import os
import functools

import PIL.Image
import numpy as np
import torch
from huggingface_hub import login
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

import jax
import jax.numpy as jnp
import flax.linen as nn  # required by the Flax decoder defined in VAEModel._decoder

import spaces

hf_token = os.getenv("HF_TOKEN")
login(token=hf_token, add_to_git_credential=True)

class PaliGemmaModel:
    """Wraps the PaliGemma checkpoint and processor for image+text inference."""

    def __init__(self):
        self.model_id = "google/paligemma-3b-mix-448"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(self.model_id).eval().to(self.device)
        self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)

    def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
        inputs = self.processor(text=text, images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}  # Move inputs to the model's device
        with torch.inference_mode():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        # The decoded string begins with the prompt; strip it and return only the completion.
        return result[0][len(text):].lstrip("\n")
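
# Illustrative usage sketch (added; not part of the original Space code).
# "example.jpg" and the prompt string are placeholder values:
#
#   pg_model = PaliGemmaModel()
#   image = PIL.Image.open("example.jpg").convert("RGB")
#   print(pg_model.infer(image, "caption en", max_new_tokens=50))
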
class VAEModel:
    """Flax VQ-VAE decoder that turns codebook indices into segmentation masks."""

    def __init__(self, model_path: str):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.params = self._get_params(model_path)

    def _get_params(self, checkpoint_path):
        """Converts a PyTorch checkpoint (loaded from .npz) into a Flax parameter tree."""
        checkpoint = dict(np.load(checkpoint_path))

        def transp(kernel):
            # PyTorch conv kernels are (out, in, kH, kW); Flax expects (kH, kW, in, out).
            return np.transpose(kernel, (2, 3, 1, 0))

        def conv(name):
            return {
                'bias': checkpoint[name + '.bias'],
                'kernel': transp(checkpoint[name + '.weight']),
            }

        def resblock(name):
            return {
                'Conv_0': conv(name + '.0'),
                'Conv_1': conv(name + '.2'),
                'Conv_2': conv(name + '.4'),
            }

        return {
            '_embeddings': checkpoint['_vq_vae._embedding'],
            'Conv_0': conv('decoder.0'),
            'ResBlock_0': resblock('decoder.2.net'),
            'ResBlock_1': resblock('decoder.3.net'),
            'ConvTranspose_0': conv('decoder.4'),
            'ConvTranspose_1': conv('decoder.6'),
            'ConvTranspose_2': conv('decoder.8'),
            'ConvTranspose_3': conv('decoder.10'),
            'Conv_1': conv('decoder.12'),
        }
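
    # Note (added): the key names above ('_vq_vae._embedding', 'decoder.0',
    # 'decoder.2.net', ...) assume a checkpoint with the layout of the mask-VAE
    # weights (e.g. the vae-oid.npz file used in the public PaliGemma
    # segmentation demos); other checkpoints may use different names.
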
    def reconstruct_masks(self, codebook_indices):
        quantized = self._quantized_values_from_codebook_indices(codebook_indices)
        return self._decoder().apply({'params': self.params}, quantized)

    def _quantized_values_from_codebook_indices(self, codebook_indices):
        batch_size, num_tokens = codebook_indices.shape
        assert num_tokens == 16, codebook_indices.shape
        unused_num_embeddings, embedding_dim = self.params['_embeddings'].shape

        encodings = jnp.take(self.params['_embeddings'], codebook_indices.reshape((-1)), axis=0)
        encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
        return encodings
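
    # Shape walkthrough (added, illustrative): a (batch, 16) array of codebook
    # indices is looked up in the (num_embeddings, embedding_dim) table to give
    # (batch * 16, embedding_dim), reshaped to a (batch, 4, 4, embedding_dim)
    # grid. The decoder below upsamples it 2x four times, producing a 64x64
    # single-channel mask per example.
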
    def _decoder(self):
        class ResBlock(nn.Module):
            features: int

            @nn.compact
            def __call__(self, x):
                original_x = x
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
                return x + original_x

        class Decoder(nn.Module):
            """Upscales quantized vectors to a mask."""

            @nn.compact
            def __call__(self, x):
                num_res_blocks = 2
                dim = 128
                num_upsample_layers = 4

                x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
                x = nn.relu(x)

                for _ in range(num_res_blocks):
                    x = ResBlock(features=dim)(x)

                for _ in range(num_upsample_layers):
                    x = nn.ConvTranspose(
                        features=dim,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        padding=2,
                        transpose_kernel=True,
                    )(x)
                    x = nn.relu(x)
                    dim //= 2

                x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
                return x

        return jax.jit(Decoder().apply, backend='cpu')
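
# Illustrative usage sketch (added; not part of the original Space code). The
# checkpoint filename and indices are placeholder values; in the PaliGemma
# segmentation format, the 16 indices per mask are typically parsed from the
# <seg###> tokens in the model's text output:
#
#   vae = VAEModel("vae-oid.npz")
#   codebook_indices = jnp.array([[0] * 16])            # placeholder (1, 16) indices
#   masks = vae.reconstruct_masks(codebook_indices)     # -> (1, 64, 64, 1) array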