import torch
import torch.nn as nn
import numpy as np
from functools import partial
from lib.model_zoo.common.get_model import register
import torch.nn.functional as F

symbol = 'clip'


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError
from transformers import CLIPTokenizer, CLIPTextModel


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


###############
# for vd next #
###############
from transformers import CLIPModel


class CLIPTextContextEncoder(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 max_length=77,
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.max_length = max_length
        self.fp16 = fp16
        self.freeze()

    def get_device(self):
        # A trick to get device
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def encode(self, text):
        batch_encoding = self.tokenizer(
            text, truncation=True, max_length=self.max_length, return_length=True,
            return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.text_model(input_ids=tokens)
        # Project every token hidden state into the CLIP text embedding space, then
        # normalize by the norm of the projected pooled (EOS) embedding.
        z = self.model.text_projection(outputs.last_hidden_state)
        z_pooled = self.model.text_projection(outputs.pooler_output)
        z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
        return z
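# Usage sketch: a minimal, hedged example of how the text context encoder is expected
# to be called. Shapes assume the default clip-vit-large-patch14 backbone, whose text
# projection dim is 768, so the returned context tensor should be [batch, 77, 768].
def _demo_text_context_encode():
    encoder = CLIPTextContextEncoder()
    ctx = encoder.encode(["a photo of a cat", "a photo of a dog"])
    # Expected: torch.Size([2, 77, 768])
    print(ctx.shape)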
from transformers import CLIPProcessor


class CLIPImageContextEncoder(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.processor = CLIPProcessor.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.fp16 = fp16
        self.freeze()

    def get_device(self):
        # A trick to get device
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def _encode(self, images):
        if isinstance(images, torch.Tensor):
            import torchvision.transforms as tvtrans
            images = [tvtrans.ToPILImage()(i) for i in images]
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
        pixels = pixels.to(self.get_device())
        outputs = self.model.vision_model(pixel_values=pixels)
        z = outputs.last_hidden_state
        z = self.model.vision_model.post_layernorm(z)
        z = self.model.visual_projection(z)
        z_pooled = z[:, 0:1]
        z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
        return z
    def _encode_wmask(self, images, masks):
        assert isinstance(masks, torch.Tensor)
        assert (len(masks.shape) == 4) and (masks.shape[1] == 1)
        masks = torch.clamp(masks, 0, 1)
        masks = masks.float()
        masks = F.interpolate(masks, [224, 224], mode='bilinear')
        if masks.sum() == masks.numel():
            return self._encode(images)
        device = images.device
        dtype = images.dtype
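        # The block below pools the pixel-level mask down to one coverage value per
        # vision token. With the default patch-14 backbone a 224x224 input gives a
        # 16x16 grid of patch tokens; convolving the mask with a ones kernel counts
        # masked pixels per patch, and dividing by the kernel area (14*14 = 196)
        # turns that into a fraction in [0, 1]. gscale (the global mask mean) plays
        # the same role for the class token, so vtoken_mask has shape [batch, 257, 1].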
        gscale = masks.mean(axis=[1, 2, 3], keepdim=True).flatten(2)
        vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
        vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
        mask_kernel = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
        vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernel, stride=vtoken_stride).flatten(2).transpose(1, 2)
        vtoken_mask = vtoken_mask / np.prod(vtoken_kernel_size)
        vtoken_mask = torch.concat([gscale, vtoken_mask], axis=1)
        import types

        def customized_embedding_forward(self, pixel_values):
            batch_size = pixel_values.shape[0]
            patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
            patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
            class_embeds = self.class_embedding.expand(batch_size, 1, -1)
            embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
            embeddings = embeddings + self.position_embedding(self.position_ids)
            # Scale every token embedding by its mask coverage (closure over vtoken_mask).
            embeddings = embeddings * vtoken_mask.to(embeddings.dtype)
            return embeddings

        # Temporarily swap in the masked embedding forward, run the regular encoder,
        # then restore the original method.
        old_forward = self.model.vision_model.embeddings.forward
        self.model.vision_model.embeddings.forward = types.MethodType(
            customized_embedding_forward, self.model.vision_model.embeddings)
        z = self._encode(images)
        self.model.vision_model.embeddings.forward = old_forward
        z = z * vtoken_mask.to(dtype)
        return z
    def encode(self, images, masks=None):
        if masks is None:
            return self._encode(images)
        else:
            return self._encode_wmask(images, masks)
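# Usage sketch: a hedged example of the masked image path. The mask is a
# [batch, 1, H, W] tensor in [0, 1]; with the default clip-vit-large-patch14 backbone
# the returned context should have shape [batch, 257, 768] on both paths.
def _demo_image_context_encode():
    encoder = CLIPImageContextEncoder()
    images = torch.rand(1, 3, 224, 224)          # dummy images in [0, 1]
    masks = torch.zeros(1, 1, 224, 224)
    masks[:, :, 64:160, 64:160] = 1.0            # keep only a central square
    ctx_full = encoder.encode(images)            # unmasked path
    ctx_masked = encoder.encode(images, masks)   # masked path
    # Expected: torch.Size([1, 257, 768]) for both
    print(ctx_full.shape, ctx_masked.shape)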