# Copyright (c) EPFL VILAB.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/facebookresearch/moco-v3
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/BUPT-PRIV/MAE-priv
# https://github.com/facebookresearch/mae
# --------------------------------------------------------
import itertools
import math
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional, Union

import torch
from einops import rearrange, repeat
from torch import nn
from torch.distributions.dirichlet import Dirichlet

from utils.registry import register_model

from .multimae_utils import Block, trunc_normal_

__all__ = [
    'pretrain_multimae_base',
    'pretrain_multimae_large',
    'multivit_base',
    'multivit_large',
]


class MultiMAE(nn.Module):
    """MultiMAE: Multi-task Multi-modal Masked Autoencoder

    This module performs masking in its forward pass.
    The MultiViT module defined below inherits from this module and performs a
    regular forward pass, and should be used instead for downstream tasks.

    :param input_adapters: Dictionary of task -> input adapters
    :param output_adapters: Optional dictionary of task -> output adapters
    :param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1
    :param dim_tokens: Dimension of encoder tokens
    :param depth: Depth of encoder
    :param num_heads: Number of attention heads
    :param mlp_ratio: MLP hidden dim ratio
    :param qkv_bias: Set to False to disable bias
    :param drop_rate: Dropout after MLPs and Attention
    :param attn_drop_rate: Attention matrix drop rate
    :param drop_path_rate: DropPath drop rate
    :param norm_layer: Type of normalization layer
    """
    def __init__(self,
                 input_adapters: Dict[str, nn.Module],
                 output_adapters: Optional[Dict[str, nn.Module]],
                 num_global_tokens: int = 1,
                 dim_tokens: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 qkv_bias: bool = True,
                 drop_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 drop_path_rate: float = 0.0,
                 norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()

        # Initialize input and output adapters
        for adapter in input_adapters.values():
            adapter.init(dim_tokens=dim_tokens)
        self.input_adapters = nn.ModuleDict(input_adapters)
        if output_adapters is not None:
            for adapter in output_adapters.values():
                adapter.init(dim_tokens_enc=dim_tokens)
            self.output_adapters = nn.ModuleDict(output_adapters)
        else:
            self.output_adapters = None

        # Additional learnable tokens that can be used by encoder to process/store global information
        self.num_global_tokens = num_global_tokens
        self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, dim_tokens))
        trunc_normal_(self.global_tokens, std=0.02)

        # Transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.encoder = nn.Sequential(*[
            Block(dim=dim_tokens, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)
        ])

        self.apply(self._init_weights)
        for name, m in self.named_modules():
            if isinstance(m, nn.Linear):
                if 'qkv' in name:
                    # treat the weights of Q, K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                elif 'kv' in name:
                    # treat the weights of K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)

            if isinstance(m, nn.Conv2d):
                if '.proj' in name:
                    # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.encoder)

    def no_weight_decay(self):
        no_wd_set = {'global_tokens'}

        for task, adapter in self.input_adapters.items():
            if hasattr(adapter, 'no_weight_decay'):
                to_skip = adapter.no_weight_decay()
                to_skip = set([f'input_adapters.{task}.{name}' for name in to_skip])
                no_wd_set = no_wd_set | to_skip

        # output_adapters may be None (e.g. when the model is used as a feature extractor)
        if self.output_adapters is not None:
            for task, adapter in self.output_adapters.items():
                if hasattr(adapter, 'no_weight_decay'):
                    to_skip = adapter.no_weight_decay()
                    to_skip = set([f'output_adapters.{task}.{name}' for name in to_skip])
                    no_wd_set = no_wd_set | to_skip

        return no_wd_set

    def sample_alphas(self, B: int, n_tasks: int, alphas: float = 1.0, eps: float = 1e-5):
        """
        Sample alphas for Dirichlet sampling such that tasks are first uniformly chosen
        and then Dirichlet sampling is performed over the chosen ones.

        :param B: Batch size
        :param n_tasks: Number of input tasks
        :param alphas: Float or list to multiply task choices {0,1} by
        :param eps: Small constant since Dirichlet alphas need to be positive
        """
        valid_task_choices = torch.Tensor([list(i) for i in itertools.product([0, 1], repeat=n_tasks)][1:])
        rand_per_sample_choice = torch.randint(0, len(valid_task_choices), (B,))
        alphas_tensor = torch.index_select(valid_task_choices, 0, rand_per_sample_choice)
        alphas_tensor = alphas_tensor * torch.tensor(alphas) + eps
        return alphas_tensor

    def generate_random_masks(self,
                              input_tokens: Dict[str, torch.Tensor],
                              num_encoded_tokens: int,
                              alphas: Union[float, List[float]] = 1.0,
                              sample_tasks_uniformly: bool = False):
        """
        Sample a total of num_encoded_tokens from different tasks using Dirichlet sampling.

        :param input_tokens: Dictionary of tensors to sample num_encoded_tokens from
        :param num_encoded_tokens: Number of tokens to select
        :param alphas: Dirichlet distribution parameter alpha. Lower alpha = harder,
            less uniform sampling. Can be float or list of floats.
        :param sample_tasks_uniformly: Set to True to first sample 1-n_tasks uniformly at random
            for each sample in the batch. Dirichlet sampling is then done over selected subsets.
        """
        B = list(input_tokens.values())[0].shape[0]
        device = list(input_tokens.values())[0].device

        alphas = [alphas] * len(input_tokens) if isinstance(alphas, float) else alphas
        if sample_tasks_uniformly:
            alphas = self.sample_alphas(B, len(input_tokens), alphas=alphas)
            task_sampling_dist = Dirichlet(alphas).sample().to(device)
        else:
            task_sampling_dist = Dirichlet(torch.Tensor(alphas)).sample((B,)).to(device)

        samples_per_task = (task_sampling_dist * num_encoded_tokens).round().long()

        task_masks = []
        num_tokens_per_task = [task_tokens.shape[1] for task_tokens in input_tokens.values()]
        for i, num_tokens in enumerate(num_tokens_per_task):
            # Use noise to shuffle arange
            noise = torch.rand(B, num_tokens, device=device)  # noise in [0, 1]
            ids_arange_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
            mask = torch.arange(num_tokens, device=device).unsqueeze(0).expand(B, -1)
            mask = torch.gather(mask, dim=1, index=ids_arange_shuffle)
            # 0 is keep (unmasked), 1 is remove (masked)
            mask = torch.where(mask < samples_per_task[:, i].unsqueeze(1), 0, 1)
            task_masks.append(mask)

        mask_all = torch.cat(task_masks, dim=1)
        ids_shuffle = torch.argsort(mask_all + torch.rand_like(mask_all.float()), dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_encoded_tokens]

        # Update binary mask to adjust for task rounding
        mask_all = torch.ones_like(mask_all)
        mask_all[:, :num_encoded_tokens] = 0
        # Unshuffle to get the binary mask
        mask_all = torch.gather(mask_all, dim=1, index=ids_restore)
        # Split to get task masks
        task_masks = torch.split(mask_all, num_tokens_per_task, dim=1)
        # Convert to dict
        task_masks = {domain: mask for domain, mask in zip(input_tokens.keys(), task_masks)}

        return task_masks, ids_keep, ids_restore

    @staticmethod
    def make_mask(N_H, N_W, xy_idxs, full_tasks=[], indicate_visible=True, flatten=True, device='cuda'):
        """
        Creates masks for each task, given lists of un-masked x,y coordinates.
        """
        xy_idxs = {
            k: torch.LongTensor(v)
            for k, v in xy_idxs.items()
        }

        task_masks = {
            k: torch.ones(N_H, N_W).to(device)
            for k in xy_idxs.keys()
        }

        for k in xy_idxs.keys():
            if len(xy_idxs[k]) > 0:
                task_masks[k][xy_idxs[k][:, 1], xy_idxs[k][:, 0]] = 0

        for task in full_tasks:
            task_masks[task][:] = 0

        if not indicate_visible:
            task_masks = {k: 1 - v for k, v in task_masks.items()}

        if flatten:
            task_masks = {k: v.flatten().unsqueeze(0) for k, v in task_masks.items()}

        return task_masks
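
    # Hedged usage sketch (not part of the original code): for a 2x2 grid of patches,
    # keeping the patches at (x=0, y=0) and (x=1, y=1) for 'rgb' and making 'depth'
    # fully visible could look like the following; values are illustrative only.
    #
    #   masks = MultiMAE.make_mask(
    #       N_H=2, N_W=2,
    #       xy_idxs={'rgb': [(0, 0), (1, 1)], 'depth': []},
    #       full_tasks=['depth'],   # mark the whole depth task as visible
    #       device='cpu',
    #   )
    #   # masks['rgb']   -> tensor([[0., 1., 1., 0.]])  (0 = visible, 1 = masked, row-major over the grid)
    #   # masks['depth'] -> tensor([[0., 0., 0., 0.]])  (fully visible)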

    def generate_input_info(self, input_task_tokens, image_size):
        input_info = OrderedDict()
        i = 0
        input_info['tasks'] = {}
        for domain, tensor in input_task_tokens.items():
            num_tokens = tensor.shape[1]
            d = {
                'num_tokens': num_tokens,
                'has_2d_posemb': True,  # TODO: Modify when adding non-2D tasks
                'start_idx': i,
                'end_idx': i + num_tokens,
            }
            i += num_tokens
            input_info['tasks'][domain] = d

        input_info['image_size'] = image_size
        input_info['num_task_tokens'] = i
        input_info['num_global_tokens'] = self.num_global_tokens

        return input_info

    def forward(self,
                x: Union[Dict[str, torch.Tensor], torch.Tensor],
                mask_inputs: bool = True,
                task_masks: Optional[Dict[str, torch.Tensor]] = None,
                num_encoded_tokens: int = 128,
                alphas: Union[float, List[float]] = 1.0,
                sample_tasks_uniformly: bool = False,
                fp32_output_adapters: List[str] = []):
        """
        Forward pass through input adapters, transformer encoder and output adapters.
        If specified, will randomly drop input tokens.

        :param x: Input tensor or dictionary of tensors
        :param mask_inputs: Set to True to enable random masking of input patches
        :param task_masks: Optional dictionary of task -> mask pairs. If None, masks are sampled randomly.
        :param num_encoded_tokens: Number of tokens to randomly select for encoder.
            Only used if mask_inputs is True.
        :param alphas: Dirichlet distribution parameter alpha for task sampling.
            Lower alpha = harder, less uniform sampling. Can be float or list of floats.
        :param sample_tasks_uniformly: Set to True if tasks should be uniformly presampled,
            before Dirichlet sampling decides share of masked tokens between them.
        :param fp32_output_adapters: List of task identifiers to force output adapters to
            run with mixed precision turned off for stability reasons.
        """
        ## Processing input modalities
        # If input x is a Tensor, assume it's RGB
        x = {'rgb': x} if isinstance(x, torch.Tensor) else x

        # Need image size for tokens->image reconstruction
        # We assume that at least one of rgb or semseg is given as input before masking
        if 'rgb' in x:
            B, C, H, W = x['rgb'].shape
        elif 'semseg' in x:
            B, H, W = x['semseg'].shape
            H *= self.input_adapters['semseg'].stride_level
            W *= self.input_adapters['semseg'].stride_level
        else:
            B, C, H, W = list(x.values())[0].shape  # TODO: Deal with case where not all have same shape

        # Encode selected inputs to tokens
        input_task_tokens = {
            domain: self.input_adapters[domain](tensor)
            for domain, tensor in x.items()
            if domain in self.input_adapters
        }

        input_info = self.generate_input_info(input_task_tokens=input_task_tokens, image_size=(H, W))

        # Select random subset of tokens from the chosen input tasks and concatenate them
        if mask_inputs:
            num_encoded_tokens = num_encoded_tokens if num_encoded_tokens is not None else self.num_encoded_tokens
        else:
            num_encoded_tokens = sum([tensor.shape[1] for tensor in input_task_tokens.values()])

        ## Generating masks
        if task_masks is None:
            task_masks, ids_keep, ids_restore = self.generate_random_masks(
                input_task_tokens,
                num_encoded_tokens,
                alphas=alphas,
                sample_tasks_uniformly=sample_tasks_uniformly
            )
        else:
            mask_all = torch.cat([task_masks[task] for task in input_task_tokens.keys()], dim=1)
            ids_shuffle = torch.argsort(mask_all, dim=1)
            ids_restore = torch.argsort(ids_shuffle, dim=1)
            ids_keep = ids_shuffle[:, :(mask_all == 0).sum()]

        input_tokens = torch.cat([task_tokens for task_tokens in input_task_tokens.values()], dim=1)

        # Apply mask
        input_tokens = torch.gather(input_tokens, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, input_tokens.shape[2]))

        # Add global tokens to input tokens
        global_tokens = repeat(self.global_tokens, '() n d -> b n d', b=B)
        input_tokens = torch.cat([input_tokens, global_tokens], dim=1)

        ## Transformer forward pass
        encoder_tokens = self.encoder(input_tokens)

        ## Output decoders
        if self.output_adapters is None:
            return encoder_tokens, task_masks

        # Decode tokens for each task using task-specific output adapters
        preds = {
            domain: self.output_adapters[domain](
                encoder_tokens=encoder_tokens,
                input_info=input_info,
                ids_keep=ids_keep,
                ids_restore=ids_restore,
            )
            for domain in self.output_adapters
            if domain not in fp32_output_adapters
        }

        # Force running selected output adapters in fp32 mode
        with torch.cuda.amp.autocast(enabled=False):
            for domain in fp32_output_adapters:
                if domain not in self.output_adapters:
                    continue
                preds[domain] = self.output_adapters[domain](
                    encoder_tokens=encoder_tokens.float(),
                    input_info=input_info,
                    ids_keep=ids_keep,
                    ids_restore=ids_restore,
                )

        return preds, task_masks
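

# Hedged sketch (not part of the original code): a minimal, self-contained illustration of how
# the Dirichlet concentration parameter used in `generate_random_masks` splits a fixed
# visible-token budget across tasks. The batch size, task count and budget below are made up.
def _demo_dirichlet_token_allocation(B: int = 4, n_tasks: int = 3, num_encoded_tokens: int = 98,
                                     alpha: float = 1.0) -> torch.Tensor:
    # Lower alpha concentrates each sample's budget on fewer tasks (harder pretext task),
    # higher alpha spreads the budget more uniformly across tasks.
    task_share = Dirichlet(torch.full((n_tasks,), alpha)).sample((B,))  # (B, n_tasks), rows sum to 1
    samples_per_task = (task_share * num_encoded_tokens).round().long()
    return samples_per_task  # e.g. tensor([[41, 30, 27], ...]) up to rounding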


@register_model
def pretrain_multimae_base(
        input_adapters: Dict[str, nn.Module],
        output_adapters: Optional[Dict[str, nn.Module]],
        **kwargs):
    model = MultiMAE(
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        dim_tokens=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


@register_model
def pretrain_multimae_large(
        input_adapters: Dict[str, nn.Module],
        output_adapters: Optional[Dict[str, nn.Module]],
        **kwargs):
    model = MultiMAE(
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        dim_tokens=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model
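

# Hedged usage sketch (not part of the original code, not runnable as-is): the adapter
# construction below is illustrative; `make_rgb_input_adapter` / `make_rgb_output_adapter` are
# hypothetical placeholders for the concrete adapter classes shipped alongside this module.
#
#   input_adapters = {'rgb': make_rgb_input_adapter()}      # task -> input adapter
#   output_adapters = {'rgb': make_rgb_output_adapter()}    # task -> output adapter
#   model = pretrain_multimae_base(input_adapters, output_adapters)
#   x = {'rgb': torch.randn(2, 3, 224, 224)}
#   preds, task_masks = model(x, num_encoded_tokens=98, alphas=1.0)
#   # preds['rgb'] holds the reconstruction, task_masks['rgb'] the binary mask (0 = visible).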


class MultiViT(MultiMAE):
    """MultiViT: Multi-modal Vision Transformer

    This is MultiMAE without masking and with a simplified / faster forward pass.

    :param input_adapters: Dictionary of task -> input adapters
    :param output_adapters: Optional dictionary of task -> output adapters
    :param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1
    :param dim_tokens: Dimension of encoder tokens
    :param depth: Depth of encoder
    :param num_heads: Number of attention heads
    :param mlp_ratio: MLP hidden dim ratio
    :param qkv_bias: Set to False to disable bias
    :param drop_rate: Dropout after MLPs and Attention
    :param attn_drop_rate: Attention matrix drop rate
    :param drop_path_rate: DropPath drop rate
    :param norm_layer: Type of normalization layer
    """
    def process_input(self, x):
        # If input x is a Tensor, assume it's RGB
        x = {'rgb': x} if isinstance(x, torch.Tensor) else x

        # Need image size for tokens->image reconstruction
        if 'rgb' in x:
            B, _, H, W = x['rgb'].shape
        elif 'semseg' in x:
            B, H, W = x['semseg'].shape
            H *= self.input_adapters['semseg'].stride_level
            W *= self.input_adapters['semseg'].stride_level
        else:
            B, _, H, W = list(x.values())[0].shape  # TODO: Deal with case where not all have same shape

        # Encode selected inputs to tokens
        input_task_tokens = {
            domain: self.input_adapters[domain](tensor)
            for domain, tensor in x.items()
            if domain in self.input_adapters
        }

        input_info = self.generate_input_info(input_task_tokens=input_task_tokens, image_size=(H, W))
        input_tokens = torch.cat([task_tokens for task_tokens in input_task_tokens.values()], dim=1)

        # Add global tokens to input tokens
        global_tokens = repeat(self.global_tokens, '() n d -> b n d', b=B)
        input_tokens = torch.cat([input_tokens, global_tokens], dim=1)

        return input_tokens, input_info

    def forward(self, x: Union[Dict[str, torch.Tensor], torch.Tensor], return_all_layers=False, **kwargs):
        """
        Forward pass through input adapters, transformer encoder and output adapters.

        :param x: Input tensor or dictionary of tensors
        :param return_all_layers: Set to True to return all transformer layers
        """
        input_tokens, input_info = self.process_input(x)

        # Pass tokens through Transformer
        if not return_all_layers:
            encoder_tokens = self.encoder(input_tokens)
        else:
            # Optionally access every intermediate layer
            encoder_tokens = []
            tokens = input_tokens
            for block in self.encoder:
                tokens = block(tokens)
                encoder_tokens.append(tokens)

        if self.output_adapters is None:
            return encoder_tokens

        # Decode tokens for each task using task-specific output adapters
        preds = {
            domain: self.output_adapters[domain](
                encoder_tokens=encoder_tokens,
                input_info=input_info,
            )
            for domain in self.output_adapters
        }

        return preds


@register_model
def multivit_base(
        input_adapters: Dict[str, nn.Module],
        output_adapters: Optional[Dict[str, nn.Module]],
        **kwargs):
    model = MultiViT(
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        dim_tokens=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


@register_model
def multivit_large(
        input_adapters: Dict[str, nn.Module],
        output_adapters: Optional[Dict[str, nn.Module]],
        **kwargs):
    model = MultiViT(
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        dim_tokens=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model
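

# Hedged usage sketch (not part of the original code, not runnable as-is): using MultiViT as a
# multi-modal backbone for downstream tasks. The adapter helper is a hypothetical placeholder;
# with output_adapters=None the model returns raw encoder tokens instead of task predictions.
#
#   model = multivit_base(input_adapters={'rgb': make_rgb_input_adapter()}, output_adapters=None)
#   tokens = model(torch.randn(2, 3, 224, 224))                  # (B, N_tokens + num_global_tokens, 768)
#   all_layers = model(torch.randn(2, 3, 224, 224), return_all_layers=True)  # one tensor per encoder block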