# Grounded-Segment-Anything/grounded-sam-osx/transformer_utils/mmpose/models/backbones/tcformer.py
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn import (build_norm_layer, constant_init, normal_init,
                      trunc_normal_init)
from mmcv.runner import _load_checkpoint, load_state_dict

from ...utils import get_root_logger
from ..builder import BACKBONES
from ..utils import (PatchEmbed, TCFormerDynamicBlock, TCFormerRegularBlock,
                     TokenConv, cluster_dpc_knn, merge_tokens,
                     tcformer_convert, token2map)


class CTM(nn.Module):
    """Clustering-based Token Merging module in TCFormer.

    Args:
        sample_ratio (float): The sample ratio of tokens.
        embed_dim (int): Input token feature dimension.
        dim_out (int): Output token feature dimension.
        k (int): Number of the nearest neighbors used in the DPC-kNN
            algorithm. Default: 5.
    """

    def __init__(self, sample_ratio, embed_dim, dim_out, k=5):
        super().__init__()
        self.sample_ratio = sample_ratio
        self.dim_out = dim_out
        self.conv = TokenConv(
            in_channels=embed_dim,
            out_channels=dim_out,
            kernel_size=3,
            stride=2,
            padding=1)
        self.norm = nn.LayerNorm(self.dim_out)
        self.score = nn.Linear(self.dim_out, 1)
        self.k = k

    def forward(self, token_dict):
        token_dict = token_dict.copy()
        x = self.conv(token_dict)
        x = self.norm(x)
        # Per-token importance scores; their exponential is used as the
        # weight when tokens are merged into clusters.
        token_score = self.score(x)
        token_weight = token_score.exp()

        token_dict['x'] = x
        B, N, C = x.shape
        token_dict['token_score'] = token_score

        # Cluster the N tokens into roughly N * sample_ratio clusters with
        # DPC-kNN, then merge each cluster into a single weighted token.
        cluster_num = max(math.ceil(N * self.sample_ratio), 1)
        idx_cluster, cluster_num = cluster_dpc_knn(token_dict, cluster_num,
                                                   self.k)
        down_dict = merge_tokens(token_dict, idx_cluster, cluster_num,
                                 token_weight)

        # Record the halved feature-map size, matching the stride-2 TokenConv.
        H, W = token_dict['map_size']
        H = math.floor((H - 1) / 2 + 1)
        W = math.floor((W - 1) / 2 + 1)
        down_dict['map_size'] = [H, W]

        return down_dict, token_dict
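

# ---------------------------------------------------------------------------
# Illustrative sketch (added for clarity, not part of the upstream file):
# how a CTM stage is meant to be driven. It mirrors the token dictionary
# that TCFormer.forward() builds below; the helper name `_ctm_usage_sketch`
# and the concrete sizes are hypothetical, and the call assumes the
# companion utilities in `..utils` (TokenConv, cluster_dpc_knn, merge_tokens)
# behave as they are used in CTM.forward() above.
# ---------------------------------------------------------------------------
def _ctm_usage_sketch():
    B, H, W, C = 2, 16, 16, 64
    x = torch.randn(B, H * W, C)
    token_dict = {
        'x': x,
        'token_num': H * W,
        'map_size': [H, W],
        'init_grid_size': [H, W],
        'idx_token': torch.arange(H * W)[None, :].repeat(B, 1),
        'agg_weight': x.new_ones(B, H * W, 1),
    }
    ctm = CTM(sample_ratio=0.25, embed_dim=C, dim_out=128, k=5)
    # `down_dict` is expected to hold roughly token_num * sample_ratio merged
    # tokens of dimension dim_out with a halved map_size; the refreshed
    # `token_dict` is returned alongside, and TCFormer.forward() below feeds
    # the pair straight into the first TCFormerDynamicBlock of the next stage.
    down_dict, token_dict = ctm(token_dict)
    return down_dict, token_dict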


@BACKBONES.register_module()
class TCFormer(nn.Module):
    """Token Clustering Transformer (TCFormer).

    Implementation of `Not All Tokens Are Equal: Human-centric Visual
    Analysis via Token Clustering Transformer
    <https://arxiv.org/abs/2204.08680>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (list[int]): Embedding dimension of each stage.
            Default: [64, 128, 256, 512].
        num_heads (Sequence[int]): Number of attention heads in each
            transformer encode layer. Default: [1, 2, 4, 8].
        mlp_ratios (Sequence[int]): Ratio of the mlp hidden dim to the
            embedding dim of each transformer block. Default: [4, 4, 4, 4].
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
        attn_drop_rate (float): Dropout rate of the attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN', eps=1e-6).
        num_layers (Sequence[int]): Number of encode layers in each
            transformer stage. Default: [3, 4, 6, 3].
        sr_ratios (Sequence[int]): Spatial reduction ratio of each
            transformer block. Default: [8, 4, 2, 1].
        num_stages (int): Number of stages. Default: 4.
        pretrained (str, optional): Path to a pretrained model.
            Default: None.
        k (int): Number of the nearest neighbors used for local density
            in DPC-kNN. Default: 5.
        sample_ratios (list[float]): Sample ratios of the CTM modules.
            Default: [0.25, 0.25, 0.25].
        return_map (bool): If True, convert the dynamic tokens to feature
            maps at the end. Default: False.
        convert_weights (bool): Whether the pre-trained model is from the
            original repo; if so, some keys are converted to make it
            compatible with this implementation. Default: True.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4],
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 num_layers=[3, 4, 6, 3],
                 sr_ratios=[8, 4, 2, 1],
                 num_stages=4,
                 pretrained=None,
                 k=5,
                 sample_ratios=[0.25, 0.25, 0.25],
                 return_map=False,
                 convert_weights=True):
        super().__init__()

        self.num_layers = num_layers
        self.num_stages = num_stages
        self.grid_stride = sr_ratios[0]
        self.embed_dims = embed_dims
        self.sr_ratios = sr_ratios
        self.mlp_ratios = mlp_ratios
        self.sample_ratios = sample_ratios
        self.return_map = return_map
        self.convert_weights = convert_weights

        # stochastic depth decay rule
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]
        cur = 0

        # In stage 1, use the standard transformer blocks
        for i in range(1):
            patch_embed = PatchEmbed(
                in_channels=in_channels if i == 0 else embed_dims[i - 1],
                embed_dims=embed_dims[i],
                kernel_size=7,
                stride=4,
                padding=3,
                bias=True,
                norm_cfg=dict(type='LN', eps=1e-6))

            block = nn.ModuleList([
                TCFormerRegularBlock(
                    dim=embed_dims[i],
                    num_heads=num_heads[i],
                    mlp_ratio=mlp_ratios[i],
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[cur + j],
                    norm_cfg=norm_cfg,
                    sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
            ])
            norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
            cur += num_layers[i]

            setattr(self, f'patch_embed{i + 1}', patch_embed)
            setattr(self, f'block{i + 1}', block)
            setattr(self, f'norm{i + 1}', norm)

        # In stage 2~4, use TCFormerDynamicBlock for dynamic tokens
        for i in range(1, num_stages):
            ctm = CTM(sample_ratios[i - 1], embed_dims[i - 1], embed_dims[i],
                      k)

            block = nn.ModuleList([
                TCFormerDynamicBlock(
                    dim=embed_dims[i],
                    num_heads=num_heads[i],
                    mlp_ratio=mlp_ratios[i],
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[cur + j],
                    norm_cfg=norm_cfg,
                    sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
            ])
            norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
            cur += num_layers[i]

            setattr(self, f'ctm{i}', ctm)
            setattr(self, f'block{i + 1}', block)
            setattr(self, f'norm{i + 1}', norm)

        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(
                pretrained, logger=logger, map_location='cpu')
            logger.warning(f'Load pre-trained model for '
                           f'{self.__class__.__name__} from original repo')

            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint

            if self.convert_weights:
                # We need to convert pre-trained weights to match this
                # implementation.
                state_dict = tcformer_convert(state_dict)
            load_state_dict(self, state_dict, strict=False, logger=logger)

        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        outs = []

        i = 0
        patch_embed = getattr(self, f'patch_embed{i + 1}')
        block = getattr(self, f'block{i + 1}')
        norm = getattr(self, f'norm{i + 1}')
        x, (H, W) = patch_embed(x)
        for blk in block:
            x = blk(x, H, W)
        x = norm(x)

        # init token dict
        B, N, _ = x.shape
        device = x.device
        idx_token = torch.arange(N)[None, :].repeat(B, 1).to(device)
        agg_weight = x.new_ones(B, N, 1)
        token_dict = {
            'x': x,
            'token_num': N,
            'map_size': [H, W],
            'init_grid_size': [H, W],
            'idx_token': idx_token,
            'agg_weight': agg_weight
        }
        outs.append(token_dict.copy())

        # stage 2~4
        for i in range(1, self.num_stages):
            ctm = getattr(self, f'ctm{i}')
            block = getattr(self, f'block{i + 1}')
            norm = getattr(self, f'norm{i + 1}')

            token_dict = ctm(token_dict)  # down sample
            for j, blk in enumerate(block):
                token_dict = blk(token_dict)

            token_dict['x'] = norm(token_dict['x'])
            outs.append(token_dict)

        if self.return_map:
            outs = [token2map(token_dict) for token_dict in outs]

        return outs
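

# ---------------------------------------------------------------------------
# Minimal smoke test (added for clarity, not part of the upstream file): it
# sketches how the backbone is typically constructed and what it returns,
# assuming the relative imports above resolve (i.e. the module is executed
# inside the mmpose package, e.g. via `python -m ...`) and that randomly
# initialized weights (pretrained=None) are acceptable.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = TCFormer(in_channels=3, return_map=False)
    model.eval()
    with torch.no_grad():
        # One 256x256 RGB image; the stage-1 patch embedding has stride 4,
        # so the initial grid is 64x64 tokens.
        outs = model(torch.randn(1, 3, 256, 256))
    # One token dict per stage: each keeps the (clustered) token features in
    # 'x' plus the bookkeeping keys ('idx_token', 'agg_weight', 'map_size',
    # ...) built in forward(). With return_map=True the dicts would instead
    # be converted to feature maps by token2map.
    for stage_idx, token_dict in enumerate(outs, start=1):
        print(stage_idx, token_dict['x'].shape)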