Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| from modelscope.msdatasets import MsDataset | |
class Interpolate(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Lets a resize step be placed inside containers such as ``nn.Sequential``.

    Args:
        size: Target output spatial size, forwarded to ``F.interpolate``.
        scale_factor: Spatial multiplier, forwarded to ``F.interpolate``.
        mode: Interpolation algorithm name, forwarded to ``F.interpolate``.
        align_corners: Forwarded to ``F.interpolate`` for interpolating modes
            (see ``_ALIGNABLE_MODES``).
    """

    # Modes for which F.interpolate accepts an explicit ``align_corners``.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(
        self,
        size=None,
        scale_factor=None,
        mode="bilinear",
        align_corners=False,
    ):
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` with the configured size/scale factor and mode."""
        align_corners = self.align_corners
        # F.interpolate raises ValueError when align_corners is anything but
        # None for non-interpolating modes (e.g. "nearest", "area"). The
        # default value here is False, so drop a falsy setting for those modes
        # to keep them usable; an explicit True is still forwarded so that
        # genuine misuse keeps raising.
        if self.mode not in self._ALIGNABLE_MODES and not align_corners:
            align_corners = None
        return F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=align_corners,
        )
class t_EvalNet:
    """Inference wrapper pairing a torchvision backbone with a custom head.

    The backbone is looked up by name in ``torchvision.models``; its metadata
    (type, weight URL, input size) is read from the ``monetjoe/cv_backbones``
    dataset. For backbone types ``"vit"`` and ``"swin_transformer"`` an
    upsampling classifier head is attached, and weights for both backbone and
    head are loaded from ``weight_path``.

    Note: this class is a plain object, not an ``nn.Module``; call
    ``forward`` explicitly.

    Args:
        backbone: Attribute name of a model constructor in
            ``torchvision.models``.
        cls_num: Number of output channels of the classifier head.
        ori_T: Size of the last dimension the head interpolates its output to.
        imgnet_ver: Dataset split passed to ``MsDataset.load``.
        weight_path: Path to a checkpoint holding ``"model"`` and
            ``"classifier"`` state dicts.

    Raises:
        ValueError: If ``backbone`` is not in ``torchvision.models``, is not
            listed in the backbone dataset, or has a type for which no
            classifier head is defined.
    """

    def __init__(
        self,
        backbone: str,
        cls_num: int,
        ori_T: int,
        imgnet_ver="v1",
        weight_path="",
    ):
        if not hasattr(models, backbone):
            raise ValueError(f"Unsupported model {backbone}.")

        self.imgnet_ver = imgnet_ver
        self.type, self.weight_url, self.input_size = self._model_info(backbone)
        # Resolve the constructor by attribute lookup instead of eval(): same
        # result for any name that passed the hasattr() check, without
        # executing a caller-supplied string as code.
        self.model: torch.nn.Module = getattr(models, backbone)()
        self.ori_T = ori_T
        self.cls_num = cls_num

        if self.type == "vit":
            self.hidden_dim = self.model.hidden_dim
            # Fresh zero class token; it is not read from the checkpoint here.
            self.class_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        elif self.type == "swin_transformer":
            self.hidden_dim = 768

        self._set_classifier()
        if not hasattr(self, "classifier"):
            # Fail early with a clear message rather than an AttributeError
            # when the state dict is loaded below.
            raise ValueError(
                f"Unsupported backbone type '{self.type}' for model {backbone}."
            )

        # Without CUDA, remap any GPU-saved tensors onto the CPU.
        checkpoint = torch.load(
            weight_path,
            map_location=None if torch.cuda.is_available() else "cpu",
        )
        # strict=False: missing/unexpected keys are tolerated silently.
        self.model.load_state_dict(checkpoint["model"], strict=False)
        self.classifier.load_state_dict(checkpoint["classifier"], strict=False)

        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.classifier = self.classifier.cuda()

        self.model.eval()
        # The head contains BatchNorm layers; put it in eval mode as well so
        # inference uses the loaded running statistics and does not mutate
        # them on every forward call.
        self.classifier.eval()

    def _get_backbone(self, backbone_ver, backbone_list):
        """Return the first record whose ``"ver"`` equals ``backbone_ver``.

        Raises:
            ValueError: If no record matches.
        """
        for backbone_info in backbone_list:
            if backbone_ver == backbone_info["ver"]:
                return backbone_info

        raise ValueError("[Backbone not found] Please check if --model is correct!")

    def _model_info(self, backbone: str):
        """Look up ``backbone`` and return ``(type, url, input_size)``.

        The records come from the ``monetjoe/cv_backbones`` dataset, using
        ``self.imgnet_ver`` as the split.
        """
        backbone_list = MsDataset.load(
            "monetjoe/cv_backbones",
            split=self.imgnet_ver,
            cache_dir="./__pycache__",
        )
        backbone_info = self._get_backbone(backbone, backbone_list)
        return (
            str(backbone_info["type"]),
            str(backbone_info["url"]),
            int(backbone_info["input_size"]),
        )

    def _create_classifier(self):
        """Build the upsampling classifier head.

        Also sets ``self.avgpool`` (used by the swin branch of ``forward``).
        The layer order must stay fixed: ``nn.Sequential`` indices are the
        keys of the classifier state dict loaded in ``__init__``.
        """
        original_T_size = self.ori_T
        self.avgpool = nn.AdaptiveAvgPool2d((1, None))  # F -> 1

        layers = []
        # Four transposed convolutions; with kernel (1, 4), stride (1, 2) and
        # padding (0, 1) each one doubles the last dimension.
        channels = (self.hidden_dim, 256, 128, 64, 32)
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.extend(
                [
                    nn.ConvTranspose2d(
                        in_ch,
                        out_ch,
                        kernel_size=(1, 4),
                        stride=(1, 2),
                        padding=(0, 1),
                    ),
                    nn.ReLU(inplace=True),
                    nn.BatchNorm2d(out_ch),
                ]
            )

        # Input for Interpolate: [bsz, C, 1, T]; resize to the original T.
        layers.append(
            Interpolate(
                size=(1, original_T_size), mode="bilinear", align_corners=False
            )
        )
        # Per-position classifier (1x1 convolutions).
        layers.extend(
            [
                nn.Conv2d(32, 32, kernel_size=(1, 1)),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(32),
                nn.Conv2d(32, self.cls_num, kernel_size=(1, 1)),
            ]
        )
        return nn.Sequential(*layers)

    def _set_classifier(self):
        """Attach the custom classifier head for supported backbone types."""
        if self.type == "vit" or self.type == "swin_transformer":
            self.classifier = self._create_classifier()

    def get_input_size(self):
        """Return the backbone input size read from the backbone dataset."""
        return self.input_size

    def forward(self, x: torch.Tensor):
        """Run the backbone and classifier head on ``x``.

        Returns the classifier output with all size-1 dimensions squeezed
        (this includes the batch dimension when the batch size is 1), or
        ``None`` if ``self.type`` is neither ``"vit"`` nor
        ``"swin_transformer"``.
        """
        if torch.cuda.is_available():
            x = x.cuda()

        if self.type == "vit":
            x = self.model._process_input(x)
            # Follow the input's device instead of forcing CUDA, so the ViT
            # path also works on CPU-only machines.
            batch_class_token = self.class_token.expand(x.size(0), -1, -1).to(
                x.device
            )
            x = torch.cat([batch_class_token, x], dim=1)
            x = self.model.encoder(x)
            # Drop the class token and move the hidden dim to the channel axis.
            x = x[:, 1:].permute(0, 2, 1)
            x = x.unsqueeze(2)  # x shape: [bsz, hidden_dim, 1, seq_len]
            x = self.classifier(x).squeeze()
            return x

        elif self.type == "swin_transformer":
            x = self.model.features(x)  # [B, H, W, C]
            x = x.permute(0, 3, 1, 2)
            x = self.avgpool(x)  # [B, C, 1, W]
            x = self.classifier(x).squeeze()
            return x

        return None