Delete External Models to prevent HF tags
This view is limited to 50 files because it contains too many changes. See raw diff.
- mapanything/models/external/README.md +0 -5
- mapanything/models/external/__init__.py +0 -0
- mapanything/models/external/anycalib/__init__.py +0 -95
- mapanything/models/external/dinov2/__init__.py +0 -6
- mapanything/models/external/dinov2/hub/__init__.py +0 -4
- mapanything/models/external/dinov2/hub/backbones.py +0 -183
- mapanything/models/external/dinov2/hub/utils.py +0 -42
- mapanything/models/external/dinov2/layers/__init__.py +0 -14
- mapanything/models/external/dinov2/layers/attention.py +0 -90
- mapanything/models/external/dinov2/layers/block.py +0 -290
- mapanything/models/external/dinov2/layers/dino_head.py +0 -67
- mapanything/models/external/dinov2/layers/drop_path.py +0 -36
- mapanything/models/external/dinov2/layers/layer_scale.py +0 -26
- mapanything/models/external/dinov2/layers/mlp.py +0 -40
- mapanything/models/external/dinov2/layers/patch_embed.py +0 -100
- mapanything/models/external/dinov2/layers/swiglu_ffn.py +0 -71
- mapanything/models/external/dinov2/models/__init__.py +0 -44
- mapanything/models/external/dinov2/models/vision_transformer.py +0 -448
- mapanything/models/external/dinov2/utils/__init__.py +0 -4
- mapanything/models/external/dinov2/utils/cluster.py +0 -102
- mapanything/models/external/dinov2/utils/config.py +0 -74
- mapanything/models/external/dinov2/utils/dtype.py +0 -38
- mapanything/models/external/dinov2/utils/param_groups.py +0 -122
- mapanything/models/external/dinov2/utils/utils.py +0 -105
- mapanything/models/external/dust3r/__init__.py +0 -217
- mapanything/models/external/mast3r/__init__.py +0 -191
- mapanything/models/external/moge/__init__.py +0 -114
- mapanything/models/external/moge/models/modules.py +0 -467
- mapanything/models/external/moge/models/utils.py +0 -477
- mapanything/models/external/moge/models/v1.py +0 -595
- mapanything/models/external/moge/models/v2.py +0 -379
- mapanything/models/external/must3r/__init__.py +0 -283
- mapanything/models/external/pi3/__init__.py +0 -119
- mapanything/models/external/pi3/layers/__init__.py +0 -0
- mapanything/models/external/pi3/layers/attention.py +0 -429
- mapanything/models/external/pi3/layers/block.py +0 -448
- mapanything/models/external/pi3/layers/camera_head.py +0 -106
- mapanything/models/external/pi3/layers/pos_embed.py +0 -190
- mapanything/models/external/pi3/layers/transformer_head.py +0 -98
- mapanything/models/external/pi3/models/__init__.py +0 -0
- mapanything/models/external/pi3/models/pi3.py +0 -251
- mapanything/models/external/pow3r/__init__.py +0 -860
- mapanything/models/external/vggt/__init__.py +0 -186
- mapanything/models/external/vggt/heads/__init__.py +0 -0
- mapanything/models/external/vggt/heads/camera_head.py +0 -167
- mapanything/models/external/vggt/heads/dpt_head.py +0 -600
- mapanything/models/external/vggt/heads/head_act.py +0 -127
- mapanything/models/external/vggt/heads/track_head.py +0 -118
- mapanything/models/external/vggt/heads/track_modules/__init__.py +0 -5
- mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py +0 -242
mapanything/models/external/README.md
DELETED
@@ -1,5 +0,0 @@
-# External Model Code for Benchmarking & Re-Training
-
-This directory contains external model code that we use to train and benchmark external models fairly. These libraries are not part of the core MapAnything codebase and are included for only benchmarking purposes. The code in this directory is licensed under the same license as the source code from which it was derived, unless otherwise specified.
-
-The open-source Apache 2.0 License of MapAnything does not apply to these libraries.
mapanything/models/external/__init__.py
DELETED
File without changes
mapanything/models/external/anycalib/__init__.py
DELETED
@@ -1,95 +0,0 @@
-"""
-Inference wrapper for AnyCalib
-"""
-
-import torch
-from anycalib import AnyCalib
-
-from mapanything.utils.geometry import get_rays_in_camera_frame
-
-
-class AnyCalibWrapper(torch.nn.Module):
-    def __init__(
-        self,
-        name,
-        model_id="anycalib_pinhole",
-        **kwargs,
-    ):
-        super().__init__()
-        self.name = name
-        self.model_id = model_id
-
-        # Initialize the model
-        self.model = AnyCalib(model_id=self.model_id)
-
-    def forward(self, views):
-        """
-        Forward pass wrapper for AnyCalib.
-
-        Assumption:
-        - The number of input views is 1.
-        - The output camera model is pinhole (fx, fy, cx, cy).
-          This can be relaxed by not hardcoding the cam_id.
-
-        Args:
-            views (List[dict]): List of dictionaries containing the input views' images and instance information.
-                Length of the list should be 1.
-                Each dictionary should contain the following keys:
-                    "img" (tensor): Image tensor of shape (B, C, H, W).
-                    "data_norm_type" (list): ["identity"]
-
-        Returns:
-            List[dict]: A list containing the final outputs for the single view. Length of the list will be 1.
-        """
-        # Check that the number of input views is 1
-        assert len(views) == 1, "AnyCalib only supports 1 input view."
-
-        # Get input shape of the images and batch size per view
-        _, _, height, width = views[0]["img"].shape
-
-        # Check the data norm type
-        # AnyCalib expects a normalized image but without the DINOv2 mean and std applied ("identity")
-        data_norm_type = views[0]["data_norm_type"][0]
-        assert data_norm_type == "identity", (
-            "AnyCalib expects a normalized image but without the DINOv2 mean and std applied"
-        )
-
-        # Run AnyCalib inference
-        # Corresponding batched output dictionary:
-        # {
-        #     "intrinsics": List[(D_i,) tensors] for each camera model "i" at the original input resolution,
-        #     "fov_field": (B, N, 2) tensor with the regressed FoV field by the network. N≈320^2 (resolution close to the one seen during training),
-        #     "tangent_coords": alias for "fov_field",
-        #     "rays": (B, N, 3) tensor with the corresponding (via the exponential map) ray directions in the camera frame (x right, y down, z forward),
-        #     "pred_size": (H, W) tuple with the image size used by the network. It can be used e.g. for resizing the FoV/ray fields to the original image size.
-        # }
-        # For "pinhole" camera model, the intrinsics are (fx, fy, cx, cy).
-        model_outputs = self.model.predict(views[0]["img"], cam_id="pinhole")
-
-        # Convert the list of intrinsics to a tensor
-        intrinsics = []
-        for intrinsics_per_sample in model_outputs["intrinsics"]:
-            pred_fx, pred_fy, pred_cx, pred_cy = intrinsics_per_sample
-            intrinsics_per_sample = torch.tensor(
-                [
-                    [pred_fx, 0, pred_cx],
-                    [0, pred_fy, pred_cy],
-                    [0, 0, 1],
-                ],
-                device=views[0]["img"].device,
-            )
-            intrinsics.append(intrinsics_per_sample)
-
-        # Convert the list of intrinsics to a tensor of size (batch_size_per_view, 3, 3)
-        intrinsics = torch.stack(intrinsics)
-
-        # Get the ray directions
-        with torch.autocast("cuda", enabled=False):
-            _, ray_directions = get_rays_in_camera_frame(
-                intrinsics, height, width, normalize_to_unit_sphere=True
-            )
-
-        # Return the output in MapAnything format
-        res = [{"ray_directions": ray_directions, "intrinsics": intrinsics}]
-
-        return res
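For reference, the wrapper above returns pinhole intrinsics and per-pixel ray directions for a single input view. A minimal usage sketch, assuming the pre-deletion package layout and an installed `anycalib` package (the input tensor is random and purely illustrative):

```python
import torch

# Hypothetical usage of the deleted wrapper; import path assumes the pre-deletion layout.
from mapanything.models.external.anycalib import AnyCalibWrapper

wrapper = AnyCalibWrapper(name="anycalib", model_id="anycalib_pinhole")
views = [{
    "img": torch.rand(1, 3, 518, 518),   # (B, C, H, W), values in [0, 1], no DINOv2 mean/std
    "data_norm_type": ["identity"],
}]
preds = wrapper(views)                   # list with one output dict
print(preds[0]["intrinsics"].shape)      # (1, 3, 3) pinhole intrinsics per sample
```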
mapanything/models/external/dinov2/__init__.py
DELETED
@@ -1,6 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-__version__ = "0.0.1"
mapanything/models/external/dinov2/hub/__init__.py
DELETED
@@ -1,4 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
mapanything/models/external/dinov2/hub/backbones.py
DELETED
@@ -1,183 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-from enum import Enum
-from typing import Union
-
-import torch
-
-from mapanything.models.external.dinov2.hub.utils import (
-    _DINOV2_BASE_URL,
-    _make_dinov2_model_name,
-)
-
-
-class Weights(Enum):
-    LVD142M = "LVD142M"
-
-
-def _make_dinov2_model(
-    *,
-    arch_name: str = "vit_large",
-    img_size: int = 518,
-    patch_size: int = 14,
-    init_values: float = 1.0,
-    ffn_layer: str = "mlp",
-    block_chunks: int = 0,
-    num_register_tokens: int = 0,
-    interpolate_antialias: bool = False,
-    interpolate_offset: float = 0.1,
-    pretrained: bool = True,
-    weights: Union[Weights, str] = Weights.LVD142M,
-    **kwargs,
-):
-    from ..models import vision_transformer as vits
-
-    if isinstance(weights, str):
-        try:
-            weights = Weights[weights]
-        except KeyError:
-            raise AssertionError(f"Unsupported weights: {weights}")
-
-    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
-    vit_kwargs = dict(
-        img_size=img_size,
-        patch_size=patch_size,
-        init_values=init_values,
-        ffn_layer=ffn_layer,
-        block_chunks=block_chunks,
-        num_register_tokens=num_register_tokens,
-        interpolate_antialias=interpolate_antialias,
-        interpolate_offset=interpolate_offset,
-    )
-    vit_kwargs.update(**kwargs)
-    model = vits.__dict__[arch_name](**vit_kwargs)
-
-    if pretrained:
-        model_full_name = _make_dinov2_model_name(
-            arch_name, patch_size, num_register_tokens
-        )
-        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
-        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
-        model.load_state_dict(state_dict, strict=True)
-
-    return model
-
-
-def dinov2_vits14(
-    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
-):
-    """
-    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(
-        arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs
-    )
-
-
-def dinov2_vitb14(
-    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
-):
-    """
-    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(
-        arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs
-    )
-
-
-def dinov2_vitl14(
-    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
-):
-    """
-    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(
-        arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs
-    )
-
-
-def dinov2_vitg14(
-    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
-):
-    """
-    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(
-        arch_name="vit_giant2",
-        ffn_layer="swiglufused",
-        weights=weights,
-        pretrained=pretrained,
-        **kwargs,
-    )
-
-
-def dinov2_vits14_reg(
-    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
-):
-    """
-    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(
-        arch_name="vit_small",
-        pretrained=pretrained,
-        weights=weights,
-        num_register_tokens=4,
-        interpolate_antialias=True,
-        interpolate_offset=0.0,
-        **kwargs,
-    )
-
-
-def dinov2_vitb14_reg(
-    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
-):
-    """
-    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(
-        arch_name="vit_base",
-        pretrained=pretrained,
-        weights=weights,
-        num_register_tokens=4,
-        interpolate_antialias=True,
-        interpolate_offset=0.0,
-        **kwargs,
-    )
-
-
-def dinov2_vitl14_reg(
-    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
-):
-    """
-    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(
-        arch_name="vit_large",
-        pretrained=pretrained,
-        weights=weights,
-        num_register_tokens=4,
-        interpolate_antialias=True,
-        interpolate_offset=0.0,
-        **kwargs,
-    )
-
-
-def dinov2_vitg14_reg(
-    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
-):
-    """
-    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(
-        arch_name="vit_giant2",
-        ffn_layer="swiglufused",
-        weights=weights,
-        pretrained=pretrained,
-        num_register_tokens=4,
-        interpolate_antialias=True,
-        interpolate_offset=0.0,
-        **kwargs,
-    )
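These constructors mirror the public DINOv2 hub entry points, vendored under the MapAnything namespace. A minimal sketch of building a backbone from them, assuming the pre-deletion package layout (set `pretrained=True` only if downloading the LVD-142M checkpoint is acceptable):

```python
import torch
from mapanything.models.external.dinov2.hub.backbones import dinov2_vitl14

# Construct ViT-L/14 without downloading weights; pretrained=True would fetch them
# from _DINOV2_BASE_URL via torch.hub.load_state_dict_from_url.
backbone = dinov2_vitl14(pretrained=False)

x = torch.rand(1, 3, 518, 518)  # 518 is a multiple of the patch size (14)
with torch.no_grad():
    out = backbone(x)
print(type(out))
```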
mapanything/models/external/dinov2/hub/utils.py
DELETED
@@ -1,42 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-import itertools
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
-
-
-def _make_dinov2_model_name(
-    arch_name: str, patch_size: int, num_register_tokens: int = 0
-) -> str:
-    compact_arch_name = arch_name.replace("_", "")[:4]
-    registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
-    return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
-
-
-class CenterPadding(nn.Module):
-    def __init__(self, multiple):
-        super().__init__()
-        self.multiple = multiple
-
-    def _get_pad(self, size):
-        new_size = math.ceil(size / self.multiple) * self.multiple
-        pad_size = new_size - size
-        pad_size_left = pad_size // 2
-        pad_size_right = pad_size - pad_size_left
-        return pad_size_left, pad_size_right
-
-    @torch.inference_mode()
-    def forward(self, x):
-        pads = list(
-            itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])
-        )
-        output = F.pad(x, pads)
-        return output
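CenterPadding pads the trailing spatial dimensions of a tensor up to the next multiple of `multiple`, splitting the padding as evenly as possible on both sides. A small illustrative sketch, assuming the pre-deletion import path:

```python
import torch
from mapanything.models.external.dinov2.hub.utils import CenterPadding

pad = CenterPadding(multiple=14)
x = torch.rand(1, 3, 500, 700)  # H=500 is not a multiple of 14; W=700 already is
y = pad(x)
print(y.shape)  # torch.Size([1, 3, 504, 700]); 504 = ceil(500 / 14) * 14
```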
mapanything/models/external/dinov2/layers/__init__.py
DELETED
@@ -1,14 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-from mapanything.models.external.dinov2.layers.dino_head import DINOHead  # noqa
-from mapanything.models.external.dinov2.layers.mlp import Mlp  # noqa
-from mapanything.models.external.dinov2.layers.patch_embed import PatchEmbed  # noqa
-from mapanything.models.external.dinov2.layers.swiglu_ffn import (
-    SwiGLUFFN,  # noqa
-    SwiGLUFFNFused,  # noqa
-)
-from mapanything.models.external.dinov2.layers.block import NestedTensorBlock  # noqa
-from mapanything.models.external.dinov2.layers.attention import MemEffAttention  # noqa
mapanything/models/external/dinov2/layers/attention.py
DELETED
@@ -1,90 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# References:
-# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
-
-import logging
-import os
-
-from torch import nn, Tensor
-
-logger = logging.getLogger("dinov2")
-
-
-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import memory_efficient_attention, unbind
-
-        XFORMERS_AVAILABLE = True
-        # warnings.warn("xFormers is available (Attention)")
-    else:
-        # warnings.warn("xFormers is disabled (Attention)")
-        raise ImportError
-except ImportError:
-    XFORMERS_AVAILABLE = False
-    # warnings.warn("xFormers is not available (Attention)")
-
-
-class Attention(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int = 8,
-        qkv_bias: bool = False,
-        proj_bias: bool = True,
-        attn_drop: float = 0.0,
-        proj_drop: float = 0.0,
-    ) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim**-0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim, bias=proj_bias)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
-        B, N, C = x.shape
-        qkv = (
-            self.qkv(x)
-            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
-            .permute(2, 0, 3, 1, 4)
-        )
-
-        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
-        attn = q @ k.transpose(-2, -1)
-
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-
-class MemEffAttention(Attention):
-    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
-        if not XFORMERS_AVAILABLE:
-            if attn_bias is not None:
-                raise AssertionError("xFormers is required for using nested tensors")
-            return super().forward(x)
-
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
-
-        q, k, v = unbind(qkv, 2)
-
-        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
-        x = x.reshape([B, N, C])
-
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
mapanything/models/external/dinov2/layers/block.py
DELETED
@@ -1,290 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# References:
-# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
-
-import logging
-import os
-from typing import Any, Callable, Dict, List, Tuple
-
-import torch
-from torch import nn, Tensor
-
-from mapanything.models.external.dinov2.layers.attention import (
-    Attention,
-    MemEffAttention,
-)
-from mapanything.models.external.dinov2.layers.drop_path import DropPath
-from mapanything.models.external.dinov2.layers.layer_scale import LayerScale
-from mapanything.models.external.dinov2.layers.mlp import Mlp
-
-logger = logging.getLogger("dinov2")
-
-
-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import fmha, index_select_cat, scaled_index_add
-
-        XFORMERS_AVAILABLE = True
-        # warnings.warn("xFormers is available (Block)")
-    else:
-        # warnings.warn("xFormers is disabled (Block)")
-        raise ImportError
-except ImportError:
-    XFORMERS_AVAILABLE = False
-    # warnings.warn("xFormers is not available (Block)")
-
-
-class Block(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int,
-        mlp_ratio: float = 4.0,
-        qkv_bias: bool = False,
-        proj_bias: bool = True,
-        ffn_bias: bool = True,
-        drop: float = 0.0,
-        attn_drop: float = 0.0,
-        init_values=None,
-        drop_path: float = 0.0,
-        act_layer: Callable[..., nn.Module] = nn.GELU,
-        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
-        attn_class: Callable[..., nn.Module] = Attention,
-        ffn_layer: Callable[..., nn.Module] = Mlp,
-    ) -> None:
-        super().__init__()
-        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
-        self.norm1 = norm_layer(dim)
-        self.attn = attn_class(
-            dim,
-            num_heads=num_heads,
-            qkv_bias=qkv_bias,
-            proj_bias=proj_bias,
-            attn_drop=attn_drop,
-            proj_drop=drop,
-        )
-        self.ls1 = (
-            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        )
-        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-        self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = ffn_layer(
-            in_features=dim,
-            hidden_features=mlp_hidden_dim,
-            act_layer=act_layer,
-            drop=drop,
-            bias=ffn_bias,
-        )
-        self.ls2 = (
-            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        )
-        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-        self.sample_drop_ratio = drop_path
-
-    def forward(self, x: Tensor) -> Tensor:
-        def attn_residual_func(x: Tensor) -> Tensor:
-            return self.ls1(self.attn(self.norm1(x)))
-
-        def ffn_residual_func(x: Tensor) -> Tensor:
-            return self.ls2(self.mlp(self.norm2(x)))
-
-        if self.training and self.sample_drop_ratio > 0.1:
-            # the overhead is compensated only for a drop path rate larger than 0.1
-            x = drop_add_residual_stochastic_depth(
-                x,
-                residual_func=attn_residual_func,
-                sample_drop_ratio=self.sample_drop_ratio,
-            )
-            x = drop_add_residual_stochastic_depth(
-                x,
-                residual_func=ffn_residual_func,
-                sample_drop_ratio=self.sample_drop_ratio,
-            )
-        elif self.training and self.sample_drop_ratio > 0.0:
-            x = x + self.drop_path1(attn_residual_func(x))
-            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
-        else:
-            x = x + attn_residual_func(x)
-            x = x + ffn_residual_func(x)
-        return x
-
-
-def drop_add_residual_stochastic_depth(
-    x: Tensor,
-    residual_func: Callable[[Tensor], Tensor],
-    sample_drop_ratio: float = 0.0,
-) -> Tensor:
-    # 1) extract subset using permutation
-    b, n, d = x.shape
-    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
-    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
-    x_subset = x[brange]
-
-    # 2) apply residual_func to get residual
-    residual = residual_func(x_subset)
-
-    x_flat = x.flatten(1)
-    residual = residual.flatten(1)
-
-    residual_scale_factor = b / sample_subset_size
-
-    # 3) add the residual
-    x_plus_residual = torch.index_add(
-        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
-    )
-    return x_plus_residual.view_as(x)
-
-
-def get_branges_scales(x, sample_drop_ratio=0.0):
-    b, n, d = x.shape
-    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
-    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
-    residual_scale_factor = b / sample_subset_size
-    return brange, residual_scale_factor
-
-
-def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
-    if scaling_vector is None:
-        x_flat = x.flatten(1)
-        residual = residual.flatten(1)
-        x_plus_residual = torch.index_add(
-            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
-        )
-    else:
-        x_plus_residual = scaled_index_add(
-            x,
-            brange,
-            residual.to(dtype=x.dtype),
-            scaling=scaling_vector,
-            alpha=residual_scale_factor,
-        )
-    return x_plus_residual
-
-
-attn_bias_cache: Dict[Tuple, Any] = {}
-
-
-def get_attn_bias_and_cat(x_list, branges=None):
-    """
-    this will perform the index select, cat the tensors, and provide the attn_bias from cache
-    """
-    batch_sizes = (
-        [b.shape[0] for b in branges]
-        if branges is not None
-        else [x.shape[0] for x in x_list]
-    )
-    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
-    if all_shapes not in attn_bias_cache.keys():
-        seqlens = []
-        for b, x in zip(batch_sizes, x_list):
-            for _ in range(b):
-                seqlens.append(x.shape[1])
-        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
-        attn_bias._batch_sizes = batch_sizes
-        attn_bias_cache[all_shapes] = attn_bias
-
-    if branges is not None:
-        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
-            1, -1, x_list[0].shape[-1]
-        )
-    else:
-        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
-        cat_tensors = torch.cat(tensors_bs1, dim=1)
-
-    return attn_bias_cache[all_shapes], cat_tensors
-
-
-def drop_add_residual_stochastic_depth_list(
-    x_list: List[Tensor],
-    residual_func: Callable[[Tensor, Any], Tensor],
-    sample_drop_ratio: float = 0.0,
-    scaling_vector=None,
-) -> Tensor:
-    # 1) generate random set of indices for dropping samples in the batch
-    branges_scales = [
-        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
-    ]
-    branges = [s[0] for s in branges_scales]
-    residual_scale_factors = [s[1] for s in branges_scales]
-
-    # 2) get attention bias and index+concat the tensors
-    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
-
-    # 3) apply residual_func to get residual, and split the result
-    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
-
-    outputs = []
-    for x, brange, residual, residual_scale_factor in zip(
-        x_list, branges, residual_list, residual_scale_factors
-    ):
-        outputs.append(
-            add_residual(
-                x, brange, residual, residual_scale_factor, scaling_vector
-            ).view_as(x)
-        )
-    return outputs
-
-
-class NestedTensorBlock(Block):
-    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
-        """
-        x_list contains a list of tensors to nest together and run
-        """
-        assert isinstance(self.attn, MemEffAttention)
-
-        if self.training and self.sample_drop_ratio > 0.0:
-
-            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
-                return self.attn(self.norm1(x), attn_bias=attn_bias)
-
-            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
-                return self.mlp(self.norm2(x))
-
-            x_list = drop_add_residual_stochastic_depth_list(
-                x_list,
-                residual_func=attn_residual_func,
-                sample_drop_ratio=self.sample_drop_ratio,
-                scaling_vector=self.ls1.gamma
-                if isinstance(self.ls1, LayerScale)
-                else None,
-            )
-            x_list = drop_add_residual_stochastic_depth_list(
-                x_list,
-                residual_func=ffn_residual_func,
-                sample_drop_ratio=self.sample_drop_ratio,
-                scaling_vector=self.ls2.gamma
-                if isinstance(self.ls1, LayerScale)
-                else None,
-            )
-            return x_list
-        else:
-
-            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
-                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
-
-            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
-                return self.ls2(self.mlp(self.norm2(x)))
-
-            attn_bias, x = get_attn_bias_and_cat(x_list)
-            x = x + attn_residual_func(x, attn_bias=attn_bias)
-            x = x + ffn_residual_func(x)
-            return attn_bias.split(x)
-
-    def forward(self, x_or_x_list):
-        if isinstance(x_or_x_list, Tensor):
-            return super().forward(x_or_x_list)
-        elif isinstance(x_or_x_list, list):
-            if not XFORMERS_AVAILABLE:
-                raise AssertionError("xFormers is required for using nested tensors")
-            return self.forward_nested(x_or_x_list)
-        else:
-            raise AssertionError
mapanything/models/external/dinov2/layers/dino_head.py
DELETED
@@ -1,67 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-from torch.nn.init import trunc_normal_
-from torch.nn.utils import weight_norm
-
-
-class DINOHead(nn.Module):
-    def __init__(
-        self,
-        in_dim,
-        out_dim,
-        use_bn=False,
-        nlayers=3,
-        hidden_dim=2048,
-        bottleneck_dim=256,
-        mlp_bias=True,
-    ):
-        super().__init__()
-        nlayers = max(nlayers, 1)
-        self.mlp = _build_mlp(
-            nlayers,
-            in_dim,
-            bottleneck_dim,
-            hidden_dim=hidden_dim,
-            use_bn=use_bn,
-            bias=mlp_bias,
-        )
-        self.apply(self._init_weights)
-        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
-        self.last_layer.weight_g.data.fill_(1)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, x):
-        x = self.mlp(x)
-        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
-        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
-        x = self.last_layer(x)
-        return x
-
-
-def _build_mlp(
-    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
-):
-    if nlayers == 1:
-        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
-    else:
-        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
-        if use_bn:
-            layers.append(nn.BatchNorm1d(hidden_dim))
-        layers.append(nn.GELU())
-        for _ in range(nlayers - 2):
-            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
-            if use_bn:
-                layers.append(nn.BatchNorm1d(hidden_dim))
-            layers.append(nn.GELU())
-        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
-        return nn.Sequential(*layers)
mapanything/models/external/dinov2/layers/drop_path.py
DELETED
@@ -1,36 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# References:
-# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
-
-
-from torch import nn
-
-
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
-    if drop_prob == 0.0 or not training:
-        return x
-    keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (
-        x.ndim - 1
-    )  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
-    if keep_prob > 0.0:
-        random_tensor.div_(keep_prob)
-    output = x * random_tensor
-    return output
-
-
-class DropPath(nn.Module):
-    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-
-    def __init__(self, drop_prob=None):
-        super(DropPath, self).__init__()
-        self.drop_prob = drop_prob
-
-    def forward(self, x):
-        return drop_path(x, self.drop_prob, self.training)
mapanything/models/external/dinov2/layers/layer_scale.py
DELETED
@@ -1,26 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
-
-from typing import Union
-
-import torch
-from torch import nn, Tensor
-
-
-class LayerScale(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        init_values: Union[float, Tensor] = 1e-5,
-        inplace: bool = False,
-    ) -> None:
-        super().__init__()
-        self.inplace = inplace
-        self.gamma = nn.Parameter(init_values * torch.ones(dim))
-
-    def forward(self, x: Tensor) -> Tensor:
-        return x.mul_(self.gamma) if self.inplace else x * self.gamma
mapanything/models/external/dinov2/layers/mlp.py
DELETED
@@ -1,40 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# References:
-# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
-
-
-from typing import Callable, Optional
-
-from torch import nn, Tensor
-
-
-class Mlp(nn.Module):
-    def __init__(
-        self,
-        in_features: int,
-        hidden_features: Optional[int] = None,
-        out_features: Optional[int] = None,
-        act_layer: Callable[..., nn.Module] = nn.GELU,
-        drop: float = 0.0,
-        bias: bool = True,
-    ) -> None:
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
mapanything/models/external/dinov2/layers/patch_embed.py
DELETED
@@ -1,100 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# References:
-# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
-
-from typing import Callable, Optional, Tuple, Union
-
-import torch.nn as nn
-from torch import Tensor
-
-
-def make_2tuple(x):
-    if isinstance(x, tuple):
-        assert len(x) == 2
-        return x
-
-    assert isinstance(x, int)
-    return (x, x)
-
-
-class PatchEmbed(nn.Module):
-    """
-    2D image to patch embedding: (B,C,H,W) -> (B,N,D)
-
-    Args:
-        img_size: Image size.
-        patch_size: Patch token size.
-        in_chans: Number of input image channels.
-        embed_dim: Number of linear projection output channels.
-        norm_layer: Normalization layer.
-    """
-
-    def __init__(
-        self,
-        img_size: Union[int, Tuple[int, int]] = 224,
-        patch_size: Union[int, Tuple[int, int]] = 16,
-        in_chans: int = 3,
-        embed_dim: int = 768,
-        norm_layer: Optional[Callable] = None,
-        flatten_embedding: bool = True,
-    ) -> None:
-        super().__init__()
-
-        image_HW = make_2tuple(img_size)
-        patch_HW = make_2tuple(patch_size)
-        patch_grid_size = (
-            image_HW[0] // patch_HW[0],
-            image_HW[1] // patch_HW[1],
-        )
-
-        self.img_size = image_HW
-        self.patch_size = patch_HW
-        self.patches_resolution = patch_grid_size
-        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
-
-        self.in_chans = in_chans
-        self.embed_dim = embed_dim
-
-        self.flatten_embedding = flatten_embedding
-
-        self.proj = nn.Conv2d(
-            in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
-        )
-        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
-
-    def forward(self, x: Tensor) -> Tensor:
-        _, _, H, W = x.shape
-        patch_H, patch_W = self.patch_size
-
-        assert H % patch_H == 0, (
-            f"Input image height {H} is not a multiple of patch height {patch_H}"
-        )
-        assert W % patch_W == 0, (
-            f"Input image width {W} is not a multiple of patch width: {patch_W}"
-        )
-
-        x = self.proj(x)  # B C H W
-        H, W = x.size(2), x.size(3)
-        x = x.flatten(2).transpose(1, 2)  # B HW C
-        x = self.norm(x)
-        if not self.flatten_embedding:
-            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
-        return x
-
-    def flops(self) -> float:
-        Ho, Wo = self.patches_resolution
-        flops = (
-            Ho
-            * Wo
-            * self.embed_dim
-            * self.in_chans
-            * (self.patch_size[0] * self.patch_size[1])
-        )
-        if self.norm is not None:
-            flops += Ho * Wo * self.embed_dim
-        return flops
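PatchEmbed maps a (B, C, H, W) image to (B, N, D) patch tokens via a strided convolution, so N is (H / patch_size) * (W / patch_size). A quick sketch, assuming the pre-deletion import path:

```python
import torch
from mapanything.models.external.dinov2.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=1024)
x = torch.rand(2, 3, 518, 518)
tokens = embed(x)
print(tokens.shape)  # torch.Size([2, 1369, 1024]); (518 // 14) ** 2 = 37 ** 2 = 1369 patches
```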
mapanything/models/external/dinov2/layers/swiglu_ffn.py
DELETED
@@ -1,71 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-import os
-from typing import Callable, Optional
-
-import torch.nn.functional as F
-from torch import nn, Tensor
-
-
-class SwiGLUFFN(nn.Module):
-    def __init__(
-        self,
-        in_features: int,
-        hidden_features: Optional[int] = None,
-        out_features: Optional[int] = None,
-        act_layer: Callable[..., nn.Module] = None,
-        drop: float = 0.0,
-        bias: bool = True,
-    ) -> None:
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
-        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
-
-    def forward(self, x: Tensor) -> Tensor:
-        x12 = self.w12(x)
-        x1, x2 = x12.chunk(2, dim=-1)
-        hidden = F.silu(x1) * x2
-        return self.w3(hidden)
-
-
-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import SwiGLU
-
-        XFORMERS_AVAILABLE = True
-        # warnings.warn("xFormers is available (SwiGLU)")
-    else:
-        # warnings.warn("xFormers is disabled (SwiGLU)")
-        raise ImportError
-except ImportError:
-    SwiGLU = SwiGLUFFN
-    XFORMERS_AVAILABLE = False
-
-    # warnings.warn("xFormers is not available (SwiGLU)")
-
-
-class SwiGLUFFNFused(SwiGLU):
-    def __init__(
-        self,
-        in_features: int,
-        hidden_features: Optional[int] = None,
-        out_features: Optional[int] = None,
-        act_layer: Callable[..., nn.Module] = None,
-        drop: float = 0.0,
-        bias: bool = True,
-    ) -> None:
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
-        super().__init__(
-            in_features=in_features,
-            hidden_features=hidden_features,
-            out_features=out_features,
-            bias=bias,
-        )
mapanything/models/external/dinov2/models/__init__.py
DELETED
@@ -1,44 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-import logging
-
-import mapanything.models.external.dinov2.models.vision_transformer as vits
-
-logger = logging.getLogger("dinov2")
-
-
-def build_model(args, only_teacher=False, img_size=224):
-    args.arch = args.arch.removesuffix("_memeff")
-    if "vit" in args.arch:
-        vit_kwargs = dict(
-            img_size=img_size,
-            patch_size=args.patch_size,
-            init_values=args.layerscale,
-            ffn_layer=args.ffn_layer,
-            block_chunks=args.block_chunks,
-            qkv_bias=args.qkv_bias,
-            proj_bias=args.proj_bias,
-            ffn_bias=args.ffn_bias,
-            num_register_tokens=args.num_register_tokens,
-            interpolate_offset=args.interpolate_offset,
-            interpolate_antialias=args.interpolate_antialias,
-        )
-        teacher = vits.__dict__[args.arch](**vit_kwargs)
-        if only_teacher:
-            return teacher, teacher.embed_dim
-        student = vits.__dict__[args.arch](
-            **vit_kwargs,
-            drop_path_rate=args.drop_path_rate,
-            drop_path_uniform=args.drop_path_uniform,
-        )
-        embed_dim = student.embed_dim
-    return student, teacher, embed_dim
-
-
-def build_model_from_cfg(cfg, only_teacher=False):
-    return build_model(
-        cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size
-    )
mapanything/models/external/dinov2/models/vision_transformer.py
DELETED
@@ -1,448 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import math
from functools import partial
from typing import Callable, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.utils.checkpoint import checkpoint

from mapanything.models.external.dinov2.layers import (
    MemEffAttention,
    Mlp,
    NestedTensorBlock as Block,
    PatchEmbed,
    SwiGLUFFNFused,
)
from mapanything.models.external.pi3.layers.attention import FlashAttention

# logger = logging.getLogger("dinov2")


def named_apply(
    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=child_name,
            depth_first=depth_first,
            include_root=True,
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
        )
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
            if num_register_tokens
            else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [
                x.item() for x in torch.linspace(0, drop_path_rate, depth)
            ]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            # logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            # logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            # logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                attn_class=FlashAttention,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append(
                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
                )
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            previous_dtype
        )

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(
                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
            )

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [
            self.prepare_tokens_with_masks(x, masks)
            for x, masks in zip(x_list, masks_list)
        ]
        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), (
            f"only {len(output)} / {len(blocks_to_take)} blocks found"
        )
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), (
            f"only {len(output)} / {len(blocks_to_take)} blocks found"
        )
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
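For reference, a minimal usage sketch of the backbone defined above, on a checkout that still contains this module. Assumptions: the package import path is unchanged, the input tensor is a dummy batch, and the environment supports the pi3 FlashAttention blocks wired in by the constructor.

# Minimal sketch: build a ViT-L/14 with 4 register tokens and extract a patch-feature map.
import torch

from mapanything.models.external.dinov2.models.vision_transformer import vit_large

model = vit_large(patch_size=14, num_register_tokens=4, block_chunks=0).eval()
images = torch.randn(1, 3, 224, 224)  # dummy batch; H and W must be divisible by patch_size

with torch.no_grad():
    # reshape=True returns (B, C, H/ps, W/ps) feature maps from the last block
    feats = model.get_intermediate_layers(images, n=1, reshape=True)[0]

print(feats.shape)  # expected: (1, 1024, 16, 16)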
mapanything/models/external/dinov2/utils/__init__.py
DELETED
@@ -1,4 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
mapanything/models/external/dinov2/utils/cluster.py
DELETED
@@ -1,102 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional


class ClusterType(Enum):
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"


def _guess_cluster_type() -> ClusterType:
    uname = os.uname()
    if uname.sysname == "Linux":
        if uname.release.endswith("-aws"):
            # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
            return ClusterType.AWS
        elif uname.nodename.startswith("rsc"):
            # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
            return ClusterType.RSC

    return ClusterType.FAIR


def get_cluster_type(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[ClusterType]:
    if cluster_type is None:
        return _guess_cluster_type()

    return cluster_type


def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    CHECKPOINT_DIRNAMES = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]


def get_user_checkpoint_path(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[Path]:
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None

    username = os.environ.get("USER")
    assert username is not None
    return checkpoint_path / username


def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    SLURM_PARTITIONS = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return SLURM_PARTITIONS[cluster_type]


def get_slurm_executor_parameters(
    nodes: int,
    num_gpus_per_node: int,
    cluster_type: Optional[ClusterType] = None,
    **kwargs,
) -> Dict[str, Any]:
    # create default parameters
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # apply cluster-specific adjustments
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type == ClusterType.AWS:
        params["cpus_per_task"] = 12
        del params["mem_gb"]
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_task"] = 12
    # set additional parameters / apply overrides
    params.update(kwargs)
    return params
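A short usage sketch for the cluster helpers above (assuming the module is importable): query the default SLURM executor parameters for a 2-node, 8-GPU-per-node job; any extra keyword argument is passed through as an override.

# Sketch: build SLURM executor parameters and override the job timeout.
from mapanything.models.external.dinov2.utils.cluster import (
    get_slurm_executor_parameters,
    get_slurm_partition,
)

params = get_slurm_executor_parameters(
    nodes=2,
    num_gpus_per_node=8,
    timeout_min=60 * 24,  # forwarded via **kwargs and applied as an override
)
print(get_slurm_partition(), params)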
mapanything/models/external/dinov2/utils/config.py
DELETED
@@ -1,74 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging
import math
import os

import dinov2.distributed as distributed
from dinov2.configs import dinov2_default_config
from dinov2.logging import setup_logging
from dinov2.utils import utils
from omegaconf import OmegaConf

logger = logging.getLogger("dinov2")


def apply_scaling_rules_to_cfg(cfg):  # to fix
    if cfg.optim.scaling_rule == "sqrt_wrt_1024":
        base_lr = cfg.optim.base_lr
        cfg.optim.lr = base_lr
        cfg.optim.lr *= math.sqrt(
            cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0
        )
        logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    else:
        raise NotImplementedError
    return cfg


def write_config(cfg, output_dir, name="config.yaml"):
    logger.info(OmegaConf.to_yaml(cfg))
    saved_cfg_path = os.path.join(output_dir, name)
    with open(saved_cfg_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return saved_cfg_path


def get_cfg_from_args(args):
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]
    default_cfg = OmegaConf.create(dinov2_default_config)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
    return cfg


def default_setup(args):
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n {}\n".format(utils.get_sha()))
    logger.info(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
    )


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg
mapanything/models/external/dinov2/utils/dtype.py
DELETED
@@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.


from typing import Dict, Union

import numpy as np
import torch

TypeSpec = Union[str, np.dtype, torch.dtype]


_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    assert isinstance(dtype, np.dtype), (
        f"Expected an instance of numpy dtype, got {type(dtype)}"
    )
    return _NUMPY_TO_TORCH_DTYPE[dtype]
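A quick sketch of the dtype helper above (assuming the module is importable): it normalizes strings and numpy dtypes to torch dtypes, and passes torch dtypes through unchanged.

# Sketch: as_torch_dtype accepts str, np.dtype, or torch.dtype.
import numpy as np
import torch

from mapanything.models.external.dinov2.utils.dtype import as_torch_dtype

assert as_torch_dtype("float32") is torch.float32          # string -> torch dtype
assert as_torch_dtype(np.dtype("int64")) is torch.int64    # numpy dtype -> torch dtype
assert as_torch_dtype(torch.bfloat16) is torch.bfloat16    # torch dtype passes through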
mapanything/models/external/dinov2/utils/param_groups.py
DELETED
@@ -1,122 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging
from collections import defaultdict

logger = logging.getLogger("dinov2")


def get_vit_lr_decay_rate(
    name,
    lr_decay_rate=1.0,
    num_layers=12,
    force_is_backbone=False,
    chunked_blocks=False,
):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
    Returns:
        lr decay rate for the given parameter.
    """
    layer_id = num_layers + 1
    if name.startswith("backbone") or force_is_backbone:
        if (
            ".pos_embed" in name
            or ".patch_embed" in name
            or ".mask_token" in name
            or ".cls_token" in name
            or ".register_tokens" in name
        ):
            layer_id = 0
        elif force_is_backbone and (
            "pos_embed" in name
            or "patch_embed" in name
            or "mask_token" in name
            or "cls_token" in name
            or "register_tokens" in name
        ):
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
        elif chunked_blocks and "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
        elif "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)


def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    chunked_blocks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0
    all_param_groups = []

    for name, param in model.named_parameters():
        name = name.replace("_fsdp_wrapped_module.", "")
        if not param.requires_grad:
            continue
        decay_rate = get_vit_lr_decay_rate(
            name,
            lr_decay_rate,
            num_layers=n_blocks,
            force_is_backbone=n_blocks > 0,
            chunked_blocks=chunked_blocks,
        )
        d = {
            "params": param,
            "is_last_layer": False,
            "lr_multiplier": decay_rate,
            "wd_multiplier": 1.0,
            "name": name,
        }

        if "last_layer" in name:
            d.update({"is_last_layer": True})

        if name.endswith(".bias") or "norm" in name or "gamma" in name:
            d.update({"wd_multiplier": 0.0})

        if "patch_embed" in name:
            d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})

        all_param_groups.append(d)
        logger.info(
            f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}"""
        )

    return all_param_groups


def fuse_params_groups(
    all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")
):
    fused_params_groups = defaultdict(lambda: {"params": []})
    for d in all_params_groups:
        identifier = ""
        for k in keys:
            identifier += k + str(d[k]) + "_"

        for k in keys:
            fused_params_groups[identifier][k] = d[k]
        fused_params_groups[identifier]["params"].append(d["params"])

    return fused_params_groups.values()
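A minimal sketch of how these param-group helpers are typically wired into an optimizer (layer-wise lr decay, then fusing groups that share multipliers); the tiny Linear model here is a placeholder standing in for a ViT backbone.

# Sketch: build per-parameter groups, fuse them, and hand them to AdamW.
import torch

from mapanything.models.external.dinov2.utils.param_groups import (
    fuse_params_groups,
    get_params_groups_with_decay,
)

model = torch.nn.Linear(8, 8)  # placeholder stand-in for a ViT backbone
groups = get_params_groups_with_decay(model, lr_decay_rate=0.9, patch_embed_lr_mult=0.2)
fused = fuse_params_groups(groups)
optimizer = torch.optim.AdamW(
    [{"params": g["params"], "lr": 1e-3 * g["lr_multiplier"]} for g in fused]
)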
mapanything/models/external/dinov2/utils/utils.py
DELETED
@@ -1,105 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
import random
import subprocess
from urllib.parse import urlparse

import numpy as np
import torch
from torch import nn

# logger = logging.getLogger("dinov2")


def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    if urlparse(pretrained_weights).scheme:  # If it looks like an URL
        state_dict = torch.hub.load_state_dict_from_url(
            pretrained_weights, map_location="cpu"
        )
    else:
        state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # remove `module.` prefix
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # remove `backbone.` prefix induced by multicrop wrapper
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    _ = model.load_state_dict(state_dict, strict=False)
    # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def fix_random_seeds(seed=31):
    """
    Fix random seeds.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


class CosineScheduler(object):
    def __init__(
        self,
        base_value,
        final_value,
        total_iters,
        warmup_iters=0,
        start_warmup_value=0,
        freeze_iters=0,
    ):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        freeze_schedule = np.zeros((freeze_iters))

        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

        iters = np.arange(total_iters - warmup_iters - freeze_iters)
        schedule = final_value + 0.5 * (base_value - final_value) * (
            1 + np.cos(np.pi * iters / len(iters))
        )
        self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))

        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        if it >= self.total_iters:
            return self.final_value
        else:
            return self.schedule[it]


def has_batchnorms(model):
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for name, module in model.named_modules():
        if isinstance(module, bn_types):
            return True
    return False
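A short sketch of the CosineScheduler defined above (assuming the module is importable): linear warmup followed by cosine decay, indexed by iteration.

# Sketch: a learning-rate schedule with 500 warmup iterations out of 10,000 total.
from mapanything.models.external.dinov2.utils.utils import CosineScheduler

lr_schedule = CosineScheduler(
    base_value=1e-3,
    final_value=1e-6,
    total_iters=10_000,
    warmup_iters=500,
)
for it in range(10_000):
    lr = lr_schedule[it]  # feed this into the optimizer's param groups each iteration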
mapanything/models/external/dust3r/__init__.py
DELETED
@@ -1,217 +0,0 @@
"""
Inference wrapper for DUSt3R
"""

import warnings

import torch
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo  # noqa

from mapanything.models.external.vggt.utils.rotation import mat_to_quat
from mapanything.utils.geometry import (
    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
    convert_z_depth_to_depth_along_ray,
    depthmap_to_camera_frame,
    get_rays_in_camera_frame,
)

inf = float("inf")


def load_model(model_path, device, verbose=True):
    if verbose:
        print("Loading model from", model_path)
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if "landscape_only" not in args:
        args = args[:-1] + ", landscape_only=False)"
    else:
        args = args.replace(" ", "").replace(
            "landscape_only=True", "landscape_only=False"
        )
    assert "landscape_only=False" in args
    if verbose:
        print(f"Instantiating: {args}")
    try:
        net = eval(args)
    except NameError:
        net = AsymmetricCroCo3DStereo(
            enc_depth=24,
            dec_depth=12,
            enc_embed_dim=1024,
            dec_embed_dim=768,
            enc_num_heads=16,
            dec_num_heads=12,
            pos_embed="RoPE100",
            patch_embed_cls="PatchEmbedDust3R",
            img_size=(512, 512),
            head_type="dpt",
            output_mode="pts3d",
            depth_mode=("exp", -inf, inf),
            conf_mode=("exp", 1, inf),
            landscape_only=False,
        )
    s = net.load_state_dict(ckpt["model"], strict=False)
    if verbose:
        print(s)
    return net.to(device)


class DUSt3RBAWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        ckpt_path,
        scene_graph="complete",
        inference_batch_size=32,
        global_optim_schedule="cosine",
        global_optim_lr=0.01,
        global_optim_niter=300,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.ckpt_path = ckpt_path
        self.scene_graph = scene_graph
        self.inference_batch_size = inference_batch_size
        self.global_optim_schedule = global_optim_schedule
        self.global_optim_lr = global_optim_lr
        self.global_optim_niter = global_optim_niter

        # Init the model and load the checkpoint
        self.model = load_model(self.ckpt_path, device="cpu")

        # Init the global aligner mode
        self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer

    def forward(self, views):
        """
        Forward pass wrapper for DUSt3R using the global aligner.

        Assumption:
        - The batch size of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Each dictionary should contain the following keys, where B is the batch size and is 1:
                "img" (tensor): Image tensor of shape (B, C, H, W).
                "data_norm_type" (list): ["dust3r"]

        Returns:
            List[dict]: A list containing the final outputs for the input views.
        """
        # Check the batch size of input views
        batch_size_per_view, _, height, width = views[0]["img"].shape
        device = views[0]["img"].device
        num_views = len(views)
        assert batch_size_per_view == 1, (
            f"Batch size of input views should be 1, but got {batch_size_per_view}."
        )

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "DUSt3R expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Convert the input views to the expected input format
        images = []
        for view in views:
            images.append(
                dict(
                    img=view["img"],
                    idx=len(images),
                    instance=str(len(images)),
                )
            )

        # Make image pairs and run inference pair-wise
        pairs = make_pairs(
            images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            output = inference(
                pairs,
                self.model,
                device,
                batch_size=self.inference_batch_size,
                verbose=False,
            )

        # Global optimization
        with torch.enable_grad():
            scene = global_aligner(
                output, device=device, mode=self.global_aligner_mode, verbose=False
            )
            _ = scene.compute_global_alignment(
                init="mst",
                niter=self.global_optim_niter,
                schedule=self.global_optim_schedule,
                lr=self.global_optim_lr,
            )

        # Make sure scene is not None
        if scene is None:
            raise RuntimeError("Global optimization failed.")

        # Get the predictions
        intrinsics = scene.get_intrinsics()
        c2w_poses = scene.get_im_poses()
        depths = scene.get_depthmaps()

        # Convert the output to the MapAnything format
        with torch.autocast("cuda", enabled=False):
            res = []
            for view_idx in range(num_views):
                # Get the current view predictions
                curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
                curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
                curr_view_depth_z = depths[view_idx].unsqueeze(0)

                # Convert the pose to quaternions and translation
                curr_view_cam_translations = curr_view_pose[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])

                # Get the camera frame pointmaps
                curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
                    curr_view_depth_z, curr_view_intrinsic
                )

                # Convert the z depth to depth along ray
                curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
                    curr_view_depth_z, curr_view_intrinsic
                )
                curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)

                # Get the ray directions on the unit sphere in the camera frame
                _, curr_view_ray_dirs = get_rays_in_camera_frame(
                    curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
                )

                # Get the pointmaps
                curr_view_pts3d = (
                    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
                        curr_view_ray_dirs,
                        curr_view_depth_along_ray,
                        curr_view_cam_translations,
                        curr_view_cam_quats,
                    )
                )

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d,
                        "pts3d_cam": curr_view_pts3d_cam,
                        "ray_directions": curr_view_ray_dirs,
                        "depth_along_ray": curr_view_depth_along_ray,
                        "cam_trans": curr_view_cam_translations,
                        "cam_quats": curr_view_cam_quats,
                    }
                )

        return res
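A hedged usage sketch for the wrapper above: the checkpoint path is a placeholder and the random tensors stand in for real preprocessed images (the benchmarking code builds `views` from the MapAnything data loaders); it only illustrates the expected input/output shapes, not the official evaluation pipeline.

# Sketch: run the DUSt3R + global-alignment wrapper on four dummy views.
import torch

views = [
    {"img": torch.randn(1, 3, 384, 512).cuda(), "data_norm_type": ["dust3r"]}
    for _ in range(4)
]
model = DUSt3RBAWrapper(
    name="dust3r",
    ckpt_path="checkpoints/dust3r_512_dpt.pth",  # placeholder checkpoint path
).cuda()
with torch.no_grad():
    preds = model(views)  # list of per-view dicts: pts3d, ray_directions, depth_along_ray, ...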
mapanything/models/external/mast3r/__init__.py
DELETED
@@ -1,191 +0,0 @@
"""
Inference wrapper for MASt3R + Sparse GA
"""

import os
import tempfile
import warnings

import torch
from dust3r.image_pairs import make_pairs
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.model import load_model

from mapanything.models.external.vggt.utils.rotation import mat_to_quat
from mapanything.utils.geometry import (
    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
    convert_z_depth_to_depth_along_ray,
    depthmap_to_camera_frame,
    get_rays_in_camera_frame,
)


class MASt3RSGAWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        ckpt_path,
        cache_dir,
        scene_graph="complete",
        sparse_ga_lr1=0.07,
        sparse_ga_niter1=300,
        sparse_ga_lr2=0.01,
        sparse_ga_niter2=300,
        sparse_ga_optim_level="refine+depth",
        sparse_ga_shared_intrinsics=False,
        sparse_ga_matching_conf_thr=5.0,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.ckpt_path = ckpt_path
        self.cache_dir = cache_dir
        self.scene_graph = scene_graph
        self.sparse_ga_lr1 = sparse_ga_lr1
        self.sparse_ga_niter1 = sparse_ga_niter1
        self.sparse_ga_lr2 = sparse_ga_lr2
        self.sparse_ga_niter2 = sparse_ga_niter2
        self.sparse_ga_optim_level = sparse_ga_optim_level
        self.sparse_ga_shared_intrinsics = sparse_ga_shared_intrinsics
        self.sparse_ga_matching_conf_thr = sparse_ga_matching_conf_thr

        # Init the model and load the checkpoint
        self.model = load_model(self.ckpt_path, device="cpu")

    def forward(self, views):
        """
        Forward pass wrapper for MASt3R using the sparse global aligner.

        Assumption:
        - The batch size of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Each dictionary should contain the following keys, where B is the batch size and is 1:
                "img" (tensor): Image tensor of shape (B, C, H, W).
                "data_norm_type" (list): ["dust3r"]
                "label" (list): ["scene_name"]
                "instance" (list): ["image_name"]

        Returns:
            List[dict]: A list containing the final outputs for the input views.
        """
        # Check the batch size of input views
        batch_size_per_view, _, height, width = views[0]["img"].shape
        device = views[0]["img"].device
        num_views = len(views)
        assert batch_size_per_view == 1, (
            f"Batch size of input views should be 1, but got {batch_size_per_view}."
        )

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "MASt3R expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Convert the input views to the expected input format
        images = []
        image_paths = []
        for view in views:
            images.append(
                dict(
                    img=view["img"].cpu(),
                    idx=len(images),
                    instance=str(len(images)),
                    true_shape=torch.tensor(view["img"].shape[-2:])[None]
                    .repeat(batch_size_per_view, 1)
                    .numpy(),
                )
            )
            view_name = os.path.join(view["label"][0], view["instance"][0])
            image_paths.append(view_name)

        # Make image pairs and run inference
        # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
        pairs = make_pairs(
            images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
        )
        with torch.enable_grad():
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=FutureWarning)
                tempfile.mkdtemp(dir=self.cache_dir)
                scene = sparse_global_alignment(
                    image_paths,
                    pairs,
                    self.cache_dir,
                    self.model,
                    lr1=self.sparse_ga_lr1,
                    niter1=self.sparse_ga_niter1,
                    lr2=self.sparse_ga_lr2,
                    niter2=self.sparse_ga_niter2,
                    device=device,
                    opt_depth="depth" in self.sparse_ga_optim_level,
                    shared_intrinsics=self.sparse_ga_shared_intrinsics,
                    matching_conf_thr=self.sparse_ga_matching_conf_thr,
                    verbose=False,
                )

        # Make sure scene is not None
        if scene is None:
            raise RuntimeError("Global optimization failed.")

        # Get the predictions
        intrinsics = scene.intrinsics
        c2w_poses = scene.get_im_poses()
        _, depths, _ = scene.get_dense_pts3d()

        # Convert the output to the MapAnything format
        with torch.autocast("cuda", enabled=False):
            res = []
            for view_idx in range(num_views):
                # Get the current view predictions
                curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
                curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
                curr_view_depth_z = (
                    depths[view_idx].reshape((height, width)).unsqueeze(0)
                )

                # Convert the pose to quaternions and translation
                curr_view_cam_translations = curr_view_pose[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])

                # Get the camera frame pointmaps
                curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
                    curr_view_depth_z, curr_view_intrinsic
                )

                # Convert the z depth to depth along ray
                curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
                    curr_view_depth_z, curr_view_intrinsic
                )
                curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)

                # Get the ray directions on the unit sphere in the camera frame
                _, curr_view_ray_dirs = get_rays_in_camera_frame(
                    curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
                )

                # Get the pointmaps
                curr_view_pts3d = (
                    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
                        curr_view_ray_dirs,
                        curr_view_depth_along_ray,
                        curr_view_cam_translations,
                        curr_view_cam_quats,
                    )
                )

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d,
                        "pts3d_cam": curr_view_pts3d_cam,
                        "ray_directions": curr_view_ray_dirs,
                        "depth_along_ray": curr_view_depth_along_ray,
                        "cam_trans": curr_view_cam_translations,
                        "cam_quats": curr_view_cam_quats,
                    }
                )

        return res
mapanything/models/external/moge/__init__.py
DELETED
@@ -1,114 +0,0 @@
"""
Inference wrapper for MoGe
"""

import torch

from mapanything.models.external.moge.models.v1 import MoGeModel as MoGeModelV1
from mapanything.models.external.moge.models.v2 import MoGeModel as MoGeModelV2


class MoGeWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        model_string="Ruicheng/moge-2-vitl",
        torch_hub_force_reload=False,
        load_custom_ckpt=False,
        custom_ckpt_path=None,
    ):
        super().__init__()
        self.name = name
        self.model_string = model_string
        self.torch_hub_force_reload = torch_hub_force_reload
        self.load_custom_ckpt = load_custom_ckpt
        self.custom_ckpt_path = custom_ckpt_path

        # Mapping of MoGe model version to checkpoint strings
        self.moge_model_map = {
            "v1": ["Ruicheng/moge-vitl"],
            "v2": [
                "Ruicheng/moge-2-vits-normal",
                "Ruicheng/moge-2-vitb-normal",
                "Ruicheng/moge-2-vitl-normal",
                "Ruicheng/moge-2-vitl",
            ],
        }

        # Initialize the model
        if self.model_string in self.moge_model_map["v1"]:
            self.model = MoGeModelV1.from_pretrained(self.model_string)
        elif self.model_string in self.moge_model_map["v2"]:
            self.model = MoGeModelV2.from_pretrained(self.model_string)
        else:
            raise ValueError(
                f"Invalid model string: {self.model_string}. Valid strings are: {self.moge_model_map}"
            )

        # Load custom checkpoint if requested
        if self.load_custom_ckpt:
            print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
            assert self.custom_ckpt_path is not None, (
                "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
            )
            custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False)
            print(self.model.load_state_dict(custom_ckpt, strict=True))
            del custom_ckpt  # in case it occupies memory

    def forward(self, views):
        """
        Forward pass wrapper for MoGe-2.
        The predicted MoGe-2 mask is not applied to the outputs.
        The number of tokens for inference is determined by the image shape.

        Assumption:
        - The number of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Length of the list should be 1.
                Each dictionary should contain the following keys:
                "img" (tensor): Image tensor of shape (B, C, H, W).
                "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list containing the final outputs for the single view. Length of the list will be 1.
        """
        # Check that the number of input views is 1
        assert len(views) == 1, "MoGe only supports 1 input view."

        # Get input shape of the images, number of tokens for inference, and batch size per view
        _, _, height, width = views[0]["img"].shape
        num_tokens = int(height // 14) * int(width // 14)

        # Check the data norm type
        # MoGe expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "MoGe expects a normalized image but without the DINOv2 mean and std applied"
        )

        # Run MoGe inference
        # Output dict contains: "points", "depth", "mask", "intrinsics", "normal" (based on model config)
        model_outputs = self.model.infer(
            image=views[0]["img"], num_tokens=num_tokens, apply_mask=False
        )

        # Get the ray directions and depth along ray
        with torch.autocast("cuda", enabled=False):
            depth_along_ray = torch.norm(model_outputs["points"], dim=-1, keepdim=True)
            ray_directions = model_outputs["points"] / depth_along_ray

        # Convert the output to MapAnything format
        result_dict = {
            "pts3d": model_outputs["points"],
            "pts3d_cam": model_outputs["points"],
            "depth_z": model_outputs["depth"].unsqueeze(-1),
            "intrinsics": model_outputs["intrinsics"],
            "non_ambiguous_mask": model_outputs["mask"],
|
| 108 |
-
"non_ambiguous_mask": model_outputs["mask"],
|
| 109 |
-
"ray_directions": ray_directions,
|
| 110 |
-
"depth_along_ray": depth_along_ray,
|
| 111 |
-
}
|
| 112 |
-
res = [result_dict]
|
| 113 |
-
|
| 114 |
-
return res
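
For reference, this is roughly how the deleted wrapper was meant to be driven. It is a sketch, not code from the repository: it assumes the file above is still importable at its old path, that `"identity"` normalization corresponds to image values in `[0, 1]`, and that the MoGe checkpoint can be fetched from the Hugging Face Hub.

```python
import torch

from mapanything.models.external.moge import MoGeWrapper  # old path, prior to this deletion

# Single-view wrapper; downloads the "Ruicheng/moge-2-vitl" checkpoint from the HF Hub
model = MoGeWrapper(name="moge", model_string="Ruicheng/moge-2-vitl").eval()

views = [
    {
        # Batch of 2 images, H and W multiples of 14, values in [0, 1] ("identity" normalization)
        "img": torch.rand(2, 3, 476, 630),
        "data_norm_type": ["identity"],
    }
]

with torch.no_grad():
    preds = model(views)

print(preds[0]["pts3d"].shape)            # per-pixel pointmap from MoGe, last dim = 3
print(preds[0]["depth_along_ray"].shape)  # per-pixel distance along each ray, last dim = 1
```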
mapanything/models/external/moge/models/modules.py
DELETED
@@ -1,467 +0,0 @@

import functools
import importlib
import itertools
from typing import List, Literal, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from mapanything.models.external.dinov2.models.vision_transformer import (
    DinoVisionTransformer,
)
from mapanything.models.external.moge.models.utils import (
    wrap_dinov2_attention_with_sdpa,
    wrap_module_with_gradient_checkpointing,
)


class ResidualConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        hidden_channels: int = None,
        kernel_size: int = 3,
        padding_mode: str = "replicate",
        activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu",
        in_norm: Literal[
            "group_norm", "layer_norm", "instance_norm", "none"
        ] = "layer_norm",
        hidden_norm: Literal[
            "group_norm", "layer_norm", "instance_norm"
        ] = "group_norm",
    ):
        super(ResidualConvBlock, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        if hidden_channels is None:
            hidden_channels = in_channels

        if activation == "relu":
            activation_cls = nn.ReLU
        elif activation == "leaky_relu":
            activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2)
        elif activation == "silu":
            activation_cls = nn.SiLU
        elif activation == "elu":
            activation_cls = nn.ELU
        else:
            raise ValueError(f"Unsupported activation function: {activation}")

        self.layers = nn.Sequential(
            nn.GroupNorm(in_channels // 32, in_channels)
            if in_norm == "group_norm"
            else nn.GroupNorm(1, in_channels)
            if in_norm == "layer_norm"
            else nn.InstanceNorm2d(in_channels)
            if in_norm == "instance_norm"
            else nn.Identity(),
            activation_cls(),
            nn.Conv2d(
                in_channels,
                hidden_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                padding_mode=padding_mode,
            ),
            nn.GroupNorm(hidden_channels // 32, hidden_channels)
            if hidden_norm == "group_norm"
            else nn.GroupNorm(1, hidden_channels)
            if hidden_norm == "layer_norm"
            else nn.InstanceNorm2d(hidden_channels)
            if hidden_norm == "instance_norm"
            else nn.Identity(),
            activation_cls(),
            nn.Conv2d(
                hidden_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                padding_mode=padding_mode,
            ),
        )

        self.skip_connection = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        skip = self.skip_connection(x)
        x = self.layers(x)
        x = x + skip
        return x


class DINOv2Encoder(nn.Module):
    "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]."

    backbone: DinoVisionTransformer
    image_mean: torch.Tensor
    image_std: torch.Tensor
    dim_features: int

    def __init__(
        self,
        backbone: str,
        intermediate_layers: Union[int, List[int]],
        dim_out: int,
        **deprecated_kwargs,
    ):
        super(DINOv2Encoder, self).__init__()

        self.intermediate_layers = intermediate_layers

        # Load the backbone
        self.hub_loader = getattr(
            importlib.import_module(
                "mapanything.models.external.dinov2.hub.backbones", __package__
            ),
            backbone,
        )
        self.backbone_name = backbone
        self.backbone = self.hub_loader(pretrained=False)

        self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
        self.num_features = (
            intermediate_layers
            if isinstance(intermediate_layers, int)
            else len(intermediate_layers)
        )

        self.output_projections = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=self.dim_features,
                    out_channels=dim_out,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for _ in range(self.num_features)
            ]
        )

        self.register_buffer(
            "image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )

    @property
    def onnx_compatible_mode(self):
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        self._onnx_compatible_mode = value
        self.backbone.onnx_compatible_mode = value

    def init_weights(self):
        pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
        self.backbone.load_state_dict(pretrained_backbone_state_dict)

    def enable_gradient_checkpointing(self):
        for i in range(len(self.backbone.blocks)):
            wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])

    def enable_pytorch_native_sdpa(self):
        for i in range(len(self.backbone.blocks)):
            wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)

    def forward(
        self,
        image: torch.Tensor,
        token_rows: Union[int, torch.LongTensor],
        token_cols: Union[int, torch.LongTensor],
        return_class_token: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        image_14 = F.interpolate(
            image,
            (token_rows * 14, token_cols * 14),
            mode="bilinear",
            align_corners=False,
            antialias=not self.onnx_compatible_mode,
        )
        image_14 = (image_14 - self.image_mean) / self.image_std

        # Get intermediate layers from the backbone
        features = self.backbone.get_intermediate_layers(
            image_14, n=self.intermediate_layers, return_class_token=True
        )

        # Project features to the desired dimensionality
        x = torch.stack(
            [
                proj(
                    feat.permute(0, 2, 1)
                    .unflatten(2, (token_rows, token_cols))
                    .contiguous()
                )
                for proj, (feat, clstoken) in zip(self.output_projections, features)
            ],
            dim=1,
        ).sum(dim=1)

        if return_class_token:
            return x, features[-1][1]
        else:
            return x


class Resampler(nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        type_: Literal[
            "pixel_shuffle",
            "nearest",
            "bilinear",
            "conv_transpose",
            "pixel_unshuffle",
            "avg_pool",
            "max_pool",
        ],
        scale_factor: int = 2,
    ):
        if type_ == "pixel_shuffle":
            nn.Sequential.__init__(
                self,
                nn.Conv2d(
                    in_channels,
                    out_channels * (scale_factor**2),
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                ),
                nn.PixelShuffle(scale_factor),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                ),
            )
            for i in range(1, scale_factor**2):
                self[0].weight.data[i :: scale_factor**2] = self[0].weight.data[
                    0 :: scale_factor**2
                ]
                self[0].bias.data[i :: scale_factor**2] = self[0].bias.data[
                    0 :: scale_factor**2
                ]
        elif type_ in ["nearest", "bilinear"]:
            nn.Sequential.__init__(
                self,
                nn.Upsample(
                    scale_factor=scale_factor,
                    mode=type_,
                    align_corners=False if type_ == "bilinear" else None,
                ),
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                ),
            )
        elif type_ == "conv_transpose":
            nn.Sequential.__init__(
                self,
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=scale_factor,
                    stride=scale_factor,
                ),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                ),
            )
            self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
        elif type_ == "pixel_unshuffle":
            nn.Sequential.__init__(
                self,
                nn.PixelUnshuffle(scale_factor),
                nn.Conv2d(
                    in_channels * (scale_factor**2),
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                ),
            )
        elif type_ == "avg_pool":
            nn.Sequential.__init__(
                self,
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                ),
                nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
            )
        elif type_ == "max_pool":
            nn.Sequential.__init__(
                self,
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                ),
                nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
            )
        else:
            raise ValueError(f"Unsupported resampler type: {type_}")


class MLP(nn.Sequential):
    def __init__(self, dims: Sequence[int]):
        nn.Sequential.__init__(
            self,
            *itertools.chain(
                *[
                    (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True))
                    for dim_in, dim_out in zip(dims[:-2], dims[1:-1])
                ]
            ),
            nn.Linear(dims[-2], dims[-1]),
        )


class ConvStack(nn.Module):
    def __init__(
        self,
        dim_in: List[Optional[int]],
        dim_res_blocks: List[int],
        dim_out: List[Optional[int]],
        resamplers: Union[
            Literal[
                "pixel_shuffle",
                "nearest",
                "bilinear",
                "conv_transpose",
                "pixel_unshuffle",
                "avg_pool",
                "max_pool",
            ],
            List,
        ],
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        res_block_in_norm: Literal[
            "layer_norm", "group_norm", "instance_norm", "none"
        ] = "layer_norm",
        res_block_hidden_norm: Literal[
            "layer_norm", "group_norm", "instance_norm", "none"
        ] = "group_norm",
        activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu",
    ):
        super().__init__()
        self.input_blocks = nn.ModuleList(
            [
                nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0)
                if dim_in_ is not None
                else nn.Identity()
                for dim_in_, dim_res_block_ in zip(
                    dim_in
                    if isinstance(dim_in, Sequence)
                    else itertools.repeat(dim_in),
                    dim_res_blocks,
                )
            ]
        )
        self.resamplers = nn.ModuleList(
            [
                Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
                for i, (dim_prev, dim_succ, resampler) in enumerate(
                    zip(
                        dim_res_blocks[:-1],
                        dim_res_blocks[1:],
                        resamplers
                        if isinstance(resamplers, Sequence)
                        else itertools.repeat(resamplers),
                    )
                )
            ]
        )
        self.res_blocks = nn.ModuleList(
            [
                nn.Sequential(
                    *(
                        ResidualConvBlock(
                            dim_res_block_,
                            dim_res_block_,
                            dim_times_res_block_hidden * dim_res_block_,
                            activation=activation,
                            in_norm=res_block_in_norm,
                            hidden_norm=res_block_hidden_norm,
                        )
                        for _ in range(
                            num_res_blocks[i]
                            if isinstance(num_res_blocks, list)
                            else num_res_blocks
                        )
                    )
                )
                for i, dim_res_block_ in enumerate(dim_res_blocks)
            ]
        )
        self.output_blocks = nn.ModuleList(
            [
                nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0)
                if dim_out_ is not None
                else nn.Identity()
                for dim_out_, dim_res_block_ in zip(
                    dim_out
                    if isinstance(dim_out, Sequence)
                    else itertools.repeat(dim_out),
                    dim_res_blocks,
                )
            ]
        )

    def enable_gradient_checkpointing(self):
        for i in range(len(self.resamplers)):
            self.resamplers[i] = wrap_module_with_gradient_checkpointing(
                self.resamplers[i]
            )
        for i in range(len(self.res_blocks)):
            for j in range(len(self.res_blocks[i])):
                self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(
                    self.res_blocks[i][j]
                )

    def forward(self, in_features: List[torch.Tensor]):
        out_features = []
        for i in range(len(self.res_blocks)):
            feature = self.input_blocks[i](in_features[i])
            if i == 0:
                x = feature
            elif feature is not None:
                x = x + feature
            x = self.res_blocks[i](x)
            out_features.append(self.output_blocks[i](x))
            if i < len(self.res_blocks) - 1:
                x = self.resamplers[i](x)
        return out_features
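
A quick shape sanity check of the building blocks above, assuming the class definitions from the deleted `modules.py` are in scope (for example, pasted into a scratch script). The channel sizes and the `ConvStack` configuration are illustrative, not values used by MoGe.

```python
import torch

# Assumes ResidualConvBlock and ConvStack from the file above are in scope.
block = ResidualConvBlock(in_channels=64)
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)  # torch.Size([1, 64, 32, 32]) -- residual blocks preserve the spatial size

# A three-level ConvStack decoder: only the coarsest level receives an external feature map,
# and only the finest level gets a 1x1 output projection. Channel sizes are illustrative.
stack = ConvStack(
    dim_in=[256, None, None],
    dim_res_blocks=[128, 64, 32],
    dim_out=[None, None, 16],
    resamplers=["bilinear", "bilinear"],  # one 2x upsampler between consecutive levels
)
feats = [torch.randn(1, 256, 16, 16), None, None]
outs = stack(feats)
print([tuple(o.shape) for o in outs])  # [(1, 128, 16, 16), (1, 64, 32, 32), (1, 16, 64, 64)]
```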
mapanything/models/external/moge/models/utils.py
DELETED
@@ -1,477 +0,0 @@

import inspect
from functools import partial, wraps
from numbers import Number
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def wrap_module_with_gradient_checkpointing(module: nn.Module):
    from torch.utils.checkpoint import checkpoint

    class _CheckpointingWrapper(module.__class__):
        _restore_cls = module.__class__

        def forward(self, *args, **kwargs):
            return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)

    module.__class__ = _CheckpointingWrapper
    return module


def unwrap_module_with_gradient_checkpointing(module: nn.Module):
    module.__class__ = module.__class__._restore_cls


def wrap_dinov2_attention_with_sdpa(module: nn.Module):
    assert torch.__version__ >= "2.0", "SDPA requires PyTorch 2.0 or later"

    class _AttentionWrapper(module.__class__):
        def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
            B, N, C = x.shape
            qkv = (
                self.qkv(x)
                .reshape(B, N, 3, self.num_heads, C // self.num_heads)
                .permute(2, 0, 3, 1, 4)
            )  # (3, B, H, N, C // H)

            q, k, v = torch.unbind(qkv, 0)  # (B, H, N, C // H)

            x = F.scaled_dot_product_attention(q, k, v, attn_bias)
            x = x.permute(0, 2, 1, 3).reshape(B, N, C)

            x = self.proj(x)
            x = self.proj_drop(x)
            return x

    module.__class__ = _AttentionWrapper
    return module


def sync_ddp_hook(
    state, bucket: torch.distributed.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    group_to_use = torch.distributed.group.WORLD
    world_size = group_to_use.size()
    grad = bucket.buffer()
    grad.div_(world_size)
    torch.distributed.all_reduce(grad, group=group_to_use)
    fut = torch.futures.Future()
    fut.set_result(grad)
    return fut


def normalized_view_plane_uv(
    width: int,
    height: int,
    aspect_ratio: float = None,
    dtype: torch.dtype = None,
    device: torch.device = None,
) -> torch.Tensor:
    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
    if aspect_ratio is None:
        aspect_ratio = width / height

    span_x = aspect_ratio / (1 + aspect_ratio**2) ** 0.5
    span_y = 1 / (1 + aspect_ratio**2) ** 0.5

    u = torch.linspace(
        -span_x * (width - 1) / width,
        span_x * (width - 1) / width,
        width,
        dtype=dtype,
        device=device,
    )
    v = torch.linspace(
        -span_y * (height - 1) / height,
        span_y * (height - 1) / height,
        height,
        dtype=dtype,
        device=device,
    )
    u, v = torch.meshgrid(u, v, indexing="xy")
    uv = torch.stack([u, v], dim=-1)
    return uv


def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
    from scipy.optimize import least_squares

    uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)

    def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
        xy_proj = xy / (z + shift)[:, None]
        f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
        err = (f * xy_proj - uv).ravel()
        return err

    solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method="lm")
    optim_shift = solution["x"].squeeze().astype(np.float32)

    xy_proj = xy / (z + optim_shift)[:, None]
    optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()

    return optim_shift, optim_focal


def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
    from scipy.optimize import least_squares

    uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)

    def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
        xy_proj = xy / (z + shift)[:, None]
        err = (focal * xy_proj - uv).ravel()
        return err

    solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method="lm")
    optim_shift = solution["x"].squeeze().astype(np.float32)

    return optim_shift


def recover_focal_shift(
    points: torch.Tensor,
    mask: torch.Tensor = None,
    focal: torch.Tensor = None,
    downsample_size: Tuple[int, int] = (64, 64),
):
    """
    Recover the depth map and FoV from a point map with unknown z shift and focal.

    Note that it assumes:
    - the optical center is at the center of the map
    - the map is undistorted
    - the map is isometric in the x and y directions

    ### Parameters:
    - `points: torch.Tensor` of shape (..., H, W, 3)
    - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.

    ### Returns:
    - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
    - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
    """
    shape = points.shape
    height, width = points.shape[-3], points.shape[-2]

    points = points.reshape(-1, *shape[-3:])
    mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
    focal = focal.reshape(-1) if focal is not None else None
    uv = normalized_view_plane_uv(
        width, height, dtype=points.dtype, device=points.device
    )  # (H, W, 2)

    points_lr = F.interpolate(
        points.permute(0, 3, 1, 2), downsample_size, mode="nearest"
    ).permute(0, 2, 3, 1)
    uv_lr = (
        F.interpolate(
            uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode="nearest"
        )
        .squeeze(0)
        .permute(1, 2, 0)
    )
    mask_lr = (
        None
        if mask is None
        else F.interpolate(
            mask.to(torch.float32).unsqueeze(1), downsample_size, mode="nearest"
        ).squeeze(1)
        > 0
    )

    uv_lr_np = uv_lr.cpu().numpy()
    points_lr_np = points_lr.detach().cpu().numpy()
    focal_np = focal.cpu().numpy() if focal is not None else None
    mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
    optim_shift, optim_focal = [], []
    for i in range(points.shape[0]):
        points_lr_i_np = (
            points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
        )
        uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
        if uv_lr_i_np.shape[0] < 2:
            optim_focal.append(1)
            optim_shift.append(0)
            continue
        if focal is None:
            optim_shift_i, optim_focal_i = solve_optimal_focal_shift(
                uv_lr_i_np, points_lr_i_np
            )
            optim_focal.append(float(optim_focal_i))
        else:
            optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
        optim_shift.append(float(optim_shift_i))
    optim_shift = torch.tensor(
        optim_shift, device=points.device, dtype=points.dtype
    ).reshape(shape[:-3])

    if focal is None:
        optim_focal = torch.tensor(
            optim_focal, device=points.device, dtype=points.dtype
        ).reshape(shape[:-3])
    else:
        optim_focal = focal.reshape(shape[:-3])

    return optim_focal, optim_shift


def suppress_traceback(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            e.__traceback__ = e.__traceback__.tb_next.tb_next
            raise

    return wrapper


def get_device(args, kwargs):
    device = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            if device is None:
                device = arg.device
            elif device != arg.device:
                raise ValueError("All tensors must be on the same device.")
    return device


def get_args_order(func, args, kwargs):
    """
    Get the order of the arguments of a function.
    """
    names = inspect.getfullargspec(func).args
    names_idx = {name: i for i, name in enumerate(names)}
    args_order = []
    kwargs_order = {}
    for name, arg in kwargs.items():
        if name in names:
            kwargs_order[name] = names_idx[name]
            names.remove(name)
    for i, arg in enumerate(args):
        if i < len(names):
            args_order.append(names_idx[names[i]])
    return args_order, kwargs_order


def broadcast_args(args, kwargs, args_dim, kwargs_dim):
    spatial = []
    for arg, arg_dim in zip(
        args + list(kwargs.values()), args_dim + list(kwargs_dim.values())
    ):
        if isinstance(arg, torch.Tensor) and arg_dim is not None:
            arg_spatial = arg.shape[: arg.ndim - arg_dim]
            if len(arg_spatial) > len(spatial):
                spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
            for j in range(len(arg_spatial)):
                if spatial[-j] < arg_spatial[-j]:
                    if spatial[-j] == 1:
                        spatial[-j] = arg_spatial[-j]
                    else:
                        raise ValueError("Cannot broadcast arguments.")
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor) and args_dim[i] is not None:
            args[i] = torch.broadcast_to(
                arg, [*spatial, *arg.shape[arg.ndim - args_dim[i] :]]
            )
    for key, arg in kwargs.items():
        if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
            kwargs[key] = torch.broadcast_to(
                arg, [*spatial, *arg.shape[arg.ndim - kwargs_dim[key] :]]
            )
    return args, kwargs, spatial


@suppress_traceback
def batched(*dims):
    """
    Decorator that allows a function to be called with batched arguments.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, device=torch.device("cpu"), **kwargs):
            args = list(args)
            # get arguments dimensions
            args_order, kwargs_order = get_args_order(func, args, kwargs)
            args_dim = [dims[i] for i in args_order]
            kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
            # convert to torch tensor
            device = get_device(args, kwargs) or device
            for i, arg in enumerate(args):
                if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
                    args[i] = torch.tensor(arg, device=device)
            for key, arg in kwargs.items():
                if (
                    isinstance(arg, (Number, list, tuple))
                    and kwargs_dim[key] is not None
                ):
                    kwargs[key] = torch.tensor(arg, device=device)
            # broadcast arguments
            args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
            for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
                if isinstance(arg, torch.Tensor) and arg_dim is not None:
                    args[i] = arg.reshape([-1, *arg.shape[arg.ndim - arg_dim :]])
            for key, arg in kwargs.items():
                if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
                    kwargs[key] = arg.reshape(
                        [-1, *arg.shape[arg.ndim - kwargs_dim[key] :]]
                    )
            # call function
            results = func(*args, **kwargs)
            type_results = type(results)
            results = list(results) if isinstance(results, (tuple, list)) else [results]
            # restore spatial dimensions
            for i, result in enumerate(results):
                results[i] = result.reshape([*spatial, *result.shape[1:]])
            if type_results is tuple:
                results = tuple(results)
            elif type_results is list:
                results = list(results)
            else:
                results = results[0]
            return results

        return wrapper

    return decorator


def image_uv(
    height: int,
    width: int,
    left: int = None,
    top: int = None,
    right: int = None,
    bottom: int = None,
    device: torch.device = None,
    dtype: torch.dtype = None,
) -> torch.Tensor:
    """
    Get image space UV grid, ranging in [0, 1].

    >>> image_uv(10, 10):
    [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
     [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
      ...             ...                  ...
     [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]

    Args:
        width (int): image width
        height (int): image height

    Returns:
        torch.Tensor: shape (height, width, 2)
    """
    if left is None:
        left = 0
    if top is None:
        top = 0
    if right is None:
        right = width
    if bottom is None:
        bottom = height
    u = torch.linspace(
        (left + 0.5) / width,
        (right - 0.5) / width,
        right - left,
        device=device,
        dtype=dtype,
    )
    v = torch.linspace(
        (top + 0.5) / height,
        (bottom - 0.5) / height,
        bottom - top,
        device=device,
        dtype=dtype,
    )
    u, v = torch.meshgrid(u, v, indexing="xy")
    uv = torch.stack([u, v], dim=-1)

    return uv


@batched(2, 1, 2, 2)
def unproject_cv(
    uv_coord: torch.Tensor,
    depth: torch.Tensor = None,
    extrinsics: torch.Tensor = None,
    intrinsics: torch.Tensor = None,
) -> torch.Tensor:
    """
    Unproject uv coordinates to 3D view space following the OpenCV convention

    Args:
        uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
            The origin (0., 0.) is corresponding to the left & top
        depth (torch.Tensor): [..., N] depth value
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
        intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix

    Returns:
        points (torch.Tensor): [..., N, 3] 3d points
    """
    assert intrinsics is not None, "intrinsics matrix is required"
    points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1)
    points = points @ torch.inverse(intrinsics).transpose(-2, -1)
    if depth is not None:
        points = points * depth[..., None]
    if extrinsics is not None:
        points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
        points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3]
    return points


def depth_to_points(
    depth: torch.Tensor, intrinsics: torch.Tensor, extrinsics: torch.Tensor = None
):
    height, width = depth.shape[-2:]
    uv = image_uv(width=width, height=height, dtype=depth.dtype, device=depth.device)
    pts = unproject_cv(
        uv,
        depth,
        intrinsics=intrinsics[..., None, :, :],
        extrinsics=extrinsics[..., None, :, :] if extrinsics is not None else None,
    )

    return pts


@batched(0, 0, 0, 0, 0, 0)
def intrinsics_from_focal_center(
    fx: Union[float, torch.Tensor],
    fy: Union[float, torch.Tensor],
    cx: Union[float, torch.Tensor],
    cy: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    Get OpenCV intrinsics matrix

    Args:
        focal_x (float | torch.Tensor): focal length in x axis
        focal_y (float | torch.Tensor): focal length in y axis
        cx (float | torch.Tensor): principal point in x axis
        cy (float | torch.Tensor): principal point in y axis

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
    """
    N = fx.shape[0]
    ret = torch.zeros((N, 3, 3), dtype=fx.dtype, device=fx.device)
    zeros, ones = (
        torch.zeros(N, dtype=fx.dtype, device=fx.device),
        torch.ones(N, dtype=fx.dtype, device=fx.device),
    )
    ret = torch.stack(
        [fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1
    ).unflatten(-1, (3, 3))
    return ret
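
A small round-trip sketch of the geometry utilities above, assuming the definitions from the deleted `utils.py` are in scope (and that `scipy` is installed, since `recover_focal_shift` uses its least-squares solver). The normalized intrinsics values are illustrative.

```python
import torch

# Assumes intrinsics_from_focal_center, depth_to_points and recover_focal_shift
# from the file above are in scope.

# Normalized OpenCV intrinsics: uv lives in [0, 1], so fx, fy, cx, cy are image-relative
K = intrinsics_from_focal_center(1.0, 1.0, 0.5, 0.5)
print(K)  # [[1, 0, 0.5], [0, 1, 0.5], [0, 0, 1]]

# Back-project a constant depth map into a camera-frame pointmap
depth = torch.ones(48, 64)
pts = depth_to_points(depth, intrinsics=K)
print(pts.shape)                           # torch.Size([48, 64, 3])
print(torch.allclose(pts[..., 2], depth))  # True: the z channel equals the z-depth

# recover_focal_shift re-estimates focal length and z-shift from such a pointmap
focal, shift = recover_focal_shift(pts)
print(focal.shape, shift.shape)  # scalar tensors (one value per pointmap)
```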
mapanything/models/external/moge/models/v1.py
DELETED
@@ -1,595 +0,0 @@

import importlib
from numbers import Number
from pathlib import Path
from typing import Any, Dict, IO, List, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.utils.checkpoint
import torch.version
from huggingface_hub import hf_hub_download

from mapanything.models.external.moge.models.utils import (
    depth_to_points,
    intrinsics_from_focal_center,
    normalized_view_plane_uv,
    recover_focal_shift,
    wrap_module_with_gradient_checkpointing,
)


class ResidualConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        hidden_channels: int = None,
        padding_mode: str = "replicate",
        activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu",
        norm: Literal["group_norm", "layer_norm"] = "group_norm",
    ):
        super(ResidualConvBlock, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        if hidden_channels is None:
            hidden_channels = in_channels

        if activation == "relu":
            activation_cls = lambda: nn.ReLU(inplace=True)  # noqa
        elif activation == "leaky_relu":
            activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)  # noqa
        elif activation == "silu":
            activation_cls = lambda: nn.SiLU(inplace=True)  # noqa
        elif activation == "elu":
            activation_cls = lambda: nn.ELU(inplace=True)  # noqa
        else:
            raise ValueError(f"Unsupported activation function: {activation}")

        self.layers = nn.Sequential(
            nn.GroupNorm(1, in_channels),
            activation_cls(),
            nn.Conv2d(
                in_channels,
                hidden_channels,
                kernel_size=3,
                padding=1,
                padding_mode=padding_mode,
            ),
            nn.GroupNorm(
                hidden_channels // 32 if norm == "group_norm" else 1, hidden_channels
            ),
            activation_cls(),
            nn.Conv2d(
                hidden_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                padding_mode=padding_mode,
            ),
        )

        self.skip_connection = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        skip = self.skip_connection(x)
        x = self.layers(x)
        x = x + skip
        return x


class Head(nn.Module):
    def __init__(
        self,
        num_features: int,
        dim_in: int,
        dim_out: List[int],
        dim_proj: int = 512,
        dim_upsample: List[int] = [256, 128, 128],
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm",
        last_res_blocks: int = 0,
        last_conv_channels: int = 32,
        last_conv_size: int = 1,
    ):
        super().__init__()

        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=dim_proj,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for _ in range(num_features)
            ]
        )

        self.upsample_blocks = nn.ModuleList(
            [
                nn.Sequential(
                    self._make_upsampler(in_ch + 2, out_ch),
                    *(
                        ResidualConvBlock(
                            out_ch,
                            out_ch,
                            dim_times_res_block_hidden * out_ch,
                            activation="relu",
                            norm=res_block_norm,
                        )
                        for _ in range(num_res_blocks)
                    ),
                )
                for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
            ]
        )

        self.output_block = nn.ModuleList(
            [
                self._make_output_block(
                    dim_upsample[-1] + 2,
                    dim_out_,
                    dim_times_res_block_hidden,
                    last_res_blocks,
                    last_conv_channels,
                    last_conv_size,
                    res_block_norm,
                )
                for dim_out_ in dim_out
            ]
        )

    def _make_upsampler(self, in_channels: int, out_channels: int):
        upsampler = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="replicate",
            ),
        )
        upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
        return upsampler

    def _make_output_block(
        self,
        dim_in: int,
        dim_out: int,
        dim_times_res_block_hidden: int,
        last_res_blocks: int,
        last_conv_channels: int,
        last_conv_size: int,
        res_block_norm: Literal["group_norm", "layer_norm"],
    ):
        return nn.Sequential(
            nn.Conv2d(
                dim_in,
                last_conv_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="replicate",
            ),
            *(
                ResidualConvBlock(
                    last_conv_channels,
                    last_conv_channels,
                    dim_times_res_block_hidden * last_conv_channels,
                    activation="relu",
                    norm=res_block_norm,
                )
                for _ in range(last_res_blocks)
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                last_conv_channels,
                dim_out,
                kernel_size=last_conv_size,
                stride=1,
                padding=last_conv_size // 2,
                padding_mode="replicate",
            ),
        )

    def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
        img_h, img_w = image.shape[-2:]
        patch_h, patch_w = img_h // 14, img_w // 14

        # Process the hidden states
        x = torch.stack(
            [
                proj(
                    feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()
                )
                for proj, (feat, clstoken) in zip(self.projects, hidden_states)
            ],
            dim=1,
        ).sum(dim=1)

        # Upsample stage
        # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
        for i, block in enumerate(self.upsample_blocks):
            # UV coordinates is for awareness of image aspect ratio
            uv = normalized_view_plane_uv(
                width=x.shape[-1],
                height=x.shape[-2],
                aspect_ratio=img_w / img_h,
                dtype=x.dtype,
                device=x.device,
            )
            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
            x = torch.cat([x, uv], dim=1)
            for layer in block:
                x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)

        # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
        x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
        uv = normalized_view_plane_uv(
            width=x.shape[-1],
            height=x.shape[-2],
            aspect_ratio=img_w / img_h,
            dtype=x.dtype,
            device=x.device,
        )
        uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
        x = torch.cat([x, uv], dim=1)

        if isinstance(self.output_block, nn.ModuleList):
            output = [
                torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
                for block in self.output_block
            ]
        else:
            output = torch.utils.checkpoint.checkpoint(
                self.output_block, x, use_reentrant=False
            )

        return output


class MoGeModel(nn.Module):
    image_mean: torch.Tensor
    image_std: torch.Tensor

    def __init__(
        self,
        encoder: str = "dinov2_vitb14",
        intermediate_layers: Union[int, List[int]] = 4,
        dim_proj: int = 512,
        dim_upsample: List[int] = [256, 128, 128],
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        remap_output: Literal[
            False, True, "linear", "sinh", "exp", "sinh_exp"
        ] = "linear",
        res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm",
        num_tokens_range: Tuple[Number, Number] = [1200, 2500],
        last_res_blocks: int = 0,
        last_conv_channels: int = 32,
        last_conv_size: int = 1,
        mask_threshold: float = 0.5,
        **deprecated_kwargs,
    ):
        super(MoGeModel, self).__init__()

        if deprecated_kwargs:
            # Process legacy arguments
            if "trained_area_range" in deprecated_kwargs:
                num_tokens_range = [
                    deprecated_kwargs["trained_area_range"][0] // 14**2,
                    deprecated_kwargs["trained_area_range"][1] // 14**2,
                ]
                del deprecated_kwargs["trained_area_range"]
            # warnings.warn(
            #     f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}"
            # )

        self.encoder = encoder
        self.remap_output = remap_output
        self.intermediate_layers = intermediate_layers
        self.num_tokens_range = num_tokens_range
        self.mask_threshold = mask_threshold

        # NOTE: We have copied the DINOv2 code in torchhub to this repository.
        # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues.
        hub_loader = getattr(
            importlib.import_module(
                "mapanything.models.external.dinov2.hub.backbones", __package__
            ),
            encoder,
        )
        self.backbone = hub_loader(pretrained=False)
        dim_feature = self.backbone.blocks[0].attn.qkv.in_features

        self.head = Head(
            num_features=intermediate_layers
            if isinstance(intermediate_layers, int)
            else len(intermediate_layers),
            dim_in=dim_feature,
            dim_out=[3, 1],
            dim_proj=dim_proj,
            dim_upsample=dim_upsample,
            dim_times_res_block_hidden=dim_times_res_block_hidden,
            num_res_blocks=num_res_blocks,
            res_block_norm=res_block_norm,
            last_res_blocks=last_res_blocks,
            last_conv_channels=last_conv_channels,
            last_conv_size=last_conv_size,
        )

        image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

        self.register_buffer("image_mean", image_mean)
        self.register_buffer("image_std", image_std)

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path, IO[bytes]],
        model_kwargs: Optional[Dict[str, Any]] = None,
        **hf_kwargs,
    ) -> "MoGeModel":
        """
        Load a model from a checkpoint file.

        ### Parameters:
        - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
        - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
        - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.

        ### Returns:
        - A new instance of `MoGe` with the parameters loaded from the checkpoint.
        """
        if Path(pretrained_model_name_or_path).exists():
            checkpoint = torch.load(
                pretrained_model_name_or_path, map_location="cpu", weights_only=True
            )
        else:
            cached_checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.pt",
                **hf_kwargs,
            )
            checkpoint = torch.load(
                cached_checkpoint_path, map_location="cpu", weights_only=True
            )
        model_config = checkpoint["model_config"]
        if model_kwargs is not None:
            model_config.update(model_kwargs)
        model = cls(**model_config)
        model.load_state_dict(checkpoint["model"])
        return model

    def init_weights(self):
        "Load the backbone with pretrained dinov2 weights from torch hub"
        state_dict = torch.hub.load(
            "facebookresearch/dinov2", self.encoder, pretrained=True
        ).state_dict()
        self.backbone.load_state_dict(state_dict)

    def enable_gradient_checkpointing(self):
        for i in range(len(self.backbone.blocks)):
            self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(
                self.backbone.blocks[i]
            )

    def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
        if self.remap_output == "linear":
            pass
        elif self.remap_output == "sinh":
            points = torch.sinh(points)
        elif self.remap_output == "exp":
            xy, z = points.split([2, 1], dim=-1)
            z = torch.exp(z)
            points = torch.cat([xy * z, z], dim=-1)
        elif self.remap_output == "sinh_exp":
            xy, z = points.split([2, 1], dim=-1)
            points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
        else:
            raise ValueError(f"Invalid remap output type: {self.remap_output}")
        return points

    def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
        original_height, original_width = image.shape[-2:]

        # Resize to expected resolution defined by num_tokens
        resize_factor = (
            (num_tokens * 14**2) / (original_height * original_width)
        ) ** 0.5
        resized_width, resized_height = (
            int(original_width * resize_factor),
            int(original_height * resize_factor),
        )
        image = F.interpolate(
            image,
            (resized_height, resized_width),
            mode="bicubic",
            align_corners=False,
            antialias=True,
        )

        # Apply image transformation for DINOv2
|
| 433 |
-
image = (image - self.image_mean) / self.image_std
|
| 434 |
-
image_14 = F.interpolate(
|
| 435 |
-
image,
|
| 436 |
-
(resized_height // 14 * 14, resized_width // 14 * 14),
|
| 437 |
-
mode="bilinear",
|
| 438 |
-
align_corners=False,
|
| 439 |
-
antialias=True,
|
| 440 |
-
)
|
| 441 |
-
|
| 442 |
-
# Get intermediate layers from the backbone
|
| 443 |
-
features = self.backbone.get_intermediate_layers(
|
| 444 |
-
image_14, self.intermediate_layers, return_class_token=True
|
| 445 |
-
)
|
| 446 |
-
|
| 447 |
-
# Predict points (and mask)
|
| 448 |
-
output = self.head(features, image)
|
| 449 |
-
points, mask = output
|
| 450 |
-
|
| 451 |
-
# Make sure fp32 precision for output
|
| 452 |
-
with torch.autocast(device_type=image.device.type, dtype=torch.float32):
|
| 453 |
-
# Resize to original resolution
|
| 454 |
-
points = F.interpolate(
|
| 455 |
-
points,
|
| 456 |
-
(original_height, original_width),
|
| 457 |
-
mode="bilinear",
|
| 458 |
-
align_corners=False,
|
| 459 |
-
antialias=False,
|
| 460 |
-
)
|
| 461 |
-
mask = F.interpolate(
|
| 462 |
-
mask,
|
| 463 |
-
(original_height, original_width),
|
| 464 |
-
mode="bilinear",
|
| 465 |
-
align_corners=False,
|
| 466 |
-
antialias=False,
|
| 467 |
-
)
|
| 468 |
-
|
| 469 |
-
# Post-process points and mask
|
| 470 |
-
points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
|
| 471 |
-
points = self._remap_points(
|
| 472 |
-
points
|
| 473 |
-
) # slightly improves the performance in case of very large output values
|
| 474 |
-
|
| 475 |
-
return_dict = {"points": points, "mask": mask}
|
| 476 |
-
return return_dict
|
| 477 |
-
|
| 478 |
-
# @torch.inference_mode()
|
| 479 |
-
def infer(
|
| 480 |
-
self,
|
| 481 |
-
image: torch.Tensor,
|
| 482 |
-
fov_x: Union[Number, torch.Tensor] = None,
|
| 483 |
-
resolution_level: int = 9,
|
| 484 |
-
num_tokens: int = None,
|
| 485 |
-
apply_mask: bool = True,
|
| 486 |
-
force_projection: bool = True,
|
| 487 |
-
use_fp16: bool = True,
|
| 488 |
-
) -> Dict[str, torch.Tensor]:
|
| 489 |
-
"""
|
| 490 |
-
User-friendly inference function
|
| 491 |
-
|
| 492 |
-
### Parameters
|
| 493 |
-
- `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)\
|
| 494 |
-
- `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
|
| 495 |
-
- `resolution_level`: An integer [0-9] for the resolution level for inference.
|
| 496 |
-
The higher, the finer details will be captured, but slower. Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size.
|
| 497 |
-
`resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.
|
| 498 |
-
- `num_tokens`: number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`.
|
| 499 |
-
`resolution_level` will be ignored if `num_tokens` is provided. Default: None
|
| 500 |
-
- `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
|
| 501 |
-
- `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True
|
| 502 |
-
- `use_fp16`: if True, use mixed precision to speed up inference. Default: True
|
| 503 |
-
|
| 504 |
-
### Returns
|
| 505 |
-
|
| 506 |
-
A dictionary containing the following keys:
|
| 507 |
-
- `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
|
| 508 |
-
- `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
|
| 509 |
-
- `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
|
| 510 |
-
"""
|
| 511 |
-
if image.dim() == 3:
|
| 512 |
-
omit_batch_dim = True
|
| 513 |
-
image = image.unsqueeze(0)
|
| 514 |
-
else:
|
| 515 |
-
omit_batch_dim = False
|
| 516 |
-
image = image.to(dtype=self.dtype, device=self.device)
|
| 517 |
-
|
| 518 |
-
original_height, original_width = image.shape[-2:]
|
| 519 |
-
aspect_ratio = original_width / original_height
|
| 520 |
-
|
| 521 |
-
if num_tokens is None:
|
| 522 |
-
min_tokens, max_tokens = self.num_tokens_range
|
| 523 |
-
num_tokens = int(
|
| 524 |
-
min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)
|
| 525 |
-
)
|
| 526 |
-
|
| 527 |
-
with torch.autocast(
|
| 528 |
-
device_type=self.device.type,
|
| 529 |
-
dtype=torch.float16,
|
| 530 |
-
enabled=use_fp16 and self.dtype != torch.float16,
|
| 531 |
-
):
|
| 532 |
-
output = self.forward(image, num_tokens)
|
| 533 |
-
points, mask = output["points"], output["mask"]
|
| 534 |
-
|
| 535 |
-
# Always process the output in fp32 precision
|
| 536 |
-
with torch.autocast(device_type=self.device.type, dtype=torch.float32):
|
| 537 |
-
points, mask, fov_x = map(
|
| 538 |
-
lambda x: x.float() if isinstance(x, torch.Tensor) else x,
|
| 539 |
-
[points, mask, fov_x],
|
| 540 |
-
)
|
| 541 |
-
|
| 542 |
-
mask_binary = mask > self.mask_threshold
|
| 543 |
-
|
| 544 |
-
# Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
|
| 545 |
-
if fov_x is None:
|
| 546 |
-
focal, shift = recover_focal_shift(points, mask_binary)
|
| 547 |
-
else:
|
| 548 |
-
focal = (
|
| 549 |
-
aspect_ratio
|
| 550 |
-
/ (1 + aspect_ratio**2) ** 0.5
|
| 551 |
-
/ torch.tan(
|
| 552 |
-
torch.deg2rad(
|
| 553 |
-
torch.as_tensor(
|
| 554 |
-
fov_x, device=points.device, dtype=points.dtype
|
| 555 |
-
)
|
| 556 |
-
/ 2
|
| 557 |
-
)
|
| 558 |
-
)
|
| 559 |
-
)
|
| 560 |
-
if focal.ndim == 0:
|
| 561 |
-
focal = focal[None].expand(points.shape[0])
|
| 562 |
-
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
|
| 563 |
-
fx = focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio
|
| 564 |
-
fy = focal / 2 * (1 + aspect_ratio**2) ** 0.5
|
| 565 |
-
intrinsics = intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
|
| 566 |
-
depth = points[..., 2] + shift[..., None, None]
|
| 567 |
-
|
| 568 |
-
# If projection constraint is forced, recompute the point map using the actual depth map
|
| 569 |
-
if force_projection:
|
| 570 |
-
points = depth_to_points(depth, intrinsics=intrinsics)
|
| 571 |
-
else:
|
| 572 |
-
points = (
|
| 573 |
-
points
|
| 574 |
-
+ torch.stack(
|
| 575 |
-
[torch.zeros_like(shift), torch.zeros_like(shift), shift],
|
| 576 |
-
dim=-1,
|
| 577 |
-
)[..., None, None, :]
|
| 578 |
-
)
|
| 579 |
-
|
| 580 |
-
# Apply mask if needed
|
| 581 |
-
if apply_mask:
|
| 582 |
-
points = torch.where(mask_binary[..., None], points, torch.inf)
|
| 583 |
-
depth = torch.where(mask_binary, depth, torch.inf)
|
| 584 |
-
|
| 585 |
-
return_dict = {
|
| 586 |
-
"points": points,
|
| 587 |
-
"intrinsics": intrinsics,
|
| 588 |
-
"depth": depth,
|
| 589 |
-
"mask": mask_binary,
|
| 590 |
-
}
|
| 591 |
-
|
| 592 |
-
if omit_batch_dim:
|
| 593 |
-
return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
|
| 594 |
-
|
| 595 |
-
return return_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
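For context, a minimal usage sketch of the v1 `infer` API listed above, assuming the deleted module were still importable; the checkpoint path and image size are placeholders, and a CUDA device is assumed:

import torch
from mapanything.models.external.moge.models.v1 import MoGeModel  # deleted module shown above

# Placeholder checkpoint path / repo id -- substitute a real MoGe v1 checkpoint.
model = MoGeModel.from_pretrained("path/to/model.pt").cuda().eval()

image = torch.rand(3, 480, 640, device="cuda")  # dummy RGB image in [0, 1], shape (3, H, W)
with torch.no_grad():
    out = model.infer(image, resolution_level=9, apply_mask=True)

points = out["points"]          # (H, W, 3) point map, masked entries set to +inf
depth = out["depth"]            # (H, W) depth map
intrinsics = out["intrinsics"]  # (3, 3) normalized intrinsics with principal point (0.5, 0.5)
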
mapanything/models/external/moge/models/v2.py
DELETED
@@ -1,379 +0,0 @@

import warnings
from numbers import Number
from pathlib import Path
from typing import Any, Dict, IO, List, Literal, Optional, Union

import torch
import torch.amp
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.utils.checkpoint
import torch.version
from huggingface_hub import hf_hub_download

from mapanything.models.external.moge.models.modules import (
    ConvStack,
    DINOv2Encoder,
    MLP,
)
from mapanything.models.external.moge.models.utils import (
    depth_to_points,
    intrinsics_from_focal_center,
    normalized_view_plane_uv,
    recover_focal_shift,
)


class MoGeModel(nn.Module):
    encoder: DINOv2Encoder
    neck: ConvStack
    points_head: ConvStack
    mask_head: ConvStack
    scale_head: MLP
    onnx_compatible_mode: bool

    def __init__(
        self,
        encoder: Dict[str, Any],
        neck: Dict[str, Any],
        points_head: Dict[str, Any] = None,
        mask_head: Dict[str, Any] = None,
        normal_head: Dict[str, Any] = None,
        scale_head: Dict[str, Any] = None,
        remap_output: Literal["linear", "sinh", "exp", "sinh_exp"] = "linear",
        num_tokens_range: List[int] = [1200, 3600],
        **deprecated_kwargs,
    ):
        super(MoGeModel, self).__init__()
        if deprecated_kwargs:
            warnings.warn(
                f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}"
            )

        self.remap_output = remap_output
        self.num_tokens_range = num_tokens_range

        self.encoder = DINOv2Encoder(**encoder)
        self.neck = ConvStack(**neck)
        if points_head is not None:
            self.points_head = ConvStack(**points_head)
        if mask_head is not None:
            self.mask_head = ConvStack(**mask_head)
        if normal_head is not None:
            self.normal_head = ConvStack(**normal_head)
        if scale_head is not None:
            self.scale_head = MLP(**scale_head)

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def onnx_compatible_mode(self) -> bool:
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        self._onnx_compatible_mode = value
        self.encoder.onnx_compatible_mode = value

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path, IO[bytes]],
        model_kwargs: Optional[Dict[str, Any]] = None,
        **hf_kwargs,
    ) -> "MoGeModel":
        """
        Load a model from a checkpoint file.

        ### Parameters:
        - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
        - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
        - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.

        ### Returns:
        - A new instance of `MoGe` with the parameters loaded from the checkpoint.
        """
        if Path(pretrained_model_name_or_path).exists():
            checkpoint_path = pretrained_model_name_or_path
        else:
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.pt",
                **hf_kwargs,
            )
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

        model_config = checkpoint["model_config"]
        if model_kwargs is not None:
            model_config.update(model_kwargs)
        model = cls(**model_config)
        model.load_state_dict(checkpoint["model"], strict=False)

        return model

    def init_weights(self):
        self.encoder.init_weights()

    def enable_gradient_checkpointing(self):
        self.encoder.enable_gradient_checkpointing()
        self.neck.enable_gradient_checkpointing()
        for head in ["points_head", "normal_head", "mask_head"]:
            if hasattr(self, head):
                getattr(self, head).enable_gradient_checkpointing()

    def enable_pytorch_native_sdpa(self):
        self.encoder.enable_pytorch_native_sdpa()

    def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
        if self.remap_output == "linear":
            pass
        elif self.remap_output == "sinh":
            points = torch.sinh(points)
        elif self.remap_output == "exp":
            xy, z = points.split([2, 1], dim=-1)
            z = torch.exp(z)
            points = torch.cat([xy * z, z], dim=-1)
        elif self.remap_output == "sinh_exp":
            xy, z = points.split([2, 1], dim=-1)
            points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
        else:
            raise ValueError(f"Invalid remap output type: {self.remap_output}")
        return points

    def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
        batch_size, _, img_h, img_w = image.shape
        device, dtype = image.device, image.dtype

        aspect_ratio = img_w / img_h
        base_h, base_w = (
            int((num_tokens / aspect_ratio) ** 0.5),
            int((num_tokens * aspect_ratio) ** 0.5),
        )
        num_tokens = base_h * base_w

        # Backbones encoding
        features, cls_token = self.encoder(
            image, base_h, base_w, return_class_token=True
        )
        features = [features, None, None, None, None]

        # Concat UVs for aspect ratio input
        for level in range(5):
            uv = normalized_view_plane_uv(
                width=base_w * 2**level,
                height=base_h * 2**level,
                aspect_ratio=aspect_ratio,
                dtype=dtype,
                device=device,
            )
            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
            if features[level] is None:
                features[level] = uv
            else:
                features[level] = torch.concat([features[level], uv], dim=1)

        # Shared neck
        features = self.neck(features)

        # Heads decoding
        points, normal, mask = (
            getattr(self, head)(features)[-1] if hasattr(self, head) else None
            for head in ["points_head", "normal_head", "mask_head"]
        )
        metric_scale = (
            self.scale_head(cls_token) if hasattr(self, "scale_head") else None
        )

        # Resize
        points, normal, mask = (
            F.interpolate(
                v, (img_h, img_w), mode="bilinear", align_corners=False, antialias=False
            )
            if v is not None
            else None
            for v in [points, normal, mask]
        )

        # Remap output
        if points is not None:
            points = points.permute(0, 2, 3, 1)
            points = self._remap_points(
                points
            )  # slightly improves the performance in case of very large output values
        if normal is not None:
            normal = normal.permute(0, 2, 3, 1)
            normal = F.normalize(normal, dim=-1)
        if mask is not None:
            mask = mask.squeeze(1).sigmoid()
        if metric_scale is not None:
            metric_scale = metric_scale.squeeze(1).exp()

        return_dict = {
            "points": points,
            "normal": normal,
            "mask": mask,
            "metric_scale": metric_scale,
        }
        return_dict = {k: v for k, v in return_dict.items() if v is not None}

        return return_dict

    # @torch.inference_mode()
    def infer(
        self,
        image: torch.Tensor,
        num_tokens: int = None,
        resolution_level: int = 9,
        force_projection: bool = True,
        apply_mask: Literal[False, True, "blend"] = True,
        fov_x: Optional[Union[Number, torch.Tensor]] = None,
        use_fp16: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        User-friendly inference function

        ### Parameters
        - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
        - `num_tokens`: the number of base ViT tokens to use for inference, `'least'` or `'most'` or an integer. Suggested range: 1200 ~ 2500.
            More tokens will result in significantly higher accuracy and finer details, but slower inference time. Default: `'most'`.
        - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
        - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
        - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
        - `use_fp16`: if True, use mixed precision to speed up inference. Default: True

        ### Returns

        A dictionary containing the following keys:
        - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
        - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
        - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
        """
        if image.dim() == 3:
            omit_batch_dim = True
            image = image.unsqueeze(0)
        else:
            omit_batch_dim = False
        image = image.to(dtype=self.dtype, device=self.device)

        original_height, original_width = image.shape[-2:]
        aspect_ratio = original_width / original_height

        # Determine the number of base tokens to use
        if num_tokens is None:
            min_tokens, max_tokens = self.num_tokens_range
            num_tokens = int(
                min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)
            )

        # Forward pass
        with torch.autocast(
            device_type=self.device.type,
            dtype=torch.float16,
            enabled=use_fp16 and self.dtype != torch.float16,
        ):
            output = self.forward(image, num_tokens=num_tokens)
        points, normal, mask, metric_scale = (
            output.get(k, None) for k in ["points", "normal", "mask", "metric_scale"]
        )

        # Always process the output in fp32 precision
        points, normal, mask, metric_scale, fov_x = map(
            lambda x: x.float() if isinstance(x, torch.Tensor) else x,
            [points, normal, mask, metric_scale, fov_x],
        )
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            if mask is not None:
                mask_binary = mask > 0.5
            else:
                mask_binary = None

            if points is not None:
                # Convert affine point map to camera-space. Recover depth and intrinsics from point map.
                # NOTE: Focal here is the focal length relative to half the image diagonal
                if fov_x is None:
                    # Recover focal and shift from predicted point map
                    focal, shift = recover_focal_shift(points, mask_binary)
                else:
                    # Focal is known, recover shift only
                    focal = (
                        aspect_ratio
                        / (1 + aspect_ratio**2) ** 0.5
                        / torch.tan(
                            torch.deg2rad(
                                torch.as_tensor(
                                    fov_x, device=points.device, dtype=points.dtype
                                )
                                / 2
                            )
                        )
                    )
                    if focal.ndim == 0:
                        focal = focal[None].expand(points.shape[0])
                    _, shift = recover_focal_shift(points, mask_binary, focal=focal)
                fx, fy = (
                    focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio,
                    focal / 2 * (1 + aspect_ratio**2) ** 0.5,
                )
                intrinsics = intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
                points[..., 2] += shift[..., None, None]
                if mask_binary is not None:
                    mask_binary &= (
                        points[..., 2] > 0
                    )  # in case depth contains negative values (which should never happen in practice)
                depth = points[..., 2].clone()
            else:
                depth, intrinsics = None, None

            # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
            if force_projection and depth is not None:
                points = depth_to_points(depth, intrinsics=intrinsics)

            # Apply metric scale
            if metric_scale is not None:
                if points is not None:
                    points *= metric_scale[:, None, None, None]
                if depth is not None:
                    depth *= metric_scale[:, None, None]

            # Apply mask
            if apply_mask and mask_binary is not None:
                points = (
                    torch.where(mask_binary[..., None], points, torch.inf)
                    if points is not None
                    else None
                )
                depth = (
                    torch.where(mask_binary, depth, torch.inf)
                    if depth is not None
                    else None
                )
                normal = (
                    torch.where(
                        mask_binary[..., None], normal, torch.zeros_like(normal)
                    )
                    if normal is not None
                    else None
                )

        return_dict = {
            "points": points,
            "intrinsics": intrinsics,
            "depth": depth,
            "mask": mask_binary,
            "normal": normal,
        }
        return_dict = {k: v for k, v in return_dict.items() if v is not None}

        if omit_batch_dim:
            return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}

        return return_dict

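Both MoGe listings above convert a known horizontal FoV into the diagonal-normalized `focal` and back into the `fx`, `fy` passed to `intrinsics_from_focal_center(fx, fy, 0.5, 0.5)`. A small standalone check of that algebra, independent of the deleted code and using plain Python only:

import math

fov_x_deg = 60.0
aspect_ratio = 640 / 480  # width / height

# focal relative to half the image diagonal, as in the fov_x branches above
focal = aspect_ratio / math.sqrt(1 + aspect_ratio**2) / math.tan(math.radians(fov_x_deg) / 2)
fx = focal / 2 * math.sqrt(1 + aspect_ratio**2) / aspect_ratio  # = 1 / (2 * tan(fov_x / 2))
fy = focal / 2 * math.sqrt(1 + aspect_ratio**2)

# Round trip: fx in normalized image units should reproduce the input horizontal FoV.
fov_x_check = 2 * math.degrees(math.atan(0.5 / fx))
assert abs(fov_x_check - fov_x_deg) < 1e-9
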
mapanything/models/external/must3r/__init__.py
DELETED
@@ -1,283 +0,0 @@

"""
Inference wrapper for MUSt3R
"""

import datetime
import os

import numpy as np
import torch
from dust3r.viz import rgb
from must3r.demo.inference import SceneState
from must3r.engine.inference import inference_multi_ar, postprocess
from must3r.model import get_pointmaps_activation, load_model

from mapanything.models.external.vggt.utils.rotation import mat_to_quat


def must3r_inference(
    views,
    filelist,
    model,
    retrieval,
    device,
    amp,
    num_mem_images,
    max_bs,
    init_num_images=2,
    batch_num_views=1,
    render_once=False,
    is_sequence=False,
    viser_server=None,
    num_refinements_iterations=2,
    verbose=True,
):
    if amp == "fp16":
        dtype = torch.float16
    elif amp == "bf16":
        assert torch.cuda.is_bf16_supported()
        dtype = torch.bfloat16
    else:
        assert not amp
        dtype = torch.float32

    max_bs = None if max_bs == 0 else max_bs
    encoder, decoder = model
    pointmaps_activation = get_pointmaps_activation(decoder, verbose=verbose)

    def post_process_function(x):
        return postprocess(
            x, pointmaps_activation=pointmaps_activation, compute_cam=True
        )

    if verbose:
        print("loading images")
    time_start = datetime.datetime.now()
    nimgs = len(views)

    ellapsed = datetime.datetime.now() - time_start
    if verbose:
        print(f"loaded in {ellapsed}")
        print("running inference")
    time_start = datetime.datetime.now()
    if viser_server is not None:
        viser_server.reset(nimgs)

    imgs = [b["img"].to("cpu") for b in views]
    true_shape = [torch.from_numpy(b["true_shape"]).to("cpu") for b in views]
    true_shape = torch.stack(true_shape, dim=0)
    nimgs = true_shape.shape[0]

    # Use all images as keyframes
    keyframes = np.linspace(0, len(imgs) - 1, num_mem_images, dtype=int).tolist()
    encoder_precomputed_features = None

    not_keyframes = sorted(set(range(nimgs)).difference(set(keyframes)))
    assert (len(keyframes) + len(not_keyframes)) == nimgs
    # reorder images
    views = [views[i] for i in keyframes] + [views[i] for i in not_keyframes]
    imgs = [b["img"].to(device) for b in views]
    true_shape = [torch.from_numpy(b["true_shape"]).to(device) for b in views]
    filenames = [filelist[i] for i in keyframes + not_keyframes]
    img_ids = [torch.tensor(v) for v in keyframes + not_keyframes]

    if encoder_precomputed_features is not None:
        x_start, pos_start = encoder_precomputed_features
        x = [x_start[i] for i in keyframes] + [x_start[i] for i in not_keyframes]
        pos = [pos_start[i] for i in keyframes] + [pos_start[i] for i in not_keyframes]
        encoder_precomputed_features = (x, pos)

    mem_batches = [init_num_images]
    while (sum_b := sum(mem_batches)) != max(num_mem_images, init_num_images):
        size_b = min(batch_num_views, num_mem_images - sum_b)
        mem_batches.append(size_b)

    if render_once:
        to_render = list(range(num_mem_images, nimgs))
    else:
        to_render = None

    with torch.autocast("cuda", dtype=dtype):
        x_out_0, x_out = inference_multi_ar(
            encoder,
            decoder,
            imgs,
            img_ids,
            true_shape,
            mem_batches,
            max_bs=max_bs,
            verbose=verbose,
            to_render=to_render,
            encoder_precomputed_features=encoder_precomputed_features,
            device=device,
            preserve_gpu_mem=True,
            post_process_function=post_process_function,
            viser_server=viser_server,
            num_refinements_iterations=num_refinements_iterations,
        )
    if to_render is not None:
        x_out = x_out_0 + x_out

    ellapsed = datetime.datetime.now() - time_start
    if verbose:
        print(f"inference in {ellapsed}")
        try:
            print(str(int(torch.cuda.max_memory_reserved(device) / (1024**2))) + " MB")
        except Exception:
            pass

    if viser_server is not None:
        viser_server.reset_cam_visility()
        viser_server.send_message("Finished")

    if verbose:
        print("preparing pointcloud")
    time_start = datetime.datetime.now()
    focals = []
    cams2world = []
    for i in range(nimgs):
        focals.append(float(x_out[i]["focal"].cpu()))
        cams2world.append(x_out[i]["c2w"].cpu())

    # x_out to cpu
    for i in range(len(x_out)):
        for k in x_out[i].keys():
            x_out[i][k] = x_out[i][k].cpu()

    rgbimg = [rgb(imgs[i], true_shape[i]) for i in range(nimgs)]
    scene = SceneState(x_out, rgbimg, true_shape, focals, cams2world, filenames)

    ellapsed = datetime.datetime.now() - time_start
    if verbose:
        print(f"pointcloud prepared in {ellapsed}")

    return scene


class MUSt3RWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        ckpt_path,
        retrieval_ckpt_path,
        img_size=512,
        amp="bf16",
        max_bs=1,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.ckpt_path = ckpt_path
        self.retrieval_ckpt_path = retrieval_ckpt_path
        self.amp = amp
        self.max_bs = max_bs

        # Init the model and load the checkpoint
        self.model = load_model(self.ckpt_path, img_size=512)

    def forward(self, views):
        """
        Forward pass wrapper for MUSt3R.

        Assumption:
        - The batch size of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Each dictionary should contain the following keys, where B is the batch size and is 1:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["dust3r"]
                    "label" (list): ["scene_name"]
                    "instance" (list): ["image_name"]

        Returns:
            List[dict]: A list containing the final outputs for the input views.
        """
        # Check the batch size of input views
        batch_size_per_view, _, height, width = views[0]["img"].shape
        device = views[0]["img"].device
        num_views = len(views)
        assert batch_size_per_view == 1, (
            f"Batch size of input views should be 1, but got {batch_size_per_view}."
        )

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "MUSt3R expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Convert the input views to the expected input format
        images = []
        image_paths = []
        for view in views:
            images.append(
                dict(
                    img=view["img"][0].cpu(),
                    idx=len(images),
                    instance=str(len(images)),
                    true_shape=np.int32([view["img"].shape[-2], view["img"].shape[-1]]),
                )
            )
            view_name = os.path.join(view["label"][0], view["instance"][0])
            image_paths.append(view_name)

        # Run MUSt3R inference
        scene = must3r_inference(
            images,
            image_paths,
            self.model,
            self.retrieval_ckpt_path,
            device,
            self.amp,
            num_views,
            self.max_bs,
            verbose=False,
        )

        # Make sure scene is not None
        if scene is None:
            raise RuntimeError("MUSt3R failed.")

        # Get the predictions
        predictions = scene.x_out

        # Convert the output to the MapAnything format
        with torch.autocast("cuda", enabled=False):
            res = []
            for view_idx in range(num_views):
                # Get the current view predictions
                curr_view_prediction = predictions[view_idx]
                curr_view_conf = curr_view_prediction["conf"]
                curr_view_pose = curr_view_prediction["c2w"].unsqueeze(0)

                # Convert the pose to quaternions and translation
                curr_view_cam_translations = curr_view_pose[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])

                # Get the camera frame pointmaps
                curr_view_pts3d_cam = curr_view_prediction["pts3d_local"].unsqueeze(0)

                # Get the depth along ray and ray directions
                curr_view_depth_along_ray = torch.norm(
                    curr_view_pts3d_cam, dim=-1, keepdim=True
                )
                curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray

                # Get the pointmaps
                curr_view_pts3d = curr_view_prediction["pts3d"].unsqueeze(0)

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d.to(device),
                        "pts3d_cam": curr_view_pts3d_cam.to(device),
                        "ray_directions": curr_view_ray_dirs.to(device),
                        "depth_along_ray": curr_view_depth_along_ray.to(device),
                        "cam_trans": curr_view_cam_translations.to(device),
                        "cam_quats": curr_view_cam_quats.to(device),
                        "conf": curr_view_conf.to(device),
                    }
                )

        return res

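A minimal sketch of the per-view input format that `MUSt3RWrapper.forward` above expects, with dummy tensors; the wrapper construction is left commented out because the checkpoint paths are placeholders:

import torch

num_views = 4
views = [
    {
        "img": torch.rand(1, 3, 384, 512),    # batch size must be 1 per view; DUSt3R-normalized in practice
        "data_norm_type": ["dust3r"],
        "label": ["scene_000"],               # scene name
        "instance": [f"frame_{i:03d}.jpg"],   # image name
    }
    for i in range(num_views)
]
# wrapper = MUSt3RWrapper("must3r", ckpt_path="<must3r.ckpt>", retrieval_ckpt_path="<retrieval.ckpt>")
# preds = wrapper(views)  # list of dicts: pts3d, pts3d_cam, ray_directions, depth_along_ray, cam_trans, cam_quats, conf
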
mapanything/models/external/pi3/__init__.py
DELETED
@@ -1,119 +0,0 @@

"""
Inference wrapper for Pi3
"""

import torch

from mapanything.models.external.pi3.models.pi3 import Pi3
from mapanything.models.external.vggt.utils.rotation import mat_to_quat


class Pi3Wrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        torch_hub_force_reload,
        load_pretrained_weights=True,
        pos_type="rope100",
        decoder_size="large",
    ):
        super().__init__()
        self.name = name
        self.torch_hub_force_reload = torch_hub_force_reload

        if load_pretrained_weights:
            # Load pre-trained weights
            if not torch_hub_force_reload:
                # Initialize the Pi3 model from huggingface hub cache
                print("Loading Pi3 from huggingface cache ...")
                self.model = Pi3.from_pretrained(
                    "yyfz233/Pi3",
                )
            else:
                # Initialize the Pi3 model
                self.model = Pi3.from_pretrained("yyfz233/Pi3", force_download=True)
        else:
            # Load the Pi3 class
            self.model = Pi3(
                pos_type=pos_type,
                decoder_size=decoder_size,
            )

        # Get the dtype for Pi3 inference
        # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
        self.dtype = (
            torch.bfloat16
            if torch.cuda.get_device_capability()[0] >= 8
            else torch.float16
        )

    def forward(self, views):
        """
        Forward pass wrapper for Pi3

        Assumption:
        - All the input views have the same image shape.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Each dictionary should contain the following keys:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list containing the final outputs for all N views.
        """
        # Get input shape of the images, number of views, and batch size per view
        batch_size_per_view, _, height, width = views[0]["img"].shape
        num_views = len(views)

        # Check the data norm type
        # Pi3 expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "Pi3 expects a normalized image but without the DINOv2 mean and std applied"
        )

        # Concatenate the images to create a single (B, V, C, H, W) tensor
        img_list = [view["img"] for view in views]
        images = torch.stack(img_list, dim=1)

        # Run the Pi3 aggregator
        with torch.autocast("cuda", dtype=self.dtype):
            results = self.model(images)

        # Need high precision for transformations
        with torch.autocast("cuda", enabled=False):
            # Convert the output to MapAnything format
            res = []
            for view_idx in range(num_views):
                # Get the extrinsics
                curr_view_extrinsic = results["camera_poses"][:, view_idx, ...]
                curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])

                # Get the depth along ray, ray directions, local point cloud & global point cloud
                curr_view_pts3d_cam = results["local_points"][:, view_idx, ...]
                curr_view_depth_along_ray = torch.norm(
                    curr_view_pts3d_cam, dim=-1, keepdim=True
                )
                curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray
                curr_view_pts3d = results["points"][:, view_idx, ...]

                # Get the confidence
                curr_view_confidence = results["conf"][:, view_idx, ...]

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d,
                        "pts3d_cam": curr_view_pts3d_cam,
                        "ray_directions": curr_view_ray_dirs,
                        "depth_along_ray": curr_view_depth_along_ray,
                        "cam_trans": curr_view_cam_translations,
                        "cam_quats": curr_view_cam_quats,
                        "conf": curr_view_confidence,
                    }
                )

        return res

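Similarly, a sketch of how `Pi3Wrapper` above is fed; dummy tensors only, with the actual call commented out since it downloads pretrained weights and assumes a CUDA device:

import torch

views = [
    {"img": torch.rand(2, 3, 224, 224), "data_norm_type": ["identity"]}  # same (H, W) for every view
    for _ in range(3)
]
# wrapper = Pi3Wrapper("pi3", torch_hub_force_reload=False).cuda().eval()
# with torch.no_grad():
#     preds = wrapper([{**v, "img": v["img"].cuda()} for v in views])
# preds[v]["pts3d"]      -> (B, H, W, 3) world-frame point map for view v
# preds[v]["cam_quats"]  -> (B, 4) camera rotation as a quaternion (from mat_to_quat)
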
mapanything/models/external/pi3/layers/__init__.py
DELETED
File without changes

mapanything/models/external/pi3/layers/attention.py
DELETED
@@ -1,429 +0,0 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py


import os

import torch
from torch import nn, Tensor
from torch.nn.attention import SDPBackend
from torch.nn.functional import scaled_dot_product_attention

XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Attention)")
    else:
        # warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Attention)")


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        # q, k, v = unbind(qkv, 2)
        q, k, v = [qkv[:, :, i] for i in range(3)]

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class FlashAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .transpose(1, 3)
        )

        # q, k, v = unbind(qkv, 2)
        q, k, v = [qkv[:, :, i] for i in range(3)]

        if q.dtype == torch.bfloat16:
            with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                x = scaled_dot_product_attention(q, k, v)
        else:
            with nn.attention.sdpa_kernel(
                [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]
            ):
                x = scaled_dot_product_attention(q, k, v)

        x = x.transpose(1, 2).reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


"""
Following is written by GPT-4o
"""


class CrossAttentionRope(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qk_norm: bool = False,
        norm_layer: nn.Module = nn.LayerNorm,
        rope=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        # Separate projection layers for query, key, and value
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)

        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = rope

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attn_bias=None,
        qpos=None,
        kpos=None,
    ) -> Tensor:
        """
        Args:
            query: Tensor of shape (B, N, C), input query
            key: Tensor of shape (B, M, C), input key
            value: Tensor of shape (B, M, C), input value
            attn_bias: Optional tensor for attention bias
        Returns:
            Tensor of shape (B, N, C), output of cross-attention
        """
        B, N, C = query.shape
        _, M, _ = key.shape

        # Project query, key, and value
        q = (
            self.q_proj(query)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        k = (
            self.k_proj(key)
            .reshape(B, M, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        v = (
            self.v_proj(value)
            .reshape(B, M, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, qpos)
            k = self.rope(k, kpos)

        # Scale query
        q = q * self.scale

        # Compute attention scores
        attn = q @ k.transpose(-2, -1)  # (B, num_heads, N, M)
        if attn_bias is not None:
            attn = attn + attn_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Compute attention output
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B, N, C)

        # Final projection
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffCrossAttentionRope(CrossAttentionRope):
    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attn_bias=None,
        qpos=None,
        kpos=None,
    ) -> Tensor:
        """
        Args:
            query: Tensor of shape (B, N, C), input query
            key: Tensor of shape (B, M, C), input key
            value: Tensor of shape (B, M, C), input value
            attn_bias: Optional tensor for attention bias
        Returns:
            Tensor of shape (B, N, C), output of cross-attention
        """
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(query, key, value, attn_bias)

        B, N, C = query.shape
        _, M, _ = key.shape

        # Project query, key, and value
        q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads)
        k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads)
        v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, qpos)
            k = self.rope(k, kpos)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)

        # Compute memory-efficient attention
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape(B, N, C)

        # Final projection
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class AttentionRope(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qk_norm: bool = False,
        norm_layer: nn.Module = nn.LayerNorm,
        rope=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()

        self.rope = rope

    def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
|
| 312 |
-
|
| 313 |
-
if self.rope is not None:
|
| 314 |
-
q = self.rope(q, xpos)
|
| 315 |
-
k = self.rope(k, xpos)
|
| 316 |
-
|
| 317 |
-
q = q * self.scale
|
| 318 |
-
attn = q @ k.transpose(-2, -1)
|
| 319 |
-
|
| 320 |
-
attn = attn.softmax(dim=-1)
|
| 321 |
-
attn = self.attn_drop(attn)
|
| 322 |
-
|
| 323 |
-
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 324 |
-
x = self.proj(x)
|
| 325 |
-
x = self.proj_drop(x)
|
| 326 |
-
return x
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
class MemEffAttentionRope(AttentionRope):
|
| 330 |
-
def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
|
| 331 |
-
if not XFORMERS_AVAILABLE:
|
| 332 |
-
if attn_bias is not None:
|
| 333 |
-
raise AssertionError("xFormers is required for using nested tensors")
|
| 334 |
-
return super().forward(x)
|
| 335 |
-
|
| 336 |
-
B, N, C = x.shape
|
| 337 |
-
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 338 |
-
|
| 339 |
-
qkv = qkv.transpose(1, 3)
|
| 340 |
-
# q, k, v = unbind(qkv, 2)
|
| 341 |
-
q, k, v = [qkv[:, :, i] for i in range(3)]
|
| 342 |
-
q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
|
| 343 |
-
|
| 344 |
-
if self.rope is not None:
|
| 345 |
-
q = self.rope(q, xpos)
|
| 346 |
-
k = self.rope(k, xpos)
|
| 347 |
-
|
| 348 |
-
q = q.transpose(1, 2)
|
| 349 |
-
k = k.transpose(1, 2)
|
| 350 |
-
v = v.transpose(1, 2)
|
| 351 |
-
|
| 352 |
-
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
| 353 |
-
x = x.reshape([B, N, C])
|
| 354 |
-
|
| 355 |
-
# score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(frame_num, 261, frame_num, 261).mean(dim=[1, 3]).sum(1) # for frame attention matrix
|
| 356 |
-
# global_valid_id = torch.where(score_matrix > 0)
|
| 357 |
-
# score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1)
|
| 358 |
-
|
| 359 |
-
x = self.proj(x)
|
| 360 |
-
x = self.proj_drop(x)
|
| 361 |
-
return x
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
class FlashAttentionRope(AttentionRope):
|
| 365 |
-
def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
|
| 366 |
-
B, N, C = x.shape
|
| 367 |
-
qkv = (
|
| 368 |
-
self.qkv(x)
|
| 369 |
-
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 370 |
-
.transpose(1, 3)
|
| 371 |
-
)
|
| 372 |
-
|
| 373 |
-
# q, k, v = unbind(qkv, 2)
|
| 374 |
-
q, k, v = [qkv[:, :, i] for i in range(3)]
|
| 375 |
-
q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
|
| 376 |
-
|
| 377 |
-
if self.rope is not None:
|
| 378 |
-
q = self.rope(q, xpos)
|
| 379 |
-
k = self.rope(k, xpos)
|
| 380 |
-
|
| 381 |
-
if q.dtype == torch.bfloat16:
|
| 382 |
-
with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
| 383 |
-
x = scaled_dot_product_attention(q, k, v)
|
| 384 |
-
else:
|
| 385 |
-
with nn.attention.sdpa_kernel(
|
| 386 |
-
[SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]
|
| 387 |
-
):
|
| 388 |
-
x = scaled_dot_product_attention(q, k, v)
|
| 389 |
-
|
| 390 |
-
x = x.transpose(1, 2).reshape([B, N, C])
|
| 391 |
-
|
| 392 |
-
x = self.proj(x)
|
| 393 |
-
x = self.proj_drop(x)
|
| 394 |
-
return x
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
def get_attn_score(blk_class, x, frame_num, token_length, xpos=None):
|
| 398 |
-
x = blk_class.norm1(x)
|
| 399 |
-
|
| 400 |
-
B, N, C = x.shape
|
| 401 |
-
qkv = blk_class.attn.qkv(x).reshape(
|
| 402 |
-
B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads
|
| 403 |
-
)
|
| 404 |
-
|
| 405 |
-
qkv = qkv.transpose(1, 3)
|
| 406 |
-
# q, k, v = unbind(qkv, 2)
|
| 407 |
-
q, k, v = [qkv[:, :, i] for i in range(3)]
|
| 408 |
-
q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype)
|
| 409 |
-
|
| 410 |
-
if blk_class.attn.rope is not None:
|
| 411 |
-
q = blk_class.attn.rope(q, xpos)
|
| 412 |
-
k = blk_class.attn.rope(k, xpos)
|
| 413 |
-
|
| 414 |
-
q = q.transpose(1, 2)
|
| 415 |
-
k = k.transpose(1, 2)
|
| 416 |
-
|
| 417 |
-
score = (
|
| 418 |
-
(
|
| 419 |
-
q.permute(0, 2, 1, 3)
|
| 420 |
-
* blk_class.attn.scale
|
| 421 |
-
@ k.permute(0, 2, 1, 3).transpose(-2, -1)
|
| 422 |
-
)
|
| 423 |
-
.sum(dim=1)
|
| 424 |
-
.reshape(B, frame_num, token_length, frame_num, token_length)
|
| 425 |
-
.mean(dim=[2, 4])
|
| 426 |
-
.sum(-1)
|
| 427 |
-
)
|
| 428 |
-
|
| 429 |
-
return score
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
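Note (added for context, not part of the diff): the dtype-dependent backend selection used in FlashAttentionRope.forward above can be reproduced with a small, self-contained sketch. It assumes PyTorch 2.3+, where torch.nn.attention.sdpa_kernel is available.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention

# Toy tensors shaped (batch, heads, tokens, head_dim); sizes are illustrative.
B, H, N, D = 2, 8, 64, 32
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# Flash-attention kernels need bf16/fp16 on CUDA; otherwise fall back to the
# math / memory-efficient backends, mirroring the branch in the deleted code.
if q.dtype == torch.bfloat16 and q.is_cuda:
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = scaled_dot_product_attention(q, k, v)
else:
    with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
        out = scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 8, 64, 32])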
mapanything/models/external/pi3/layers/block.py
DELETED
@@ -1,448 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import os
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch import nn, Tensor

from mapanything.models.external.dinov2.layers.drop_path import DropPath
from mapanything.models.external.dinov2.layers.layer_scale import LayerScale
from mapanything.models.external.dinov2.layers.mlp import Mlp
from mapanything.models.external.pi3.layers.attention import (
    Attention,
    CrossAttentionRope,
    MemEffAttention,
)

XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, index_select_cat, scaled_index_add

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Block)")
    else:
        # warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Block)")


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(
        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
    )
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(
            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
        )
    else:
        x_plus_residual = scaled_index_add(
            x,
            brange,
            residual.to(dtype=x.dtype),
            scaling=scaling_vector,
            alpha=residual_scale_factor,
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = (
        [b.shape[0] for b in branges]
        if branges is not None
        else [x.shape[0] for x in x_list]
    )
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
            1, -1, x_list[0].shape[-1]
        )
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [
        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
    ]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(
        x_list, branges, residual_list, residual_scale_factors
    ):
        outputs.append(
            add_residual(
                x, brange, residual, residual_scale_factor, scaling_vector
            ).view_as(x)
        )
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma
                if isinstance(self.ls1, LayerScale)
                else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma
                if isinstance(self.ls1, LayerScale)
                else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError


class BlockRope(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            rope=rope,
        )

        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, xpos=None) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), xpos=xpos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


class CrossBlockRope(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        init_values=None,
        qk_norm: bool = False,
        rope=None,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm,
        )

        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.ls_y = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.norm2 = norm_layer(dim)
        self.norm_y = norm_layer(dim)
        self.cross_attn = cross_attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm,
        )

        self.norm3 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            bias=ffn_bias,
        )

    def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), xpos=xpos))

        def cross_attn_residual_func(x: Tensor, y: Tensor) -> Tensor:
            return self.ls_y(self.cross_attn(self.norm2(x), y, y, qpos=xpos, kpos=ypos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm3(x)))

        x = x + attn_residual_func(x)
        y_ = self.norm_y(y)
        x = x + cross_attn_residual_func(x, y_)
        x = x + ffn_residual_func(x)

        return x
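For orientation (an illustrative sketch, not part of the deleted file): Block and BlockRope above follow the standard pre-norm residual pattern x + f(norm(x)). A stripped-down analogue without LayerScale, RoPE, or stochastic depth, using only stock PyTorch modules (TinyBlock is a hypothetical name), looks like this.

import torch
from torch import nn

class TinyBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # attention residual
        x = x + self.mlp(self.norm2(x))                    # feed-forward residual
        return x

x = torch.randn(2, 16, 64)
print(TinyBlock(64, 8)(x).shape)  # torch.Size([2, 16, 64])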
mapanything/models/external/pi3/layers/camera_head.py
DELETED
@@ -1,106 +0,0 @@
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F


# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172'
class ResConvBlock(nn.Module):
    """
    1x1 convolution residual block
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.head_skip = (
            nn.Identity()
            if self.in_channels == self.out_channels
            else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
        )
        # self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
        # self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
        # self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)

        # change 1x1 convolution to linear
        self.res_conv1 = nn.Linear(self.in_channels, self.out_channels)
        self.res_conv2 = nn.Linear(self.out_channels, self.out_channels)
        self.res_conv3 = nn.Linear(self.out_channels, self.out_channels)

    def forward(self, res):
        x = F.relu(self.res_conv1(res))
        x = F.relu(self.res_conv2(x))
        x = F.relu(self.res_conv3(x))
        res = self.head_skip(res) + x
        return res


class CameraHead(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        output_dim = dim
        self.res_conv = nn.ModuleList(
            [deepcopy(ResConvBlock(output_dim, output_dim)) for _ in range(2)]
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.more_mlps = nn.Sequential(
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
        )
        self.fc_t = nn.Linear(output_dim, 3)
        self.fc_rot = nn.Linear(output_dim, 9)

    def forward(self, feat, patch_h, patch_w):
        BN, hw, c = feat.shape

        for i in range(2):
            feat = self.res_conv[i](feat)

        # feat = self.avgpool(feat)
        feat = self.avgpool(
            feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous()
        )  ##########
        feat = feat.view(feat.size(0), -1)

        feat = self.more_mlps(feat)  # [B, D_]
        with torch.amp.autocast(device_type="cuda", enabled=False):
            out_t = self.fc_t(feat.float())  # [B,3]
            out_r = self.fc_rot(feat.float())  # [B,9]
            pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)

        return pose

    def convert_pose_to_4x4(self, B, out_r, out_t, device):
        out_r = self.svd_orthogonalize(out_r)  # [N,3,3]
        pose = torch.zeros((B, 4, 4), device=device)
        pose[:, :3, :3] = out_r
        pose[:, :3, 3] = out_t
        pose[:, 3, 3] = 1.0
        return pose

    def svd_orthogonalize(self, m):
        """Convert 9D representation to SO(3) using SVD orthogonalization.

        Args:
            m: [BATCH, 3, 3] 3x3 matrices.

        Returns:
            [BATCH, 3, 3] SO(3) rotation matrices.
        """
        if m.dim() < 3:
            m = m.reshape((-1, 3, 3))
        m_transpose = torch.transpose(
            torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2
        )
        u, s, v = torch.svd(m_transpose)
        det = torch.det(torch.matmul(v, u.transpose(-2, -1)))
        # Check orientation reflection.
        r = torch.matmul(
            torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2),
            u.transpose(-2, -1),
        )
        return r
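As a sanity check on the idea behind svd_orthogonalize (illustrative only, not part of the diff): the standard SVD projection of an arbitrary 3x3 matrix onto SO(3) yields an orthonormal matrix with determinant +1. The deleted code additionally row-normalizes and transposes before the SVD; that detail is skipped here.

import torch

m = torch.randn(4, 3, 3)  # a batch of unconstrained 9D rotation predictions
u, s, vt = torch.linalg.svd(m)
det = torch.det(u @ vt)
# Flip the last column of U when det(U V^T) is negative to avoid a reflection.
u_fixed = torch.cat([u[..., :-1], u[..., -1:] * det.view(-1, 1, 1)], dim=-1)
r = u_fixed @ vt

print(torch.allclose(r.transpose(-1, -2) @ r, torch.eye(3).expand(4, 3, 3), atol=1e-5))  # True
print(torch.det(r))  # all close to +1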
mapanything/models/external/pi3/layers/pos_embed.py
DELETED
@@ -1,190 +0,0 @@
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------


import numpy as np
import torch


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if n_cls_token > 0:
        pos_embed = np.concatenate(
            [np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed


# ----------------------------------------------------------
# RoPE2D: RoPE implementation in 2D
# ----------------------------------------------------------

try:
    from models.curope import cuRoPE2D

    RoPE2D = cuRoPE2D
except ImportError:

    class RoPE2D(torch.nn.Module):
        def __init__(self, freq=100.0, F0=1.0):
            super().__init__()
            self.base = freq
            self.F0 = F0
            self.cache = {}

        def get_cos_sin(self, D, seq_len, device, dtype):
            if (D, seq_len, device, dtype) not in self.cache:
                inv_freq = 1.0 / (
                    self.base ** (torch.arange(0, D, 2).float().to(device) / D)
                )
                t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
                freqs = torch.cat((freqs, freqs), dim=-1)
                cos = freqs.cos()  # (Seq, Dim)
                sin = freqs.sin()
                self.cache[D, seq_len, device, dtype] = (cos, sin)
            return self.cache[D, seq_len, device, dtype]

        @staticmethod
        def rotate_half(x):
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rope1d(self, tokens, pos1d, cos, sin):
            assert pos1d.ndim == 2
            cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
            sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
            return (tokens * cos) + (self.rotate_half(tokens) * sin)

        def forward(self, tokens, positions):
            """
            input:
                * tokens: batch_size x nheads x ntokens x dim
                * positions: batch_size x ntokens x 2 (y and x position of each token)
            output:
                * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
            """
            assert tokens.size(3) % 2 == 0, (
                "number of dimensions should be a multiple of two"
            )
            D = tokens.size(3) // 2
            assert positions.ndim == 3 and positions.shape[-1] == 2  # Batch, Seq, 2
            cos, sin = self.get_cos_sin(
                D, int(positions.max()) + 1, tokens.device, tokens.dtype
            )
            # split features into two along the feature dimension, and apply rope1d on each half
            y, x = tokens.chunk(2, dim=-1)
            y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
            x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
            tokens = torch.cat((y, x), dim=-1)
            return tokens


# patch embedding
class PositionGetter(object):
    """return positions of patches"""

    def __init__(self):
        self.cache_positions = {}

    def __call__(self, b, h, w, device):
        if (h, w) not in self.cache_positions:
            x = torch.arange(w, device=device)
            y = torch.arange(h, device=device)
            self.cache_positions[h, w] = torch.cartesian_prod(y, x)  # (h, w, 2)
        pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone()
        return pos
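Illustrative only (not part of the diff): the (y, x) patch positions consumed by RoPE2D can be built exactly the way PositionGetter does it, e.g. for a 3x4 patch grid and a batch of 2.

import torch

h, w, b = 3, 4, 2
pos = torch.cartesian_prod(torch.arange(h), torch.arange(w))  # (h*w, 2) row-major grid
pos = pos.view(1, h * w, 2).expand(b, -1, 2).clone()          # broadcast over the batch
print(pos.shape)  # torch.Size([2, 12, 2]); pos[..., 0] is y, pos[..., 1] is x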
mapanything/models/external/pi3/layers/transformer_head.py
DELETED
@@ -1,98 +0,0 @@
from functools import partial

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from mapanything.models.external.dinov2.layers import Mlp
from mapanything.models.external.pi3.layers.attention import FlashAttentionRope
from mapanything.models.external.pi3.layers.block import BlockRope


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()

        self.projects = (
            nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        )
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList(
            [
                BlockRope(
                    dim=dec_embed_dim,
                    num_heads=dec_num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    proj_bias=True,
                    ffn_bias=True,
                    drop_path=0.0,
                    norm_layer=partial(nn.LayerNorm, eps=1e-6),
                    act_layer=nn.GELU,
                    ffn_layer=Mlp,
                    init_values=None,
                    qk_norm=False,
                    # attn_class=MemEffAttentionRope,
                    attn_class=FlashAttentionRope,
                    rope=rope,
                )
                for _ in range(depth)
            ]
        )

        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None):
        hidden = self.projects(hidden)
        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint and self.training:
                hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, xpos=xpos)
        out = self.linear_out(hidden)
        return out


class LinearPts3d(nn.Module):
    """
    Linear head for dust3r
    Each token outputs: - 16x16 3D points (+ confidence)
    """

    def __init__(
        self,
        patch_size,
        dec_embed_dim,
        output_dim=3,
    ):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Linear(dec_embed_dim, (output_dim) * self.patch_size**2)

    def forward(self, decout, img_shape):
        H, W = img_shape
        tokens = decout[-1]
        B, S, D = tokens.shape

        # extract 3D points
        feat = self.proj(tokens)  # B,S,D
        feat = feat.transpose(-1, -2).view(
            B, -1, H // self.patch_size, W // self.patch_size
        )
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,3,H,W

        # permute + norm depth
        return feat.permute(0, 2, 3, 1)
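Illustrative only (not part of the diff): LinearPts3d turns per-patch token predictions into a dense per-pixel map via pixel_shuffle. A toy version with patch size 4 instead of 14 shows the reshaping.

import torch
import torch.nn.functional as F

B, patch, C = 1, 4, 3
H, W = 8, 12
# One token per patch, each predicting C * patch * patch values.
tokens = torch.randn(B, (H // patch) * (W // patch), C * patch * patch)
feat = tokens.transpose(-1, -2).view(B, -1, H // patch, W // patch)
dense = F.pixel_shuffle(feat, patch)      # (B, C, H, W)
print(dense.permute(0, 2, 3, 1).shape)    # torch.Size([1, 8, 12, 3])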
mapanything/models/external/pi3/models/__init__.py
DELETED
File without changes
mapanything/models/external/pi3/models/pi3.py
DELETED
|
@@ -1,251 +0,0 @@
|
|
| 1 |
-
from copy import deepcopy
|
| 2 |
-
from functools import partial
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
from huggingface_hub import PyTorchModelHubMixin
|
| 7 |
-
|
| 8 |
-
from mapanything.models.external.dinov2.hub.backbones import dinov2_vitl14_reg
|
| 9 |
-
from mapanything.models.external.dinov2.layers import Mlp
|
| 10 |
-
from mapanything.models.external.pi3.layers.attention import FlashAttentionRope
|
| 11 |
-
from mapanything.models.external.pi3.layers.block import BlockRope
|
| 12 |
-
from mapanything.models.external.pi3.layers.camera_head import CameraHead
|
| 13 |
-
from mapanything.models.external.pi3.layers.pos_embed import PositionGetter, RoPE2D
|
| 14 |
-
from mapanything.models.external.pi3.layers.transformer_head import (
|
| 15 |
-
LinearPts3d,
|
| 16 |
-
TransformerDecoder,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def homogenize_points(
|
| 21 |
-
points,
|
| 22 |
-
):
|
| 23 |
-
"""Convert batched points (xyz) to (xyz1)."""
|
| 24 |
-
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class Pi3(nn.Module, PyTorchModelHubMixin):
|
| 28 |
-
def __init__(
|
| 29 |
-
self,
|
| 30 |
-
pos_type="rope100",
|
| 31 |
-
decoder_size="large",
|
| 32 |
-
):
|
| 33 |
-
super().__init__()
|
| 34 |
-
|
| 35 |
-
# ----------------------
|
| 36 |
-
# Encoder
|
| 37 |
-
# ----------------------
|
| 38 |
-
self.encoder = dinov2_vitl14_reg(pretrained=False)
|
| 39 |
-
self.patch_size = 14
|
| 40 |
-
del self.encoder.mask_token
|
| 41 |
-
|
| 42 |
-
# ----------------------
|
| 43 |
-
# Positonal Encoding
|
| 44 |
-
# ----------------------
|
| 45 |
-
self.pos_type = pos_type if pos_type is not None else "none"
|
| 46 |
-
self.rope = None
|
| 47 |
-
if self.pos_type.startswith("rope"): # eg rope100
|
| 48 |
-
if RoPE2D is None:
|
| 49 |
-
raise ImportError(
|
| 50 |
-
"Cannot find cuRoPE2D, please install it following the README instructions"
|
| 51 |
-
)
|
| 52 |
-
freq = float(self.pos_type[len("rope") :])
|
| 53 |
-
self.rope = RoPE2D(freq=freq)
|
| 54 |
-
self.position_getter = PositionGetter()
|
| 55 |
-
else:
|
| 56 |
-
raise NotImplementedError
|
| 57 |
-
|
| 58 |
-
# ----------------------
|
| 59 |
-
# Decoder
|
| 60 |
-
# ----------------------
|
| 61 |
-
if decoder_size == "small":
|
| 62 |
-
dec_embed_dim = 384
|
| 63 |
-
dec_num_heads = 6
|
| 64 |
-
mlp_ratio = 4
|
| 65 |
-
dec_depth = 24
|
| 66 |
-
elif decoder_size == "base":
|
| 67 |
-
dec_embed_dim = 768
|
| 68 |
-
dec_num_heads = 12
|
| 69 |
-
mlp_ratio = 4
|
| 70 |
-
dec_depth = 24
|
| 71 |
-
elif decoder_size == "large":
|
| 72 |
-
dec_embed_dim = 1024
|
| 73 |
-
dec_num_heads = 16
|
| 74 |
-
mlp_ratio = 4
|
| 75 |
-
dec_depth = 36
|
| 76 |
-
else:
|
| 77 |
-
raise NotImplementedError
|
| 78 |
-
self.decoder = nn.ModuleList(
|
| 79 |
-
[
|
| 80 |
-
BlockRope(
|
| 81 |
-
dim=dec_embed_dim,
|
| 82 |
-
num_heads=dec_num_heads,
|
| 83 |
-
mlp_ratio=mlp_ratio,
|
| 84 |
-
qkv_bias=True,
|
| 85 |
-
proj_bias=True,
|
| 86 |
-
ffn_bias=True,
|
| 87 |
-
drop_path=0.0,
|
| 88 |
-
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 89 |
-
act_layer=nn.GELU,
|
| 90 |
-
ffn_layer=Mlp,
|
| 91 |
-
init_values=0.01,
|
| 92 |
-
qk_norm=True,
|
| 93 |
-
attn_class=FlashAttentionRope,
|
| 94 |
-
rope=self.rope,
|
| 95 |
-
)
|
| 96 |
-
for _ in range(dec_depth)
|
| 97 |
-
]
|
| 98 |
-
)
|
| 99 |
-
self.dec_embed_dim = dec_embed_dim
|
| 100 |
-
|
| 101 |
-
# ----------------------
|
| 102 |
-
# Register_token
|
| 103 |
-
# ----------------------
|
| 104 |
-
num_register_tokens = 5
|
| 105 |
-
self.patch_start_idx = num_register_tokens
|
| 106 |
-
self.register_token = nn.Parameter(
|
| 107 |
-
torch.randn(1, 1, num_register_tokens, self.dec_embed_dim)
|
| 108 |
-
)
|
| 109 |
-
nn.init.normal_(self.register_token, std=1e-6)
|
| 110 |
-
|
| 111 |
-
# ----------------------
|
| 112 |
-
# Local Points Decoder
|
| 113 |
-
# ----------------------
|
| 114 |
-
self.point_decoder = TransformerDecoder(
|
| 115 |
-
in_dim=2 * self.dec_embed_dim,
|
| 116 |
-
dec_embed_dim=1024,
|
| 117 |
-
dec_num_heads=16,
|
| 118 |
-
            out_dim=1024,
            rope=self.rope,
        )
        self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)

        # ----------------------
        # Conf Decoder
        # ----------------------
        self.conf_decoder = deepcopy(self.point_decoder)
        self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)

        # ----------------------
        # Camera Pose Decoder
        # ----------------------
        self.camera_decoder = TransformerDecoder(
            in_dim=2 * self.dec_embed_dim,
            dec_embed_dim=1024,
            dec_num_heads=16,  # 8
            out_dim=512,
            rope=self.rope,
            use_checkpoint=False,
        )
        self.camera_head = CameraHead(dim=512)

        # For ImageNet Normalize
        image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

        self.register_buffer("image_mean", image_mean)
        self.register_buffer("image_std", image_std)

    def decode(self, hidden, N, H, W):
        BN, hw, _ = hidden.shape
        B = BN // N

        final_output = []

        hidden = hidden.reshape(B * N, hw, -1)

        register_token = self.register_token.repeat(B, N, 1, 1).reshape(
            B * N, *self.register_token.shape[-2:]
        )

        # Concatenate special tokens with patch tokens
        hidden = torch.cat([register_token, hidden], dim=1)
        hw = hidden.shape[1]

        if self.pos_type.startswith("rope"):
            pos = self.position_getter(
                B * N, H // self.patch_size, W // self.patch_size, hidden.device
            )

        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            pos = pos + 1
            pos_special = (
                torch.zeros(B * N, self.patch_start_idx, 2)
                .to(hidden.device)
                .to(pos.dtype)
            )
            pos = torch.cat([pos_special, pos], dim=1)

        for i in range(len(self.decoder)):
            blk = self.decoder[i]

            if i % 2 == 0:
                pos = pos.reshape(B * N, hw, -1)
                hidden = hidden.reshape(B * N, hw, -1)
            else:
                pos = pos.reshape(B, N * hw, -1)
                hidden = hidden.reshape(B, N * hw, -1)

            hidden = blk(hidden, xpos=pos)

            if i + 1 in [len(self.decoder) - 1, len(self.decoder)]:
                final_output.append(hidden.reshape(B * N, hw, -1))

        return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(
            B * N, hw, -1
        )

    def forward(self, imgs):
        imgs = (imgs - self.image_mean) / self.image_std

        B, N, _, H, W = imgs.shape
        patch_h, patch_w = H // 14, W // 14

        # encode by dinov2
        imgs = imgs.reshape(B * N, _, H, W)
        hidden = self.encoder(imgs, is_training=True)

        if isinstance(hidden, dict):
            hidden = hidden["x_norm_patchtokens"]

        hidden, pos = self.decode(hidden, N, H, W)

        point_hidden = self.point_decoder(hidden, xpos=pos)
        conf_hidden = self.conf_decoder(hidden, xpos=pos)
        camera_hidden = self.camera_decoder(hidden, xpos=pos)

        with torch.amp.autocast(device_type="cuda", enabled=False):
            # local points
            point_hidden = point_hidden.float()
            ret = self.point_head(
                [point_hidden[:, self.patch_start_idx :]], (H, W)
            ).reshape(B, N, H, W, -1)
            xy, z = ret.split([2, 1], dim=-1)
            z = torch.exp(z)
            local_points = torch.cat([xy * z, z], dim=-1)

            # confidence
            conf_hidden = conf_hidden.float()
            conf = self.conf_head(
                [conf_hidden[:, self.patch_start_idx :]], (H, W)
            ).reshape(B, N, H, W, -1)

            # camera
            camera_hidden = camera_hidden.float()
            camera_poses = self.camera_head(
                camera_hidden[:, self.patch_start_idx :], patch_h, patch_w
            ).reshape(B, N, 4, 4)

            # unproject local points using camera poses
            points = torch.einsum(
                "bnij, bnhwj -> bnhwi", camera_poses, homogenize_points(local_points)
            )[..., :3]

            return dict(
                points=points,
                local_points=local_points,
                conf=conf,
                camera_poses=camera_poses,
            )
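For reference, the unprojection step above composes each view's 4x4 cam2world pose with homogenized local points. A minimal standalone sketch of that operation, using random tensors and a re-implemented homogenize_points helper whose behavior (appending a homogeneous 1) is an assumption for illustration, not the original utility:

import torch

def homogenize_points(pts):
    # assumed behavior: append a homogeneous coordinate of 1 along the last dim
    return torch.cat([pts, torch.ones_like(pts[..., :1])], dim=-1)

B, N, H, W = 1, 2, 4, 6
local_points = torch.rand(B, N, H, W, 3)        # per-view points in the camera frame
camera_poses = torch.eye(4).expand(B, N, 4, 4)  # cam2world poses, identity for the sketch

# (B,N,4,4) applied to (B,N,H,W,4), then the homogeneous coordinate is dropped
points = torch.einsum(
    "bnij, bnhwj -> bnhwi", camera_poses, homogenize_points(local_points)
)[..., :3]
assert points.shape == (B, N, H, W, 3)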
mapanything/models/external/pow3r/__init__.py
DELETED
@@ -1,860 +0,0 @@
"""
Inference wrapper for Pow3R
"""

import warnings
from copy import deepcopy

import pow3r.model.blocks  # noqa
import roma
import torch
import torch.nn as nn
import tqdm
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.image_pairs import make_pairs
from dust3r.inference import check_if_same_size
from dust3r.model import CroCoNet
from dust3r.patch_embed import get_patch_embed as dust3r_patch_embed
from dust3r.utils.device import collate_with_cat, to_cpu
from dust3r.utils.misc import (
    fill_default_args,
    freeze_all_params,
    interleave,
    is_symmetrized,
    transpose_to_landscape,
)
from pow3r.model.blocks import Block, BlockInject, DecoderBlock, DecoderBlockInject, Mlp
from pow3r.model.heads import head_factory
from pow3r.model.inference import (
    add_depth,
    add_intrinsics,
    add_relpose,
)
from pow3r.model.patch_embed import get_patch_embed

from mapanything.models.external.vggt.utils.rotation import mat_to_quat
from mapanything.utils.geometry import (
    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
    convert_z_depth_to_depth_along_ray,
    depthmap_to_camera_frame,
    get_rays_in_camera_frame,
)


class Pow3R(CroCoNet):
    """Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).
    """

    def __init__(
        self,
        mode="embed",
        head_type="linear",
        patch_embed_cls="PatchEmbedDust3R",
        freeze="none",
        landscape_only=True,
        **croco_kwargs,
    ):
        # retrieve all default arguments using python magic
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)
        del self.mask_token  # useless
        del self.prediction_head

        dec_dim, enc_dim = self.decoder_embed.weight.shape
        self.enc_embed_dim = enc_dim
        self.dec_embed_dim = dec_dim

        self.mode = mode
        # additional parameters in the encoder
        img_size = self.patch_embed.img_size
        patch_size = self.patch_embed.patch_size[0]
        self.patch_embed = dust3r_patch_embed(
            patch_embed_cls, img_size, patch_size, self.enc_embed_dim
        )
        self.patch_embed_rays = get_patch_embed(
            patch_embed_cls + "_Mlp",
            img_size,
            patch_size,
            self.enc_embed_dim,
            in_chans=3,
        )
        self.patch_embed_depth = get_patch_embed(
            patch_embed_cls + "_Mlp",
            img_size,
            patch_size,
            self.enc_embed_dim,
            in_chans=2,
        )
        self.pose_embed = Mlp(12, 4 * dec_dim, dec_dim)

        # additional parameters in the decoder
        self.dec_cls = "_cls" in self.mode
        self.dec_num_cls = 0
        if self.dec_cls:
            # use a CLS token in the decoder only
            self.mode = self.mode.replace("_cls", "")
            self.cls_token1 = nn.Parameter(torch.zeros((dec_dim,)))
            self.cls_token2 = nn.Parameter(torch.zeros((dec_dim,)))
            self.dec_num_cls = 1  # affects all blocks

        use_ln = "_ln" in self.mode  # TODO remove?
        self.patch_ln = nn.LayerNorm(enc_dim) if use_ln else nn.Identity()
        self.dec1_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity()
        self.dec2_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity()

        self.dec_blocks2 = deepcopy(self.dec_blocks)

        # here we modify some of the blocks
        self.replace_some_blocks()

        self.set_downstream_head(head_type, landscape_only, **croco_kwargs)
        self.set_freeze(freeze)

    def replace_some_blocks(self):
        assert self.mode.startswith("inject")  # inject[0,0.5]
        NewBlock = BlockInject
        DecoderNewBlock = DecoderBlockInject

        all_layers = {
            i / n
            for i in range(len(self.enc_blocks))
            for n in [len(self.enc_blocks), len(self.dec_blocks)]
        }
        which_layers = eval(self.mode[self.mode.find("[") :]) or all_layers
        assert isinstance(which_layers, (set, list))

        n = 0
        for i, block in enumerate(self.enc_blocks):
            if i / len(self.enc_blocks) in which_layers:
                block.__class__ = NewBlock
                block.init(self.enc_embed_dim)
                n += 1
            else:
                block.__class__ = Block
        assert n == len(which_layers), breakpoint()

        n = 0
        for i in range(len(self.dec_blocks)):
            for blocks in [self.dec_blocks, self.dec_blocks2]:
                block = blocks[i]
                if i / len(self.dec_blocks) in which_layers:
                    block.__class__ = DecoderNewBlock
                    block.init(self.dec_embed_dim)
                    n += 1
                else:
                    block.__class__ = DecoderBlock
        assert n == 2 * len(which_layers), breakpoint()

    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kw):
        return _load_model(pretrained_model_path, device="cpu")

    def load_state_dict(self, ckpt, **kw):
        # duplicate all weights for the second decoder if not present
        new_ckpt = dict(ckpt)
        if not any(k.startswith("dec_blocks2") for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith("dec_blocks"):
                    new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value
        # remove layers that have different shapes
        cur_ckpt = self.state_dict()
        for key, val in ckpt.items():
            if key.startswith("downstream_head2.proj"):
                if key in cur_ckpt and cur_ckpt[key].shape != val.shape:
                    print(f" (removing ckpt[{key}] because wrong shape)")
                    del new_ckpt[key]
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        self.freeze = freeze
        to_be_frozen = {
            "none": [],
            "encoder": [self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def set_prediction_head(self, *args, **kwargs):
        """No prediction head"""
        return

    def set_downstream_head(
        self,
        head_type,
        landscape_only,
        patch_size,
        img_size,
        mlp_ratio,
        dec_depth,
        **kw,
    ):
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, (
            f"{img_size=} must be multiple of {patch_size=}"
        )

        # split heads if different
        heads = head_type.split(";")
        assert len(heads) in (1, 2)
        head1_type, head2_type = (heads + heads)[:2]

        # allocate heads
        self.downstream_head1 = head_factory(head1_type, self)
        self.downstream_head2 = head_factory(head2_type, self)

        # magic wrapper
        self.head1 = transpose_to_landscape(
            self.downstream_head1, activate=landscape_only
        )
        self.head2 = transpose_to_landscape(
            self.downstream_head2, activate=landscape_only
        )

    def _encode_image(self, image, true_shape, rays=None, depth=None):
        # embed the image into patches (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        if rays is not None:  # B,3,H,W
            rays_emb, pos2 = self.patch_embed_rays(rays, true_shape=true_shape)
            assert (pos == pos2).all()
            if self.mode.startswith("embed"):
                x = x + rays_emb
        else:
            rays_emb = None

        if depth is not None:  # B,2,H,W
            depth_emb, pos2 = self.patch_embed_depth(depth, true_shape=true_shape)
            assert (pos == pos2).all()
            if self.mode.startswith("embed"):
                x = x + depth_emb
        else:
            depth_emb = None

        x = self.patch_ln(x)

        # add positional embedding without cls token
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos, rays=rays_emb, depth=depth_emb)

        x = self.enc_norm(x)
        return x, pos

    def encode_symmetrized(self, view1, view2):
        img1 = view1["img"]
        img2 = view2["img"]
        B = img1.shape[0]
        # Recover true_shape when available, otherwise assume that the img shape is the true one
        shape1 = view1.get(
            "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1)
        )
        shape2 = view2.get(
            "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1)
        )
        # warning! maybe the images have different portrait/landscape orientations

        # privileged information
        rays1 = view1.get("known_rays", None)
        rays2 = view2.get("known_rays", None)
        depth1 = view1.get("known_depth", None)
        depth2 = view2.get("known_depth", None)

        if is_symmetrized(view1, view2):
            # computing half of forward pass!'
            def hsub(x):
                return None if x is None else x[::2]

            feat1, pos1 = self._encode_image(
                img1[::2], shape1[::2], rays=hsub(rays1), depth=hsub(depth1)
            )
            feat2, pos2 = self._encode_image(
                img2[::2], shape2[::2], rays=hsub(rays2), depth=hsub(depth2)
            )

            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, pos1 = self._encode_image(img1, shape1, rays=rays1, depth=depth1)
            feat2, pos2 = self._encode_image(img2, shape2, rays=rays2, depth=depth2)

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2, relpose1=None, relpose2=None):
        final_output = [(f1, f2)]  # before projection

        # project to decoder dim
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        # add CLS token for the decoder
        if self.dec_cls:
            cls1 = self.cls_token1[None, None].expand(len(f1), 1, -1).clone()
            cls2 = self.cls_token2[None, None].expand(len(f2), 1, -1).clone()

        if relpose1 is not None:  # shape = (B, 4, 4)
            pose_emb1 = self.pose_embed(relpose1[:, :3].flatten(1)).unsqueeze(1)
            if self.mode.startswith("embed"):
                if self.dec_cls:
                    cls1 = cls1 + pose_emb1
                else:
                    f1 = f1 + pose_emb1
        else:
            pose_emb1 = None

        if relpose2 is not None:  # shape = (B, 4, 4)
            pose_emb2 = self.pose_embed(relpose2[:, :3].flatten(1)).unsqueeze(1)
            if self.mode.startswith("embed"):
                if self.dec_cls:
                    cls2 = cls2 + pose_emb2
                else:
                    f2 = f2 + pose_emb2
        else:
            pose_emb2 = None

        if self.dec_cls:
            f1, pos1 = cat_cls(cls1, f1, pos1)
            f2, pos2 = cat_cls(cls2, f2, pos2)

        f1 = self.dec1_pre_ln(f1)
        f2 = self.dec2_pre_ln(f2)

        final_output.append((f1, f2))  # to be removed later
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # img1 side
            f1, _ = blk1(
                *final_output[-1][::+1],
                pos1,
                pos2,
                relpose=pose_emb1,
                num_cls=self.dec_num_cls,
            )
            # img2 side
            f2, _ = blk2(
                *final_output[-1][::-1],
                pos2,
                pos1,
                relpose=pose_emb2,
                num_cls=self.dec_num_cls,
            )
            # store the result
            final_output.append((f1, f2))

        del final_output[1]  # duplicate with final_output[0] (after decoder proj)
        if self.dec_cls:  # remove cls token for decoder layers
            final_output[1:] = [(f1[:, 1:], f2[:, 1:]) for f1, f2 in final_output[1:]]
        # normalize last output
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        return zip(*final_output)

    def _downstream_head(self, head_num, decout, img_shape):
        B, S, D = decout[-1].shape
        head = getattr(self, f"head{head_num}")
        return head(decout, img_shape)

    def forward(self, view1, view2):
        # encode the two images --> B,S,D
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self.encode_symmetrized(
            view1, view2
        )

        # combine all ref images into object-centric representation
        dec1, dec2 = self._decoder(
            feat1,
            pos1,
            feat2,
            pos2,
            relpose1=view1.get("known_pose"),
            relpose2=view2.get("known_pose"),
        )
        with torch.autocast("cuda", enabled=False):
            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        res2["pts3d_in_other_view"] = res2.pop(
            "pts3d"
        )  # predict view2's pts3d in view1's frame
        return res1, res2


def convert_release_dust3r_args(args):
    args.model = (
        args.model.replace("patch_embed_cls", "patch_embed")
        .replace("AsymmetricMASt3R", "AsymmetricCroCo3DStereo")
        .replace("PatchEmbedDust3R", "convManyAR")
        .replace(
            "pos_embed='RoPE100'",
            "enc_pos_embed='cuRoPE100', dec_pos_embed='cuRoPE100'",
        )
    )
    return args


def _load_model(model_path, device):
    print("... loading model from", model_path)
    ckpt = torch.load(model_path, map_location="cpu")
    try:
        net = eval(
            ckpt["args"].model[:-1].replace("convManyAR", "convP")
            + ", landscape_only=False)"
        )
    except Exception:
        args = convert_release_dust3r_args(ckpt["args"])
        net = eval(
            args.model[:-1].replace("convManyAR", "convP") + ", landscape_only=False)"
        )
    ckpt["model"] = {
        k.replace("_downstream_head", "downstream_head"): v
        for k, v in ckpt["model"].items()
    }
    print(net.load_state_dict(ckpt["model"], strict=False))
    return net.to(device)


def cat_cls(cls, tokens, pos):
    tokens = torch.cat((cls, tokens), dim=1)
    pos = torch.cat((-pos.new_ones(len(cls), 1, 2), pos), dim=1)
    return tokens, pos


class Pow3RWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        ckpt_path,
        geometric_input_config,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.ckpt_path = ckpt_path
        self.geometric_input_config = geometric_input_config

        # Init the model and load the checkpoint
        print(f"Loading checkpoint from {self.ckpt_path} ...")
        ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False)
        model = ckpt["definition"]
        print(f"Creating model = {model}")
        self.model = eval(model)
        print(self.model.load_state_dict(ckpt["weights"]))

    def forward(self, views):
        """
        Forward pass wrapper for Pow3R.

        Assumption:
        - The number of input views is 2.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Length of the list should be 2.
                Each dictionary should contain the following keys:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["dust3r"]
                Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs:
                    "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3).
                    "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation.
                    "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1).

        Returns:
            List[dict]: A list containing the final outputs for the two views. Length of the list will be 2.
        """
        # Check that the number of input views is 2
        assert len(views) == 2, "Pow3R requires 2 input views."

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "Pow3R expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Get the batch size per view, device and two views
        batch_size_per_view = views[0]["img"].shape[0]
        device = views[0]["img"].device
        view1, view2 = views

        # Decide if we need to use the geometric inputs
        if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]:
            # Decide if we need to use the camera intrinsics
            if (
                torch.rand(1, device=device)
                < self.geometric_input_config["ray_dirs_prob"]
            ):
                add_intrinsics(view1, view1.get("camera_intrinsics"))
                add_intrinsics(view2, view2.get("camera_intrinsics"))

            # Decide if we need to use the depth map
            if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]:
                depthmap1 = view1.get("depthmap")
                depthmap2 = view2.get("depthmap")
                if depthmap1 is not None:
                    depthmap1 = depthmap1.squeeze(-1).to(device)
                if depthmap2 is not None:
                    depthmap2 = depthmap2.squeeze(-1).to(device)
                add_depth(view1, depthmap1)
                add_depth(view2, depthmap2)

            # Decide if we need to use the camera pose
            if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]:
                cam1 = view1.get("camera_pose")
                cam2 = view2.get("camera_pose")
                add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1)
                add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1)

        # Get the model predictions
        preds = self.model(view1, view2)

        # Convert the output to MapAnything format
        with torch.autocast("cuda", enabled=False):
            res = []
            for view_idx in range(2):
                # Get the model predictions for the current view
                curr_view_pred = preds[view_idx]

                # For the first view
                if view_idx == 0:
                    # Get the global frame and camera frame pointmaps
                    global_pts = curr_view_pred["pts3d"]
                    cam_pts = curr_view_pred["pts3d"]
                    conf = curr_view_pred["conf"]

                    # Get the ray directions and depth along ray
                    depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True)
                    ray_directions = cam_pts / depth_along_ray

                    # Initalize identity camera pose
                    cam_rot = torch.eye(3, device=device)
                    cam_quat = mat_to_quat(cam_rot)
                    cam_trans = torch.zeros(3, device=device)
                    cam_quat = cam_quat.unsqueeze(0).repeat(batch_size_per_view, 1)
                    cam_trans = cam_trans.unsqueeze(0).repeat(batch_size_per_view, 1)
                # For the second view
                elif view_idx == 1:
                    # Get the global frame and camera frame pointmaps
                    pred_global_pts = curr_view_pred["pts3d_in_other_view"]
                    cam_pts = curr_view_pred["pts3d2"]
                    conf = (curr_view_pred["conf"] * curr_view_pred["conf2"]).sqrt()

                    # Get the ray directions and depth along ray
                    depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True)
                    ray_directions = cam_pts / depth_along_ray

                    # Compute the camera pose using the pointmaps
                    cam_rot, cam_trans, scale = roma.rigid_points_registration(
                        cam_pts.reshape(batch_size_per_view, -1, 3),
                        pred_global_pts.reshape(batch_size_per_view, -1, 3),
                        weights=conf.reshape(batch_size_per_view, -1),
                        compute_scaling=True,
                    )
                    cam_quat = mat_to_quat(cam_rot)

                    # Scale the predicted camera frame pointmap and compute the new global frame pointmap
                    cam_pts = scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * cam_pts
                    global_pts = cam_pts.reshape(
                        batch_size_per_view, -1, 3
                    ) @ cam_rot.permute(0, 2, 1) + cam_trans.unsqueeze(1)
                    global_pts = global_pts.view(pred_global_pts.shape)

                # Append the result in MapAnything format
                res.append(
                    {
                        "pts3d": global_pts,
                        "pts3d_cam": cam_pts,
                        "ray_directions": ray_directions,
                        "depth_along_ray": depth_along_ray,
                        "cam_trans": cam_trans,
                        "cam_quats": cam_quat,
                        "conf": conf,
                    }
                )

        return res


class Pow3RBAWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        ckpt_path,
        geometric_input_config,
        scene_graph="complete",
        inference_batch_size=32,
        global_optim_schedule="cosine",
        global_optim_lr=0.01,
        global_optim_niter=300,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.ckpt_path = ckpt_path
        self.geometric_input_config = geometric_input_config
        self.scene_graph = scene_graph
        self.inference_batch_size = inference_batch_size
        self.global_optim_schedule = global_optim_schedule
        self.global_optim_lr = global_optim_lr
        self.global_optim_niter = global_optim_niter

        # Init the model and load the checkpoint
        print(f"Loading checkpoint from {self.ckpt_path} ...")
        ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False)
        model = ckpt["definition"]
        print(f"Creating model = {model}")
        self.model = eval(model)
        print(self.model.load_state_dict(ckpt["weights"]))

        # Init the global aligner mode
        self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer

    def infer_two_views(self, views):
        """
        Wrapper for Pow3R 2-View inference.

        Assumption:
        - The number of input views is 2.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Length of the list should be 2.
                Each dictionary should contain the following keys:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["dust3r"]
                Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs:
                    "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3).
                    "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation.
                    "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1).

        Returns:
            List[dict]: A list containing the final outputs for the two views. Length of the list will be 2.
        """
        # Check that the number of input views is 2
        assert len(views) == 2, "Pow3R requires 2 input views."

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "Pow3R expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Get the device and two views
        device = views[0]["img"].device
        view1, view2 = views

        # Decide if we need to use the geometric inputs
        if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]:
            # Decide if we need to use the camera intrinsics
            if (
                torch.rand(1, device=device)
                < self.geometric_input_config["ray_dirs_prob"]
            ):
                add_intrinsics(view1, view1.get("camera_intrinsics"))
                add_intrinsics(view2, view2.get("camera_intrinsics"))

            # Decide if we need to use the depth map
            if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]:
                depthmap1 = view1.get("depthmap")
                depthmap2 = view2.get("depthmap")
                if depthmap1 is not None:
                    depthmap1 = depthmap1.squeeze(-1).to(device)
                if depthmap2 is not None:
                    depthmap2 = depthmap2.squeeze(-1).to(device)
                add_depth(view1, depthmap1)
                add_depth(view2, depthmap2)

            # Decide if we need to use the camera pose
            if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]:
                cam1 = view1.get("camera_pose")
                cam2 = view2.get("camera_pose")
                add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1)
                add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1)

        # Get the model predictions
        preds = self.model(view1, view2)

        return preds

    def loss_of_one_batch(self, batch, device):
        """
        Compute prediction for two views.
        """
        view1, view2 = batch
        ignore_keys = set(
            [
                "dataset",
                "label",
                "instance",
                "idx",
                "true_shape",
                "rng",
                "name",
                "data_norm_type",
            ]
        )
        for view in batch:
            for name in view.keys():  # pseudo_focal
                if name in ignore_keys:
                    continue
                view[name] = view[name].to(device, non_blocking=True)

        pred1, pred2 = self.infer_two_views([view1, view2])

        result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2)

        return result

    @torch.no_grad()
    def inference(self, pairs, device, verbose=False):
        """
        Wrapper for multi-pair inference using Pow3R.
        """
        if verbose:
            print(f">> Inference with model on {len(pairs)} image pairs")
        result = []

        multiple_shapes = not (check_if_same_size(pairs))
        if multiple_shapes:
            self.inference_batch_size = 1

        for i in tqdm.trange(
            0, len(pairs), self.inference_batch_size, disable=not verbose
        ):
            res = self.loss_of_one_batch(
                collate_with_cat(pairs[i : i + self.inference_batch_size]), device
            )
            result.append(to_cpu(res))

        result = collate_with_cat(result, lists=multiple_shapes)

        return result

    def forward(self, views):
        """
        Forward pass wrapper for Pow3R using the global aligner.

        Assumption:
        - The batch size of input views is 1.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Each dictionary should contain the following keys, where B is the batch size and is 1:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["dust3r"]

        Returns:
            List[dict]: A list containing the final outputs for the input views.
        """
        # Check the batch size of input views
        batch_size_per_view, _, height, width = views[0]["img"].shape
        device = views[0]["img"].device
        num_views = len(views)
        assert batch_size_per_view == 1, (
            f"Batch size of input views should be 1, but got {batch_size_per_view}."
        )

        # Check the data norm type
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "dust3r", (
            "Pow3R-BA expects a normalized image with the DUSt3R normalization scheme applied"
        )

        # Convert the input views to the expected input format
        images = []
        for view in views:
            images.append(
                dict(
                    img=view["img"],
                    camera_intrinsics=view["camera_intrinsics"],
                    depthmap=view["depthmap"],
                    camera_pose=view["camera_pose"],
                    data_norm_type=view["data_norm_type"],
                    true_shape=view["true_shape"],
                    idx=len(images),
                    instance=str(len(images)),
                )
            )

        # Make image pairs and run inference pair-wise
        pairs = make_pairs(
            images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            output = self.inference(
                pairs,
                device,
                verbose=False,
            )

        # Global optimization
        with torch.enable_grad():
            scene = global_aligner(
                output, device=device, mode=self.global_aligner_mode, verbose=False
            )
            _ = scene.compute_global_alignment(
                init="mst",
                niter=self.global_optim_niter,
                schedule=self.global_optim_schedule,
                lr=self.global_optim_lr,
            )

        # Make sure scene is not None
        if scene is None:
            raise RuntimeError("Global optimization failed.")

        # Get the predictions
        intrinsics = scene.get_intrinsics()
        c2w_poses = scene.get_im_poses()
        depths = scene.get_depthmaps()

        # Convert the output to the MapAnything format
        with torch.autocast("cuda", enabled=False):
            res = []
            for view_idx in range(num_views):
                # Get the current view predictions
                curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
                curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
                curr_view_depth_z = depths[view_idx].unsqueeze(0)

                # Convert the pose to quaternions and translation
                curr_view_cam_translations = curr_view_pose[..., :3, 3]
                curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])

                # Get the camera frame pointmaps
                curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
                    curr_view_depth_z, curr_view_intrinsic
                )

                # Convert the z depth to depth along ray
                curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
                    curr_view_depth_z, curr_view_intrinsic
                )
                curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)

                # Get the ray directions on the unit sphere in the camera frame
                _, curr_view_ray_dirs = get_rays_in_camera_frame(
                    curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
                )

                # Get the pointmaps
                curr_view_pts3d = (
                    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
                        curr_view_ray_dirs,
                        curr_view_depth_along_ray,
                        curr_view_cam_translations,
                        curr_view_cam_quats,
                    )
                )

                # Append the outputs to the result list
                res.append(
                    {
                        "pts3d": curr_view_pts3d,
                        "pts3d_cam": curr_view_pts3d_cam,
                        "ray_directions": curr_view_ray_dirs,
                        "depth_along_ray": curr_view_depth_along_ray,
                        "cam_trans": curr_view_cam_translations,
                        "cam_quats": curr_view_cam_quats,
                    }
                )

        return res
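For reference, the second-view branch of Pow3RWrapper.forward recovers a scaled rigid transform that aligns camera-frame points onto the predicted global-frame points. A small self-contained sketch of that registration step on synthetic points (shapes and values are arbitrary; it assumes the roma API used above behaves as in the wrapper):

import roma
import torch

B, P = 2, 1024
cam_pts = torch.rand(B, P, 3)

# Ground-truth similarity transform used to synthesize "global" points
gt_rot = roma.random_rotmat(B)                       # (B, 3, 3)
gt_trans = torch.rand(B, 3)
gt_scale = torch.tensor([1.5, 0.8])
global_pts = gt_scale[:, None, None] * cam_pts @ gt_rot.permute(0, 2, 1) + gt_trans[:, None]

# Same call pattern as the wrapper: register cam-frame points onto global-frame points
cam_rot, cam_trans, scale = roma.rigid_points_registration(
    cam_pts, global_pts, compute_scaling=True
)

# Applying the recovered similarity transform reproduces the global points
recon = scale[:, None, None] * cam_pts @ cam_rot.permute(0, 2, 1) + cam_trans[:, None]
assert torch.allclose(recon, global_pts, atol=1e-3)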
mapanything/models/external/vggt/__init__.py
DELETED
@@ -1,186 +0,0 @@
"""
Inference wrapper for VGGT
"""

import torch

from mapanything.models.external.vggt.models.vggt import VGGT
from mapanything.models.external.vggt.utils.geometry import closed_form_inverse_se3
from mapanything.models.external.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from mapanything.models.external.vggt.utils.rotation import mat_to_quat
from mapanything.utils.geometry import (
    convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
    convert_z_depth_to_depth_along_ray,
    depthmap_to_camera_frame,
    get_rays_in_camera_frame,
)


class VGGTWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        torch_hub_force_reload,
        load_pretrained_weights=True,
        depth=24,
        num_heads=16,
        intermediate_layer_idx=[4, 11, 17, 23],
        load_custom_ckpt=False,
        custom_ckpt_path=None,
    ):
        super().__init__()
        self.name = name
        self.torch_hub_force_reload = torch_hub_force_reload
        self.load_custom_ckpt = load_custom_ckpt
        self.custom_ckpt_path = custom_ckpt_path

        if load_pretrained_weights:
            # Load pre-trained weights
            if not torch_hub_force_reload:
                # Initialize the 1B VGGT model from huggingface hub cache
                print("Loading facebook/VGGT-1B from huggingface cache ...")
                self.model = VGGT.from_pretrained(
                    "facebook/VGGT-1B",
                )
            else:
                # Initialize the 1B VGGT model
                print("Re-downloading facebook/VGGT-1B ...")
                self.model = VGGT.from_pretrained(
                    "facebook/VGGT-1B", force_download=True
                )
        else:
            # Load the VGGT class
            self.model = VGGT(
                depth=depth,
                num_heads=num_heads,
                intermediate_layer_idx=intermediate_layer_idx,
            )

        # Get the dtype for VGGT inference
        # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
        self.dtype = (
            torch.bfloat16
            if torch.cuda.get_device_capability()[0] >= 8
            else torch.float16
        )

        # Load custom checkpoint if requested
        if self.load_custom_ckpt:
            print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
            assert self.custom_ckpt_path is not None, (
                "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
            )
            custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False)
            print(self.model.load_state_dict(custom_ckpt, strict=True))
            del custom_ckpt  # in case it occupies memory

    def forward(self, views):
        """
        Forward pass wrapper for VGGT

        Assumption:
        - All the input views have the same image shape.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Each dictionary should contain the following keys:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list containing the final outputs for all N views.
        """
        # Get input shape of the images, number of views, and batch size per view
        batch_size_per_view, _, height, width = views[0]["img"].shape
        num_views = len(views)

        # Check the data norm type
        # VGGT expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "VGGT expects a normalized image but without the DINOv2 mean and std applied"
        )

        # Concatenate the images to create a single (B, V, C, H, W) tensor
        img_list = [view["img"] for view in views]
        images = torch.stack(img_list, dim=1)

        # Run the VGGT aggregator
        with torch.autocast("cuda", dtype=self.dtype):
            aggregated_tokens_list, ps_idx = self.model.aggregator(images)

        # Run the Camera + Pose Branch of VGGT
        with torch.autocast("cuda", enabled=False):
            # Predict Cameras
            pose_enc = self.model.camera_head(aggregated_tokens_list)[-1]
            # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
            # Extrinsics Shape: (B, V, 3, 4)
            # Intrinsics Shape: (B, V, 3, 3)
            extrinsic, intrinsic = pose_encoding_to_extri_intri(
                pose_enc, images.shape[-2:]
            )

            # Predict Depth Maps
            # Depth Shape: (B, V, H, W, 1)
            # Depth Confidence Shape: (B, V, H, W)
            depth_map, depth_conf = self.model.depth_head(
                aggregated_tokens_list, images, ps_idx
            )

        # Convert the output to MapAnything format
        res = []
        for view_idx in range(num_views):
            # Get the extrinsics, intrinsics, depth map for the current view
            curr_view_extrinsic = extrinsic[:, view_idx, ...]
            curr_view_extrinsic = closed_form_inverse_se3(
                curr_view_extrinsic
            )  # Convert to cam2world
            curr_view_intrinsic = intrinsic[:, view_idx, ...]
            curr_view_depth_z = depth_map[:, view_idx, ...]
            curr_view_depth_z = curr_view_depth_z.squeeze(-1)
            curr_view_confidence = depth_conf[:, view_idx, ...]

            # Get the camera frame pointmaps
            curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
                curr_view_depth_z, curr_view_intrinsic
            )

            # Convert the extrinsics to quaternions and translations
            curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
            curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])

            # Convert the z depth to depth along ray
            curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
                curr_view_depth_z, curr_view_intrinsic
            )
            curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)

            # Get the ray directions on the unit sphere in the camera frame
            _, curr_view_ray_dirs = get_rays_in_camera_frame(
                curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
            )

            # Get the pointmaps
            curr_view_pts3d = (
                convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
                    curr_view_ray_dirs,
                    curr_view_depth_along_ray,
                    curr_view_cam_translations,
                    curr_view_cam_quats,
                )
            )

            # Append the outputs to the result list
            res.append(
                {
                    "pts3d": curr_view_pts3d,
                    "pts3d_cam": curr_view_pts3d_cam,
                    "ray_directions": curr_view_ray_dirs,
                    "depth_along_ray": curr_view_depth_along_ray,
                    "cam_trans": curr_view_cam_translations,
                    "cam_quats": curr_view_cam_quats,
                    "conf": curr_view_confidence,
                }
            )

        return res
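For reference, the z-depth to depth-along-ray conversion used above follows from back-projecting each pixel through the intrinsics. A standalone sketch of that relation with plain torch and an assumed pinhole intrinsic (this re-implements the idea for illustration, not the MapAnything utility itself):

import torch

H, W = 4, 6
K = torch.tensor([[100.0, 0.0, W / 2],
                  [0.0, 100.0, H / 2],
                  [0.0, 0.0, 1.0]])
z_depth = torch.rand(H, W) + 0.5

# Back-project every pixel to a camera-frame ray with unit z component
v, u = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                      torch.arange(W, dtype=torch.float32), indexing="ij")
pix_homo = torch.stack([u, v, torch.ones_like(u)], dim=-1)   # (H, W, 3)
rays = pix_homo @ torch.linalg.inv(K).T                      # (H, W, 3)

# Depth along the ray is the Euclidean length of the camera-frame point
pts_cam = rays * z_depth[..., None]
depth_along_ray = pts_cam.norm(dim=-1)
assert torch.allclose(depth_along_ray, z_depth * rays.norm(dim=-1))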
mapanything/models/external/vggt/heads/__init__.py
DELETED
File without changes
mapanything/models/external/vggt/heads/camera_head.py
DELETED
@@ -1,167 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn as nn

from mapanything.models.external.vggt.heads.head_act import activate_pose
from mapanything.models.external.vggt.layers import Mlp
from mapanything.models.external.vggt.layers.block import Block


class CameraHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using iterative refinement.

    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                )
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)
        )

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_in // 2,
            out_features=self.target_dim,
            drop=0,
        )

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape  # S is expected to be 1.
        pred_pose_enc = None
        pred_pose_enc_list = []

        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(
                3, dim=-1
            )

            # Adaptive layer normalization and modulation.
            pose_tokens_modulated = gate_msa * modulate(
                self.adaln_norm(pose_tokens), shift_msa, scale_msa
            )
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            pose_tokens_modulated = self.trunk(pose_tokens_modulated)
            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(
                self.trunk_norm(pose_tokens_modulated)
            )

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc,
                trans_act=self.trans_act,
                quat_act=self.quat_act,
                fl_act=self.fl_act,
            )
            pred_pose_enc_list.append(activated_pose)

        return pred_pose_enc_list


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Modulate the input tensor using scaling and shifting parameters.
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    return x * (1 + scale) + shift
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
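For reference, the camera head above conditions its trunk with DiT-style adaptive LayerNorm: the current pose estimate is embedded and mapped to a shift/scale/gate triple that modulates the normalized camera token before each refinement pass. Below is a minimal sketch of just that conditioning step, assuming a toy token width and a stand-in conditioner; none of the shapes or names come from the deleted file.

import torch
import torch.nn as nn


def modulate(x, shift, scale):
    # Same shift/scale form as the helper above: x * (1 + scale) + shift
    return x * (1 + scale) + shift


dim = 16                                   # assumed token width for this sketch
token = torch.randn(2, 1, dim)             # [B, 1, C] camera token
cond = torch.randn(2, 1, dim)              # embedding of the current pose estimate (stand-in)

# Conditioner producing shift, scale, and gate, mirroring the poseLN_modulation pattern.
to_msa = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))
shift, scale, gate = to_msa(cond).chunk(3, dim=-1)

norm = nn.LayerNorm(dim, elementwise_affine=False)
modulated = gate * modulate(norm(token), shift, scale) + token  # gated residual update
print(modulated.shape)  # torch.Size([2, 1, 16])
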
mapanything/models/external/vggt/heads/dpt_head.py DELETED
@@ -1,600 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


# Inspired by https://github.com/DepthAnything/Depth-Anything-V2


from typing import List, Tuple, Union

import torch
import torch.nn as nn

from .head_act import activate_head
from .utils import create_uv_grid, position_grid_to_embed


class DPTHead(nn.Module):
    """
    DPT Head for dense prediction tasks.

    This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
    backbone and produces dense predictions by fusing multi-scale features.

    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers for upsampling feature maps.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0],
                    out_channels=out_channels[0],
                    kernel_size=4,
                    stride=4,
                    padding=0,
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1],
                    out_channels=out_channels[1],
                    kernel_size=2,
                    stride=2,
                    padding=0,
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3],
                    out_channels=out_channels[3],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            ]
        )

        self.scratch = _make_scratch(
            out_channels,
            features,
            expand=False,
        )

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
            )
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1,
                head_features_1 // 2,
                kernel_size=3,
                stride=1,
                padding=1,
            )
            conv2_in_channels = head_features_1 // 2

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(
                    conv2_in_channels,
                    head_features_2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    head_features_2, output_dim, kernel_size=1, stride=1, padding=0
                ),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.
        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or larger than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: Feature maps with shape [B, S, C, H, W]
                - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
        """
        B, S, _, H, W = images.shape

        # If frames_chunk_size is not specified or greater than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage
        assert frames_chunk_size > 0

        # Process frames in batches
        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            # Process batch of frames
            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list,
                    images,
                    patch_start_idx,
                    frames_start_idx,
                    frames_end_idx,
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list,
                    images,
                    patch_start_idx,
                    frames_start_idx,
                    frames_end_idx,
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        else:
            return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: int = None,
        frames_end_idx: int = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.

        This method processes a specific chunk of frames from the sequence.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        if frames_start_idx is not None and frames_end_idx is not None:
            images = images[:, frames_start_idx:frames_end_idx].contiguous()

        B, S, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []
        dpt_idx = 0

        for layer_idx in self.intermediate_layer_idx:
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Select frames if processing a chunk
            if frames_start_idx is not None and frames_end_idx is not None:
                x = x[:, frames_start_idx:frames_end_idx]

            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)

            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)
            dpt_idx += 1

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (
                int(patch_h * self.patch_size / self.down_ratio),
                int(patch_w * self.patch_size / self.down_ratio),
            ),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.view(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(
            out, activation=self.activation, conf_activation=self.conf_activation
        )

        preds = preds.view(B, S, *preds.shape[1:])
        conf = conf.view(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(
        self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
    ) -> torch.Tensor:
        """
        Apply positional embedding to tensor x.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(
            patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
        )
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): List of feature maps from different layers.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out


################################################################################
# Modules
################################################################################


def _make_fusion_block(
    features: int, size: int = None, has_residual: bool = True, groups: int = 1
) -> nn.Module:
    return FeatureFusionBlock(
        features,
        nn.ReLU(inplace=True),
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )


def _make_scratch(
    in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
) -> nn.Module:
    scratch = nn.Module()
    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        if len(in_shape) >= 4:
            out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0],
        out_shape1,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1],
        out_shape2,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2],
        out_shape3,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    if len(in_shape) >= 4:
        scratch.layer4_rn = nn.Conv2d(
            in_shape[3],
            out_shape4,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            groups=groups,
        )
    return scratch


class ResidualConvUnit(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn, groups=1):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn
        self.groups = groups
        self.conv1 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=self.groups,
        )
        self.conv2 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=self.groups,
        )

        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.norm1 is not None:
            out = self.norm1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.norm2 is not None:
            out = self.norm2(out)

        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=self.groups,
        )

        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(
                features, activation, bn, groups=self.groups
            )

        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(
            features, activation, bn, groups=self.groups
        )

        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if self.has_residual:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = custom_interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )
        output = self.out_conv(output)

        return output


def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
    """
    if size is None:
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736

    input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if input_elements > INT_MAX:
        chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
        interpolated_chunks = [
            nn.functional.interpolate(
                chunk, size=size, mode=mode, align_corners=align_corners
            )
            for chunk in chunks
        ]
        x = torch.cat(interpolated_chunks, dim=0)
        return x.contiguous()
    else:
        return nn.functional.interpolate(
            x, size=size, mode=mode, align_corners=align_corners
        )

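The custom_interpolate helper above exists only to sidestep the element-count limit in nn.functional.interpolate by splitting the batch and concatenating the results. A minimal sketch of that chunk-and-concat pattern, with an assumed threshold and toy shapes that are not taken from the deleted file:

import torch
import torch.nn.functional as F


def chunked_resize(x, size, max_elements=1_610_612_736):
    # Split along the batch dimension when the upsampled tensor would be too large.
    n_out = size[0] * size[1] * x.shape[0] * x.shape[1]
    if n_out <= max_elements:
        return F.interpolate(x, size=size, mode="bilinear", align_corners=True)
    chunks = torch.chunk(x, chunks=(n_out // max_elements) + 1, dim=0)
    out = [
        F.interpolate(c, size=size, mode="bilinear", align_corners=True) for c in chunks
    ]
    return torch.cat(out, dim=0).contiguous()


x = torch.randn(4, 3, 64, 64)
print(chunked_resize(x, (128, 128)).shape)  # torch.Size([4, 3, 128, 128])
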
mapanything/models/external/vggt/heads/head_act.py DELETED
@@ -1,127 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn.functional as F


def activate_pose(
    pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"
):
    """
    Activate pose parameters with specified activation functions.

    Args:
        pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
        trans_act: Activation type for translation component
        quat_act: Activation type for quaternion component
        fl_act: Activation type for focal length component

    Returns:
        Activated pose parameters tensor
    """
    T = pred_pose_enc[..., :3]
    quat = pred_pose_enc[..., 3:7]
    fl = pred_pose_enc[..., 7:]  # or fov

    T = base_pose_act(T, trans_act)
    quat = base_pose_act(quat, quat_act)
    fl = base_pose_act(fl, fl_act)  # or fov

    pred_pose_enc = torch.cat([T, quat, fl], dim=-1)

    return pred_pose_enc


def base_pose_act(pose_enc, act_type="linear"):
    """
    Apply basic activation function to pose parameters.

    Args:
        pose_enc: Tensor containing encoded pose parameters
        act_type: Activation type ("linear", "inv_log", "exp", "relu")

    Returns:
        Activated pose parameters
    """
    if act_type == "linear":
        return pose_enc
    elif act_type == "inv_log":
        return inverse_log_transform(pose_enc)
    elif act_type == "exp":
        return torch.exp(pose_enc)
    elif act_type == "relu":
        return F.relu(pose_enc)
    else:
        raise ValueError(f"Unknown act_type: {act_type}")


def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """
    Process network output to extract 3D points and confidence values.

    Args:
        out: Network output tensor (B, C, H, W)
        activation: Activation type for 3D points
        conf_activation: Activation type for confidence values

    Returns:
        Tuple of (3D points tensor, confidence tensor)
    """
    # Move channels from last dim to the 4th dimension => (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C expected

    # Split into xyz (first C-1 channels) and confidence (last channel)
    xyz = fmap[:, :, :, :-1]
    conf = fmap[:, :, :, -1]

    if activation == "norm_exp":
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        xyz_normed = xyz / d
        pts3d = xyz_normed * torch.expm1(d)
    elif activation == "norm":
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        conf_out = 1 + conf.exp()
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out


def inverse_log_transform(y):
    """
    Apply inverse log transform: sign(y) * (exp(|y|) - 1)

    Args:
        y: Input tensor

    Returns:
        Transformed tensor
    """
    return torch.sign(y) * (torch.expm1(torch.abs(y)))

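The inverse_log_transform above is the inverse of a signed log1p compression: the network regresses values in log space and the head expands them back. A quick numerical check of that round trip, with arbitrarily chosen values for illustration:

import torch


def inverse_log_transform(y):
    # sign(y) * (exp(|y|) - 1), as in the helper above
    return torch.sign(y) * torch.expm1(torch.abs(y))


def log_transform(x):
    # Forward counterpart: sign(x) * log(1 + |x|)
    return torch.sign(x) * torch.log1p(torch.abs(x))


x = torch.tensor([-5.0, -0.5, 0.0, 2.0, 100.0])
print(torch.allclose(inverse_log_transform(log_transform(x)), x))  # True
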
mapanything/models/external/vggt/heads/track_head.py DELETED
@@ -1,118 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from .dpt_head import DPTHead
from .track_modules.base_track_predictor import BaseTrackerPredictor


class TrackHead(nn.Module):
    """
    Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
    The tracking is performed iteratively, refining predictions over multiple iterations.
    """

    def __init__(
        self,
        dim_in,
        patch_size=14,
        features=128,
        iters=4,
        predict_conf=True,
        stride=2,
        corr_levels=7,
        corr_radius=4,
        hidden_size=384,
    ):
        """
        Initialize the TrackHead module.

        Args:
            dim_in (int): Input dimension of tokens from the backbone.
            patch_size (int): Size of image patches used in the vision transformer.
            features (int): Number of feature channels in the feature extractor output.
            iters (int): Number of refinement iterations for tracking predictions.
            predict_conf (bool): Whether to predict confidence scores for tracked points.
            stride (int): Stride value for the tracker predictor.
            corr_levels (int): Number of correlation pyramid levels
            corr_radius (int): Radius for correlation computation, controlling the search area.
            hidden_size (int): Size of hidden layers in the tracker network.
        """
        super().__init__()

        self.patch_size = patch_size

        # Feature extractor based on DPT architecture
        # Processes tokens into feature maps for tracking
        self.feature_extractor = DPTHead(
            dim_in=dim_in,
            patch_size=patch_size,
            features=features,
            feature_only=True,  # Only output features, no activation
            down_ratio=2,  # Reduces spatial dimensions by factor of 2
            pos_embed=False,
        )

        # Tracker module that predicts point trajectories
        # Takes feature maps and predicts coordinates and visibility
        self.tracker = BaseTrackerPredictor(
            latent_dim=features,  # Match the output_dim of feature extractor
            predict_conf=predict_conf,
            stride=stride,
            corr_levels=corr_levels,
            corr_radius=corr_radius,
            hidden_size=hidden_size,
        )

        self.iters = iters

    def forward(
        self,
        aggregated_tokens_list,
        images,
        patch_start_idx,
        query_points=None,
        iters=None,
    ):
        """
        Forward pass of the TrackHead.

        Args:
            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
            images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
                B = batch size, S = sequence length.
            patch_start_idx (int): Starting index for patch tokens.
            query_points (torch.Tensor, optional): Initial query points to track.
                If None, points are initialized by the tracker.
            iters (int, optional): Number of refinement iterations. If None, uses self.iters.

        Returns:
            tuple:
                - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
                - vis_scores (torch.Tensor): Visibility scores for tracked points.
                - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
        """
        B, S, _, H, W = images.shape

        # Extract features from tokens
        # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
        feature_maps = self.feature_extractor(
            aggregated_tokens_list, images, patch_start_idx
        )

        # Use default iterations if not specified
        if iters is None:
            iters = self.iters

        # Perform tracking using the extracted features
        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points,
            fmaps=feature_maps,
            iters=iters,
        )

        return coord_preds, vis_scores, conf_scores

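TrackHead hands the tracker feature maps at half the image resolution (down_ratio=2), and the tracker below works at a further internal stride, so query points are scaled down on the way in and predictions are scaled back up on the way out. A toy illustration of that coordinate bookkeeping; the factors mirror the defaults shown here, the point values are made up:

import torch

down_ratio, stride = 2, 2                      # defaults of the head and tracker shown in this diff
query_points = torch.tensor([[100.0, 40.0]])   # (N, 2) pixel coordinates in the input image

# Going into the tracker: image pixels -> feature-map coordinates.
coords_feat = query_points / down_ratio / stride
print(coords_feat)   # tensor([[25., 10.]])

# Coming back out: feature-map coordinates -> image pixels.
coords_img = coords_feat * stride * down_ratio
print(coords_img)    # tensor([[100., 40.]])
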
mapanything/models/external/vggt/heads/track_modules/__init__.py DELETED
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py DELETED
@@ -1,242 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from einops import rearrange

from .blocks import CorrBlock, EfficientUpdateFormer
from .modules import Mlp
from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d


class BaseTrackerPredictor(nn.Module):
    def __init__(
        self,
        stride=1,
        corr_levels=5,
        corr_radius=4,
        latent_dim=128,
        hidden_size=384,
        use_spaceatt=True,
        depth=6,
        max_scale=518,
        predict_conf=True,
    ):
        super(BaseTrackerPredictor, self).__init__()
        """
        The base template to create a track predictor

        Modified from https://github.com/facebookresearch/co-tracker/
        and https://github.com/facebookresearch/vggsfm
        """

        self.stride = stride
        self.latent_dim = latent_dim
        self.corr_levels = corr_levels
        self.corr_radius = corr_radius
        self.hidden_size = hidden_size
        self.max_scale = max_scale
        self.predict_conf = predict_conf

        self.flows_emb_dim = latent_dim // 2

        self.corr_mlp = Mlp(
            in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
            hidden_features=self.hidden_size,
            out_features=self.latent_dim,
        )

        self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4

        self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))

        space_depth = depth if use_spaceatt else 0
        time_depth = depth

        self.updateformer = EfficientUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=self.transformer_dim,
            hidden_size=self.hidden_size,
            output_dim=self.latent_dim + 2,
            mlp_ratio=4.0,
            add_space_attn=use_spaceatt,
        )

        self.fmap_norm = nn.LayerNorm(self.latent_dim)
        self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)

        # A linear layer to update track feats at each iteration
        self.ffeat_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()
        )

        self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

        if predict_conf:
            self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

    def forward(
        self,
        query_points,
        fmaps=None,
        iters=6,
        return_feat=False,
        down_ratio=1,
        apply_sigmoid=True,
    ):
        """
        query_points: B x N x 2, the number of batches, tracks, and xy
        fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
            note HH and WW is the size of feature maps instead of original images
        """
        B, N, D = query_points.shape
        B, S, C, HH, WW = fmaps.shape

        assert D == 2, "Input points must be 2D coordinates"

        # apply a layernorm to fmaps here
        fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
        fmaps = fmaps.permute(0, 1, 4, 2, 3)

        # Scale the input query_points because we may downsample the images
        # by down_ratio or self.stride
        # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
        # its query_points should be query_points/4
        if down_ratio > 1:
            query_points = query_points / float(down_ratio)

        query_points = query_points / float(self.stride)

        # Init with coords as the query points
        # It means the search will start from the position of query points at the reference frames
        coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)

        # Sample/extract the features of the query points in the query frame
        query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])

        # init track feats by query feats
        track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C
        # back up the init coords
        coords_backup = coords.clone()

        fcorr_fn = CorrBlock(
            fmaps, num_levels=self.corr_levels, radius=self.corr_radius
        )

        coord_preds = []

        # Iterative Refinement
        for _ in range(iters):
            # Detach the gradients from the last iteration
            # (in my experience, not very important for performance)
            coords = coords.detach()

            fcorrs = fcorr_fn.corr_sample(track_feats, coords)

            corr_dim = fcorrs.shape[3]
            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
            fcorrs_ = self.corr_mlp(fcorrs_)

            # Movement of current coords relative to query points
            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)

            flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)

            # (In my trials, it is also okay to just add the flows_emb instead of concat)
            flows_emb = torch.cat(
                [flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1
            )

            track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(
                B * N, S, self.latent_dim
            )

            # Concatenate them as the input for the transformers
            transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)

            # 2D positional embed
            # TODO: this can be much simplified
            pos_embed = get_2d_sincos_pos_embed(
                self.transformer_dim, grid_size=(HH, WW)
            ).to(query_points.device)
            sampled_pos_emb = sample_features4d(
                pos_embed.expand(B, -1, -1, -1), coords[:, 0]
            )

            sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(
                1
            )

            x = transformer_input + sampled_pos_emb

            # Add the query ref token to the track feats
            query_ref_token = torch.cat(
                [
                    self.query_ref_token[:, 0:1],
                    self.query_ref_token[:, 1:2].expand(-1, S - 1, -1),
                ],
                dim=1,
            )
            x = x + query_ref_token.to(x.device).to(x.dtype)

            # B, N, S, C
            x = rearrange(x, "(b n) s d -> b n s d", b=B)

            # Compute the delta coordinates and delta track features
            delta, _ = self.updateformer(x)

            # BN, S, C
            delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
            delta_coords_ = delta[:, :, :2]
            delta_feats_ = delta[:, :, 2:]

            track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)

            # Update the track features
            track_feats_ = (
                self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
            )

            track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(
                0, 2, 1, 3
            )  # BxSxNxC

            # B x S x N x 2
            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)

            # Force coord0 as query
            # because we assume the query points should not be changed
            coords[:, 0] = coords_backup[:, 0]

            # The predicted tracks are in the original image scale
            if down_ratio > 1:
                coord_preds.append(coords * self.stride * down_ratio)
            else:
                coord_preds.append(coords * self.stride)

        # B, S, N
        vis_e = self.vis_predictor(
            track_feats.reshape(B * S * N, self.latent_dim)
        ).reshape(B, S, N)
        if apply_sigmoid:
            vis_e = torch.sigmoid(vis_e)

        if self.predict_conf:
            conf_e = self.conf_predictor(
                track_feats.reshape(B * S * N, self.latent_dim)
            ).reshape(B, S, N)
            if apply_sigmoid:
                conf_e = torch.sigmoid(conf_e)
        else:
            conf_e = None

        if return_feat:
            return coord_preds, vis_e, track_feats, query_track_feat, conf_e
        else:
            return coord_preds, vis_e, conf_e

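The predictor above refines tracks with the usual RAFT/CoTracker-style loop: detach the current estimate, predict a delta, add it back, and re-pin the query frame. A stripped-down version of just that control flow; the random predict_delta below stands in for the correlation and transformer machinery and is purely illustrative:

import torch


def predict_delta(coords):
    # Placeholder for the correlation + transformer update in the real predictor.
    return 0.1 * torch.randn_like(coords)


B, S, N = 1, 4, 3
query = torch.rand(B, N, 2) * 100
coords = query[:, None].repeat(1, S, 1, 1)     # start every frame at the query position
coords_backup = coords.clone()

preds = []
for _ in range(4):                             # refinement iterations
    coords = coords.detach()                   # stop gradients flowing across iterations
    coords = coords + predict_delta(coords)    # apply the predicted refinement
    coords[:, 0] = coords_backup[:, 0]         # frame 0 is the query frame; keep it fixed
    preds.append(coords)

print(len(preds), preds[-1].shape)             # 4 torch.Size([1, 4, 3, 2])
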