Spaces:
Running
Running
| # Copyright (c) EPFL VILAB. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # -------------------------------------------------------- | |
| # Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases | |
| # https://github.com/rwightman/pytorch-image-models/tree/master/timm | |
| # https://github.com/facebookresearch/deit | |
| # https://github.com/facebookresearch/dino | |
| # https://github.com/facebookresearch/moco-v3 | |
| # https://github.com/microsoft/unilm/tree/master/beit | |
| # https://github.com/BUPT-PRIV/MAE-priv | |
| # https://github.com/facebookresearch/mae | |
| # -------------------------------------------------------- | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange, repeat | |
| from .multimae_utils import build_2d_sincos_posemb, pair, trunc_normal_ | |
| class PatchedInputAdapter(nn.Module): | |
| """Adapter for spatial inputs, like images or feature maps. | |
| Creates tokens from patches over the image. | |
| :param num_channels: Number of input channels of the image/feature map | |
| :param stride_level: Stride level compared to the full-sized image. | |
| E.g. 4 for 1/4th the size of the image. | |
| :param patch_size_full: Int or tuple of the patch size over the full image size. | |
| Patch size for smaller inputs will be computed accordingly. | |
| :param dim_tokens: Dimension of output tokens. Can be set using init method. | |
| :param sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings | |
| :param learnable_pos_emb: Set to True to learn positional embeddings instead | |
| :param image_size: Default image size. Used to initialize size of positional embeddings. | |
| """ | |
| def __init__(self, | |
| num_channels: int, | |
| stride_level: int, | |
| patch_size_full: Union[int, Tuple[int,int]], | |
| dim_tokens: Optional[int] = None, | |
| sincos_pos_emb: bool = True, | |
| learnable_pos_emb: bool = False, | |
| image_size: Union[int, Tuple[int]] = 224): | |
| super().__init__() | |
| self.num_channels = num_channels | |
| self.stride_level = stride_level | |
| self.patch_size_full = pair(patch_size_full) | |
| self.dim_tokens = dim_tokens | |
| self.sincos_pos_emb = sincos_pos_emb | |
| self.learnable_pos_emb = learnable_pos_emb | |
| self.image_size = pair(image_size) | |
| self.num_patches = (self.image_size[0] // patch_size_full) * (self.image_size[1] // patch_size_full) | |
| # Actual patch height and width, taking into account stride of input | |
| self.P_H = max(1, self.patch_size_full[0] // stride_level) | |
| self.P_W = max(1, self.patch_size_full[1] // stride_level) | |
| if self.dim_tokens is not None: | |
| self.init(dim_tokens=dim_tokens) | |
| def init(self, dim_tokens: int = 768): | |
| """ | |
| Initialize parts of encoder that are dependent on dimension of tokens. | |
| Should be called when setting up MultiMAE. | |
| :param dim_tokens: Dimension of tokens | |
| """ | |
| self.dim_tokens = dim_tokens | |
| # Task embedding identifying from which task a given token comes from | |
| # Fixed-size positional embeddings. Can be interpolated to different input sizes | |
| h_posemb = self.image_size[0] // (self.stride_level * self.P_H) | |
| w_posemb = self.image_size[1] // (self.stride_level * self.P_W) | |
| if self.sincos_pos_emb: | |
| self.pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens) | |
| self.pos_emb = nn.Parameter(self.pos_emb, requires_grad=self.learnable_pos_emb) | |
| else: | |
| self.pos_emb = nn.Parameter(torch.zeros(1, self.dim_tokens, h_posemb, w_posemb)) | |
| trunc_normal_(self.pos_emb, std=0.02) | |
| # Image -> tokens projection | |
| self.proj = nn.Conv2d( | |
| in_channels=self.num_channels, out_channels=self.dim_tokens, | |
| kernel_size=(self.P_H, self.P_W), stride=(self.P_H, self.P_W) | |
| ) | |
| def no_weight_decay(self): | |
| return {'pos_emb'} | |
| def forward(self, x): | |
| """ | |
| Forward pass through input adapter, transforming image to sequence of tokens. | |
| Adds task and positional encodings. | |
| :param x: Input image tensor | |
| """ | |
| B, C, H, W = x.shape | |
| assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first' | |
| assert (H % self.P_H == 0) and (W % self.P_W == 0), f'Image sizes {H}x{W} must be divisible by patch sizes {self.P_H}x{self.P_W}' | |
| N_H, N_W = H // self.P_H, W // self.P_W # Number of patches in height and width | |
| # Create patches [B, C, H, W] -> [B, (H*W), C] | |
| x_patch = rearrange(self.proj(x), 'b d nh nw -> b (nh nw) d') | |
| # Create positional embedding | |
| x_pos_emb = F.interpolate(self.pos_emb, size=(N_H, N_W), mode='bicubic', align_corners=False) | |
| x_pos_emb = rearrange(x_pos_emb, 'b d nh nw -> b (nh nw) d') | |
| # Add patches and positional embeddings | |
| x = x_patch + x_pos_emb | |
| return x | |
| class SemSegInputAdapter(nn.Module): | |
| """ | |
| Adapter for spatial inputs, like images or feature maps. | |
| Creates tokens from patches over the image. | |
| :param num_classes: Number of input semantic classes | |
| :param stride_level: Stride level compared to the full-sized image. | |
| E.g. 4 for 1/4th the size of the image. | |
| :param patch_size_full: Int or tuple of the patch size over the full image size. | |
| Patch size for smaller inputs will be computed accordingly. | |
| :param dim_tokens: Dimension of output tokens. Can be set using init method. | |
| :param sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings | |
| :param learnable_pos_emb: Set to True to learn positional embeddings instead | |
| :param image_size: Default image size. Used to initialize size of positional embeddings. | |
| :param dim_class_emb: Dimension of learned class embedding | |
| :param interpolate_class_emb: Set to True to average pool class embeddings of each patch | |
| :param emb_padding_idx: Padding index (e.g. image border), default is None | |
| """ | |
| def __init__(self, | |
| num_classes: int, | |
| stride_level: int, | |
| patch_size_full: Union[int, Tuple[int, int]], | |
| dim_tokens: Optional[int] = None, | |
| sincos_pos_emb: int = True, | |
| learnable_pos_emb: int = False, | |
| image_size: Union[int, Tuple[int]] = 224, | |
| dim_class_emb: int = 64, | |
| interpolate_class_emb: bool = False, | |
| emb_padding_idx: int = None | |
| ): | |
| super().__init__() | |
| self.num_classes = num_classes | |
| self.stride_level = stride_level | |
| self.patch_size_full = pair(patch_size_full) | |
| self.dim_tokens = dim_tokens | |
| self.sincos_pos_emb = sincos_pos_emb | |
| self.learnable_pos_emb = learnable_pos_emb | |
| self.image_size = pair(image_size) | |
| self.dim_class_emb = dim_class_emb | |
| self.interpolate_class_emb = interpolate_class_emb | |
| self.emb_padding_idx = emb_padding_idx | |
| if self.emb_padding_idx is not None: | |
| self.num_classes += 1 | |
| # Actual patch height and width, taking into account stride of input | |
| self.P_H = max(1, self.patch_size_full[0] // stride_level) | |
| self.P_W = max(1, self.patch_size_full[1] // stride_level) | |
| if self.dim_tokens is not None: | |
| self.init(dim_tokens=dim_tokens) | |
| def init(self, dim_tokens: int = 768): | |
| ''' | |
| Initialize parts of encoder that are dependent on dimension of tokens. | |
| Should be called when setting up MultiMAE. | |
| :param dim_tokens: Dimension of tokens | |
| ''' | |
| self.dim_tokens = dim_tokens | |
| # Task embedding identifying from which task a given token comes from | |
| # Fixed-size positional embeddings. Can be interpolated to different input sizes | |
| h_posemb = self.image_size[0] // (self.stride_level * self.P_H) | |
| w_posemb = self.image_size[1] // (self.stride_level * self.P_W) | |
| if self.sincos_pos_emb: | |
| self.pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens) | |
| self.pos_emb = nn.Parameter(self.pos_emb, requires_grad=self.learnable_pos_emb) | |
| else: | |
| self.pos_emb = nn.Parameter(torch.zeros(1, self.dim_tokens, h_posemb, w_posemb)) | |
| trunc_normal_(self.pos_emb, std=0.02) | |
| # Image -> tokens projection | |
| self.class_emb = nn.Embedding(num_embeddings=self.num_classes, embedding_dim=self.dim_class_emb, padding_idx=self.emb_padding_idx) | |
| trunc_normal_(self.class_emb.weight, std=0.02) | |
| if self.interpolate_class_emb: | |
| self.proj = nn.Sequential( | |
| nn.Upsample(scale_factor=(1 / self.P_H, 1 / self.P_W), | |
| mode='bilinear'), # Actually a downsample operation | |
| nn.Conv2d(in_channels=self.dim_class_emb, out_channels=self.dim_tokens, | |
| kernel_size=1, stride=1), | |
| ) | |
| else: | |
| self.proj = nn.Conv2d( | |
| in_channels=self.dim_class_emb, out_channels=self.dim_tokens, | |
| kernel_size=(self.P_H, self.P_W), stride=(self.P_H, self.P_W) | |
| ) | |
| def no_weight_decay(self): | |
| return {'pos_emb', 'class_emb'} | |
| def forward(self, x): | |
| ''' | |
| Forward pass through input adapter, transforming image to sequence of tokens. | |
| Adds task and positional encodings. | |
| :param x: Input image tensor | |
| ''' | |
| B, H, W = x.shape | |
| assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first' | |
| assert (H % self.P_H == 0) and ( | |
| W % self.P_W == 0), f'Image sizes {H}x{W} must be divisible by patch sizes {self.P_H}x{self.P_W}' | |
| N_H, N_W = H // self.P_H, W // self.P_W # Number of patches in height and width | |
| # Map to embedding | |
| x = rearrange(self.class_emb(x), 'b nh nw c -> b c nh nw') | |
| # Create patches [B, C, H, W] -> [B, (H*W), C] | |
| x_patch = rearrange(self.proj(x), 'b d nh nw -> b (nh nw) d') | |
| # Create positional embedding | |
| x_pos_emb = F.interpolate(self.pos_emb, size=(N_H, N_W), mode='bilinear') | |
| x_pos_emb = rearrange(x_pos_emb, 'b d nh nw -> b (nh nw) d') | |
| # Add patches and positional embeddings | |
| x = x_patch + x_pos_emb | |
| return x | |