Spaces:
Runtime error
Runtime error
| # ------------------------------------------------------------------------ | |
| # Copyright 2023 Haotian Liu | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ------------------------------------------------------------------------ | |
| # Modified from LLaVA (https://github.com/haotian-liu/LLaVA) and S^2(https://github.com/bfshi/scaling_on_scales) | |
| # Copyright 2024 Jiachen Li | |
| # ------------------------------------------------------------------------ | |
| import torch | |
| import torch.nn as nn | |
| from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig | |
| import torch.nn.functional as F | |
| from transformers.activations import ACT2FN | |
| import math | |
| from einops import rearrange | |
| from .clip import CLIPVisionTransformer | |
| from .clip_smoe import CLIPSMoEVisionTransformer | |
class CLIPVisionTower(nn.Module):
    """CLIP vision encoder with S^2-style multi-scale feature extraction.

    Every image is encoded at each scale in ``args.scales``. The first
    pyramid level is always the unmodified input (``scales[0]`` is
    effectively ignored). Each later level is:

    1. bicubically resized to ``int(W * scale)``;
    2. split into ``ceil(scale)**2`` tiles stacked on the batch dimension;
    3. encoded tile by tile;
    4. stitched back together;
    5. area-pooled to the token grid of the first level.

    Features of all levels are concatenated along the last (channel) dim.

    When ``args.clip_smoe`` is set, the wrapped model is a
    ``CLIPSMoEVisionTransformer`` whose call returns
    ``(features, balance_loss, router_z_loss)``. Those auxiliary losses are
    averaged over levels and returned alongside the features.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        # NOTE(review): ``delay_load`` is accepted for interface compatibility
        # but unused -- the model is always constructed here.
        super().__init__()
        self.vision_tower_name = vision_tower
        # NOTE(review): stored but never read by this class.
        self.select_layer = args.mm_vision_select_layer
        self.clip_smoe = args.clip_smoe
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.scales = args.scales
        if args.clip_smoe:
            self.vision_model = CLIPSMoEVisionTransformer(
                self.cfg_only, num_experts=args.num_experts, num_selected=args.num_selected)
        else:
            self.vision_model = CLIPVisionTransformer(self.cfg_only)
        self.is_loaded = True

    def feature_select(self, image_features):
        """Select which tokens of the encoder output to keep.

        ``'patch'`` drops the first token (presumably CLS); ``'cls_patch'``
        keeps every token.

        Raises:
            ValueError: if ``self.select_feature`` is neither option.
        """
        if self.select_feature == 'patch':
            return image_features[:, 1:]
        if self.select_feature == 'cls_patch':
            return image_features
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    def split_chessboard(self, x, num_split):
        """Split ``x`` (b, c, h, w) into ``num_split**2`` sub-squares.

        The sub-squares are concatenated on the batch dimension in row-major
        tile order (tile row ``i``, tile column ``j`` -> index
        ``i * num_split + j``).
        """
        B, C, H, W = x.shape
        assert H % num_split == 0 and W % num_split == 0
        h, w = H // num_split, W // num_split
        return torch.cat(
            [x[:, :, i * h:(i + 1) * h, j * w:(j + 1) * w]
             for i in range(num_split) for j in range(num_split)],
            dim=0)

    def merge_chessboard(self, x, num_split):
        """Inverse of :meth:`split_chessboard`.

        ``x`` (b, c, h, w) holds ``num_split**2`` sub-squares stacked on the
        batch dimension. They are stitched back into whole squares.
        """
        B, C, H, W = x.shape
        assert B % (num_split ** 2) == 0
        b = B // (num_split ** 2)
        rows = [
            torch.cat([x[(i * num_split + j) * b:(i * num_split + j + 1) * b]
                       for j in range(num_split)], dim=-1)
            for i in range(num_split)
        ]
        return torch.cat(rows, dim=-2)

    def _encode(self, x):
        """Run the wrapped model on ``x`` and apply :meth:`feature_select`.

        Returns ``(features, balance_loss, router_z_loss)``. The losses are
        ``None`` unless ``self.clip_smoe`` is set.
        """
        if self.clip_smoe:
            out_x, balance_loss, router_z_loss = self.vision_model(x)
        else:
            out_x, balance_loss, router_z_loss = self.vision_model(x), None, None
        return self.feature_select(out_x), balance_loss, router_z_loss

    def _merge_tiles(self, out_x, num_split, num_tokens, dtype):
        """Stitch per-tile token grids and area-pool them to ``num_tokens`` tokens.

        ``out_x`` is (num_split**2 * b, h*w, c). The result is
        (b, num_tokens, c), cast to ``dtype``.

        Raises:
            ValueError: if the per-tile token count is not a perfect square.
        """
        batch, tokens, channels = out_x.shape
        side = math.isqrt(tokens)
        if side * side != tokens:
            raise ValueError(
                f'Cannot arrange {tokens} tokens into a square grid; multi-scale '
                "merging requires a square patch grid (e.g. select_feature='patch').")
        # (b, h*w, c) -> (b, c, h, w), same as einops 'b (h w) c -> b c h w'.
        out_x = out_x.transpose(1, 2).reshape(batch, channels, side, side)
        out_x = self.merge_chessboard(out_x, num_split=num_split)
        out_x = F.interpolate(out_x.to(torch.float32), size=math.isqrt(num_tokens),
                              mode='area').to(dtype)
        # (b, c, h, w) -> (b, h*w, c)
        return out_x.flatten(2).transpose(1, 2)

    def _forward_tensor(self, images):
        """Multi-scale encoding of a batched image tensor (B, C, H, W).

        Returns ``(features, balance_loss, router_z_loss)``. ``features``
        concatenates every pyramid level on the last dim. The losses are the
        mean over levels, or ``None`` when ``self.clip_smoe`` is false.
        """
        # Scaled sizes are derived from the width (dim 3); the square grid
        # reshapes below assume square inputs -- TODO confirm with the processor.
        input_size = images.shape[3]
        img_sizes = [int(input_size * scale) for scale in self.scales]
        num_splits = [math.ceil(size / input_size) for size in img_sizes]

        image_pyramids = [images]
        for size, num_split in zip(img_sizes[1:], num_splits[1:]):
            # Interpolate in float32 (bicubic may not support low-precision dtypes).
            x = F.interpolate(images.to(torch.float32), size=size, mode='bicubic').to(images.dtype)
            image_pyramids.append(self.split_chessboard(x, num_split=num_split))

        image_features, balance_losses, router_z_losses = [], [], []
        for i, (x, num_split) in enumerate(zip(image_pyramids, num_splits)):
            out_x, balance_loss, router_z_loss = self._encode(x)
            if i > 0:
                # Bring every level back to the first level's token grid.
                out_x = self._merge_tiles(out_x, num_split, image_features[0].shape[1], x.dtype)
            image_features.append(out_x)
            balance_losses.append(balance_loss)
            router_z_losses.append(router_z_loss)

        image_features = torch.cat(image_features, dim=-1)
        if not self.clip_smoe:
            return image_features, None, None
        return image_features, torch.stack(balance_losses).mean(), torch.stack(router_z_losses).mean()

    def forward(self, images):
        """Encode images at every configured scale.

        Args:
            images: a batched tensor (B, C, H, W), or a list of per-image
                (C, H, W) tensors.

        Returns:
            ``(features, balance_loss, router_z_loss)``.

            - Tensor input: ``features`` is one tensor.
            - List input: ``features`` is a list with one (1, N, D) tensor
              per image, cast back to that image's dtype.
            - The losses are scalar tensors when ``self.clip_smoe`` is set
              (averaged over levels, then over images for list input), and
              ``None`` otherwise.
        """
        if not isinstance(images, list):
            return self._forward_tensor(images)

        image_features, balance_losses, router_z_losses = [], [], []
        for image in images:
            feature, balance_loss, router_z_loss = self._forward_tensor(
                image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
            image_features.append(feature.to(image.dtype))
            balance_losses.append(balance_loss)
            router_z_losses.append(router_z_loss)
        if not self.clip_smoe or not images:
            return image_features, None, None
        return image_features, torch.stack(balance_losses).mean(), torch.stack(router_z_losses).mean()

    @property
    def dummy_feature(self):
        """A (1, hidden_size) zero tensor on the tower's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """dtype of the wrapped vision model."""
        # The wrapped model may be a plain nn.Module, which has no ``.dtype``;
        # fall back to its first parameter.
        dtype = getattr(self.vision_model, 'dtype', None)
        return dtype if dtype is not None else next(self.vision_model.parameters()).dtype

    @property
    def device(self):
        """Device of the wrapped vision model."""
        # Same fallback as ``dtype``: a plain nn.Module has no ``.device``.
        device = getattr(self.vision_model, 'device', None)
        return device if device is not None else next(self.vision_model.parameters()).device

    @property
    def config(self):
        """Vision config; the wrapped model is built from ``self.cfg_only``."""
        if self.is_loaded:
            return getattr(self.vision_model, 'config', self.cfg_only)
        return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return self.num_patches_per_side ** 2