	| """Wraps `big_vision` PaliGemma model for easy use in demo.""" | |
| from collections.abc import Callable | |
| import dataclasses | |
| from typing import Any | |
| import jax | |
| import jax.numpy as jnp | |
| import ml_collections | |
| import numpy as np | |
| import PIL.Image | |
| from big_vision import sharding | |
| from big_vision import utils | |
| from big_vision.models.proj.paligemma import paligemma | |
| from big_vision.pp import builder as pp_builder | |
| from big_vision.pp import ops_general # pylint: disable=unused-import | |
| from big_vision.pp import ops_image # pylint: disable=unused-import | |
| from big_vision.pp import ops_text # pylint: disable=unused-import | |
| from big_vision.pp import tokenizer | |
| from big_vision.pp.proj.paligemma import ops as ops_paligemma # pylint: disable=unused-import | |
| from big_vision.trainers.proj.paligemma import predict_fns | |
| mesh = jax.sharding.Mesh(jax.devices(), 'data') | |


def _recover_bf16(x):
  """Views arrays loaded as raw 2-byte void back as bfloat16."""
  if x.dtype == np.dtype('V2'):
    x = x.view('bfloat16')
  return x
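

# Background: some numpy versions load bfloat16 checkpoints as raw 2-byte
# void ('V2') records. Viewing reinterprets the buffer without a copy.
# Minimal sketch (assumes the 'bfloat16' dtype string is registered with
# numpy, which importing jax / ml_dtypes takes care of):
#   raw = np.zeros(3, dtype='V2')
#   raw.view('bfloat16')  # -> bfloat16 array backed by the same bytes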


def _load(
    path, tokenizer_spec='gemma(tokensets=("loc", "seg"))', vocab_size=257_152
):
  """Loads model, params, decode functions and tokenizer."""
  tok = tokenizer.get_tokenizer(tokenizer_spec)
  config = ml_collections.FrozenConfigDict(dict(
      llm_model='proj.paligemma.gemma_bv',
      llm=dict(vocab_size=vocab_size, variant='gemma_2b'),
      img=dict(variant='So400m/14', pool_type='none', scan=True),
  ))
  model = paligemma.Model(**config)
  decode_fns = predict_fns.get_all(model)
  decode = decode_fns['decode']
  beam_decode = decode_fns['beam_decode']
  params_cpu = paligemma.load(None, path, config)
  # Some numpy versions don't load bfloat16 correctly:
  params_cpu = jax.tree.map(_recover_bf16, params_cpu)
  return model, params_cpu, decode, beam_decode, tok
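

# Usage sketch (the checkpoint path is hypothetical; any PaliGemma checkpoint
# accepted by `paligemma.load` works):
#   model, params_cpu, decode, beam_decode, tok = _load(
#       './ckpts/paligemma-3b-pt-224.bf16.npz')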


def _shard_params(params_cpu):
  """Shards `params_cpu` with fsdp strategy on all available devices."""
  params_sharding = sharding.infer_sharding(
      params_cpu, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh
  )
  params = jax.tree.map(utils.reshard, params_cpu, params_sharding)
  return params


def _pil2np(img):
  """Accepts `PIL.Image` or `np.ndarray` and returns an RGB `np.ndarray`."""
  if isinstance(img, PIL.Image.Image):
    img = np.array(img)
  if img.ndim == 2:  # Grayscale without a channel dimension.
    img = img[..., None]
  if img.shape[-1] == 1:  # Repeat a single channel into RGB.
    img = np.repeat(img, 3, axis=-1)
  return img[..., :3]  # Drop the alpha channel, if any.
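

# Sketch: each of these normalizes to an (8, 8, 3) array:
#   _pil2np(PIL.Image.new('RGBA', (8, 8)))        # PIL RGBA -> RGB
#   _pil2np(np.zeros((8, 8), dtype=np.uint8))     # 2-D grayscale -> 3 channels
#   _pil2np(np.zeros((8, 8, 4), dtype=np.uint8))  # RGBA array -> alpha dropped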


def _prepare_batch(
    images,
    prefixes,
    *,
    res=224,
    tokenizer_spec='gemma(tokensets=("loc", "seg"))',
    suffixes=None,
    text_len=64,
):
  """Returns non-sharded batch."""
  pp_fn = pp_builder.get_preprocess_fn('|'.join([
      f'resize({res}, antialias=True)|value_range(-1, 1)',
      f"tok(key='prefix', bos='yes', model='{tokenizer_spec}')",
      f"tok(key='septok', text='\\n', model='{tokenizer_spec}')",
      f"tok(key='suffix', model='{tokenizer_spec}')",
      'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])',  # pylint: disable=line-too-long
      f'tolen({text_len}, pad_value=0, key="text")',
      f'tolen({text_len}, pad_value=1, key="mask_ar")',
      f'tolen({text_len}, pad_value=0, key="mask_input")',
      'keep("image", "text", "mask_ar", "mask_input")',
  ]), log_data=False)
  assert not isinstance(prefixes, str), f'expected batch: {prefixes}'
  assert (
      isinstance(images, (list, tuple)) or images.ndim == 4
  ), f'expected batch: {images.shape}'
  if suffixes is None:
    suffixes = [''] * len(prefixes)
  assert len(prefixes) == len(suffixes) == len(images)
  examples = [{'_mask': True, **pp_fn({
      'image': np.asarray(_pil2np(image)),
      'prefix': np.array(prefix),
      'suffix': np.array(suffix),
  })} for image, prefix, suffix in zip(images, prefixes, suffixes)]
  batch = jax.tree.map(lambda *xs: np.stack(xs), *examples)
  return batch
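

# Usage sketch (images and prompts are hypothetical):
#   batch = _prepare_batch([image1, image2], ['caption en', 'detect cat'])
#   batch['image'].shape  # -> (2, 224, 224, 3) at the default res=224
#   batch['text'].shape   # -> (2, 64) at the default text_len=64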


def _shard_batch(batch, n=None):
  """Shards `batch` along the 'data' mesh axis on all available devices."""
  if n is None:
    n = jax.local_device_count()
  def pad(x):
    # Zero-pad the batch dimension to a multiple of `n` for even sharding.
    return jnp.pad(x, [(0, -len(x) % n)] + [(0, 0)] * (x.ndim - 1))
  batch = {k: pad(v) for k, v in batch.items()}
  data_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec('data')
  )
  batch_on_device = utils.reshard(batch, data_sharding)
  return batch_on_device
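

# Sketch: on a host with 8 devices, a 3-example batch is zero-padded to 8
# rows; the padded rows end up with `_mask == False` (zero-padding a boolean
# column yields False), so downstream code can ignore them:
#   sharded = _shard_batch({'x': np.ones((3, 5)), '_mask': np.ones(3, bool)})
#   sharded['x'].shape  # -> (8, 5), laid out across the 'data' mesh axis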


@dataclasses.dataclass(frozen=True)
class PaligemmaConfig:
  """Describes a `big_vision` PaliGemma model."""

  ckpt: str
  res: int
  text_len: int
  tokenizer: str
  vocab_size: int


@dataclasses.dataclass(frozen=True)
class PaliGemmaModel:
  """Wraps a `big_vision` PaliGemma model."""

  config: PaligemmaConfig
  tokenizer: tokenizer.Tokenizer
  decode: Callable[..., Any]
  beam_decode: Callable[..., Any]

  @classmethod
  def shard_batch(cls, batch):
    return _shard_batch(batch)

  @classmethod
  def shard_params(cls, params_cpu):
    return _shard_params(params_cpu)

  def prepare_batch(self, images, texts, suffixes=None):
    return _prepare_batch(
        images=images,
        prefixes=texts,
        suffixes=suffixes,
        res=self.config.res,
        tokenizer_spec=self.config.tokenizer,
        text_len=self.config.text_len,
    )

  def predict(
      self,
      params,
      batch,
      devices=None,
      max_decode_len=128,
      sampler='greedy',
      **kw,
  ):
    """Returns tokens."""
    if devices is None:
      devices = jax.devices()
    if sampler == 'beam':
      decode = self.beam_decode
    else:
      decode = self.decode
      kw['sampler'] = sampler  # `sampler` is only forwarded to the non-beam decoder.
    return decode(
        {'params': params},
        batch=batch,
        devices=devices,
        eos_token=self.tokenizer.eos_token,
        max_decode_len=max_decode_len,
        **kw,
    )


ParamsCpu = Any


def load_model(config: PaligemmaConfig) -> tuple[PaliGemmaModel, ParamsCpu]:
  """Loads model from config."""
  model, params_cpu, decode, beam_decode, tok = _load(
      path=config.ckpt,
      tokenizer_spec=config.tokenizer,
      vocab_size=config.vocab_size,
  )
  del model  # Only the decode functions and params are needed from here on.
  return PaliGemmaModel(
      config=config, tokenizer=tok, decode=decode, beam_decode=beam_decode,
  ), params_cpu
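

# End-to-end usage sketch (checkpoint path, prompt and image are hypothetical;
# `to_str` assumes the `big_vision` tokenizer's decoding method):
#   config = PaligemmaConfig(
#       ckpt='./ckpts/paligemma-3b-pt-224.bf16.npz',
#       res=224,
#       text_len=64,
#       tokenizer='gemma(tokensets=("loc", "seg"))',
#       vocab_size=257_152,
#   )
#   model, params_cpu = load_model(config)
#   params = model.shard_params(params_cpu)
#   batch = model.shard_batch(model.prepare_batch([image], ['caption en']))
#   tokens = model.predict(params, batch, sampler='greedy')
#   text = model.tokenizer.to_str(tokens[0])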
