| """Model-related code and constants.""" | |
| import dataclasses | |
| import os | |
| import re | |
| import PIL.Image | |
| # pylint: disable=g-bad-import-order | |
| import gradio_helpers | |
| import paligemma_bv | |

ORGANIZATION = 'google'
BASE_MODELS = [
    ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
    ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
]
MODELS = {
    **{
        model_name: (
            f'{ORGANIZATION}/{repo}',
            f'{model_name}.bf16.npz',
            'bfloat16',  # Model repo revision.
        )
        for repo, model_name in BASE_MODELS
    },
}
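# Illustrative: the comprehension above expands to entries such as
# MODELS['paligemma-3b-mix-224'] == (
#     'google/paligemma-3b-mix-224-jax',  # Model repo.
#     'paligemma-3b-mix-224.bf16.npz',    # Checkpoint filename.
#     'bfloat16',                         # Model repo revision.
# )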

MODELS_INFO = {
    'paligemma-3b-mix-224': (
        'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
        'bfloat16 and float16 format for research purposes only.'
    ),
    'paligemma-3b-mix-448': (
        'JAX/FLAX PaliGemma 3B weights, finetuned with 448x448 input images and 512 token input/output '
        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
        'bfloat16 and float16 format for research purposes only.'
    ),
}

# (input resolution, text sequence length) per model, used below to derive the
# per-model config.
MODELS_RES_SEQ = {
    'paligemma-3b-mix-224': (224, 256),
    'paligemma-3b-mix-448': (448, 512),
}
| # "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM. | |
| # Below value should be smaller than "available RAM - one model". | |
| # A single bf16 is about 5860 MB. | |
| MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9) | |
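# Illustrative arithmetic: RAM_CACHE_GB=9 gives MAX_RAM_CACHE = 9_000_000_000
# bytes, i.e. room for roughly one cached ~5860 MB model besides the active one.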

config = paligemma_bv.PaligemmaConfig(
    ckpt='',  # Set per model in get_cached_model() below.
    res=224,
    text_len=64,
    tokenizer='gemma(tokensets=("loc", "seg"))',
    vocab_size=256_000 + 1024 + 128,
)
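# The vocab covers the 256_000 base Gemma tokens plus the 1024 <locXXXX>
# (detection) and 128 <segXXX> (segmentation) tokens added by
# tokensets=("loc", "seg") above.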


def get_cached_model(
    model_name: str,
) -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
  """Returns model and params, using RAM cache."""
  res, seq = MODELS_RES_SEQ[model_name]
  model_path = gradio_helpers.get_paths()[model_name]
  config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq)
  model, params_cpu = gradio_helpers.get_memory_cache(
      config_,
      lambda: paligemma_bv.load_model(config_),
      max_cache_size_bytes=MAX_RAM_CACHE,
  )
  return model, params_cpu
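
# Note: cached params stay on the CPU host; they are sharded onto accelerator
# devices per request in generate() below, so the cache bounds host RAM rather
# than device memory.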


def generate(
    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
) -> str:
  """Generates output with the specified `model_name` and `sampler`."""
  model, params_cpu = get_cached_model(model_name)
  batch = model.shard_batch(model.prepare_batch([image], [prompt]))
  with gradio_helpers.timed('sharding'):
    params = model.shard_params(params_cpu)
  with gradio_helpers.timed('computation', start_message=True):
    tokens = model.predict(params, batch, sampler=sampler)
  return model.tokenizer.to_str(tokens[0])
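

# Minimal usage sketch (illustrative, not part of the original app). Assumes the
# checkpoint was already downloaded and registered so gradio_helpers.get_paths()
# has an entry for the model, and that the underlying decoder accepts 'greedy'
# as a sampler name; 'example.jpg' is a hypothetical local file.
if __name__ == '__main__':
  example_image = PIL.Image.open('example.jpg')
  print(generate('paligemma-3b-mix-224', 'greedy', example_image, 'caption en'))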