# --------------------------------------------------------
# Neighborhood Attention Transformer
# Licensed under The MIT License
# Written by Ali Hassani
# --------------------------------------------------------
# Modified by Jitesh Jain

import torch
import torch.nn as nn
from timm.models.layers import DropPath

from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
from natten import NeighborhoodAttention2D as NeighborhoodAttention


class ConvTokenizer(nn.Module):
    def __init__(self, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        )
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x
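
# ConvTokenizer downsamples the input 4x overall (two stride-2 convolutions) and returns
# channels-last (N, H, W, C) tensors; the attention layers and LayerNorms below operate in
# this layout until forward_tokens permutes the outputs back to (N, C, H, W).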


class ConvDownsampler(nn.Module):
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
        x = self.reduction(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        x = self.norm(x)
        return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class NATLayer(nn.Module):
    def __init__(self, dim, num_heads, kernel_size=7, dilation=None,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=None):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(dim)
        self.attn = NeighborhoodAttention(
            dim, kernel_size=kernel_size, dilation=dilation, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.layer_scale = False
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        if not self.layer_scale:
            shortcut = x
            x = self.norm1(x)
            x = self.attn(x)
            x = shortcut + self.drop_path(x)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = shortcut + self.drop_path(self.gamma1 * x)
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        return x
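
# When `layer_scale` is a numeric value, NATLayer learns per-channel scales gamma1/gamma2
# for the attention and MLP residual branches; otherwise it uses plain pre-norm residuals.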


class NATBlock(nn.Module):
    def __init__(self, dim, depth, num_heads, kernel_size, dilations=None,
                 downsample=True,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, layer_scale=None):
        super().__init__()
        self.dim = dim
        self.depth = depth

        self.blocks = nn.ModuleList([
            NATLayer(dim=dim,
                     num_heads=num_heads,
                     kernel_size=kernel_size,
                     dilation=None if dilations is None else dilations[i],
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop, attn_drop=attn_drop,
                     drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                     norm_layer=norm_layer,
                     layer_scale=layer_scale)
            for i in range(depth)])

        self.downsample = None if not downsample else ConvDownsampler(dim=dim, norm_layer=norm_layer)

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is None:
            return x, x
        return self.downsample(x), x
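
# NATBlock.forward returns a pair: (features fed to the next level, pre-downsample features
# of this level). DiNAT.forward_tokens normalizes the second element and exposes it as the
# stage output.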


class DiNAT(nn.Module):
    def __init__(self,
                 embed_dim,
                 mlp_ratio,
                 depths,
                 num_heads,
                 drop_path_rate=0.2,
                 in_chans=3,
                 kernel_size=7,
                 dilations=None,
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 norm_layer=nn.LayerNorm,
                 frozen_stages=-1,
                 layer_scale=None,
                 **kwargs):
        super().__init__()
        self.num_levels = len(depths)
        self.embed_dim = embed_dim
        self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_levels)]
        self.mlp_ratio = mlp_ratio

        self.patch_embed = ConvTokenizer(in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.levels = nn.ModuleList()
        for i in range(self.num_levels):
            level = NATBlock(dim=int(embed_dim * 2 ** i),
                             depth=depths[i],
                             num_heads=num_heads[i],
                             kernel_size=kernel_size,
                             dilations=None if dilations is None else dilations[i],
                             mlp_ratio=self.mlp_ratio,
                             qkv_bias=qkv_bias, qk_scale=qk_scale,
                             drop=drop_rate, attn_drop=attn_drop_rate,
                             drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                             norm_layer=norm_layer,
                             downsample=(i < self.num_levels - 1),
                             layer_scale=layer_scale)
            self.levels.append(level)

        # add a norm layer for each output
        self.out_indices = out_indices
        for i_layer in self.out_indices:
            layer = norm_layer(self.num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.frozen_stages = frozen_stages

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 2:
            # freeze the first (frozen_stages - 1) attention levels
            for i in range(0, self.frozen_stages - 1):
                m = self.levels[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(DiNAT, self).train(mode)
        self._freeze_stages()

    def forward_embeddings(self, x):
        x = self.patch_embed(x)
        return x

    def forward_tokens(self, x):
        outs = {}
        for idx, level in enumerate(self.levels):
            x, xo = level(x)
            if idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(xo)
                outs["res{}".format(idx + 2)] = x_out.permute(0, 3, 1, 2).contiguous()
        return outs

    def forward(self, x):
        x = self.forward_embeddings(x)
        return self.forward_tokens(x)
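
# DiNAT.forward returns a dict of feature maps keyed "res{i+2}" ("res2".."res5" for the
# default four levels, at strides 4/8/16/32 relative to the input), permuted back to
# (N, C, H, W) for downstream detectron2 heads.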


# Register with detectron2 so that cfg.MODEL.BACKBONE.NAME = "D2DiNAT" can build this backbone.
@BACKBONE_REGISTRY.register()
class D2DiNAT(DiNAT, Backbone):
    def __init__(self, cfg, input_shape):
        embed_dim = cfg.MODEL.DiNAT.EMBED_DIM
        mlp_ratio = cfg.MODEL.DiNAT.MLP_RATIO
        depths = cfg.MODEL.DiNAT.DEPTHS
        num_heads = cfg.MODEL.DiNAT.NUM_HEADS
        drop_path_rate = cfg.MODEL.DiNAT.DROP_PATH_RATE
        kernel_size = cfg.MODEL.DiNAT.KERNEL_SIZE
        out_indices = cfg.MODEL.DiNAT.OUT_INDICES
        dilations = cfg.MODEL.DiNAT.DILATIONS

        super().__init__(
            embed_dim=embed_dim,
            mlp_ratio=mlp_ratio,
            depths=depths,
            num_heads=num_heads,
            drop_path_rate=drop_path_rate,
            kernel_size=kernel_size,
            out_indices=out_indices,
            dilations=dilations,
        )

        self._out_features = cfg.MODEL.DiNAT.OUT_FEATURES

        self._out_feature_strides = {
            "res2": 4,
            "res3": 8,
            "res4": 16,
            "res5": 32,
        }
        self._out_feature_channels = {
            "res2": self.num_features[0],
            "res3": self.num_features[1],
            "res4": self.num_features[2],
            "res5": self.num_features[3],
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N, C, H, W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert (
            x.dim() == 4
        ), f"DiNAT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        y = super().forward(x)
        for k in y.keys():
            if k in self._out_features:
                outputs[k] = y[k]
        return outputs

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

    @property
    def size_divisibility(self):
        return 32
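

if __name__ == "__main__":
    # Minimal smoke-test sketch: it builds a small DiNAT directly, without detectron2 configs.
    # The hyperparameters below are illustrative only and do not correspond to an official
    # DiNAT variant; running this assumes natten's neighborhood-attention kernels are installed.
    # Each per-stage dilation should satisfy kernel_size * dilation <= feature-map size at that
    # stage (56/28/14/7 for a 224x224 input).
    model = DiNAT(
        embed_dim=64,
        mlp_ratio=3.0,
        depths=[2, 2, 2, 2],
        num_heads=[2, 4, 8, 16],
        kernel_size=7,
        dilations=[[1, 8], [1, 4], [1, 2], [1, 1]],
        drop_path_rate=0.1,
    )
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    for name, feat in feats.items():
        # Expected keys res2..res5 with strides 4/8/16/32 and 64/128/256/512 channels.
        print(name, tuple(feat.shape))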