# Modified from PyTorch nn.Transformer

from typing import List, Callable, Optional, Tuple

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

from tracker.model.channel_attn import CAResBlock


class SelfAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 nhead: int,
                 dropout: float = 0.0,
                 batch_first: bool = True,
                 add_pe_to_qkv: List[bool] = [True, True, False]):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.add_pe_to_qkv = add_pe_to_qkv

    def forward(self,
                x: torch.Tensor,
                pe: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pre-norm, then add the positional encoding to q/k/v as configured
        # by add_pe_to_qkv.
        x = self.norm(x)
        if any(self.add_pe_to_qkv):
            x_with_pe = x + pe
            q = x_with_pe if self.add_pe_to_qkv[0] else x
            k = x_with_pe if self.add_pe_to_qkv[1] else x
            v = x_with_pe if self.add_pe_to_qkv[2] else x
        else:
            q = k = v = x

        r = x
        x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
        return r + self.dropout(x)
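

# Usage sketch (not part of the original module; sizes below are illustrative
# assumptions). With batch_first=True the inputs are (batch, tokens, dim), and
# the positional encoding pe must have the same shape as x (or broadcast to it).
def _example_self_attention() -> torch.Tensor:
    attn = SelfAttention(dim=256, nhead=8)
    x = torch.randn(2, 1024, 256)   # e.g. 1024 flattened pixel tokens
    pe = torch.randn(2, 1024, 256)  # positional encoding, same shape as x
    return attn(x, pe)              # residual output, shape (2, 1024, 256)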


# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
class CrossAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 nhead: int,
                 dropout: float = 0.0,
                 batch_first: bool = True,
                 add_pe_to_qkv: List[bool] = [True, True, False],
                 residual: bool = True,
                 norm: bool = True):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim,
                                                nhead,
                                                dropout=dropout,
                                                batch_first=batch_first)
        if norm:
            self.norm = nn.LayerNorm(dim)
        else:
            self.norm = nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.add_pe_to_qkv = add_pe_to_qkv
        self.residual = residual

    def forward(self,
                x: torch.Tensor,
                mem: torch.Tensor,
                x_pe: torch.Tensor,
                mem_pe: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                *,
                need_weights: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Queries come from x; keys/values come from mem. Positional encodings
        # are added to q/k/v according to add_pe_to_qkv.
        x = self.norm(x)
        if self.add_pe_to_qkv[0]:
            q = x + x_pe
        else:
            q = x

        if any(self.add_pe_to_qkv[1:]):
            mem_with_pe = mem + mem_pe
            k = mem_with_pe if self.add_pe_to_qkv[1] else mem
            v = mem_with_pe if self.add_pe_to_qkv[2] else mem
        else:
            k = v = mem

        r = x
        x, weights = self.cross_attn(q,
                                     k,
                                     v,
                                     attn_mask=attn_mask,
                                     need_weights=need_weights,
                                     average_attn_weights=False)

        if self.residual:
            return r + self.dropout(x), weights
        else:
            return self.dropout(x), weights
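

# Usage sketch (not part of the original module; sizes are illustrative
# assumptions): query tokens attend into a memory sequence of a different
# length; with need_weights=True the per-head attention map is also returned.
def _example_cross_attention() -> Tuple[torch.Tensor, torch.Tensor]:
    attn = CrossAttention(dim=256, nhead=8)
    x = torch.randn(2, 1024, 256)     # query tokens
    x_pe = torch.randn(2, 1024, 256)  # positional encoding for x
    mem = torch.randn(2, 16, 256)     # memory tokens (keys/values)
    mem_pe = torch.randn(2, 16, 256)  # positional encoding for mem
    out, weights = attn(x, mem, x_pe, mem_pe, need_weights=True)
    # out: (2, 1024, 256); weights: (2, 8, 1024, 16) since average_attn_weights=False
    return out, weights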


class FFN(nn.Module):
    def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_ff)
        self.linear2 = nn.Linear(dim_ff, dim_in)
        self.norm = nn.LayerNorm(dim_in)

        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r = x
        x = self.norm(x)
        x = self.linear2(self.activation(self.linear1(x)))
        x = r + x
        return x
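

# Usage sketch (illustrative sizes): FFN expands to dim_ff internally but
# preserves the input width, so the output shape matches the input.
def _example_ffn() -> torch.Tensor:
    ffn = FFN(dim_in=256, dim_ff=1024)
    x = torch.randn(2, 1024, 256)
    return ffn(x)  # same shape as x: (2, 1024, 256)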


class PixelFFN(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.conv = CAResBlock(dim, dim)

    def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
        # pixel: batch_size * num_objects * dim * H * W
        # pixel_flat: (batch_size*num_objects) * (H*W) * dim
        bs, num_objects, _, h, w = pixel.shape
        pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
        pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()

        x = self.conv(pixel_flat)
        x = x.view(bs, num_objects, self.dim, h, w)
        return x
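

# Usage sketch (illustrative sizes; assumes CAResBlock(dim, dim) maps an
# (N, dim, H, W) tensor to the same shape, as its use above implies).
def _example_pixel_ffn() -> torch.Tensor:
    ffn = PixelFFN(dim=256)
    bs, num_objects, h, w = 2, 3, 24, 24
    pixel = torch.randn(bs, num_objects, 256, h, w)
    # Flatten to the (batch*objects, H*W, dim) token layout expected above.
    pixel_flat = pixel.flatten(0, 1).flatten(2).transpose(1, 2).contiguous()
    return ffn(pixel, pixel_flat)  # (bs, num_objects, 256, h, w)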


class OutputFFN(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_out)
        self.linear2 = nn.Linear(dim_out, dim_out)

        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.activation(self.linear1(x)))
        return x
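

# Usage sketch (illustrative sizes): unlike FFN, this block changes the channel
# width and applies no normalization or residual connection.
def _example_output_ffn() -> torch.Tensor:
    ffn = OutputFFN(dim_in=256, dim_out=64)
    x = torch.randn(2, 1024, 256)
    return ffn(x)  # (2, 1024, 64)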


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
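

# Usage sketch: FFN and OutputFFN accept either a callable or one of the
# strings handled above.
def _example_string_activation() -> FFN:
    assert _get_activation_fn("gelu") is F.gelu
    return FFN(dim_in=256, dim_ff=1024, activation="gelu")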