import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .seecoder_utils import with_pos_embed
from lib.model_zoo.common.get_model import get_model, register

symbol = 'seecoder'

###########
# helpers #
###########

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

def c2_xavier_fill(module):
    # Caffe2-style XavierFill (as used in Detectron2): kaiming_uniform_ with a=1, zero bias.
    nn.init.kaiming_uniform_(module.weight, a=1)
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)

def with_pos_embed(x, pos):
    # Local definition; overrides the helper of the same name imported above.
    return x if pos is None else x + pos

###########
# Modules #
###########

class Conv2d_Convenience(nn.Conv2d):
    # nn.Conv2d that optionally applies a `norm` module and an `activation`
    # callable after the convolution.
    def __init__(self, *args, **kwargs):
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)
        self.norm = norm
        self.activation = activation

    def forward(self, x):
        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
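# Usage sketch (illustrative only, not part of the model): the extra kwargs attach
# normalization and activation to the convolution in one module.
#   conv = Conv2d_Convenience(64, 256, kernel_size=1, bias=False,
#                             norm=nn.GroupNorm(32, 256), activation=F.relu)
#   y = conv(torch.randn(1, 64, 32, 32))    # -> (1, 256, 32, 32)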
class DecoderLayer(nn.Module):
    # Post-norm transformer block operating on batch-first (bs, length, dim) tensors.
    def __init__(self,
                 dim=256,
                 feedforward_dim=1024,
                 dropout=0.1,
                 activation="relu",
                 n_heads=8,):
        super().__init__()
        # batch_first=True: the Decoder below feeds (bs, h*w, dim) tensors directly.
        self.self_attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.linear1 = nn.Linear(dim, feedforward_dim)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(feedforward_dim, dim)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        h = x
        h1 = self.self_attn(x, x, x, attn_mask=None)[0]
        h = h + self.dropout1(h1)
        h = self.norm1(h)
        h2 = self.linear2(self.dropout2(self.activation(self.linear1(h))))
        h = h + self.dropout3(h2)
        h = self.norm2(h)
        return h

class DecoderLayerStacked(nn.Module):
    def __init__(self, layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, x):
        h = x
        for layer in self.layers:
            h = layer(h)
        if self.norm is not None:
            h = self.norm(h)
        return h
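# Usage sketch (illustrative only): stacking the block over a batch-first sequence.
#   layer = DecoderLayer(dim=256, n_heads=8, dropout=0.1, feedforward_dim=1024)
#   stack = DecoderLayerStacked(layer, num_layers=6)
#   y = stack(torch.randn(2, 100, 256))     # -> (2, 100, 256)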
class SelfAttentionLayer(nn.Module):
    def __init__(self, channels, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(channels, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(channels)
        self.dropout = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward_post(self,
                     qkv,
                     qk_pos=None,
                     mask=None,):
        h = qkv
        qk = with_pos_embed(qkv, qk_pos).transpose(0, 1)
        v = qkv.transpose(0, 1)
        h1 = self.self_attn(qk, qk, v, attn_mask=mask)[0]
        h1 = h1.transpose(0, 1)
        h = h + self.dropout(h1)
        h = self.norm(h)
        return h

    def forward_pre(self, tgt,
                    tgt_mask=None,
                    tgt_key_padding_mask=None,
                    query_pos=None):
        # Deprecated: the pre-norm path is not used.
        assert False
        tgt2 = self.norm(tgt)
        q = k = with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, *args, **kwargs):
        if self.normalize_before:
            return self.forward_pre(*args, **kwargs)
        return self.forward_post(*args, **kwargs)

class CrossAttentionLayer(nn.Module):
    def __init__(self, channels, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(channels, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(channels)
        self.dropout = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward_post(self,
                     q,
                     kv,
                     q_pos=None,
                     k_pos=None,
                     mask=None,):
        h = q
        q = with_pos_embed(q, q_pos).transpose(0, 1)
        k = with_pos_embed(kv, k_pos).transpose(0, 1)
        v = kv.transpose(0, 1)
        h1 = self.multihead_attn(q, k, v, attn_mask=mask)[0]
        h1 = h1.transpose(0, 1)
        h = h + self.dropout(h1)
        h = self.norm(h)
        return h

    def forward_pre(self, tgt, memory,
                    memory_mask=None,
                    memory_key_padding_mask=None,
                    pos=None,
                    query_pos=None):
        # Deprecated: the pre-norm path is not used.
        assert False
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=with_pos_embed(tgt2, query_pos),
                                   key=with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, *args, **kwargs):
        if self.normalize_before:
            return self.forward_pre(*args, **kwargs)
        return self.forward_post(*args, **kwargs)
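# Usage sketch (illustrative only): both attention layers take batch-first
# (bs, length, channels) tensors and transpose internally for nn.MultiheadAttention.
#   cross = CrossAttentionLayer(channels=256, nhead=8)
#   queries = torch.randn(2, 16, 256)
#   feats = torch.randn(2, 400, 256)        # flattened image features
#   out = cross(q=queries, kv=feats)        # -> (2, 16, 256)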
class FeedForwardLayer(nn.Module):
    def __init__(self, channels, hidden_channels=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.linear1 = nn.Linear(channels, hidden_channels)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_channels, channels)
        self.norm = nn.LayerNorm(channels)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward_post(self, x):
        h = x
        h1 = self.linear2(self.dropout(self.activation(self.linear1(h))))
        h = h + self.dropout(h1)
        h = self.norm(h)
        return h

    def forward_pre(self, x):
        xn = self.norm(x)
        h = x
        h1 = self.linear2(self.dropout(self.activation(self.linear1(xn))))
        h = h + self.dropout(h1)
        return h

    def forward(self, *args, **kwargs):
        if self.normalize_before:
            return self.forward_pre(*args, **kwargs)
        return self.forward_post(*args, **kwargs)

class MLP(nn.Module):
    def __init__(self, in_channels, channels, out_channels, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [channels] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k)
            for n, k in zip([in_channels] + h, h + [out_channels]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
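# Usage sketch (illustrative only): a plain ReLU MLP, here 256 -> 512 -> 512 -> 64.
#   mlp = MLP(in_channels=256, channels=512, out_channels=64, num_layers=3)
#   y = mlp(torch.randn(8, 256))            # -> (8, 64)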
class PPE_MLP(nn.Module):
    def __init__(self, freq_num=20, freq_max=None, out_channel=768, mlp_layer=3):
        super().__init__()
        self.freq_num = freq_num
        self.freq_max = freq_max
        self.out_channel = out_channel
        self.mlp_layer = mlp_layer
        self.twopi = 2 * math.pi
        mlp = []
        in_channel = freq_num * 4
        for idx in range(mlp_layer):
            linear = nn.Linear(in_channel, out_channel, bias=True)
            nn.init.xavier_normal_(linear.weight)
            nn.init.constant_(linear.bias, 0)
            mlp.append(linear)
            if idx != mlp_layer - 1:
                mlp.append(nn.SiLU())
            in_channel = out_channel
        self.mlp = nn.Sequential(*mlp)
        nn.init.constant_(self.mlp[-1].weight, 0)

    def forward(self, x, mask=None):
        assert mask is None, "Mask not implemented"
        h, w = x.shape[-2:]
        minlen = min(h, w)
        h_embed, w_embed = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
        if self.training:
            import numpy.random as npr
            pertube_h, pertube_w = npr.uniform(-0.5, 0.5), npr.uniform(-0.5, 0.5)
        else:
            pertube_h, pertube_w = 0, 0
        h_embed = (h_embed + 0.5 - h / 2 + pertube_h) / minlen * self.twopi
        w_embed = (w_embed + 0.5 - w / 2 + pertube_w) / minlen * self.twopi
        h_embed, w_embed = h_embed.to(x.device).to(x.dtype), w_embed.to(x.device).to(x.dtype)
        dim_t = torch.linspace(0, 1, self.freq_num, dtype=torch.float32, device=x.device)
        freq_max = self.freq_max if self.freq_max is not None else minlen / 2
        dim_t = freq_max ** dim_t.to(x.dtype)
        pos_h = h_embed[:, :, None] * dim_t
        pos_w = w_embed[:, :, None] * dim_t
        pos = torch.cat((pos_h.sin(), pos_h.cos(), pos_w.sin(), pos_w.cos()), dim=-1)
        pos = self.mlp(pos)
        pos = pos.permute(2, 0, 1)[None]
        return pos

    def __repr__(self, _repr_indent=4):
        head = "Positional encoding " + self.__class__.__name__
        body = [
            "freq_num: {}".format(self.freq_num),
            "freq_max: {}".format(self.freq_max),
            "out_channel: {}".format(self.out_channel),
            "mlp_layer: {}".format(self.mlp_layer),
        ]
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
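# Usage sketch (illustrative only): PPE_MLP maps pixel coordinates of a feature map
# to a learned positional encoding; the output broadcasts over the batch dimension.
#   pe = PPE_MLP(freq_num=20, out_channel=256, mlp_layer=3)
#   pos = pe(torch.randn(2, 256, 24, 24))   # -> (1, 256, 24, 24)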
###########
# Decoder #
###########

class Decoder(nn.Module):
    def __init__(
            self,
            inchannels,
            trans_input_tags,
            trans_num_layers,
            trans_dim,
            trans_nheads,
            trans_dropout,
            trans_feedforward_dim,):
        super().__init__()
        trans_inchannels = {
            k: v for k, v in inchannels.items() if k in trans_input_tags}
        fpn_inchannels = {
            k: v for k, v in inchannels.items() if k not in trans_input_tags}
        self.trans_tags = sorted(list(trans_inchannels.keys()))
        self.fpn_tags = sorted(list(fpn_inchannels.keys()))
        self.all_tags = sorted(list(inchannels.keys()))
        assert len(self.trans_tags) > 0, "at least one input tag must go through the transformer"
        self.num_trans_lvls = len(self.trans_tags)

        self.inproj_layers = nn.ModuleDict()
        for tagi in self.trans_tags:
            layeri = nn.Sequential(
                nn.Conv2d(trans_inchannels[tagi], trans_dim, kernel_size=1),
                nn.GroupNorm(32, trans_dim),)
            nn.init.xavier_uniform_(layeri[0].weight, gain=1)
            nn.init.constant_(layeri[0].bias, 0)
            self.inproj_layers[tagi] = layeri

        tlayer = DecoderLayer(
            dim=trans_dim,
            n_heads=trans_nheads,
            dropout=trans_dropout,
            feedforward_dim=trans_feedforward_dim,
            activation='relu',)
        self.transformer = DecoderLayerStacked(tlayer, trans_num_layers)
        for p in self.transformer.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        self.level_embed = nn.Parameter(torch.Tensor(len(self.trans_tags), trans_dim))
        nn.init.normal_(self.level_embed)

        self.lateral_layers = nn.ModuleDict()
        self.output_layers = nn.ModuleDict()
        for tagi in self.all_tags:
            lateral_conv = Conv2d_Convenience(
                inchannels[tagi], trans_dim, kernel_size=1,
                bias=False, norm=nn.GroupNorm(32, trans_dim))
            c2_xavier_fill(lateral_conv)
            self.lateral_layers[tagi] = lateral_conv
        for tagi in self.fpn_tags:
            output_conv = Conv2d_Convenience(
                trans_dim, trans_dim, kernel_size=3, stride=1, padding=1,
                bias=False, norm=nn.GroupNorm(32, trans_dim), activation=F.relu,)
            c2_xavier_fill(output_conv)
            self.output_layers[tagi] = output_conv

    def forward(self, features):
        x = []
        spatial_shapes = {}
        for idx, tagi in enumerate(self.trans_tags[::-1]):
            xi = features[tagi]
            xi = self.inproj_layers[tagi](xi)
            bs, _, h, w = xi.shape
            spatial_shapes[tagi] = (h, w)
            xi = xi.flatten(2).transpose(1, 2) + self.level_embed[idx].view(1, 1, -1)
            x.append(xi)
        x_length = [xi.shape[1] for xi in x]
        x_concat = torch.cat(x, 1)
        y_concat = self.transformer(x_concat)
        y = torch.split(y_concat, x_length, dim=1)

        out = {}
        for idx, tagi in enumerate(self.trans_tags[::-1]):
            h, w = spatial_shapes[tagi]
            yi = y[idx].transpose(1, 2).view(bs, -1, h, w)
            out[tagi] = yi
        for idx, tagi in enumerate(self.all_tags[::-1]):
            lconv = self.lateral_layers[tagi]
            if tagi in self.trans_tags:
                out[tagi] = out[tagi] + lconv(features[tagi])
                tag_save = tagi
            else:
                oconv = self.output_layers[tagi]
                h = lconv(features[tagi])
                oprev = out[tag_save]
                h = h + F.interpolate(oconv(oprev), size=h.shape[-2:], mode="bilinear", align_corners=False)
                out[tagi] = h
        return out
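# Usage sketch (illustrative only; the channel counts below are assumptions for the
# example, not values taken from any model config):
#   dec = Decoder(
#       inchannels={'res2': 96, 'res3': 192, 'res4': 384, 'res5': 768},
#       trans_input_tags=['res3', 'res4', 'res5'],
#       trans_num_layers=6, trans_dim=256, trans_nheads=8,
#       trans_dropout=0.1, trans_feedforward_dim=1024)
#   feats = {'res2': torch.randn(1, 96, 64, 64),  'res3': torch.randn(1, 192, 32, 32),
#            'res4': torch.randn(1, 384, 16, 16), 'res5': torch.randn(1, 768, 8, 8)}
#   out = dec(feats)                        # dict of (1, 256, h, w) maps, one per tag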
#####################
# Query Transformer #
#####################

class QueryTransformer(nn.Module):
    def __init__(self,
                 in_channels,
                 hidden_dim,
                 num_queries=[8, 144],
                 nheads=8,
                 num_layers=9,
                 feedforward_dim=2048,
                 mask_dim=256,
                 pre_norm=False,
                 num_feature_levels=3,
                 enforce_input_project=False,
                 with_fea2d_pos=True):
        super().__init__()
        if with_fea2d_pos:
            self.pe_layer = PPE_MLP(freq_num=20, freq_max=None, out_channel=hidden_dim, mlp_layer=3)
        else:
            self.pe_layer = None

        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj = nn.ModuleList()
            for _ in range(num_feature_levels):
                self.input_proj.append(nn.Conv2d(in_channels, hidden_dim, kernel_size=1))
                c2_xavier_fill(self.input_proj[-1])
        else:
            self.input_proj = None

        self.num_heads = nheads
        self.num_layers = num_layers
        self.transformer_selfatt_layers = nn.ModuleList()
        self.transformer_crossatt_layers = nn.ModuleList()
        self.transformer_feedforward_layers = nn.ModuleList()
        for _ in range(self.num_layers):
            self.transformer_selfatt_layers.append(
                SelfAttentionLayer(
                    channels=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,))
            self.transformer_crossatt_layers.append(
                CrossAttentionLayer(
                    channels=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,))
            self.transformer_feedforward_layers.append(
                FeedForwardLayer(
                    channels=hidden_dim,
                    hidden_channels=feedforward_dim,
                    dropout=0.0,
                    normalize_before=pre_norm,))

        self.num_queries = num_queries
        num_gq, num_lq = self.num_queries
        self.init_query = nn.Embedding(num_gq + num_lq, hidden_dim)
        self.query_pos_embedding = nn.Embedding(num_gq + num_lq, hidden_dim)
        self.num_feature_levels = num_feature_levels
        self.level_embed = nn.Embedding(num_feature_levels, hidden_dim)

    def forward(self, x):
        # x is a list of multi-scale features
        assert len(x) == self.num_feature_levels
        fea2d = []
        fea2d_pos = []
        size_list = []
        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            if self.pe_layer is not None:
                pi = self.pe_layer(x[i], None).flatten(2)
                pi = pi.transpose(1, 2)
            else:
                pi = None
            xi = self.input_proj[i](x[i]) if self.input_proj is not None else x[i]
            xi = xi.flatten(2) + self.level_embed.weight[i][None, :, None]
            xi = xi.transpose(1, 2)
            fea2d.append(xi)
            fea2d_pos.append(pi)

        bs, _, _ = fea2d[0].shape
        num_gq, num_lq = self.num_queries
        gquery = self.init_query.weight[:num_gq].unsqueeze(0).repeat(bs, 1, 1)
        lquery = self.init_query.weight[num_gq:].unsqueeze(0).repeat(bs, 1, 1)
        gquery_pos = self.query_pos_embedding.weight[:num_gq].unsqueeze(0).repeat(bs, 1, 1)
        lquery_pos = self.query_pos_embedding.weight[num_gq:].unsqueeze(0).repeat(bs, 1, 1)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            qout = self.transformer_crossatt_layers[i](
                q=lquery,
                kv=fea2d[level_index],
                q_pos=lquery_pos,
                k_pos=fea2d_pos[level_index],
                mask=None,)
            lquery = qout
            qout = self.transformer_selfatt_layers[i](
                qkv=torch.cat([gquery, lquery], dim=1),
                qk_pos=torch.cat([gquery_pos, lquery_pos], dim=1),)
            qout = self.transformer_feedforward_layers[i](qout)
            gquery = qout[:, :num_gq]
            lquery = qout[:, num_gq:]

        output = torch.cat([gquery, lquery], dim=1)
        return output
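# Usage sketch (illustrative only; query counts and dims are example values):
#   qt = QueryTransformer(in_channels=256, hidden_dim=768, num_queries=[4, 144],
#                         nheads=8, num_layers=6, feedforward_dim=2048)
#   feats = [torch.randn(1, 256, s, s) for s in (8, 16, 32)]
#   q = qt(feats)                           # -> (1, 4 + 144, 768)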
##################
# Main structure #
##################

class SemanticExtractionEncoder(nn.Module):
    def __init__(self,
                 imencoder_cfg,
                 imdecoder_cfg,
                 qtransformer_cfg):
        super().__init__()
        self.imencoder = get_model()(imencoder_cfg)
        self.imdecoder = get_model()(imdecoder_cfg)
        self.qtransformer = get_model()(qtransformer_cfg)

    def forward(self, x):
        fea = self.imencoder(x)
        hs = {'res3': fea['res3'],
              'res4': fea['res4'],
              'res5': fea['res5'], }
        hs = self.imdecoder(hs)
        hs = [hs['res3'], hs['res4'], hs['res5']]
        q = self.qtransformer(hs)
        return q

    def encode(self, x):
        return self(x)
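# Usage note (sketch): the three *_cfg arguments are resolved through the repo's
# registry (get_model()(cfg)), so construction depends on configs defined elsewhere.
# Once built, `encode` is simply an alias of the forward pass:
#   see = SemanticExtractionEncoder(imencoder_cfg, imdecoder_cfg, qtransformer_cfg)
#   queries = see.encode(images)            # images: (bs, 3, H, W) tensor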