import copy
import os

import torch
import torch.nn as nn
import timm
from torchvision.transforms import Normalize
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

class IndentityMapping(nn.Module):
    """Pass-through "encoder" that flattens (B, C, H, W) features into (B, H*W, C) tokens."""

    def __init__(self):
        super().__init__()

    def forward(self, x, resize=True):
        # `resize` is unused here; it is kept so the call signature matches the other encoders.
        b, c, h, w = x.shape
        x = x.reshape(b, c, h * w).transpose(1, 2)
        return x
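
# A quick shape illustration for IndentityMapping (illustrative sketch, not part of the module):
# a (2, 3, 4, 4) feature map becomes a (2, 16, 3) token sequence.
#
#   tokens = IndentityMapping()(torch.randn(2, 3, 4, 4))
#   assert tokens.shape == (2, 16, 3)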

class DINOv2(nn.Module):
    def __init__(self, weight_path: str, base_patch_size=16):
        super(DINOv2, self).__init__()
        # `weight_path` points at a local torch.hub checkout: the directory is the hub repo,
        # the basename is the entry-point name of the checkpoint to load.
        directory = os.path.dirname(weight_path)
        weight_path = os.path.basename(weight_path)
        self.encoder = torch.hub.load(
            directory,
            weight_path,
            source="local",
            skip_validation=True,
        )
        self.encoder = self.encoder.to(torch.bfloat16)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()  # drop the classification head; keep features only
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()
        self.base_patch_size = base_patch_size
        self.encoder.compile()

    def forward(self, x, resize=True):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        if resize:
            # Rescale so DINOv2's 14-pixel patches produce the same token grid that
            # `base_patch_size`-pixel patches would give at the original resolution.
            x = torch.nn.functional.interpolate(
                x,
                (int(14 * h / self.base_patch_size), int(14 * w / self.base_patch_size)),
                mode='bicubic',
            )
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        feature = feature.to(torch.bfloat16)
        return feature
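
# Hedged usage sketch for DINOv2 (the checkpoint path and entry-point name below are
# assumptions for illustration, not part of this module): point it at a local
# torch.hub checkout and extract normalized patch tokens.
#
#   dino = DINOv2("/path/to/dinov2/dinov2_vitl14").cuda()
#   imgs = torch.rand(1, 3, 256, 256, device="cuda", dtype=torch.bfloat16)
#   tokens = dino(imgs)  # (1, num_patches, embed_dim) patch features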

from transformers import CLIPModel, CLIPTokenizer


class CLIP(nn.Module):
    def __init__(self, weight_path: str):
        super(CLIP, self).__init__()
        self.model = CLIPModel.from_pretrained(weight_path).to(torch.bfloat16)
        self.tokenizer = CLIPTokenizer.from_pretrained(weight_path)
        self.height = self.model.config.vision_config.image_size
        self.width = self.model.config.vision_config.image_size
        self.model.vision_model.compile()
        self.model.text_model.compile()

    def forward(self, x, text, resize=True):
        # Tokenize on CPU, then move the ids to the GPU the model is expected to live on.
        tokens = self.tokenizer(
            text,
            truncation=True,
            return_tensors='pt',
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
        ).input_ids.cuda()
        # Project every text token into the shared embedding space and L2-normalize.
        text_output = self.model.text_model(input_ids=tokens).last_hidden_state
        text_output = self.model.text_projection(text_output)
        text_output = torch.nn.functional.normalize(text_output, dim=-1, p=2)
        if resize:
            x = torch.nn.functional.interpolate(x, (self.height, self.width), mode='bicubic')
        x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x)
        # Drop the CLS token so only patch embeddings remain, then project and normalize.
        vision_output = self.model.vision_model(x).last_hidden_state[:, 1:]
        vision_output = self.model.visual_projection(vision_output)
        vision_output = torch.nn.functional.normalize(vision_output, dim=-1, p=2)
        # Cosine similarity between every image patch and every text token.
        output = torch.bmm(vision_output, text_output.transpose(1, 2))
        return output
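
# Hedged usage sketch for CLIP (the model id is an assumption for illustration):
# per-patch / per-token cosine similarities; note that the image batch size must
# match the number of text prompts for the final bmm.
#
#   clip = CLIP("openai/clip-vit-base-patch16").cuda()
#   imgs = torch.rand(1, 3, 512, 512, device="cuda", dtype=torch.bfloat16)
#   sims = clip(imgs, ["a photo of a cat"])  # (1, num_patches, text_seq_len)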

from transformers import SiglipModel, GemmaTokenizer, SiglipTokenizer


class SigLIP(nn.Module):
    def __init__(self, weight_path: str):
        super(SigLIP, self).__init__()
        # SigLIP 2 checkpoints ship a Gemma tokenizer; the original SigLIP uses its own tokenizer.
        if "siglip2" in weight_path:
            self.tokenizer = GemmaTokenizer.from_pretrained(weight_path)
        else:
            self.tokenizer = SiglipTokenizer.from_pretrained(weight_path)
        self.model = SiglipModel.from_pretrained(weight_path).to(torch.bfloat16)
        self.height = self.model.config.vision_config.image_size
        self.width = self.model.config.vision_config.image_size
        # SigLIP preprocessing maps pixels to [-1, 1].
        self.mean = 0.5
        self.std = 0.5
        self.model.vision_model.compile()
        self.model.text_model.compile()

    def forward(self, x, text, resize=True):
        # SigLIP text encoders expect sequences padded to 64 tokens.
        tokens = self.tokenizer(
            text,
            truncation=True,
            return_tensors='pt',
            padding="max_length",
            max_length=64,
        ).input_ids.cuda()
        text_output = self.model.text_model(input_ids=tokens).last_hidden_state
        text_output = torch.nn.functional.normalize(text_output, dim=-1, p=2)
        if resize:
            x = torch.nn.functional.interpolate(x, (self.height, self.width), mode='bicubic')
        x = (x - self.mean) / self.std
        vision_output = self.model.vision_model(x).last_hidden_state
        vision_output = torch.nn.functional.normalize(vision_output, dim=-1, p=2)
        # Cosine similarity between every image patch and every text token.
        output = torch.bmm(vision_output, text_output.transpose(1, 2))
        return output
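
# Hedged usage sketch for SigLIP (the model id is an assumption for illustration);
# the text side is padded to 64 tokens, so the similarity map has 64 text positions.
#
#   siglip = SigLIP("google/siglip-base-patch16-224").cuda()
#   imgs = torch.rand(1, 3, 224, 224, device="cuda", dtype=torch.bfloat16)
#   sims = siglip(imgs, ["a photo of a cat"])  # (1, num_patches, 64)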

from transformers import SiglipVisionModel


class SigLIPVision(nn.Module):
    def __init__(self, weight_path: str, base_patch_size=16):
        super(SigLIPVision, self).__init__()
        self.model = SiglipVisionModel.from_pretrained(weight_path).to(torch.bfloat16)
        self.height = self.model.config.image_size
        self.width = self.model.config.image_size
        self.patch_size = self.model.config.patch_size
        self.base_patch_size = base_patch_size
        self.model.compile()
        # SigLIP preprocessing maps pixels to [-1, 1].
        self.mean = 0.5
        self.std = 0.5

    def forward(self, x, resize=True):
        if resize:
            # Rescale so the SigLIP patch grid matches what `base_patch_size`-pixel patches
            # would produce at the original resolution.
            h, w = x.shape[-2:]
            new_h = int(self.patch_size * h / self.base_patch_size)
            new_w = int(self.patch_size * w / self.base_patch_size)
            x = torch.nn.functional.interpolate(x, (new_h, new_w), mode='bicubic')
        x = (x - self.mean) / self.std
        vision_output = self.model.vision_model(x).last_hidden_state
        return vision_output
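
# Hedged usage sketch for SigLIPVision (the model id is an assumption for illustration):
# image-only patch features, rescaled to keep the token grid aligned with a
# base_patch_size-pixel patch layout.
#
#   enc = SigLIPVision("google/siglip-base-patch16-224").cuda()
#   imgs = torch.rand(1, 3, 256, 256, device="cuda", dtype=torch.bfloat16)
#   feats = enc(imgs)  # (1, num_patches, hidden_dim)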