# Copyright (C) 2024 Apple Inc. All Rights Reserved.
# Depth Pro: Sharp Monocular Metric Depth in Less Than a Second

from __future__ import annotations

from dataclasses import dataclass
from typing import Mapping, Optional, Tuple, Union

import torch
from torch import nn
from torchvision.transforms import (
    Compose,
    ConvertImageDtype,
    Lambda,
    Normalize,
    ToTensor,
)

from .network.decoder import MultiresConvDecoder
from .network.encoder import DepthProEncoder
from .network.fov import FOVNetwork
from .network.vit_factory import VIT_CONFIG_DICT, ViTPreset, create_vit


@dataclass
class DepthProConfig:
    """Configuration for DepthPro."""

    patch_encoder_preset: ViTPreset
    image_encoder_preset: ViTPreset
    decoder_features: int

    checkpoint_uri: Optional[str] = None
    fov_encoder_preset: Optional[ViTPreset] = None
    use_fov_head: bool = True


DEFAULT_MONODEPTH_CONFIG_DICT = DepthProConfig(
    patch_encoder_preset="dinov2l16_384",
    image_encoder_preset="dinov2l16_384",
    checkpoint_uri="./checkpoints/depth_pro.pt",
    decoder_features=256,
    use_fov_head=True,
    fov_encoder_preset="dinov2l16_384",
)
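

# Illustrative note (not part of the original module): since `DepthProConfig` is a
# dataclass, variants can be built field by field; the checkpoint path below is a
# hypothetical placeholder.
#
#   custom_config = DepthProConfig(
#       patch_encoder_preset="dinov2l16_384",
#       image_encoder_preset="dinov2l16_384",
#       fov_encoder_preset="dinov2l16_384",
#       checkpoint_uri="/path/to/depth_pro.pt",  # hypothetical checkpoint location
#       decoder_features=256,
#   )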


def create_backbone_model(
    preset: ViTPreset,
) -> Tuple[nn.Module, ViTPreset]:
    """Create and load a backbone model given a config.

    Args:
    ----
        preset: A backbone preset to load pre-defined configs.

    Returns:
    -------
        A Torch module and the associated config.

    """
    if preset in VIT_CONFIG_DICT:
        config = VIT_CONFIG_DICT[preset]
        model = create_vit(preset=preset, use_pretrained=False)
    else:
        raise KeyError(f"Preset {preset} not found.")

    return model, config
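

# Illustrative note (not part of the original module): this factory is used
# internally by `create_model_and_transforms` below; e.g.
#
#   encoder_module, encoder_config = create_backbone_model(preset="dinov2l16_384")
#
# returns a randomly initialized ViT (weights come from the DepthPro checkpoint
# later) together with its preset configuration.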


def create_model_and_transforms(
    config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT,
    device: torch.device = torch.device("cpu"),
    precision: torch.dtype = torch.float32,
) -> Tuple[DepthPro, Compose]:
    """Create a DepthPro model and load weights from `config.checkpoint_uri`.

    Args:
    ----
        config: The configuration for the DPT model architecture.
        device: The optional Torch device to load the model onto, default runs on "cpu".
        precision: The optional precision used for the model, default is FP32.

    Returns:
    -------
        The Torch DepthPro model and associated Transform.

    """
    patch_encoder, patch_encoder_config = create_backbone_model(
        preset=config.patch_encoder_preset
    )
    image_encoder, _ = create_backbone_model(
        preset=config.image_encoder_preset
    )

    fov_encoder = None
    if config.use_fov_head and config.fov_encoder_preset is not None:
        fov_encoder, _ = create_backbone_model(preset=config.fov_encoder_preset)

    dims_encoder = patch_encoder_config.encoder_feature_dims
    hook_block_ids = patch_encoder_config.encoder_feature_layer_ids
    encoder = DepthProEncoder(
        dims_encoder=dims_encoder,
        patch_encoder=patch_encoder,
        image_encoder=image_encoder,
        hook_block_ids=hook_block_ids,
        decoder_features=config.decoder_features,
    )
    decoder = MultiresConvDecoder(
        dims_encoder=[config.decoder_features] + list(encoder.dims_encoder),
        dim_decoder=config.decoder_features,
    )
    model = DepthPro(
        encoder=encoder,
        decoder=decoder,
        last_dims=(32, 1),
        use_fov_head=config.use_fov_head,
        fov_encoder=fov_encoder,
    ).to(device)

    if precision == torch.half:
        model.half()

    transform = Compose(
        [
            ToTensor(),
            Lambda(lambda x: x.to(device)),
            Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ConvertImageDtype(precision),
        ]
    )

    if config.checkpoint_uri is not None:
        state_dict = torch.load(config.checkpoint_uri, map_location="cpu")
        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict=state_dict, strict=True
        )

        if len(unexpected_keys) != 0:
            raise KeyError(
                f"Found unexpected keys when loading monodepth: {unexpected_keys}"
            )

        # fc_norm is only used by the classification head, which we do not use;
        # we only use the encoder features.
        missing_keys = [key for key in missing_keys if "fc_norm" not in key]
        if len(missing_keys) != 0:
            raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}")

    return model, transform
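

# Illustrative usage of the factory above (not part of the original module). The
# defaults build a CPU/FP32 model; a CUDA half-precision variant, assuming a GPU
# is available, would look like:
#
#   model, transform = create_model_and_transforms()                  # CPU, FP32
#   model_gpu, transform_gpu = create_model_and_transforms(
#       device=torch.device("cuda"), precision=torch.half,
#   )                                                                  # GPU, FP16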


class DepthPro(nn.Module):
    """DepthPro network."""

    def __init__(
        self,
        encoder: DepthProEncoder,
        decoder: MultiresConvDecoder,
        last_dims: tuple[int, int],
        use_fov_head: bool = True,
        fov_encoder: Optional[nn.Module] = None,
    ):
        """Initialize DepthPro.

        Args:
        ----
            encoder: The DepthProEncoder backbone.
            decoder: The MultiresConvDecoder decoder.
            last_dims: The dimension for the last convolution layers.
            use_fov_head: Whether to use the field-of-view head.
            fov_encoder: A separate encoder for the field of view.

        """
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        dim_decoder = decoder.dim_decoder
        self.head = nn.Sequential(
            nn.Conv2d(
                dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=1
            ),
            nn.ConvTranspose2d(
                in_channels=dim_decoder // 2,
                out_channels=dim_decoder // 2,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
            ),
            nn.Conv2d(
                dim_decoder // 2,
                last_dims[0],
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(True),
            nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
        )

        # Set the final convolution layer's bias to be 0.
        self.head[4].bias.data.fill_(0)

        # Set the FOV estimation head.
        if use_fov_head:
            self.fov = FOVNetwork(num_features=dim_decoder, fov_encoder=fov_encoder)

    @property
    def img_size(self) -> int:
        """Return the internal image size of the network."""
        return self.encoder.img_size

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Decode by projection and fusion of multi-resolution encodings.

        Args:
        ----
            x (torch.Tensor): Input image.

        Returns:
        -------
            The canonical inverse depth map [m] and the optional estimated field of view [deg].

        """
        _, _, H, W = x.shape
        assert H == self.img_size and W == self.img_size

        encodings = self.encoder(x)
        features, features_0 = self.decoder(encodings)
        canonical_inverse_depth = self.head(features)

        fov_deg = None
        if hasattr(self, "fov"):
            fov_deg = self.fov.forward(x, features_0.detach())

        return canonical_inverse_depth, fov_deg
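
    # Note (added for clarity): `forward` returns inverse depth in the network's
    # canonical camera; `infer` below converts it to metric depth via
    #     depth = f_px / (W * canonical_inverse_depth),
    # where W is the input width and f_px the focal length in pixels, either
    # supplied by the caller or derived from the estimated field of view.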

    @torch.no_grad()
    def infer(
        self,
        x: torch.Tensor,
        f_px: Optional[Union[float, torch.Tensor]] = None,
        interpolation_mode="bilinear",
    ) -> Mapping[str, torch.Tensor]:
        """Infer depth and fov for a given image.

        If the image is not at network resolution, it is resized to 1536x1536 and
        the estimated depth is resized to the original image resolution.
        Note: if the focal length is given, the estimated value is ignored and the
        provided focal length is used to generate the metric depth values.

        Args:
        ----
            x (torch.Tensor): Input image.
            f_px (torch.Tensor): Optional focal length in pixels corresponding to `x`.
            interpolation_mode (str): Interpolation function for downsampling/upsampling.

        Returns:
        -------
            Tensor dictionary (torch.Tensor): depth [m], focallength [pixels].

        """
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        _, _, H, W = x.shape
        resize = H != self.img_size or W != self.img_size

        if resize:
            x = nn.functional.interpolate(
                x,
                size=(self.img_size, self.img_size),
                mode=interpolation_mode,
                align_corners=False,
            )

        canonical_inverse_depth, fov_deg = self.forward(x)
        if f_px is None:
            # Recover the focal length from the estimated horizontal field of view
            # via the pinhole relation: f_px = (W / 2) / tan(fov / 2).
            f_px = 0.5 * W / torch.tan(0.5 * torch.deg2rad(fov_deg.to(torch.float)))
        elif not isinstance(f_px, torch.Tensor):
            # Accept a plain float focal length as well.
            f_px = torch.tensor(f_px, dtype=torch.float, device=x.device)

        inverse_depth = canonical_inverse_depth * (W / f_px)
        f_px = f_px.squeeze()

        if resize:
            inverse_depth = nn.functional.interpolate(
                inverse_depth, size=(H, W), mode=interpolation_mode, align_corners=False
            )

        depth = 1.0 / torch.clamp(inverse_depth, min=1e-4, max=1e4)

        return {
            "depth": depth.squeeze(),
            "focallength_px": f_px,
        }
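

# End-to-end example (an illustrative sketch, not part of the original module;
# assumes PIL is installed, "example.jpg" is a placeholder path, and the default
# checkpoint is present at ./checkpoints/depth_pro.pt):
#
#   from PIL import Image
#
#   model, transform = create_model_and_transforms()
#   model.eval()
#   image = transform(Image.open("example.jpg").convert("RGB"))
#   prediction = model.infer(image)           # or model.infer(image, f_px=...) with
#                                             # a known focal length in pixels
#   depth_m = prediction["depth"]             # metric depth in meters
#   f_px = prediction["focallength_px"]       # focal length estimate in pixels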