Spaces:
Runtime error
Runtime error
| import fvcore.nn.weight_init as weight_init | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .msdeformattn import PositionEmbeddingSine, _get_clones, _get_activation_fn | |
| from lib.model_zoo.common.get_model import get_model, register | |
| ########## | |
| # helper # | |
| ########## | |
def with_pos_embed(x, pos):
    """Add the positional term ``pos`` to ``x``; return ``x`` untouched when ``pos`` is None."""
    if pos is None:
        return x
    return x + pos
| ############## | |
| # One Former # | |
| ############## | |
class Transformer(nn.Module):
    """Encoder/decoder transformer applied to a 4-D feature map.

    The encoder is a stack of ``TransformerEncoderLayer`` and the decoder a
    stack of ``TransformerDecoderLayer``. Every parameter with more than one
    dimension is re-initialised with Xavier-uniform after construction.

    Args:
        d_model: channel width of all layers.
        nhead: number of attention heads.
        num_encoder_layers: depth of the encoder stack.
        num_decoder_layers: depth of the decoder stack.
        dim_feedforward: hidden width of the per-layer feed-forward network.
        dropout: dropout probability handed to every layer.
        activation: activation name handed to every layer.
        normalize_before: select the pre-norm variant of the layers; also
            decides whether the encoder gets a final LayerNorm.
        return_intermediate_dec: forwarded to the decoder as
            ``return_intermediate``.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()
        layer_args = (d_model, nhead, dim_feedforward, dropout, activation, normalize_before)

        # Encoder: only the pre-norm variant needs a closing LayerNorm.
        enc_layer = TransformerEncoderLayer(*layer_args)
        enc_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(enc_layer, num_encoder_layers, enc_norm)

        # Decoder: always closed by a LayerNorm.
        dec_layer = TransformerDecoderLayer(*layer_args)
        dec_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            dec_layer,
            num_decoder_layers,
            dec_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every weight that has more than one dimension."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, src, mask, query_embed, pos_embed, task_token=None):
        """Encode ``src`` and decode a set of queries against it.

        Args:
            src: feature map of shape (N, C, H, W).
            mask: optional padding mask; flattened from dim 1 onward and used
                as key-padding mask in encoder and decoder.
            query_embed: 2-D query embedding; a batch axis is inserted at
                dim 1 and repeated N times.
            pos_embed: positional map, flattened the same way as ``src``.
            task_token: optional tensor repeated along dim 0 (once per query)
                to form the initial decoder input; zeros are used when None.

        Returns:
            Tuple of the decoder output with dims 1 and 2 swapped, and the
            encoder memory reshaped back to (N, C, H, W).
        """
        bs, c, h, w = src.shape

        # (N, C, H, W) -> (H*W, N, C)
        tokens = src.flatten(2).permute(2, 0, 1)
        token_pos = pos_embed.flatten(2).permute(2, 0, 1)
        queries = query_embed.unsqueeze(1).repeat(1, bs, 1)
        padding_mask = None if mask is None else mask.flatten(1)

        if task_token is not None:
            tgt = task_token.repeat(queries.shape[0], 1, 1)
        else:
            tgt = torch.zeros_like(queries)

        memory = self.encoder(tokens, src_key_padding_mask=padding_mask, pos=token_pos)
        hs = self.decoder(
            tgt,
            memory,
            memory_key_padding_mask=padding_mask,
            pos=token_pos,
            query_pos=queries,
        )
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
class TransformerEncoder(nn.Module):
    """Sequential stack of cloned encoder layers with an optional closing norm.

    Args:
        encoder_layer: prototype layer; ``num_layers`` clones are made with
            ``_get_clones``.
        num_layers: number of clones in the stack.
        norm: optional module applied to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, pos=None):
        """Feed ``src`` through every layer in order and apply ``norm`` if set.

        ``mask`` is handed to each layer as ``src_mask``;
        ``src_key_padding_mask`` and ``pos`` are passed through unchanged.
        """
        hidden = src
        for block in self.layers:
            hidden = block(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        return hidden if self.norm is None else self.norm(hidden)
class TransformerDecoder(nn.Module):
    """Sequential stack of cloned decoder layers with an optional closing norm.

    Args:
        decoder_layer: prototype layer; ``num_layers`` clones are made with
            ``_get_clones``.
        num_layers: number of clones in the stack.
        norm: optional module applied to layer outputs before they are
            returned. May be None.
        return_intermediate: when True, the (normalised) output of every
            layer is returned instead of only the last one.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def _apply_norm(self, tensor):
        """Apply ``self.norm`` when configured, otherwise return ``tensor`` as is."""
        return tensor if self.norm is None else self.norm(tensor)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        """Run ``tgt`` through all decoder layers against ``memory``.

        All mask and positional arguments are forwarded unchanged to every
        layer.

        Returns:
            With ``return_intermediate``: the per-layer outputs stacked along
            a new leading dimension. Otherwise: the final output with a
            leading dimension of size 1.
        """
        output = tgt
        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                # The norm is optional: without this guard a decoder built
                # with norm=None and return_intermediate=True would call None.
                intermediate.append(self._apply_norm(output))

        if self.return_intermediate:
            # The last collected entry already is the normalised final output.
            return torch.stack(intermediate)

        return self._apply_norm(output).unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
    """Self-attention followed by a two-layer feed-forward network.

    Both sub-blocks are residual. ``normalize_before`` selects whether the
    LayerNorm is applied to the sub-block input (pre-norm) or to the residual
    sum (post-norm). ``nn.MultiheadAttention`` is built with its default
    settings.

    Args:
        d_model: channel width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability for attention, the feed-forward hidden
            activation and both residual branches.
        activation: name resolved by ``_get_activation_fn``.
        normalize_before: use the pre-norm variant when True.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, x, pos):
        """Return ``x + pos``, or ``x`` itself when ``pos`` is None."""
        if pos is None:
            return x
        return x + pos

    def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """Post-norm variant: normalise after each residual addition."""
        attn_in = self.with_pos_embed(src, pos)
        attn_out, _ = self.self_attn(
            attn_in,
            attn_in,
            value=src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        src = self.norm1(src + self.dropout1(attn_out))

        hidden = self.activation(self.linear1(src))
        ffn_out = self.linear2(self.dropout(hidden))
        return self.norm2(src + self.dropout2(ffn_out))

    def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """Pre-norm variant: normalise the input of each sub-block."""
        normed = self.norm1(src)
        attn_in = self.with_pos_embed(normed, pos)
        attn_out, _ = self.self_attn(
            attn_in,
            attn_in,
            value=normed,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        src = src + self.dropout1(attn_out)

        normed = self.norm2(src)
        hidden = self.activation(self.linear1(normed))
        ffn_out = self.linear2(self.dropout(hidden))
        return src + self.dropout2(ffn_out)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """Dispatch to the pre- or post-norm implementation."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
    """Self-attention, cross-attention and feed-forward network, each residual.

    ``normalize_before`` selects whether the LayerNorm is applied to each
    sub-block's input (pre-norm) or to the residual sum (post-norm). Both
    ``nn.MultiheadAttention`` modules are built with default settings.

    Args:
        d_model: channel width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability used throughout the layer.
        activation: name resolved by ``_get_activation_fn``.
        normalize_before: use the pre-norm variant when True.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, x, pos):
        """Return ``x + pos``, or ``x`` itself when ``pos`` is None."""
        if pos is None:
            return x
        return x + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        """Post-norm variant: normalise after each residual addition."""
        # Self-attention over the queries.
        self_in = self.with_pos_embed(tgt, query_pos)
        self_out, _ = self.self_attn(
            self_in,
            self_in,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        tgt = self.norm1(tgt + self.dropout1(self_out))

        # Cross-attention from the queries to the memory.
        cross_out, _ = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        tgt = self.norm2(tgt + self.dropout2(cross_out))

        # Feed-forward network.
        hidden = self.activation(self.linear1(tgt))
        ffn_out = self.linear2(self.dropout(hidden))
        return self.norm3(tgt + self.dropout3(ffn_out))

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        """Pre-norm variant: normalise the input of each sub-block."""
        # Self-attention over the queries.
        normed = self.norm1(tgt)
        self_in = self.with_pos_embed(normed, query_pos)
        self_out, _ = self.self_attn(
            self_in,
            self_in,
            value=normed,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        tgt = tgt + self.dropout1(self_out)

        # Cross-attention from the queries to the memory.
        normed = self.norm2(tgt)
        cross_out, _ = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        tgt = tgt + self.dropout2(cross_out)

        # Feed-forward network.
        normed = self.norm3(tgt)
        hidden = self.activation(self.linear1(normed))
        ffn_out = self.linear2(self.dropout(hidden))
        return tgt + self.dropout3(ffn_out)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        """Dispatch to the pre- or post-norm implementation."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )
class SelfAttentionLayer(nn.Module):
    """Residual self-attention block with a single LayerNorm.

    The attention module is built with its default (sequence-first)
    settings, so inputs are transposed on dims 0 and 1 before attention and
    transposed back afterwards. Pre- and post-norm variants share that
    handling, so the expected tensor layout does not depend on
    ``normalize_before``.

    Args:
        d_model: channel width.
        nhead: number of attention heads.
        dropout: dropout probability for attention and the residual branch.
        activation: name resolved by ``_get_activation_fn`` (stored, not used
            by the forward pass).
        normalize_before: use the pre-norm variant when True.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every weight that has more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def _attend(self, tgt, tgt_mask, tgt_key_padding_mask, query_pos):
        """Self-attend over ``tgt`` and return the result in the layout of ``tgt``."""
        q = k = self.with_pos_embed(tgt, query_pos).transpose(0, 1)
        attended = self.self_attn(
            q, k,
            value=tgt.transpose(0, 1),
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        return attended.transpose(0, 1)

    def forward_post(self, tgt,
                     tgt_mask=None,
                     tgt_key_padding_mask=None,
                     query_pos=None):
        """Post-norm variant: attention, residual add, then LayerNorm."""
        tgt2 = self._attend(tgt, tgt_mask, tgt_key_padding_mask, query_pos)
        tgt = tgt + self.dropout(tgt2)
        return self.norm(tgt)

    def forward_pre(self, tgt,
                    tgt_mask=None,
                    tgt_key_padding_mask=None,
                    query_pos=None):
        """Pre-norm variant: LayerNorm, attention, then residual add.

        Uses the same transposition as ``forward_post``; previously this path
        skipped it and therefore interpreted the input in a different layout.
        """
        tgt2 = self._attend(self.norm(tgt), tgt_mask, tgt_key_padding_mask, query_pos)
        return tgt + self.dropout(tgt2)

    def forward(self, tgt,
                tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=None):
        """Dispatch to the pre- or post-norm implementation."""
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask,
                                    tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask,
                                 tgt_key_padding_mask, query_pos)
class CrossAttentionLayer(nn.Module):
    """Residual cross-attention block with a single LayerNorm.

    The attention module is built with its default (sequence-first)
    settings, so ``tgt`` and ``memory`` are transposed on dims 0 and 1
    before attention and the result is transposed back. Pre- and post-norm
    variants share that handling, so the expected tensor layout does not
    depend on ``normalize_before``.

    Args:
        d_model: channel width.
        nhead: number of attention heads.
        dropout: dropout probability for attention and the residual branch.
        activation: name resolved by ``_get_activation_fn`` (stored, not used
            by the forward pass).
        normalize_before: use the pre-norm variant when True.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every weight that has more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def _attend(self, tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos):
        """Attend from ``tgt`` to ``memory``; result is in the layout of ``tgt``."""
        attended = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos).transpose(0, 1),
            key=self.with_pos_embed(memory, pos).transpose(0, 1),
            value=memory.transpose(0, 1),
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        return attended.transpose(0, 1)

    def forward_post(self, tgt, memory,
                     memory_mask=None,
                     memory_key_padding_mask=None,
                     pos=None,
                     query_pos=None):
        """Post-norm variant: attention, residual add, then LayerNorm."""
        tgt2 = self._attend(tgt, memory, memory_mask,
                            memory_key_padding_mask, pos, query_pos)
        tgt = tgt + self.dropout(tgt2)
        return self.norm(tgt)

    def forward_pre(self, tgt, memory,
                    memory_mask=None,
                    memory_key_padding_mask=None,
                    pos=None,
                    query_pos=None):
        """Pre-norm variant: LayerNorm, attention, then residual add.

        Uses the same transposition as ``forward_post``; previously this path
        skipped it and therefore interpreted the inputs in a different layout.
        """
        tgt2 = self._attend(self.norm(tgt), memory, memory_mask,
                            memory_key_padding_mask, pos, query_pos)
        return tgt + self.dropout(tgt2)

    def forward(self, tgt, memory,
                memory_mask=None,
                memory_key_padding_mask=None,
                pos=None,
                query_pos=None):
        """Dispatch to the pre- or post-norm implementation."""
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)
class FFNLayer(nn.Module):
    """Residual two-layer feed-forward block with a single LayerNorm.

    The same ``nn.Dropout`` module is applied to the hidden activation and to
    the block output before the residual addition.

    Args:
        d_model: input and output width.
        dim_feedforward: hidden width.
        dropout: dropout probability.
        activation: name resolved by ``_get_activation_fn``.
        normalize_before: use the pre-norm variant when True.
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every weight that has more than one dimension."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt):
        """Post-norm variant: feed-forward, residual add, then LayerNorm."""
        hidden = self.activation(self.linear1(tgt))
        update = self.linear2(self.dropout(hidden))
        return self.norm(tgt + self.dropout(update))

    def forward_pre(self, tgt):
        """Pre-norm variant: LayerNorm, feed-forward, then residual add."""
        hidden = self.activation(self.linear1(self.norm(tgt)))
        update = self.linear2(self.dropout(hidden))
        return tgt + self.dropout(update)

    def forward(self, tgt):
        """Dispatch to the pre- or post-norm implementation."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt)
class MLP(nn.Module):
    """Plain multi-layer perceptron (FFN): linear layers with ReLU in between.

    Args:
        input_dim: width of the input features.
        hidden_dim: width of every hidden layer.
        output_dim: width of the output features.
        num_layers: number of linear layers; ReLU follows all but the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Consecutive pairs of this list are the (in, out) widths of each layer.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, x):
        """Apply the layers in order; the last one has no activation."""
        last = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if idx < last:
                x = F.relu(x)
        return x
class Seet_OneFormer_TDecoder(nn.Module):
    """Masked-attention query decoder with a task-conditioned class transformer.

    A ``Transformer`` first produces all but one of the object queries from
    the mask features; the task token supplies the remaining query. The
    queries are then refined by ``dec_layers`` rounds of cross-attention
    (cycling through three feature levels), self-attention and a
    feed-forward block. After every round, class logits and masks are
    predicted and the masks gate the next round's cross-attention.

    Args:
        in_channels: channel count of the input feature maps.
        mask_classification: must be truthy; only mask classification is
            supported.
        num_classes: number of classes; the classifier emits one extra logit.
        hidden_dim: transformer width.
        num_queries: number of queries, including the task query.
        nheads: number of attention heads.
        dropout: dropout for the class transformer (the refinement layers
            are built with dropout 0.0).
        dim_feedforward: feed-forward hidden width.
        enc_layers: encoder depth of the class transformer.
        is_train: stored on the module; not read in this class.
        dec_layers: number of refinement rounds.
        class_dec_layers: decoder depth of the class transformer.
        pre_norm: use pre-norm variants of all layers.
        mask_dim: channel count of the mask embedding.
        enforce_input_project: always use a 1x1 conv input projection, even
            if ``in_channels == hidden_dim``.
        use_task_norm: normalise the task token with ``decoder_norm`` and use
            it as initial input of the class transformer's decoder.
    """

    def __init__(
        self,
        in_channels,
        mask_classification,
        num_classes,
        hidden_dim,
        num_queries,
        nheads,
        dropout,
        dim_feedforward,
        enc_layers,
        is_train,
        dec_layers,
        class_dec_layers,
        pre_norm,
        mask_dim,
        enforce_input_project,
        use_task_norm,
    ):
        super().__init__()

        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification
        self.is_train = is_train
        self.use_task_norm = use_task_norm

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        self.class_transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=class_dec_layers,
            normalize_before=pre_norm,
            return_intermediate_dec=False,
        )

        # refinement decoder: one (cross-attn, self-attn, FFN) triple per round
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(nn.Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                # identity projection
                self.input_proj.append(nn.Sequential())

        self.class_input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
        weight_init.c2_xavier_fill(self.class_input_proj)

        # output FFNs
        if self.mask_classification:
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, x, mask_features, tasks):
        """Decode class logits and masks from multi-scale features.

        Args:
            x: sequence of ``num_feature_levels`` feature maps, each 4-D with
                batch first.
            mask_features: 4-D per-pixel feature map, batch first.
            tasks: task embedding; assumed to be (batch, hidden_dim), since a
                leading axis is inserted to form a (1, batch, hidden_dim)
                token — TODO confirm against the caller.

        Returns:
            Dict with ``'pred_logits'`` and ``'pred_masks'`` taken from the
            last refinement round.
        """
        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        for level, feat in enumerate(x):
            size_list.append(feat.shape[-2:])
            level_pos = self.pe_layer(feat, None).flatten(2)
            level_src = (
                self.input_proj[level](feat).flatten(2)
                + self.level_embed.weight[level][None, :, None]
            )
            # (B, C, H*W) -> (B, H*W, C)
            pos.append(level_pos.transpose(1, 2))
            src.append(level_src.transpose(1, 2))

        bs = src[0].shape[0]

        # (Q, C) -> (B, Q, C)
        query_embed = self.query_embed.weight.unsqueeze(0).repeat(bs, 1, 1)

        tasks = tasks.unsqueeze(0)
        if self.use_task_norm:
            tasks = self.decoder_norm(tasks)

        # The positional encoding of the mask features is the encoder input;
        # the projected mask features are passed as its positional term.
        feats = self.pe_layer(mask_features, None)
        out_t, _ = self.class_transformer(
            feats,
            None,
            self.query_embed.weight[:-1],
            self.class_input_proj(mask_features),
            tasks if self.use_task_norm else None,
        )
        out_t = out_t[0]

        # out_t is batch-first while the task token carries the batch on
        # dim 1; swap its first two dims so the concatenation along the query
        # axis also works for batch sizes above one (no-op for batch size 1).
        output = torch.cat([out_t, tasks.transpose(0, 1)], dim=1)

        predictions_class = []
        predictions_mask = []

        # prediction heads on learnable query features
        outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[0]
        )
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # A query whose mask row is entirely True would attend to nothing;
            # clear such rows so it can attend everywhere instead.
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # attention: cross-attention first
            output = self.transformer_cross_attention_layers[i](
                output,
                src[level_index],
                memory_mask=attn_mask,
                memory_key_padding_mask=None,
                pos=pos[level_index],
                query_pos=query_embed,
            )

            output = self.transformer_self_attention_layers[i](
                output,
                tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed,
            )

            # FFN
            output = self.transformer_ffn_layers[i](output)

            next_size = size_list[(i + 1) % self.num_feature_levels]
            outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
                output, mask_features, attn_mask_target_size=next_size
            )
            predictions_class.append(outputs_class)
            predictions_mask.append(outputs_mask)

        assert len(predictions_class) == self.num_layers + 1

        out = {
            'pred_logits': predictions_class[-1],
            'pred_masks': predictions_mask[-1],
        }
        return out

    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):
        """Predict class logits, masks and the next attention mask.

        Args:
            output: query features, (B, Q, C).
            mask_features: per-pixel features, (B, C, H, W).
            attn_mask_target_size: spatial size the attention mask is
                resampled to.

        Returns:
            Tuple ``(outputs_class, outputs_mask, attn_mask)``. ``attn_mask``
            is a detached boolean tensor of shape (B * num_heads, Q, h * w)
            that is True where the resampled mask probability is below 0.5.
        """
        decoder_output = self.decoder_norm(output)
        outputs_class = self.class_embed(decoder_output)
        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)

        resized = F.interpolate(
            outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False
        )
        # (B, Q, h, w) -> (B, Q, h*w) -> (B, heads, Q, h*w) -> (B*heads, Q, h*w)
        per_head = (
            resized.sigmoid()
            .flatten(2)
            .unsqueeze(1)
            .repeat(1, self.num_heads, 1, 1)
            .flatten(0, 1)
        )
        attn_mask = (per_head < 0.5).bool()
        attn_mask = attn_mask.detach()

        return outputs_class, outputs_mask, attn_mask