Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel | |
| # from tinyllava.utils.data_utils import get_value_from_kwargs | |
def get_value_from_kwargs(kwargs, name):
    """Remove ``name`` from ``kwargs`` and return its value.

    Args:
        kwargs: Mutable mapping of keyword arguments. It is modified in
            place: the entry for ``name`` is removed when present.
        name: Key to extract.

    Returns:
        The value stored under ``name``, or ``None`` when the key is absent
        (``kwargs`` is then left unchanged).
    """
    # A single pop-with-default replaces the separate membership test + pop.
    return kwargs.pop(name, None)
class VisionTower(nn.Module):
    """Wrapper around a vision encoder that returns per-token image features.

    The wrapped encoder is stored in ``_vision_tower``. ``__init__`` leaves it
    as ``None``, so it must be assigned (e.g. by a subclass, or through the
    ``vision_tower`` property) before ``load_model`` or ``forward`` is called.
    """

    def __init__(self, cfg):
        """Store ``cfg`` and create empty encoder / image-processor slots."""
        super().__init__()
        self._vision_tower = None
        self._image_processor = None
        self.config = cfg

    def load_model(self, vision_tower_name, **kwargs):
        """Load encoder weights via ``_load_model`` and freeze the encoder.

        Args:
            vision_tower_name: Model identifier forwarded to ``_load_model``.
            **kwargs: Forwarded unchanged to ``_load_model``.
        """
        self._load_model(vision_tower_name, **kwargs)
        # The encoder's parameters are excluded from gradient computation.
        self._vision_tower.requires_grad_(False)

    def _load_model(self, vision_tower_name, **kwargs):
        """Load encoder weights from a model name or a local checkpoint.

        Args:
            vision_tower_name: Identifier passed to ``from_pretrained`` when
                the encoder is a Hugging Face ``PreTrainedModel``.
            **kwargs: ``pretrained_vision_tower_path`` is popped from here.
                When set, it replaces ``vision_tower_name`` for Hugging Face
                encoders, or names the directory holding
                ``pytorch_model.bin`` for any other encoder. The remaining
                entries are forwarded to ``from_pretrained``.
        """
        pretrained_vision_tower_path = get_value_from_kwargs(kwargs, 'pretrained_vision_tower_path')

        if isinstance(self._vision_tower, PreTrainedModel):  # hf model
            if pretrained_vision_tower_path is not None:
                vision_tower_name = pretrained_vision_tower_path
            self._vision_tower = self._vision_tower.from_pretrained(vision_tower_name, **kwargs)
        elif pretrained_vision_tower_path is not None:  # plain nn.Module
            checkpoint_file = os.path.join(pretrained_vision_tower_path, 'pytorch_model.bin')
            # NOTE(security): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            vision_tower_weights = torch.load(checkpoint_file, map_location='cpu')
            self._vision_tower.load_state_dict(vision_tower_weights)

        # Printed on every path, including when no weights were loaded above.
        print("Loading vision tower from ", vision_tower_name)

    def forward(self, x, **kwargs):
        """Encode ``x`` and return the hidden states of one encoder layer.

        Args:
            x: Input tensor; cast to ``float32`` before being passed to the
                encoder. (Presumably a batch of preprocessed images -- confirm
                against the caller.)
            **kwargs:
                vision_feature_layer: Index into the encoder's
                    ``hidden_states`` sequence. Defaults to ``-2``.
                vision_feature_select_strategy: ``'patch'`` (default) drops
                    the first token along dim 1; ``'cls_patch'`` keeps every
                    token.

        Returns:
            The selected hidden-state tensor.

        Raises:
            ValueError: If ``vision_feature_select_strategy`` is neither
                ``'patch'`` nor ``'cls_patch'``.
        """
        feature_layer = kwargs.get('vision_feature_layer', -2)
        select_strategy = kwargs.get('vision_feature_select_strategy', 'patch')

        x = x.to(torch.float32)
        encoder_output = self._vision_tower(x, output_hidden_states=True)
        image_features = encoder_output.hidden_states[feature_layer]

        if select_strategy == 'patch':
            # Skip index 0 along dim 1 (presumably the CLS token -- confirm
            # for the encoder in use).
            return image_features[:, 1:]
        if select_strategy == 'cls_patch':
            return image_features
        raise ValueError(f"Unexpected select feature: {select_strategy}")

    @property
    def vision_tower(self):
        """The wrapped encoder (``None`` until one has been assigned)."""
        return self._vision_tower

    @vision_tower.setter
    def vision_tower(self, vision_tower):
        # NOTE(review): nn.Module.__setattr__ appears to register nn.Module
        # values as submodules under the assigned name before a property
        # setter is consulted, so ``tower.vision_tower = some_module`` may
        # bypass this setter. Assign ``_vision_tower`` directly for modules
        # until that is confirmed.
        self._vision_tower = vision_tower