# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on the timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv, MAE, DPT and ConvNeXt code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/facebookresearch/moco-v3
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/BUPT-PRIV/MAE-priv
# https://github.com/facebookresearch/mae
# https://github.com/isl-org/DPT
# https://github.com/facebookresearch/ConvNeXt
# --------------------------------------------------------
from functools import partial
from typing import Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .multimae_utils import (Block, CrossAttention, Mlp,
                             build_2d_sincos_posemb, pair, trunc_normal_)
from .output_adapter_utils import (ConvNeXtBlock, Interpolate,
                                   make_fusion_block, make_scratch)

class SpatialOutputAdapter(nn.Module):
    """Cross-attention adapter for spatial outputs, like images or feature maps.

    :param num_channels: Number of input channels of the image/feature map
    :param stride_level: Stride level compared to the full-sized image.
        E.g. 4 for 1/4th the size of the image.
    :param patch_size_full: Int or tuple of the patch size over the full image size.
        Patch size for smaller inputs will be computed accordingly.
    :param dim_tokens_enc: Dimension of tokens coming from encoder. Can be set using init method.
    :param dim_tokens: Dimension of decoder tokens
    :param depth: Number of additional (full self-attention) transformer layers after initial cross attention and MLP
    :param learnable_pos_emb: Set to True to learn positional embeddings instead of using fixed sin-cos embeddings
    :param image_size: Default image size. Used to initialize size of positional embeddings.
    :param mlp_ratio: MLP hidden dim ratio
    :param num_heads: Number of attention heads
    :param qkv_bias: Set to True to enable bias
    :param drop_rate: Probability of dropping attention layer outputs
    :param attn_drop_rate: Probability of dropping attention matrix elements
    :param drop_path_rate: DropPath drop rate
    :param norm_layer: Type of normalization layer
    :param use_task_queries: When set to True, adds task specific tokens from encoder (if available)
        to the corresponding query entries
    :param task: Task for which encoder tokens are added to the queries of the decoder (e.g. RGB if decoder is used for RGB)
    :param context_tasks: Tasks / modalities from the encoder. Used to create learned embeddings for each task.
    :param use_xattn: When set to True, attend to the tokens from the encoder through a cross-attention layer
    """
    def __init__(self,
                 num_channels: int,
                 stride_level: int,
                 patch_size_full: Union[int, Tuple[int, int]],
                 dim_tokens_enc: Optional[int] = None,
                 dim_tokens: int = 256,
                 depth: int = 0,
                 learnable_pos_emb: bool = False,
                 image_size: Union[int, Tuple[int, int]] = 224,
                 mlp_ratio: float = 4.0,
                 num_heads: int = 8,
                 qkv_bias: bool = True,
                 drop_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 drop_path_rate: float = 0.0,
                 norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
                 use_task_queries: bool = True,
                 task: Optional[str] = None,
                 context_tasks: Optional[list] = None,
                 use_xattn: bool = True
                 ):
        super().__init__()
        self.num_channels = num_channels
        self.stride_level = stride_level
        self.patch_size_full = pair(patch_size_full)
        self.dim_tokens_enc = dim_tokens_enc
        self.dim_tokens = dim_tokens
        self.learnable_pos_emb = learnable_pos_emb
        self.image_size = pair(image_size)
        self.use_task_queries = use_task_queries
        self.task = task
        self.use_xattn = use_xattn

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size_full[0] // stride_level)
        self.P_W = max(1, self.patch_size_full[1] // stride_level)

        if context_tasks is not None:
            self.task_embeddings = nn.ParameterDict(
                {task: nn.Parameter(torch.zeros(1, 1, self.dim_tokens)) for task in context_tasks})
            for embedding in self.task_embeddings.values():
                trunc_normal_(embedding, std=0.02)
        else:
            # Explicitly set to None so the `self.task_embeddings is not None` checks below do not fail
            self.task_embeddings = None

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
        # Fixed-size positional embeddings. Can be interpolated to different input sizes
        h_posemb = self.image_size[0] // (self.stride_level * self.P_H)
        w_posemb = self.image_size[1] // (self.stride_level * self.P_W)
        if not self.learnable_pos_emb:
            self.pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
            self.pos_emb = nn.Parameter(self.pos_emb, requires_grad=False)
        else:
            # Use the same (B, D, H, W) layout as the sin-cos embeddings, which is what
            # F.interpolate and the 'b d nh nw' rearranges below expect
            self.pos_emb = nn.Parameter(torch.zeros(1, self.dim_tokens, h_posemb, w_posemb))
            trunc_normal_(self.pos_emb, std=0.02)
        # One cross attention layer followed by MLP block, an optional transformer, and an output projection
        if self.use_xattn:
            self.decoder = CrossAttention(
                dim=self.dim_tokens, num_heads=num_heads, qkv_bias=qkv_bias,
                attn_drop=attn_drop_rate, proj_drop=drop_rate)
            self.context_norm = norm_layer(self.dim_tokens)
            self.query_norm = norm_layer(self.dim_tokens)
            self.out_norm = norm_layer(self.dim_tokens)

            mlp_hidden_dim = int(self.dim_tokens * mlp_ratio)
            self.mlp = Mlp(in_features=self.dim_tokens, hidden_features=mlp_hidden_dim)

        # Optional full self-attention transformer layers
        if depth > 0:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
            self.decoder_transformer = nn.Sequential(*[
                Block(dim=self.dim_tokens, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                      attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
                for i in range(depth)
            ])
        else:
            self.decoder_transformer = nn.Identity()

        self.dim_patch = self.num_channels * self.P_H * self.P_W
        self.out_proj = nn.Linear(self.dim_tokens, self.dim_patch)

        if self.dim_tokens_enc is not None:
            self.init(dim_tokens_enc=dim_tokens_enc)
    def init(self, dim_tokens_enc: int = 768):
        '''
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        '''
        self.dim_tokens_enc = dim_tokens_enc

        # Projection of encoder tokens to the patch dimension
        self.proj_context = nn.Linear(self.dim_tokens_enc, self.dim_tokens)

    def no_weight_decay(self):
        return {'pos_emb', 'mask_token', 'task_embeddings'}
    def generate_context_embeddings(self, input_info,
                                    bs: int,
                                    size: Tuple[int, int],
                                    device: Optional[torch.device] = None):
        context_embeddings = []
        for task, info in input_info["tasks"].items():
            if self.task_embeddings is not None and task in self.task_embeddings:
                task_emb = repeat(self.task_embeddings[task], '() () d -> b n d', b=bs, n=info['num_tokens'])
            else:
                task_emb = torch.zeros((bs, info['num_tokens'], self.dim_tokens), device=device)

            if info['has_2d_posemb']:
                pos_emb = F.interpolate(self.pos_emb, size=size, mode='bilinear', align_corners=False)
                pos_emb = rearrange(pos_emb, 'b d nh nw -> b (nh nw) d')
                assert info['num_tokens'] == pos_emb.shape[1]
                task_emb = task_emb + pos_emb

            context_embeddings.append(task_emb)

        context_embeddings = torch.cat(context_embeddings, dim=1)

        return context_embeddings
    def get_queries_and_context(self, context_tokens, input_info, ids_keep, ids_restore):
        B = context_tokens.shape[0]
        H, W = input_info['image_size']

        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        if 'num_global_tokens' in input_info:
            context_tokens_without_global = context_tokens[:, :-input_info['num_global_tokens']]
        else:
            context_tokens_without_global = context_tokens

        # Add mask tokens
        mask_tokens = repeat(self.mask_token, '() () d -> b n d', b=B,
                             n=input_info['num_task_tokens'] - context_tokens_without_global.shape[1])
        context_with_mask = torch.cat([context_tokens_without_global, mask_tokens], dim=1)

        # Unshuffle context_with_mask
        context_with_mask = torch.gather(context_with_mask, dim=1,
                                         index=ids_restore.unsqueeze(-1).repeat(1, 1, context_with_mask.shape[2]))

        # Generate context_emb and add them to context
        context_emb = self.generate_context_embeddings(input_info=input_info, bs=B, size=(N_H, N_W),
                                                       device=context_tokens.device)
        context_with_mask = context_with_mask + context_emb

        # Generate queries
        if self.use_task_queries and self.task in input_info['tasks']:
            start_idx = input_info['tasks'][self.task]['start_idx']
            end_idx = input_info['tasks'][self.task]['end_idx']
            queries = context_with_mask[:, start_idx:end_idx]
        else:
            queries = repeat(self.mask_token, '() () d -> b n d', b=B, n=N_H * N_W)
            queries_pos_emb = F.interpolate(self.pos_emb, size=(N_H, N_W), mode='bilinear', align_corners=False)
            queries_pos_emb = rearrange(queries_pos_emb, 'b d nh nw -> b (nh nw) d')
            queries = queries + queries_pos_emb
            if self.task_embeddings is not None and self.task in self.task_embeddings:
                queries_task_emb = repeat(self.task_embeddings[self.task], '() () d -> b n d', b=B, n=N_H * N_W)
                queries = queries + queries_task_emb

        # Unshuffle context and keep only initial context (yes, again)
        context_tokens_without_global = torch.gather(context_with_mask, dim=1,
                                                     index=ids_keep.unsqueeze(-1).repeat(1, 1, context_with_mask.shape[2]))

        # Add back global tokens
        if 'num_global_tokens' in input_info:
            context_tokens = torch.cat(
                [context_tokens_without_global, context_tokens[:, -input_info['num_global_tokens']:]], dim=1)
        else:
            context_tokens = context_tokens_without_global

        return queries, context_tokens
    def forward(self,
                encoder_tokens: torch.Tensor,
                input_info: Dict,
                ids_keep: torch.Tensor,
                ids_restore: torch.Tensor,
                ):
        """
        Forward pass taking output tokens from encoder and optionally a subset of them corresponding
        to this output adapter's task (needs an additional mask describing position of these tokens in the queries).

        :param encoder_tokens: Output of encoder
        :param input_info: Dictionary with information about the input modalities
        :param ids_keep: IDs of unmasked tokens (tokens given to the encoder)
        :param ids_restore: IDs to unshuffle tokens
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        H, W = input_info['image_size']

        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Project encoder tokens to decoder tokens
        context_tokens = self.proj_context(encoder_tokens)

        # Get queries and context
        queries, context_tokens = self.get_queries_and_context(context_tokens, input_info, ids_keep, ids_restore)

        # Perform cross attention of queries to context tokens, followed by an MLP
        if self.use_xattn:
            x = self.decoder(self.query_norm(queries), self.context_norm(context_tokens))
            x = x + self.mlp(self.out_norm(x))
        else:
            x = queries

        # Optional transformer layers if depth > 0
        x = self.decoder_transformer(x)

        # Project each token to (C * P_H * P_W)
        x = self.out_proj(x)

        # Reshape sequence of patches into image
        x = rearrange(
            x, 'b (nh nw) (c ph pw) -> b c (nh ph) (nw pw)',
            nh=N_H, nw=N_W, ph=self.P_H, pw=self.P_W, c=self.num_channels
        )

        return x
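
# Hedged usage sketch (not part of the original module): how a SpatialOutputAdapter could be
# driven with dummy tensors for RGB reconstruction. The `input_info` dict below only contains
# the fields read in this file; the real dict built by the MultiMAE encoder may hold more.
def _example_spatial_output_adapter():
    adapter = SpatialOutputAdapter(num_channels=3, stride_level=1, patch_size_full=16,
                                   dim_tokens=256, depth=0, task='rgb', context_tasks=['rgb'])
    adapter.init(dim_tokens_enc=768)  # encoder dim 768 is an assumption

    B, num_patches, num_visible = 2, 14 * 14, 49  # 224x224 image, 16x16 patches
    input_info = {
        'image_size': (224, 224),
        'num_task_tokens': num_patches,
        'num_global_tokens': 1,
        'tasks': {'rgb': {'num_tokens': num_patches, 'has_2d_posemb': True,
                          'start_idx': 0, 'end_idx': num_patches}},
    }
    # MAE-style shuffle/restore indices for the dummy masking
    ids_shuffle = torch.rand(B, num_patches).argsort(dim=1)
    ids_restore = ids_shuffle.argsort(dim=1)
    ids_keep = ids_shuffle[:, :num_visible]
    encoder_tokens = torch.randn(B, num_visible + 1, 768)  # visible tokens + 1 global token

    out = adapter(encoder_tokens, input_info, ids_keep, ids_restore)
    assert out.shape == (B, 3, 224, 224)
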
class LinearOutputAdapter(nn.Module):
    """
    Linear output adapter.

    :param num_classes: Number of classes
    :param dim_tokens_enc: Dimension of tokens from the encoder
    :param use_mean_pooling: When set to True, uses mean pooling before linear classification head.
        Otherwise, use last token (usually the global token)
    :param norm_layer: Normalization layer
    :param init_scale: Initialization scale for linear classification head
    """

    def __init__(self,
                 num_classes: int,
                 dim_tokens_enc: Optional[int] = None,
                 use_mean_pooling: bool = True,
                 norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
                 init_scale: float = 1.0):
        super().__init__()
        self.num_classes = num_classes
        self.dim_tokens_enc = dim_tokens_enc
        self.use_mean_pooling = use_mean_pooling
        self.norm_layer = norm_layer
        self.init_scale = init_scale

        if self.dim_tokens_enc is not None:
            self.init(dim_tokens_enc=dim_tokens_enc)

    def init(self, dim_tokens_enc: int = 768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        self.dim_tokens_enc = dim_tokens_enc

        self.norm = self.norm_layer(self.dim_tokens_enc)
        self.head = nn.Linear(dim_tokens_enc, self.num_classes) if self.num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        self.head.weight.data.mul_(self.init_scale)
        self.head.bias.data.mul_(self.init_scale)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.init(dim_tokens_enc=self.dim_tokens_enc)

    def forward(self,
                encoder_tokens: torch.Tensor,
                **kwargs):
        if self.use_mean_pooling:
            x = encoder_tokens.mean(1)
        else:
            # Global token is added at the end
            x = encoder_tokens[:, -1]

        x = self.head(self.norm(x))
        return x
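
# Hedged usage sketch (not part of the original module): classification head on top of encoder
# tokens. Token count, class count and encoder dimension below are illustrative assumptions.
def _example_linear_output_adapter():
    adapter = LinearOutputAdapter(num_classes=1000, dim_tokens_enc=768, use_mean_pooling=True)
    encoder_tokens = torch.randn(2, 197, 768)  # e.g. 196 patch tokens + 1 global token
    logits = adapter(encoder_tokens)  # mean-pooled tokens -> LayerNorm -> linear head
    assert logits.shape == (2, 1000)
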
class SegmenterMaskTransformerAdapter(nn.Module):
    """Output adapter inspired by the Segmenter-Mask architecture

    This head is the implementation of `Segmenter: <https://arxiv.org/abs/2105.05633>`_.

    :param num_classes: Number of classes
    :param depth: Depth of decoder
    :param num_heads: Number of attention heads
    :param embed_dim: Dimension of decoder tokens
    :param mlp_ratio: MLP hidden dim ratio
    :param drop_path_rate: DropPath drop rate
    :param drop_rate: Dropout after MLPs and Attention
    :param attn_drop_rate: Attention matrix drop rate
    :param qkv_bias: Set to False to disable bias
    :param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept.
    :param patch_size: Size of patches
    :param norm_layer: Type of normalization layer
    """
    def __init__(
            self,
            num_classes,
            depth: int = 2,
            num_heads: int = 12,
            embed_dim: int = 768,
            mlp_ratio=4,
            drop_path_rate=0.1,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            qkv_bias=True,
            main_tasks: Iterable[str] = ('rgb',),
            patch_size: int = 16,
            norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
            **kwargs,
    ):
        super().__init__()
        self.main_tasks = main_tasks
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_classes = num_classes

        self.cls_emb = nn.Parameter(torch.zeros(1, num_classes, embed_dim))
        trunc_normal_(self.cls_emb, std=0.02)

        self.patch_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.classes_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                  attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)
        ])

        self.decoder_norm = norm_layer(embed_dim)
        self.mask_norm = norm_layer(num_classes)
        self.apply(self._init_weights)
    def init(self, dim_tokens_enc: int = 768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        self.in_channels = dim_tokens_enc * len(self.main_tasks)

        # Projection of encoder tokens to the patch dimension
        self.proj_dec = nn.Linear(self.in_channels, self.embed_dim)
        self._init_weights(self.proj_dec)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def adapt_tokens(self, encoder_tokens, input_info):
        # Adapt tokens
        x = []
        for task in self.main_tasks:
            start_idx = input_info['tasks'][task]['start_idx']
            end_idx = input_info['tasks'][task]['end_idx']
            x.append(encoder_tokens[:, start_idx:end_idx])

        x = torch.cat(x, dim=-1)
        return x

    def forward(self, encoder_tokens: torch.Tensor, input_info: Dict):
        H, W = input_info['image_size']
        N_H, N_W = H // self.patch_size, W // self.patch_size

        x = self.adapt_tokens(encoder_tokens, input_info)

        x = self.proj_dec(x)
        cls_emb = self.cls_emb.expand(x.shape[0], -1, -1)
        x = torch.cat((x, cls_emb), 1)

        for blk in self.blocks:
            x = blk(x)

        x = self.decoder_norm(x)

        patches = self.patch_proj(x[:, :-self.num_classes])
        cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])

        patches = F.normalize(patches, dim=2, p=2)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)

        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)
        masks = rearrange(masks, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W)

        # Interpolate to semseg res
        masks = F.interpolate(masks, size=(H, W), mode="bilinear")

        return masks
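
# Hedged usage sketch (not part of the original module): Segmenter-style semantic segmentation
# from the tokens of a single 'rgb' input. `input_info` only carries the fields this head reads;
# the class count (40) and encoder dim (768) are illustrative assumptions.
def _example_segmenter_adapter():
    adapter = SegmenterMaskTransformerAdapter(num_classes=40, depth=2, embed_dim=768,
                                              main_tasks=('rgb',), patch_size=16)
    adapter.init(dim_tokens_enc=768)

    B, num_patches = 2, 14 * 14  # 224x224 image with 16x16 patches
    input_info = {'image_size': (224, 224),
                  'tasks': {'rgb': {'start_idx': 0, 'end_idx': num_patches}}}
    encoder_tokens = torch.randn(B, num_patches + 1, 768)  # the extra global token is sliced away

    masks = adapter(encoder_tokens, input_info)
    assert masks.shape == (B, 40, 224, 224)
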
class ConvNeXtAdapter(nn.Module):
    """Output adapter with ConvNext blocks for semantic segmentation

    :param num_classes: Number of classes
    :param embed_dim: Token dimension after projection, and before reshaping operation.
    :param preds_per_patch: Increases size of feature map by reshaping each patch. Each patch gets reshaped
        from embed_dim x 1 x 1 to (embed_dim / preds_per_patch) x (preds_per_patch ** 0.5) x (preds_per_patch ** 0.5)
    :param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept.
    :param patch_size: Size of patches
    :param depth: Number of ConvNeXt blocks
    :param interpolate_mode: Interpolation mode for final upsampling
    """
    def __init__(
            self,
            num_classes,
            embed_dim: int = 6144,
            preds_per_patch: int = 16,
            main_tasks: Iterable[str] = ('rgb',),
            patch_size: int = 16,
            depth: int = 4,
            interpolate_mode: str = 'bilinear',
            **kwargs,
    ):
        super().__init__()
        self.main_tasks = main_tasks
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.preds_per_patch = preds_per_patch
        self.class_dim = embed_dim // preds_per_patch
        self.num_classes = num_classes
        self.interpolate_mode = interpolate_mode

        self.blocks = nn.Sequential(*[
            ConvNeXtBlock(dim=self.class_dim)
            for _ in range(depth)
        ])
        self.final_layer = nn.Conv2d(self.class_dim, self.num_classes, 1)
        self.apply(self._init_weights)

    def init(self, dim_tokens_enc: int = 768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        self.in_channels = dim_tokens_enc * len(self.main_tasks)

        # Projection of encoder tokens to the patch dimension
        self.proj_dec = nn.Linear(self.in_channels, self.embed_dim)
        self._init_weights(self.proj_dec)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def adapt_tokens(self, encoder_tokens, input_info):
        # Adapt tokens
        x = []
        for task in self.main_tasks:
            start_idx = input_info['tasks'][task]['start_idx']
            end_idx = input_info['tasks'][task]['end_idx']
            x.append(encoder_tokens[:, start_idx:end_idx])

        x = torch.cat(x, dim=-1)
        return x
    def forward(self, encoder_tokens: torch.Tensor, input_info: Dict):
        H, W = input_info['image_size']
        N_H, N_W = H // self.patch_size, W // self.patch_size

        x = self.adapt_tokens(encoder_tokens, input_info)

        x = self.proj_dec(x)
        x = rearrange(x, "b n (p c) -> b (n p) c", n=N_H * N_W, p=self.preds_per_patch, c=self.class_dim)
        x = rearrange(x, "b (nh nw ph pw) c -> b c (nh ph) (nw pw)",
                      nh=N_H, nw=N_W,
                      ph=int(self.preds_per_patch ** 0.5),
                      pw=int(self.preds_per_patch ** 0.5))

        x = self.blocks(x)
        x = self.final_layer(x)

        # Interpolate to semseg res
        x = F.interpolate(x, size=(H, W), mode=self.interpolate_mode)

        return x
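
# Hedged usage sketch (not part of the original module): with the defaults embed_dim=6144 and
# preds_per_patch=16, each 16x16 patch is reshaped into a 4x4 grid of 384-dim predictions before
# the ConvNeXt blocks. The class count (40) and encoder dim (768) are illustrative assumptions.
def _example_convnext_adapter():
    adapter = ConvNeXtAdapter(num_classes=40, embed_dim=6144, preds_per_patch=16,
                              main_tasks=('rgb',), patch_size=16, depth=4)
    adapter.init(dim_tokens_enc=768)

    B, num_patches = 2, 14 * 14  # 224x224 image with 16x16 patches
    input_info = {'image_size': (224, 224),
                  'tasks': {'rgb': {'start_idx': 0, 'end_idx': num_patches}}}
    encoder_tokens = torch.randn(B, num_patches + 1, 768)  # the extra global token is sliced away

    seg = adapter(encoder_tokens, input_info)
    assert seg.shape == (B, 40, 224, 224)
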
class DPTOutputAdapter(nn.Module):
    """DPT output adapter.

    :param num_classes: Number of output channels
    :param stride_level: Stride level compared to the full-sized image.
        E.g. 4 for 1/4th the size of the image.
    :param patch_size: Int or tuple of the patch size over the full image size.
        Patch size for smaller inputs will be computed accordingly.
    :param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept.
    :param hooks: Index of intermediate layers
    :param layer_dims: Dimension of intermediate layers
    :param feature_dim: Feature dimension
    :param use_bn: If set to True, activates batch norm
    :param dim_tokens_enc: Dimension of tokens coming from encoder
    :param head_type: Type of output head, either 'regression' or 'semseg'
    """
    def __init__(self,
                 num_classes: int = 3,
                 stride_level: int = 1,
                 patch_size: Union[int, Tuple[int, int]] = 16,
                 main_tasks: Iterable[str] = ('rgb',),
                 hooks: List[int] = [2, 5, 8, 11],
                 layer_dims: List[int] = [96, 192, 384, 768],
                 feature_dim: int = 256,
                 use_bn: bool = False,
                 dim_tokens_enc: Optional[int] = None,
                 head_type: str = 'regression',
                 **kwargs):
        super().__init__()
        self.num_channels = num_classes
        self.stride_level = stride_level
        self.patch_size = pair(patch_size)
        self.main_tasks = main_tasks
        self.hooks = hooks
        self.layer_dims = layer_dims
        self.feature_dim = feature_dim
        self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
        self.head_type = head_type

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size[0] // stride_level)
        self.P_W = max(1, self.patch_size[1] // stride_level)

        self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)

        self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn)
        self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn)
        self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn)
        self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn)

        if self.head_type == 'regression':
            # The "DPTDepthModel" head
            self.head = nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
                Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(feature_dim // 2, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(32, self.num_channels, kernel_size=1, stride=1, padding=0)
            )
        elif self.head_type == 'semseg':
            # The "DPTSegmentationModel" head
            self.head = nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
                nn.ReLU(True),
                nn.Dropout(0.1, False),
                nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
                Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            )
        else:
            raise ValueError('DPT head_type must be "regression" or "semseg".')

        if self.dim_tokens_enc is not None:
            self.init(dim_tokens_enc=dim_tokens_enc)
    def init(self, dim_tokens_enc: int = 768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks)

        # Set up activation postprocessing layers
        self.act_1_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc,
                out_channels=self.layer_dims[0],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=self.layer_dims[0],
                out_channels=self.layer_dims[0],
                kernel_size=4, stride=4, padding=0,
                bias=True, dilation=1, groups=1,
            )
        )

        self.act_2_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc,
                out_channels=self.layer_dims[1],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=self.layer_dims[1],
                out_channels=self.layer_dims[1],
                kernel_size=2, stride=2, padding=0,
                bias=True, dilation=1, groups=1,
            )
        )

        self.act_3_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc,
                out_channels=self.layer_dims[2],
                kernel_size=1, stride=1, padding=0,
            )
        )

        self.act_4_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc,
                out_channels=self.layer_dims[3],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.Conv2d(
                in_channels=self.layer_dims[3],
                out_channels=self.layer_dims[3],
                kernel_size=3, stride=2, padding=1,
            )
        )

        self.act_postprocess = nn.ModuleList([
            self.act_1_postprocess,
            self.act_2_postprocess,
            self.act_3_postprocess,
            self.act_4_postprocess
        ])
    def adapt_tokens(self, encoder_tokens, input_info):
        # Adapt tokens
        x = []
        for task in self.main_tasks:
            start_idx = input_info['tasks'][task]['start_idx']
            end_idx = input_info['tasks'][task]['end_idx']
            x.append(encoder_tokens[:, start_idx:end_idx])

        x = torch.cat(x, dim=-1)
        return x

    def forward(self, encoder_tokens: List[torch.Tensor], input_info: Dict):
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        H, W = input_info['image_size']

        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(l, input_info) for l in layers]

        # Reshape tokens to spatial representation
        layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]

        # Postprocess activations
        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]

        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]

        # Fuse layers using refinement stages
        path_4 = self.scratch.refinenet4(layers[3])
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)

        return out
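
# Hedged usage sketch (not part of the original module): DPT-style dense regression from
# intermediate ViT layers. A 12-layer encoder is assumed so that hooks [2, 5, 8, 11] exist;
# `encoder_tokens` holds one token tensor per encoder layer, and output shapes depend on the
# fusion blocks defined in output_adapter_utils.
def _example_dpt_output_adapter():
    adapter = DPTOutputAdapter(num_classes=1, hooks=[2, 5, 8, 11], head_type='regression')
    adapter.init(dim_tokens_enc=768)  # encoder dim 768 is an assumption

    B, num_patches = 2, 14 * 14  # 224x224 image with 16x16 patches
    input_info = {'image_size': (224, 224),
                  'tasks': {'rgb': {'start_idx': 0, 'end_idx': num_patches}}}
    encoder_tokens = [torch.randn(B, num_patches + 1, 768) for _ in range(12)]

    depth_map = adapter(encoder_tokens, input_info)  # expected (B, 1, 224, 224) for this setup
    return depth_map
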