Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # dpt head implementation for DUST3R | |
| # Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ; | |
| # or if it takes as input the output at every layer, the attribute return_all_layers should be set to True | |
| # the forward function also takes as input img_info, a (height, width) pair read as img_info[0] and img_info[1] | |
| # for PixelwiseTask, the output will be of dimension B x num_channels x H x W | |
| # -------------------------------------------------------- | |
| from einops import rearrange | |
| from typing import List | |
| import torch | |
| import torch.nn as nn | |
| from utils.dust3r.heads.postprocess import postprocess | |
| # import utils.dust3r.utils.path_to_croco # noqa: F401 | |
| from utils.dust3r.dpt_block import DPTOutputAdapter # noqa | |
| from pdb import set_trace as st | |
class DPTOutputAdapter_fix(DPTOutputAdapter):
    """
    Dust3r-specific variant of croco's DPTOutputAdapter.

    Differences from the parent class:
      * ``init`` deletes the duplicated ``act_1_postprocess`` ..
        ``act_4_postprocess`` modules once the parent has created them.
      * ``forward`` accepts an optional explicit ``image_size`` and falls back
        to ``self.image_size`` when it is not given.
    """

    def init(self, dim_tokens_enc=768):
        """
        Finish construction through the parent ``init``, then drop duplicates.

        Args:
            dim_tokens_enc: encoder token dimension(s), forwarded unchanged to
                ``DPTOutputAdapter.init``.
        """
        super().init(dim_tokens_enc)
        # The parent registers these four modules as duplicated weights;
        # ``forward`` below only goes through ``self.act_postprocess``.
        for stage in (1, 2, 3, 4):
            delattr(self, 'act_%d_postprocess' % stage)

    def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
        """
        Decode the hooked token layers into a dense output map.

        Args:
            encoder_tokens: indexable collection of per-layer token tensors;
                each entry selected by ``self.hooks`` is rearranged from
                ``b (nh nw) c`` to ``b c nh nw``.
            image_size: ``(H, W)`` of the target image; ``None`` means use
                ``self.image_size``.

        Returns:
            The output of ``self.head`` applied to the last fused feature map.
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        if image_size is None:
            image_size = self.image_size
        H, W = image_size
        # Token-grid size = image size divided by the effective patch stride.
        grid_h = H // (self.stride_level * self.P_H)
        grid_w = W // (self.stride_level * self.P_W)

        # Select the hooked layers and keep only the task-relevant tokens.
        selected = [self.adapt_tokens(encoder_tokens[hook]) for hook in self.hooks]
        # Token sequences -> 2D feature maps.
        maps = [
            rearrange(tok, 'b (nh nw) c -> b c nh nw', nh=grid_h, nw=grid_w)
            for tok in selected
        ]
        # Per-level post-processing, then projection to the fusion feature dim.
        maps = [self.act_postprocess[i](m) for i, m in enumerate(maps)]
        feats = [self.scratch.layer_rn[i](m) for i, m in enumerate(maps)]

        # Fuse from feats[3] down to feats[0]. refinenet4 receives a single
        # input; its output is cropped to feats[2]'s spatial size before being
        # passed to refinenet3 together with feats[2].
        fused = self.scratch.refinenet4(feats[3])
        fused = fused[:, :, :feats[2].shape[2], :feats[2].shape[3]]
        fused = self.scratch.refinenet3(fused, feats[2])
        fused = self.scratch.refinenet2(fused, feats[1])
        fused = self.scratch.refinenet1(fused, feats[0])
        return self.head(fused)
class PixelwiseTaskWithDPT(nn.Module):
    """
    DPT module for dust3r; can return 3D points + confidence for all pixels.

    Wraps a ``DPTOutputAdapter_fix`` and optionally applies a postprocessing
    callable to its raw output.

    Keyword-only args:
        n_cls_token: number of class tokens in the input; only 0 is supported.
        hooks_idx: token-layer indices fed to the DPT adapter; ``None`` keeps
            the adapter's own default.
        dim_tokens: forwarded to ``DPTOutputAdapter_fix.init`` as
            ``dim_tokens_enc``; ``None`` keeps the adapter's default.
        output_width_ratio, num_channels, **kwargs: forwarded to the adapter
            constructor.
        postprocess: optional callable invoked as
            ``postprocess(out, depth_mode, conf_mode)``. Note this parameter
            shadows any module-level ``postprocess`` inside ``__init__``.
        depth_mode, conf_mode: passed through to ``postprocess``.

    Raises:
        NotImplementedError: if ``n_cls_token`` is not 0.
    """

    def __init__(self,
                 *,
                 n_cls_token=0,
                 hooks_idx=None,
                 dim_tokens=None,
                 output_width_ratio=1,
                 num_channels=1,
                 postprocess=None,
                 depth_mode=None,
                 conf_mode=None,
                 **kwargs):
        super(PixelwiseTaskWithDPT, self).__init__()
        # An explicit raise (not ``assert``) so the check survives ``python -O``;
        # n_cls_token is never forwarded, so other values would be silently ignored.
        if n_cls_token != 0:
            raise NotImplementedError("Not implemented")
        self.return_all_layers = True  # backbone needs to return all layers
        self.postprocess = postprocess
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        dpt_args = dict(output_width_ratio=output_width_ratio,
                        num_channels=num_channels,
                        **kwargs)
        if hooks_idx is not None:
            dpt_args.update(hooks=hooks_idx)
        self.dpt = DPTOutputAdapter_fix(**dpt_args)

        dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
        self.dpt.init(**dpt_init_args)
        # refinenet4 is always called with a single input in
        # DPTOutputAdapter_fix.forward, so its resConfUnit1 is unused.
        # NOTE(review): inferred from the adapter's forward — confirm against dpt_block.
        del self.dpt.scratch.refinenet4.resConfUnit1

    def forward(self, x, img_info):
        """
        Run the DPT adapter and, if configured, the postprocessing callable.

        Args:
            x: per-layer token tensors, indexed by the adapter's hooks.
            img_info: (height, width) pair; only img_info[0] and img_info[1]
                are read.

        Returns:
            The adapter output, or ``postprocess(out, depth_mode, conf_mode)``
            when a postprocess callable was given.
        """
        out = self.dpt(x, image_size=(img_info[0], img_info[1]))
        if self.postprocess is not None:
            out = self.postprocess(out, self.depth_mode, self.conf_mode)
        return out
def create_dpt_head(net, has_conf=False):
    """
    Build the dust3r ``PixelwiseTaskWithDPT`` regression head for ``net``.

    Reads ``net.dec_depth`` (must exceed 9), ``net.enc_embed_dim``,
    ``net.dec_embed_dim``, ``net.depth_mode`` and ``net.conf_mode``.
    The head outputs 3 channels, plus one more when ``has_conf`` is true.
    Hooks are ``[0, depth*2//4, depth*3//4, depth]`` with token widths
    ``[enc_embed_dim, dec_embed_dim, dec_embed_dim, dec_embed_dim]``.
    """
    assert net.dec_depth > 9
    depth = net.dec_depth
    enc_dim = net.enc_embed_dim
    dec_dim = net.dec_embed_dim

    fusion_dim = 256
    hook_layers = [0, depth * 2 // 4, depth * 3 // 4, depth]
    token_dims = [enc_dim] + 3 * [dec_dim]

    return PixelwiseTaskWithDPT(
        num_channels=3 + has_conf,
        feature_dim=fusion_dim,
        last_dim=fusion_dim // 2,
        hooks_idx=hook_layers,
        dim_tokens=token_dims,
        postprocess=postprocess,
        depth_mode=net.depth_mode,
        conf_mode=net.conf_mode,
        head_type='regression',
    )
| # def create_dpt_head_ln3diff(net, has_conf=False): | |
def create_dpt_head_ln3diff(out_nchan, feature_dim, l2, dec_embed_dim,
                            patch_size=2, has_conf=False):
    """
    Return a PixelwiseTaskWithDPT whose four hooks all read tokens of width
    ``dec_embed_dim``.

    Args:
        out_nchan: output channels before the optional confidence channel.
        feature_dim: DPT fusion feature width; ``last_dim`` is half of it.
        l2: number of token layers available. Hooks are taken at the last
            layer of each quarter: ``l2*k//4 - 1`` for ``k`` in 1..4.
        dec_embed_dim: token dimension of every hooked layer.
        patch_size: forwarded to the DPT adapter.
        has_conf: when true, one extra output channel is added.

    Returns:
        A ``PixelwiseTaskWithDPT`` with ``head_type='regression_gs'`` and no
        postprocessing.

    Raises:
        ValueError: if ``l2 < 4``. In that case the first hook index would be
            negative and wrap around to the last layer instead of an early one.
    """
    if l2 < 4:
        raise ValueError(
            'create_dpt_head_ln3diff needs l2 >= 4 token layers, got %r' % (l2,))
    # k = 4 yields l2 - 1, i.e. the last layer.
    hooks_idx = [l2 * k // 4 - 1 for k in range(1, 5)]
    return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
                                feature_dim=feature_dim,
                                last_dim=feature_dim // 2,
                                patch_size=patch_size,
                                hooks_idx=hooks_idx,
                                dim_tokens=[dec_embed_dim] * 4,
                                postprocess=None,
                                head_type='regression_gs')