"""
DFormerv2: Geometry Self-Attention for RGBD Semantic Segmentation
Code: https://github.com/VCIP-RGBD/DFormer
Author: yinbow
Email: bowenyin@mail.nankai.edu.cn

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import math
import os
import sys
from collections import OrderedDict
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from mmengine.runner.checkpoint import load_checkpoint, load_state_dict
from timm.models.layers import DropPath, trunc_normal_


class LayerNorm2d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, x: torch.Tensor):
        """
        input shape (b c h w)
        """
        x = x.permute(0, 2, 3, 1).contiguous()  # (b h w c)
        x = self.norm(x)  # (b h w c)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding
    """

    def __init__(self, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim // 2, 3, 2, 1),
            nn.SyncBatchNorm(embed_dim // 2),
            nn.GELU(),
            nn.Conv2d(embed_dim // 2, embed_dim // 2, 3, 1, 1),
            nn.SyncBatchNorm(embed_dim // 2),
            nn.GELU(),
            nn.Conv2d(embed_dim // 2, embed_dim, 3, 2, 1),
            nn.SyncBatchNorm(embed_dim),
            nn.GELU(),
            nn.Conv2d(embed_dim, embed_dim, 3, 1, 1),
            nn.SyncBatchNorm(embed_dim),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # two stride-2 convolutions give a 4x downsampled embedding, returned as (b h/4 w/4 c)
        x = self.proj(x).permute(0, 2, 3, 1)
        return x


class DWConv2d(nn.Module):
    def __init__(self, dim, kernel_size, stride, padding):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size, stride, padding, groups=dim)

    def forward(self, x: torch.Tensor):
        """
        input (b h w c)
        """
        x = x.permute(0, 3, 1, 2)
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)
        return x


class PatchMerging(nn.Module):
    """
    Patch Merging Layer.
    """

    def __init__(self, dim, out_dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Conv2d(dim, out_dim, 3, 2, 1)
        self.norm = nn.SyncBatchNorm(out_dim)

    def forward(self, x):
        """
        x: B H W C
        """
        x = x.permute(0, 3, 1, 2).contiguous()  # (b c h w)
        x = self.reduction(x)  # (b oc oh ow)
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)  # (b oh ow oc)
        return x


def angle_transform(x, sin, cos):
    # Rotary-style position transform: rotates each pair of channels of x by the
    # per-position (sin, cos) angles before the attention dot product.
    x1 = x[:, :, :, :, ::2]
    x2 = x[:, :, :, :, 1::2]
    return (x * cos) + (torch.stack([-x2, x1], dim=-1).flatten(-2) * sin)
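
# Illustrative shape check for angle_transform (a sketch; the tensors below are
# assumed for illustration and are not part of the model definition):
#   q = torch.randn(1, 4, 8, 8, 16)           # (B, heads, H, W, head_dim)
#   sin = torch.randn(8, 8, 16)               # per-position rotation angles
#   cos = torch.randn(8, 8, 16)
#   q_rot = angle_transform(q, sin, cos)      # same shape as q: (1, 4, 8, 8, 16)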


class GeoPriorGen(nn.Module):
    def __init__(self, embed_dim, num_heads, initial_value, heads_range):
        super().__init__()
        angle = 1.0 / (10000 ** torch.linspace(0, 1, embed_dim // num_heads // 2))
        angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
        self.weight = nn.Parameter(torch.ones(2, 1, 1, 1), requires_grad=True)
        decay = torch.log(
            1 - 2 ** (-initial_value - heads_range * torch.arange(num_heads, dtype=torch.float) / num_heads)
        )
        self.register_buffer('angle', angle)
        self.register_buffer('decay', decay)

    def generate_depth_decay(self, H: int, W: int, depth_grid):
        """
        generate 2d depth decay mask, the result is (HW)*(HW)
        H, W are the numbers of patches at each column and row
        """
        B, _, H, W = depth_grid.shape
        grid_d = depth_grid.reshape(B, H * W, 1)
        mask_d = grid_d[:, :, None, :] - grid_d[:, None, :, :]
        mask_d = (mask_d.abs()).sum(dim=-1)
        mask_d = mask_d.unsqueeze(1) * self.decay[None, :, None, None]
        return mask_d

    def generate_pos_decay(self, H: int, W: int):
        """
        generate 2d spatial decay mask, the result is (HW)*(HW)
        H, W are the numbers of patches at each column and row
        """
        index_h = torch.arange(H).to(self.decay)
        index_w = torch.arange(W).to(self.decay)
        # explicit row-major indexing (the default behaviour) to avoid the meshgrid deprecation warning
        grid = torch.meshgrid([index_h, index_w], indexing='ij')
        grid = torch.stack(grid, dim=-1).reshape(H * W, 2)
        mask = grid[:, None, :] - grid[None, :, :]
        mask = (mask.abs()).sum(dim=-1)
        mask = mask * self.decay[:, None, None]
        return mask

    def generate_1d_depth_decay(self, H, W, depth_grid):
        """
        generate 1d depth decay mask, the result is l*l
        """
        mask = depth_grid[:, :, :, :, None] - depth_grid[:, :, :, None, :]
        mask = mask.abs()
        mask = mask * self.decay[:, None, None, None]
        assert mask.shape[2:] == (W, H, H)
        return mask

    def generate_1d_decay(self, l: int):
        """
        generate 1d decay mask, the result is l*l
        """
        index = torch.arange(l).to(self.decay)
        mask = index[:, None] - index[None, :]
        mask = mask.abs()
        mask = mask * self.decay[:, None, None]
        return mask

    def forward(self, HW_tuple: Tuple[int], depth_map, split_or_not=False):
        """
        depth_map: depth patches
        HW_tuple: (H, W)
        H * W == l
        """
        depth_map = F.interpolate(depth_map, size=HW_tuple, mode='bilinear', align_corners=False)
        if split_or_not:
            index = torch.arange(HW_tuple[0] * HW_tuple[1]).to(self.decay)
            sin = torch.sin(index[:, None] * self.angle[None, :])
            sin = sin.reshape(HW_tuple[0], HW_tuple[1], -1)
            cos = torch.cos(index[:, None] * self.angle[None, :])
            cos = cos.reshape(HW_tuple[0], HW_tuple[1], -1)
            mask_d_h = self.generate_1d_depth_decay(HW_tuple[0], HW_tuple[1], depth_map.transpose(-2, -1))
            mask_d_w = self.generate_1d_depth_decay(HW_tuple[1], HW_tuple[0], depth_map)
            mask_h = self.generate_1d_decay(HW_tuple[0])
            mask_w = self.generate_1d_decay(HW_tuple[1])
            mask_h = self.weight[0] * mask_h.unsqueeze(0).unsqueeze(2) + self.weight[1] * mask_d_h
            mask_w = self.weight[0] * mask_w.unsqueeze(0).unsqueeze(2) + self.weight[1] * mask_d_w
            geo_prior = ((sin, cos), (mask_h, mask_w))
        else:
            index = torch.arange(HW_tuple[0] * HW_tuple[1]).to(self.decay)
            sin = torch.sin(index[:, None] * self.angle[None, :])
            sin = sin.reshape(HW_tuple[0], HW_tuple[1], -1)
            cos = torch.cos(index[:, None] * self.angle[None, :])
            cos = cos.reshape(HW_tuple[0], HW_tuple[1], -1)
            mask = self.generate_pos_decay(HW_tuple[0], HW_tuple[1])
            mask_d = self.generate_depth_decay(HW_tuple[0], HW_tuple[1], depth_map)
            mask = self.weight[0] * mask + self.weight[1] * mask_d
            geo_prior = ((sin, cos), mask)
        return geo_prior
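
# Shape sketch for GeoPriorGen (illustrative only; the toy sizes below are assumed,
# not taken from any released config):
#   gen = GeoPriorGen(embed_dim=64, num_heads=4, initial_value=2, heads_range=4)
#   depth = torch.rand(2, 1, 60, 80)                          # (B, 1, H_img, W_img)
#   (sin, cos), mask = gen((15, 20), depth)                   # full prior
#   # sin/cos: (15, 20, 16); mask: (2, 4, 300, 300) = (B, heads, HW, HW)
#   (sin, cos), (mask_h, mask_w) = gen((15, 20), depth, split_or_not=True)
#   # mask_h: (2, 4, 20, 15, 15); mask_w: (2, 4, 15, 20, 20)  # decomposed prior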


class Decomposed_GSA(nn.Module):
    def __init__(self, embed_dim, num_heads, value_factor=1):
        super().__init__()
        self.factor = value_factor
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim * self.factor // num_heads
        self.key_dim = self.embed_dim // num_heads
        self.scaling = self.key_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim * self.factor, bias=True)
        self.lepe = DWConv2d(embed_dim, 5, 1, 2)
        self.out_proj = nn.Linear(embed_dim * self.factor, embed_dim, bias=True)
        self.reset_parameters()

    def forward(self, x: torch.Tensor, rel_pos, split_or_not=False):
        """
        x: (b h w c)
        rel_pos: ((sin, cos), (mask_h, mask_w)) from GeoPriorGen with split_or_not=True
        """
        bsz, h, w, _ = x.size()
        (sin, cos), (mask_h, mask_w) = rel_pos
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        lepe = self.lepe(v)
        k = k * self.scaling
        q = q.view(bsz, h, w, self.num_heads, self.key_dim).permute(0, 3, 1, 2, 4)  # (b n h w d1)
        k = k.view(bsz, h, w, self.num_heads, self.key_dim).permute(0, 3, 1, 2, 4)  # (b n h w d1)
        qr = angle_transform(q, sin, cos)
        kr = angle_transform(k, sin, cos)
        # attention along the width axis
        qr_w = qr.transpose(1, 2)
        kr_w = kr.transpose(1, 2)
        v = v.reshape(bsz, h, w, self.num_heads, -1).permute(0, 1, 3, 2, 4)
        qk_mat_w = qr_w @ kr_w.transpose(-1, -2)
        qk_mat_w = qk_mat_w + mask_w.transpose(1, 2)
        qk_mat_w = torch.softmax(qk_mat_w, -1)
        v = torch.matmul(qk_mat_w, v)
        # attention along the height axis
        qr_h = qr.permute(0, 3, 1, 2, 4)
        kr_h = kr.permute(0, 3, 1, 2, 4)
        v = v.permute(0, 3, 2, 1, 4)
        qk_mat_h = qr_h @ kr_h.transpose(-1, -2)
        qk_mat_h = qk_mat_h + mask_h.transpose(1, 2)
        qk_mat_h = torch.softmax(qk_mat_h, -1)
        output = torch.matmul(qk_mat_h, v)
        output = output.permute(0, 3, 1, 2, 4).flatten(-2, -1)
        output = output + lepe
        output = self.out_proj(output)
        return output

    def reset_parameters(self):
        nn.init.xavier_normal_(self.q_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_normal_(self.k_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_normal_(self.v_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_normal_(self.out_proj.weight)
        nn.init.constant_(self.out_proj.bias, 0.0)


class Full_GSA(nn.Module):
    def __init__(self, embed_dim, num_heads, value_factor=1):
        super().__init__()
        self.factor = value_factor
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim * self.factor // num_heads
        self.key_dim = self.embed_dim // num_heads
        self.scaling = self.key_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim * self.factor, bias=True)
        self.lepe = DWConv2d(embed_dim, 5, 1, 2)
        self.out_proj = nn.Linear(embed_dim * self.factor, embed_dim, bias=True)
        self.reset_parameters()

    def forward(self, x: torch.Tensor, rel_pos, split_or_not=False):
        """
        x: (b h w c)
        rel_pos: ((sin, cos), mask) with mask of shape (b n l l), l = h * w
        """
        bsz, h, w, _ = x.size()
        (sin, cos), mask = rel_pos
        assert h * w == mask.size(3)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        lepe = self.lepe(v)
        k = k * self.scaling
        q = q.view(bsz, h, w, self.num_heads, -1).permute(0, 3, 1, 2, 4)
        k = k.view(bsz, h, w, self.num_heads, -1).permute(0, 3, 1, 2, 4)
        qr = angle_transform(q, sin, cos)
        kr = angle_transform(k, sin, cos)
        qr = qr.flatten(2, 3)
        kr = kr.flatten(2, 3)
        vr = v.reshape(bsz, h, w, self.num_heads, -1).permute(0, 3, 1, 2, 4)
        vr = vr.flatten(2, 3)
        qk_mat = qr @ kr.transpose(-1, -2)
        qk_mat = qk_mat + mask
        qk_mat = torch.softmax(qk_mat, -1)
        output = torch.matmul(qk_mat, vr)
        output = output.transpose(1, 2).reshape(bsz, h, w, -1)
        output = output + lepe
        output = self.out_proj(output)
        return output

    def reset_parameters(self):
        nn.init.xavier_normal_(self.q_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_normal_(self.k_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_normal_(self.v_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_normal_(self.out_proj.weight)
        nn.init.constant_(self.out_proj.bias, 0.0)
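
# Illustrative pairing of Full_GSA with the non-decomposed geometry prior
# (a sketch with assumed toy sizes, not a test from the original repository):
#   gsa = Full_GSA(embed_dim=64, num_heads=4)
#   geo = GeoPriorGen(embed_dim=64, num_heads=4, initial_value=2, heads_range=4)
#   x = torch.randn(2, 15, 20, 64)                            # (B, H, W, C)
#   depth = torch.rand(2, 1, 60, 80)                          # (B, 1, H_img, W_img)
#   out = gsa(x, geo((15, 20), depth))                        # -> (2, 15, 20, 64)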


class FeedForwardNetwork(nn.Module):
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn=F.gelu,
        dropout=0.0,
        activation_dropout=0.0,
        layernorm_eps=1e-6,
        subln=False,
        subconv=True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_fn = activation_fn
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        self.ffn_layernorm = nn.LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
        self.dwconv = DWConv2d(ffn_dim, 3, 1, 1) if subconv else None

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        if self.ffn_layernorm is not None:
            self.ffn_layernorm.reset_parameters()

    def forward(self, x: torch.Tensor):
        """
        input shape: (b h w c)
        """
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.activation_dropout_module(x)
        residual = x
        if self.dwconv is not None:
            x = self.dwconv(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = x + residual
        x = self.fc2(x)
        x = self.dropout_module(x)
        return x


class RGBD_Block(nn.Module):
    def __init__(self, split_or_not: bool, embed_dim: int, num_heads: int, ffn_dim: int, drop_path=0.,
                 layerscale=False, layer_init_values=1e-5, init_value=2, heads_range=4):
        super().__init__()
        self.layerscale = layerscale
        self.embed_dim = embed_dim
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=1e-6)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=1e-6)
        if split_or_not:
            self.Attention = Decomposed_GSA(embed_dim, num_heads)
        else:
            self.Attention = Full_GSA(embed_dim, num_heads)
        self.drop_path = DropPath(drop_path)
        # FFN
        self.ffn = FeedForwardNetwork(embed_dim, ffn_dim)
        self.cnn_pos_encode = DWConv2d(embed_dim, 3, 1, 1)
        # the function to generate the geometry prior for the current block
        self.Geo = GeoPriorGen(embed_dim, num_heads, init_value, heads_range)
        if layerscale:
            self.gamma_1 = nn.Parameter(layer_init_values * torch.ones(1, 1, 1, embed_dim), requires_grad=True)
            self.gamma_2 = nn.Parameter(layer_init_values * torch.ones(1, 1, 1, embed_dim), requires_grad=True)

    def forward(self, x: torch.Tensor, x_e: torch.Tensor, split_or_not=False):
        x = x + self.cnn_pos_encode(x)
        b, h, w, d = x.size()
        geo_prior = self.Geo((h, w), x_e, split_or_not=split_or_not)
        if self.layerscale:
            x = x + self.drop_path(self.gamma_1 * self.Attention(self.layer_norm1(x), geo_prior, split_or_not))
            x = x + self.drop_path(self.gamma_2 * self.ffn(self.layer_norm2(x)))
        else:
            x = x + self.drop_path(self.Attention(self.layer_norm1(x), geo_prior, split_or_not))
            x = x + self.drop_path(self.ffn(self.layer_norm2(x)))
        return x


class BasicLayer(nn.Module):
    """
    A basic RGB-D layer in DFormerv2.
    """

    def __init__(self, embed_dim, out_dim, depth, num_heads,
                 init_value: float, heads_range: float,
                 ffn_dim=96, drop_path=0., norm_layer=nn.LayerNorm, split_or_not=False,
                 downsample: PatchMerging = None, use_checkpoint=False,
                 layerscale=False, layer_init_values=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.split_or_not = split_or_not
        # build blocks
        self.blocks = nn.ModuleList([
            RGBD_Block(split_or_not, embed_dim, num_heads, ffn_dim,
                       drop_path[i] if isinstance(drop_path, list) else drop_path,
                       layerscale, layer_init_values, init_value=init_value, heads_range=heads_range)
            for i in range(depth)])
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=embed_dim, out_dim=out_dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_e):
        b, h, w, d = x.size()
        for blk in self.blocks:
            if self.use_checkpoint:
                # positional arguments: reentrant checkpointing does not accept
                # keyword arguments for the wrapped module
                x = checkpoint.checkpoint(blk, x, x_e, self.split_or_not)
            else:
                x = blk(x, x_e, split_or_not=self.split_or_not)
        if self.downsample is not None:
            x_down = self.downsample(x)
            return x, x_down
        else:
            return x, x


class dformerv2(nn.Module):
    def __init__(self, out_indices=(0, 1, 2, 3),
                 embed_dims=[64, 128, 256, 512], depths=[2, 2, 8, 2], num_heads=[4, 4, 8, 16],
                 init_values=[2, 2, 2, 2], heads_ranges=[4, 4, 6, 6], mlp_ratios=[4, 4, 3, 3],
                 drop_path_rate=0.1, norm_layer=nn.LayerNorm,
                 patch_norm=True, use_checkpoint=False, projection=1024, norm_cfg=None,
                 layerscales=[False, False, False, False], layer_init_values=1e-6, norm_eval=True):
        super().__init__()
        self.out_indices = out_indices
        self.num_layers = len(depths)
        self.embed_dim = embed_dims[0]
        self.patch_norm = patch_norm
        self.num_features = embed_dims[-1]
        self.mlp_ratios = mlp_ratios
        self.norm_eval = norm_eval
        # patch embedding
        self.patch_embed = PatchEmbed(in_chans=3, embed_dim=embed_dims[0],
                                      norm_layer=norm_layer if self.patch_norm else None)
        # drop path rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                embed_dim=embed_dims[i_layer],
                out_dim=embed_dims[i_layer + 1] if (i_layer < self.num_layers - 1) else None,
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                init_value=init_values[i_layer],
                heads_range=heads_ranges[i_layer],
                ffn_dim=int(mlp_ratios[i_layer] * embed_dims[i_layer]),
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                split_or_not=(i_layer != 3),
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                layerscale=layerscales[i_layer],
                layer_init_values=layer_init_values,
            )
            self.layers.append(layer)
        self.extra_norms = nn.ModuleList()
        for i in range(3):
            self.extra_norms.append(nn.LayerNorm(embed_dims[i + 1]))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm without affine parameters has weight/bias set to None
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in the backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            _state_dict = torch.load(pretrained)
            if 'model' in _state_dict:
                _state_dict = _state_dict['model']
            if 'state_dict' in _state_dict:
                _state_dict = _state_dict['state_dict']
            # strip the 'backbone.' prefix used by full segmentation checkpoints
            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[len('backbone.'):]] = v
                else:
                    state_dict[k] = v
            print('load ' + pretrained)
            load_state_dict(self, state_dict, strict=False)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward(self, x, x_e):
        # rgb input
        x = self.patch_embed(x)
        # depth input: keep a single channel
        x_e = x_e[:, 0, :, :].unsqueeze(1)
        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, x = layer(x, x_e)
            if i in self.out_indices:
                if i != 0:
                    x_out = self.extra_norms[i - 1](x_out)
                out = x_out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)
        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keeping the normalization
        layers frozen."""
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()


def DFormerv2_S(pretrained=False, **kwargs):
    model = dformerv2(embed_dims=[64, 128, 256, 512], depths=[3, 4, 18, 4], num_heads=[4, 4, 8, 16],
                      heads_ranges=[4, 4, 6, 6], **kwargs)
    return model


def DFormerv2_B(pretrained=False, **kwargs):
    model = dformerv2(embed_dims=[80, 160, 320, 512], depths=[4, 8, 25, 8], num_heads=[5, 5, 10, 16],
                      heads_ranges=[5, 5, 6, 6],
                      layerscales=[False, False, True, True],
                      layer_init_values=1e-6, **kwargs)
    return model


def DFormerv2_L(pretrained=False, **kwargs):
    model = dformerv2(embed_dims=[112, 224, 448, 640], depths=[4, 8, 25, 8], num_heads=[7, 7, 14, 20],
                      heads_ranges=[6, 6, 6, 6],
                      layerscales=[False, False, True, True],
                      layer_init_values=1e-6, **kwargs)
    return model
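

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original release).
    # Assumes a reasonably recent PyTorch in which nn.SyncBatchNorm falls back to
    # standard batch-norm statistics when no distributed process group is
    # initialized; otherwise run this under DDP on GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = DFormerv2_S().to(device).eval()
    rgb = torch.randn(1, 3, 256, 256, device=device)   # RGB batch (B, 3, H, W)
    depth = torch.rand(1, 1, 256, 256, device=device)  # depth map; only channel 0 is used
    with torch.no_grad():
        feats = model(rgb, depth)
    for i, f in enumerate(feats):
        # expected channel counts 64/128/256/512 at strides 4/8/16/32
        print(f"stage {i}: {tuple(f.shape)}")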