from typing import Dict, Optional, Tuple

from omegaconf import DictConfig
import torch
import torch.nn as nn

from tracker.model.group_modules import GConv2d
from tracker.utils.tensor_utils import aggregate
from tracker.model.transformer.positional_encoding import PositionalEncoding
from tracker.model.transformer.transformer_layers import *


class QueryTransformerBlock(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_transformer
        self.embed_dim = this_cfg.embed_dim
        self.num_heads = this_cfg.num_heads
        self.num_queries = this_cfg.num_queries
        self.ff_dim = this_cfg.ff_dim

        self.read_from_pixel = CrossAttention(self.embed_dim,
                                              self.num_heads,
                                              add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
        self.self_attn = SelfAttention(self.embed_dim,
                                       self.num_heads,
                                       add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
        self.ffn = FFN(self.embed_dim, self.ff_dim)
        self.read_from_query = CrossAttention(self.embed_dim,
                                              self.num_heads,
                                              add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
                                              norm=this_cfg.read_from_query.output_norm)
        self.pixel_ffn = PixelFFN(self.embed_dim)
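
    # A single object-transformer block: the object queries read from the pixel
    # features via masked cross-attention, update themselves with self-attention
    # and an FFN, then write back into the pixel features through a reverse
    # cross-attention followed by a pixel FFN.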
    def forward(
        self,
        x: torch.Tensor,
        pixel: torch.Tensor,
        query_pe: torch.Tensor,
        pixel_pe: torch.Tensor,
        attn_mask: torch.Tensor,
        need_weights: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # x: (bs*num_objects)*num_queries*embed_dim
        # pixel: bs*num_objects*C*H*W
        # query_pe: (bs*num_objects)*num_queries*embed_dim
        # pixel_pe: (bs*num_objects)*(H*W)*C
        # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)

        # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
        pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
        x, q_weights = self.read_from_pixel(x,
                                            pixel_flat,
                                            query_pe,
                                            pixel_pe,
                                            attn_mask=attn_mask,
                                            need_weights=need_weights)
        x = self.self_attn(x, query_pe)
        x = self.ffn(x)

        pixel_flat, p_weights = self.read_from_query(pixel_flat,
                                                     x,
                                                     pixel_pe,
                                                     query_pe,
                                                     need_weights=need_weights)
        pixel = self.pixel_ffn(pixel, pixel_flat)

        if need_weights:
            bs, num_objects, _, h, w = pixel.shape
            q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
            p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
                                                       self.num_queries, h, w)

        return x, pixel, q_weights, p_weights
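

# Stacks num_blocks QueryTransformerBlocks. Object queries are initialized from
# time-aggregated object summaries; pixel features are projected and given a
# sinusoidal positional encoding. After every block, an auxiliary mask is
# predicted from the current pixel features and reused as the attention mask
# for the next block's query-to-pixel cross-attention.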
class QueryTransformer(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_transformer
        self.value_dim = model_cfg.value_dim
        self.embed_dim = this_cfg.embed_dim
        self.num_heads = this_cfg.num_heads
        self.num_queries = this_cfg.num_queries

        # query initialization and embedding
        self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
        self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)

        # projection from object summaries to query initialization and embedding
        self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
        self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)

        self.pixel_pe_scale = model_cfg.pixel_pe_scale
        self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
        self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
        self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
        self.spatial_pe = PositionalEncoding(self.embed_dim,
                                             scale=self.pixel_pe_scale,
                                             temperature=self.pixel_pe_temperature,
                                             channel_last=False,
                                             transpose_output=True)

        # transformer blocks
        self.num_blocks = this_cfg.num_blocks
        self.blocks = nn.ModuleList(
            QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
        self.mask_pred = nn.ModuleList(
            nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
            for _ in range(self.num_blocks + 1))

        self.act = nn.ReLU(inplace=True)
    def forward(self,
                pixel: torch.Tensor,
                obj_summaries: torch.Tensor,
                selector: Optional[torch.Tensor] = None,
                need_weights: bool = False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # pixel: B*num_objects*embed_dim*H*W
        # obj_summaries: B*num_objects*T*num_queries*(embed_dim+1)
        T = obj_summaries.shape[2]
        bs, num_objects, _, H, W = pixel.shape

        # normalize object values
        # the last channel is the cumulative area of the object
        obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
                                           self.embed_dim + 1)
        # sum over time
        # during inference, T=1 as we already did streaming average in memory_manager
        obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
        obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
        obj_values = obj_sums / (obj_area + 1e-4)
        obj_init = self.summary_to_query_init(obj_values)
        obj_emb = self.summary_to_query_emb(obj_values)

        # positional embeddings for object queries
        query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
        query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb

        # positional embeddings for pixel features
        pixel_init = self.pixel_init_proj(pixel)
        pixel_emb = self.pixel_emb_proj(pixel)
        pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
        pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
        pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb
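        # the positional encoding above mixes the fixed sinusoidal spatial
        # encoding with a learned per-pixel embedding projected from the input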
        pixel = pixel_init

        # run the transformer
        aux_features = {'logits': []}

        # first aux output
        aux_logits = self.mask_pred[0](pixel).squeeze(2)
        attn_mask = self._get_aux_mask(aux_logits, selector)
        aux_features['logits'].append(aux_logits)
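
        # each block refines both the queries and the pixel features; the mask
        # predicted from the updated pixel features re-gates the query-to-pixel
        # cross-attention of the following block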
        for i in range(self.num_blocks):
            query, pixel, q_weights, p_weights = self.blocks[i](query,
                                                                pixel,
                                                                query_emb,
                                                                pixel_pe,
                                                                attn_mask,
                                                                need_weights=need_weights)

            if self.training or i <= self.num_blocks - 1 or need_weights:
                aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
                attn_mask = self._get_aux_mask(aux_logits, selector)
                aux_features['logits'].append(aux_logits)

        aux_features['q_weights'] = q_weights  # last layer only
        aux_features['p_weights'] = p_weights  # last layer only

        if self.training:
            # no need to save all heads
            aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
                                                       self.num_queries, H, W)[:, :, 0]

        return pixel, aux_features

    def _get_aux_mask(self, logits: torch.Tensor,
                      selector: Optional[torch.Tensor]) -> torch.Tensor:
        # logits: batch_size*num_objects*H*W
        # selector: batch_size*num_objects*1*1
        # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W)
        # where True means the attention is blocked
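        # the first half of the queries is restricted to the foreground of its
        # object and the second half to the background; rows that would be
        # fully blocked are unblocked so the attention softmax stays finite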
        if selector is None:
            prob = logits.sigmoid()
        else:
            prob = logits.sigmoid() * selector
        logits = aggregate(prob, dim=1)

        is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
        foreground_mask = is_foreground.bool().flatten(start_dim=2)
        inv_foreground_mask = ~foreground_mask
        inv_background_mask = foreground_mask

        aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
        aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)

        aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)
        aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False

        return aux_mask
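

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. The config
    # values below are assumptions chosen only to exercise the tensor shapes;
    # the real values come from the project's model config. In particular,
    # add_pe_to_qkv is assumed to be a per-(q, k, v) list of booleans, matching
    # how the attention layers in transformer_layers appear to consume it.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        'value_dim': 256,
        'pixel_pe_scale': 32,
        'pixel_pe_temperature': 128,
        'object_transformer': {
            'embed_dim': 256,
            'ff_dim': 2048,
            'num_heads': 8,
            'num_queries': 16,
            'num_blocks': 3,
            'read_from_pixel': {'add_pe_to_qkv': [True, True, False]},
            'read_from_query': {'add_pe_to_qkv': [True, True, False], 'output_norm': False},
            'query_self_attention': {'add_pe_to_qkv': [True, True, False]},
        },
    })

    model = QueryTransformer(cfg)
    bs, num_objects, T, H, W = 1, 2, 1, 30, 54
    pixel = torch.randn(bs, num_objects, cfg.object_transformer.embed_dim, H, W)
    # the last channel of each object summary is the cumulative object area
    obj_summaries = torch.randn(bs, num_objects, T,
                                cfg.object_transformer.num_queries,
                                cfg.object_transformer.embed_dim + 1)
    out_pixel, aux = model(pixel, obj_summaries)
    print(out_pixel.shape, [v.shape for v in aux['logits']])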