Spaces:
Build error
Build error
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import math | |
| from os.path import join | |
| import numpy as np | |
| import copy | |
| from functools import partial | |
| import torch | |
| from torch import nn | |
| import torch.utils.model_zoo as model_zoo | |
| import torch.nn.functional as F | |
| import fvcore.nn.weight_init as weight_init | |
| from detectron2.modeling.backbone import FPN | |
| from detectron2.modeling.backbone.build import BACKBONE_REGISTRY | |
| from detectron2.layers.batch_norm import get_norm, FrozenBatchNorm2d | |
| from detectron2.modeling.backbone import Backbone | |
| from timm import create_model | |
| from timm.models.helpers import build_model_with_cfg | |
| from timm.models.registry import register_model | |
| from timm.models.resnet import ResNet, Bottleneck | |
| from timm.models.resnet import default_cfgs as default_cfgs_resnet | |
class CustomResNet(ResNet):
    """timm ResNet whose forward pass returns intermediate feature maps.

    The required keyword argument ``out_indices`` selects which outputs
    ``forward`` returns: index 0 is the stem output after max-pooling and
    indices 1-4 are the outputs of ``layer1`` through ``layer4``.
    """

    def __init__(self, **kwargs):
        # Strip the extra option first; everything else is forwarded
        # unchanged to the parent constructor.
        self.out_indices = kwargs.pop('out_indices')
        super().__init__(**kwargs)

    def forward(self, x):
        """Run the stem and all four stages, returning the selected outputs.

        Returns:
            A list with one tensor per entry of ``self.out_indices``, in
            that order.
        """
        x = self.maxpool(self.act1(self.bn1(self.conv1(x))))
        stage_outputs = [x]
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            stage_outputs.append(x)
        return [stage_outputs[idx] for idx in self.out_indices]

    def load_pretrained(self, cached_file):
        """Load weights from ``cached_file`` (deserialized onto the CPU).

        The checkpoint may either be a state dict itself or wrap one under
        the ``'state_dict'`` key.
        """
        checkpoint = torch.load(cached_file, map_location='cpu')
        weights = (
            checkpoint['state_dict'] if 'state_dict' in checkpoint
            else checkpoint)
        self.load_state_dict(weights)
# Constructor arguments per supported variant. Both variants share the same
# Bottleneck block and [3, 4, 6, 3] stage depths; each entry gets its own
# dict and list so they can be modified independently.
model_params = {
    variant: dict(block=Bottleneck, layers=[3, 4, 6, 3])
    for variant in ('resnet50', 'resnet50_in21k')
}
def create_timm_resnet(variant, out_indices, pretrained=False, **kwargs):
    """Build a ``CustomResNet`` for ``variant`` through timm's model builder.

    Args:
        variant: key of ``model_params`` (and of timm's ResNet default
            configs) naming the architecture.
        out_indices: indices of the feature maps the model should return.
        pretrained: forwarded to ``build_model_with_cfg``.
        **kwargs: extra keyword arguments forwarded to ``build_model_with_cfg``.

    Returns:
        Whatever ``build_model_with_cfg`` returns for ``CustomResNet``.

    Raises:
        KeyError: if ``variant`` is not a key of ``model_params``.
    """
    arch_kwargs = model_params[variant]

    # (Re-)register the 'resnet50_in21k' config on every call: a copy of the
    # stock resnet50 config with its own checkpoint URL and class count.
    in21k_cfg = copy.deepcopy(default_cfgs_resnet['resnet50'])
    in21k_cfg['url'] = 'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth'
    in21k_cfg['num_classes'] = 11221
    default_cfgs_resnet['resnet50_in21k'] = in21k_cfg

    variant_cfg = default_cfgs_resnet[variant]
    return build_model_with_cfg(
        CustomResNet,
        variant,
        pretrained,
        default_cfg=variant_cfg,
        out_indices=out_indices,
        pretrained_custom_load=True,
        **arch_kwargs,
        **kwargs)
class LastLevelP6P7_P5(nn.Module):
    """Extra block producing two further downsampled maps (P6, P7) from "p5".

    Each map comes from a 3x3 convolution with stride 2 and padding 1; a
    ReLU is applied to P6 before it is fed to the P7 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.num_levels = 2
        self.in_feature = "p5"
        self.p6 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.p7 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        weight_init.c2_xavier_fill(self.p6)
        weight_init.c2_xavier_fill(self.p7)

    def forward(self, c5):
        """Return ``[p6, p7]`` computed from the input feature map."""
        coarse = self.p6(c5)
        coarser = self.p7(F.relu(coarse))
        return [coarse, coarser]
def freeze_module(x):
    """Freeze a module: stop gradients and convert BatchNorm to frozen form.

    Args:
        x: the module to freeze.

    Returns:
        The frozen module. ``convert_frozen_batchnorm`` replaces BatchNorm
        children in place and returns the same object, but when ``x`` itself
        is a BatchNorm layer it returns a *new* module, so callers must use
        the returned value rather than ``x``.
    """
    for p in x.parameters():
        p.requires_grad = False
    # Return the conversion result: discarding it would hand back an
    # unconverted module whenever ``x`` is itself a BatchNorm layer.
    return FrozenBatchNorm2d.convert_frozen_batchnorm(x)
class TIMM(Backbone):
    """detectron2 ``Backbone`` wrapping a timm ResNet or EfficientNet.

    Args:
        base_name: model name. Names containing ``'resnet'`` are built with
            ``create_timm_resnet`` (no pretrained weights requested); names
            containing ``'eff'`` are built with ``timm.create_model`` using
            ``features_only=True`` and ``pretrained=True``.
        out_levels: levels to expose. Level ``l`` maps to feature index
            ``l - 1`` of the base model and is named ``'layer{l}'``.
        freeze_at: passed to ``freeze`` (ResNet variants only).
        norm: when ``'FrozenBN'``, BatchNorm layers are converted to
            ``FrozenBatchNorm2d``.

    Raises:
        ValueError: if ``base_name`` matches no supported model family.
    """

    def __init__(self, base_name, out_levels, freeze_at=0, norm='FrozenBN'):
        super().__init__()
        out_indices = [level - 1 for level in out_levels]
        if 'resnet' in base_name:
            self.base = create_timm_resnet(
                base_name, out_indices=out_indices,
                pretrained=False)
        elif 'eff' in base_name:
            self.base = create_model(
                base_name, features_only=True,
                out_indices=out_indices, pretrained=True)
        else:
            # An explicit raise (rather than ``assert``) still fires when
            # Python runs with -O.
            raise ValueError(
                'Unsupported timm backbone: {}'.format(base_name))

        feature_info = [
            dict(num_chs=f['num_chs'], reduction=f['reduction'])
            for f in self.base.feature_info]
        self._out_features = ['layer{}'.format(x) for x in out_levels]
        self._out_feature_channels = {
            'layer{}'.format(l): feature_info[l - 1]['num_chs']
            for l in out_levels}
        self._out_feature_strides = {
            'layer{}'.format(l): feature_info[l - 1]['reduction']
            for l in out_levels}
        self._size_divisibility = max(self._out_feature_strides.values())

        if 'resnet' in base_name:
            self.freeze(freeze_at)
        if norm == 'FrozenBN':
            # ``self`` is not a BatchNorm layer, so the conversion swaps the
            # BatchNorm children in place; the return value is not needed.
            FrozenBatchNorm2d.convert_frozen_batchnorm(self)

    def freeze(self, freeze_at=0):
        """Freeze the first stages of the base ResNet.

        ``freeze_at >= 1`` freezes ``base.conv1``; ``freeze_at >= 2``
        additionally freezes ``base.layer1``.
        """
        if freeze_at >= 1:
            print('Frezing', self.base.conv1)
            self.base.conv1 = freeze_module(self.base.conv1)
        if freeze_at >= 2:
            print('Frezing', self.base.layer1)
            self.base.layer1 = freeze_module(self.base.layer1)

    def forward(self, x):
        """Return a dict mapping each ``'layer{l}'`` name to its feature map."""
        features = self.base(x)
        return dict(zip(self._out_features, features))

    @property
    def size_divisibility(self):
        """Largest output stride among the exposed levels.

        Declared as a property to match ``Backbone.size_divisibility``,
        which detectron2 reads as an attribute.
        """
        return self._size_divisibility
@BACKBONE_REGISTRY.register()
def build_timm_backbone(cfg, input_shape):
    """Build a ``TIMM`` backbone from ``cfg.MODEL.TIMM``.

    Registered in ``BACKBONE_REGISTRY`` so detectron2 can resolve it by name
    from ``cfg.MODEL.BACKBONE.NAME``.

    Args:
        cfg: config providing ``MODEL.TIMM.BASE_NAME``, ``OUT_LEVELS``,
            ``FREEZE_AT`` and ``NORM``.
        input_shape: unused; present to satisfy the builder signature.

    Returns:
        The constructed ``TIMM`` module.
    """
    timm_cfg = cfg.MODEL.TIMM
    return TIMM(
        timm_cfg.BASE_NAME,
        timm_cfg.OUT_LEVELS,
        freeze_at=timm_cfg.FREEZE_AT,
        norm=timm_cfg.NORM,
    )
@BACKBONE_REGISTRY.register()
def build_p67_timm_fpn_backbone(cfg, input_shape):
    """Build an FPN over a timm backbone with extra P6/P7 levels.

    Registered in ``BACKBONE_REGISTRY`` so detectron2 can resolve it by name
    from ``cfg.MODEL.BACKBONE.NAME``.

    Args:
        cfg: config providing the ``MODEL.TIMM`` options consumed by
            ``build_timm_backbone`` and ``MODEL.FPN.IN_FEATURES``,
            ``OUT_CHANNELS``, ``NORM`` and ``FUSE_TYPE``.
        input_shape: forwarded to ``build_timm_backbone``.

    Returns:
        An ``FPN`` whose top block is ``LastLevelP6P7_P5``.
    """
    bottom_up = build_timm_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        # The top block reads the FPN's "p5" output, so its input width
        # equals the FPN output width.
        top_block=LastLevelP6P7_P5(out_channels, out_channels),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone
@BACKBONE_REGISTRY.register()
def build_p35_timm_fpn_backbone(cfg, input_shape):
    """Build an FPN over a timm backbone without any extra top levels.

    Registered in ``BACKBONE_REGISTRY`` so detectron2 can resolve it by name
    from ``cfg.MODEL.BACKBONE.NAME``.

    Args:
        cfg: config providing the ``MODEL.TIMM`` options consumed by
            ``build_timm_backbone`` and ``MODEL.FPN.IN_FEATURES``,
            ``OUT_CHANNELS``, ``NORM`` and ``FUSE_TYPE``.
        input_shape: forwarded to ``build_timm_backbone``.

    Returns:
        An ``FPN`` built with ``top_block=None``.
    """
    bottom_up = build_timm_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        top_block=None,
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone