NikV09 committed
Commit 3f62b54 · verified · 1 Parent(s): f3026b8

Delete External Models to prevent HF tags

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. mapanything/models/external/README.md +0 -5
  2. mapanything/models/external/__init__.py +0 -0
  3. mapanything/models/external/anycalib/__init__.py +0 -95
  4. mapanything/models/external/dinov2/__init__.py +0 -6
  5. mapanything/models/external/dinov2/hub/__init__.py +0 -4
  6. mapanything/models/external/dinov2/hub/backbones.py +0 -183
  7. mapanything/models/external/dinov2/hub/utils.py +0 -42
  8. mapanything/models/external/dinov2/layers/__init__.py +0 -14
  9. mapanything/models/external/dinov2/layers/attention.py +0 -90
  10. mapanything/models/external/dinov2/layers/block.py +0 -290
  11. mapanything/models/external/dinov2/layers/dino_head.py +0 -67
  12. mapanything/models/external/dinov2/layers/drop_path.py +0 -36
  13. mapanything/models/external/dinov2/layers/layer_scale.py +0 -26
  14. mapanything/models/external/dinov2/layers/mlp.py +0 -40
  15. mapanything/models/external/dinov2/layers/patch_embed.py +0 -100
  16. mapanything/models/external/dinov2/layers/swiglu_ffn.py +0 -71
  17. mapanything/models/external/dinov2/models/__init__.py +0 -44
  18. mapanything/models/external/dinov2/models/vision_transformer.py +0 -448
  19. mapanything/models/external/dinov2/utils/__init__.py +0 -4
  20. mapanything/models/external/dinov2/utils/cluster.py +0 -102
  21. mapanything/models/external/dinov2/utils/config.py +0 -74
  22. mapanything/models/external/dinov2/utils/dtype.py +0 -38
  23. mapanything/models/external/dinov2/utils/param_groups.py +0 -122
  24. mapanything/models/external/dinov2/utils/utils.py +0 -105
  25. mapanything/models/external/dust3r/__init__.py +0 -217
  26. mapanything/models/external/mast3r/__init__.py +0 -191
  27. mapanything/models/external/moge/__init__.py +0 -114
  28. mapanything/models/external/moge/models/modules.py +0 -467
  29. mapanything/models/external/moge/models/utils.py +0 -477
  30. mapanything/models/external/moge/models/v1.py +0 -595
  31. mapanything/models/external/moge/models/v2.py +0 -379
  32. mapanything/models/external/must3r/__init__.py +0 -283
  33. mapanything/models/external/pi3/__init__.py +0 -119
  34. mapanything/models/external/pi3/layers/__init__.py +0 -0
  35. mapanything/models/external/pi3/layers/attention.py +0 -429
  36. mapanything/models/external/pi3/layers/block.py +0 -448
  37. mapanything/models/external/pi3/layers/camera_head.py +0 -106
  38. mapanything/models/external/pi3/layers/pos_embed.py +0 -190
  39. mapanything/models/external/pi3/layers/transformer_head.py +0 -98
  40. mapanything/models/external/pi3/models/__init__.py +0 -0
  41. mapanything/models/external/pi3/models/pi3.py +0 -251
  42. mapanything/models/external/pow3r/__init__.py +0 -860
  43. mapanything/models/external/vggt/__init__.py +0 -186
  44. mapanything/models/external/vggt/heads/__init__.py +0 -0
  45. mapanything/models/external/vggt/heads/camera_head.py +0 -167
  46. mapanything/models/external/vggt/heads/dpt_head.py +0 -600
  47. mapanything/models/external/vggt/heads/head_act.py +0 -127
  48. mapanything/models/external/vggt/heads/track_head.py +0 -118
  49. mapanything/models/external/vggt/heads/track_modules/__init__.py +0 -5
  50. mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py +0 -242
mapanything/models/external/README.md DELETED
@@ -1,5 +0,0 @@
1
- # External Model Code for Benchmarking & Re-Training
2
-
3
- This directory contains external model code that we use to train and benchmark external models fairly. These libraries are not part of the core MapAnything codebase and are included for only benchmarking purposes. The code in this directory is licensed under the same license as the source code from which it was derived, unless otherwise specified.
4
-
5
- The open-source Apache 2.0 License of MapAnything does not apply to these libraries.
 
mapanything/models/external/__init__.py DELETED
File without changes
mapanything/models/external/anycalib/__init__.py DELETED
@@ -1,95 +0,0 @@
1
- """
2
- Inference wrapper for AnyCalib
3
- """
4
-
5
- import torch
6
- from anycalib import AnyCalib
7
-
8
- from mapanything.utils.geometry import get_rays_in_camera_frame
9
-
10
-
11
- class AnyCalibWrapper(torch.nn.Module):
12
- def __init__(
13
- self,
14
- name,
15
- model_id="anycalib_pinhole",
16
- **kwargs,
17
- ):
18
- super().__init__()
19
- self.name = name
20
- self.model_id = model_id
21
-
22
- # Initialize the model
23
- self.model = AnyCalib(model_id=self.model_id)
24
-
25
- def forward(self, views):
26
- """
27
- Forward pass wrapper for AnyCalib.
28
-
29
- Assumption:
30
- - The number of input views is 1.
31
- - The output camera model is pinhole (fx, fy, cx, cy).
32
- This can be relaxed by not hardcoding the cam_id.
33
-
34
- Args:
35
- views (List[dict]): List of dictionaries containing the input views' images and instance information.
36
- Length of the list should be 1.
37
- Each dictionary should contain the following keys:
38
- "img" (tensor): Image tensor of shape (B, C, H, W).
39
- "data_norm_type" (list): ["identity"]
40
-
41
- Returns:
42
- List[dict]: A list containing the final outputs for the single view. Length of the list will be 1.
43
- """
44
- # Check that the number of input views is 1
45
- assert len(views) == 1, "AnyCalib only supports 1 input view."
46
-
47
- # Get input shape of the images and batch size per view
48
- _, _, height, width = views[0]["img"].shape
49
-
50
- # Check the data norm type
51
- # AnyCalib expects a normalized image but without the DINOv2 mean and std applied ("identity")
52
- data_norm_type = views[0]["data_norm_type"][0]
53
- assert data_norm_type == "identity", (
54
- "AnyCalib expects a normalized image but without the DINOv2 mean and std applied"
55
- )
56
-
57
- # Run AnyCalib inference
58
- # Corresponding batched output dictionary:
59
- # {
60
- # "intrinsics": List[(D_i,) tensors] for each camera model "i" at the original input resolution,
61
- # "fov_field": (B, N, 2) tensor with the regressed FoV field by the network. N≈320^2 (resolution close to the one seen during training),
62
- # "tangent_coords": alias for "fov_field",
63
- # "rays": (B, N, 3) tensor with the corresponding (via the exponential map) ray directions in the camera frame (x right, y down, z forward),
64
- # "pred_size": (H, W) tuple with the image size used by the network. It can be used e.g. for resizing the FoV/ray fields to the original image size.
65
- # }
66
- # For "pinhole" camera model, the intrinsics are (fx, fy, cx, cy).
67
- model_outputs = self.model.predict(views[0]["img"], cam_id="pinhole")
68
-
69
- # Convert the list of intrinsics to a tensor
70
- intrinsics = []
71
- for intrinsics_per_sample in model_outputs["intrinsics"]:
72
- pred_fx, pred_fy, pred_cx, pred_cy = intrinsics_per_sample
73
- intrinsics_per_sample = torch.tensor(
74
- [
75
- [pred_fx, 0, pred_cx],
76
- [0, pred_fy, pred_cy],
77
- [0, 0, 1],
78
- ],
79
- device=views[0]["img"].device,
80
- )
81
- intrinsics.append(intrinsics_per_sample)
82
-
83
- # Convert the list of intrinsics to a tensor of size (batch_size_per_view, 3, 3)
84
- intrinsics = torch.stack(intrinsics)
85
-
86
- # Get the ray directions
87
- with torch.autocast("cuda", enabled=False):
88
- _, ray_directions = get_rays_in_camera_frame(
89
- intrinsics, height, width, normalize_to_unit_sphere=True
90
- )
91
-
92
- # Return the output in MapAnything format
93
- res = [{"ray_directions": ray_directions, "intrinsics": intrinsics}]
94
-
95
- return res
 
mapanything/models/external/dinov2/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- __version__ = "0.0.1"
 
mapanything/models/external/dinov2/hub/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
 
mapanything/models/external/dinov2/hub/backbones.py DELETED
@@ -1,183 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from enum import Enum
7
- from typing import Union
8
-
9
- import torch
10
-
11
- from mapanything.models.external.dinov2.hub.utils import (
12
- _DINOV2_BASE_URL,
13
- _make_dinov2_model_name,
14
- )
15
-
16
-
17
- class Weights(Enum):
18
- LVD142M = "LVD142M"
19
-
20
-
21
- def _make_dinov2_model(
22
- *,
23
- arch_name: str = "vit_large",
24
- img_size: int = 518,
25
- patch_size: int = 14,
26
- init_values: float = 1.0,
27
- ffn_layer: str = "mlp",
28
- block_chunks: int = 0,
29
- num_register_tokens: int = 0,
30
- interpolate_antialias: bool = False,
31
- interpolate_offset: float = 0.1,
32
- pretrained: bool = True,
33
- weights: Union[Weights, str] = Weights.LVD142M,
34
- **kwargs,
35
- ):
36
- from ..models import vision_transformer as vits
37
-
38
- if isinstance(weights, str):
39
- try:
40
- weights = Weights[weights]
41
- except KeyError:
42
- raise AssertionError(f"Unsupported weights: {weights}")
43
-
44
- model_base_name = _make_dinov2_model_name(arch_name, patch_size)
45
- vit_kwargs = dict(
46
- img_size=img_size,
47
- patch_size=patch_size,
48
- init_values=init_values,
49
- ffn_layer=ffn_layer,
50
- block_chunks=block_chunks,
51
- num_register_tokens=num_register_tokens,
52
- interpolate_antialias=interpolate_antialias,
53
- interpolate_offset=interpolate_offset,
54
- )
55
- vit_kwargs.update(**kwargs)
56
- model = vits.__dict__[arch_name](**vit_kwargs)
57
-
58
- if pretrained:
59
- model_full_name = _make_dinov2_model_name(
60
- arch_name, patch_size, num_register_tokens
61
- )
62
- url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
63
- state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
64
- model.load_state_dict(state_dict, strict=True)
65
-
66
- return model
67
-
68
-
69
- def dinov2_vits14(
70
- *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
71
- ):
72
- """
73
- DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
74
- """
75
- return _make_dinov2_model(
76
- arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs
77
- )
78
-
79
-
80
- def dinov2_vitb14(
81
- *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
82
- ):
83
- """
84
- DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
85
- """
86
- return _make_dinov2_model(
87
- arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs
88
- )
89
-
90
-
91
- def dinov2_vitl14(
92
- *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
93
- ):
94
- """
95
- DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
96
- """
97
- return _make_dinov2_model(
98
- arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs
99
- )
100
-
101
-
102
- def dinov2_vitg14(
103
- *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
104
- ):
105
- """
106
- DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
107
- """
108
- return _make_dinov2_model(
109
- arch_name="vit_giant2",
110
- ffn_layer="swiglufused",
111
- weights=weights,
112
- pretrained=pretrained,
113
- **kwargs,
114
- )
115
-
116
-
117
- def dinov2_vits14_reg(
118
- *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
119
- ):
120
- """
121
- DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
122
- """
123
- return _make_dinov2_model(
124
- arch_name="vit_small",
125
- pretrained=pretrained,
126
- weights=weights,
127
- num_register_tokens=4,
128
- interpolate_antialias=True,
129
- interpolate_offset=0.0,
130
- **kwargs,
131
- )
132
-
133
-
134
- def dinov2_vitb14_reg(
135
- *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
136
- ):
137
- """
138
- DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
139
- """
140
- return _make_dinov2_model(
141
- arch_name="vit_base",
142
- pretrained=pretrained,
143
- weights=weights,
144
- num_register_tokens=4,
145
- interpolate_antialias=True,
146
- interpolate_offset=0.0,
147
- **kwargs,
148
- )
149
-
150
-
151
- def dinov2_vitl14_reg(
152
- *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
153
- ):
154
- """
155
- DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
156
- """
157
- return _make_dinov2_model(
158
- arch_name="vit_large",
159
- pretrained=pretrained,
160
- weights=weights,
161
- num_register_tokens=4,
162
- interpolate_antialias=True,
163
- interpolate_offset=0.0,
164
- **kwargs,
165
- )
166
-
167
-
168
- def dinov2_vitg14_reg(
169
- *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
170
- ):
171
- """
172
- DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
173
- """
174
- return _make_dinov2_model(
175
- arch_name="vit_giant2",
176
- ffn_layer="swiglufused",
177
- weights=weights,
178
- pretrained=pretrained,
179
- num_register_tokens=4,
180
- interpolate_antialias=True,
181
- interpolate_offset=0.0,
182
- **kwargs,
183
- )
 
mapanything/models/external/dinov2/hub/utils.py DELETED
@@ -1,42 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import itertools
7
- import math
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
-
13
- _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
14
-
15
-
16
- def _make_dinov2_model_name(
17
- arch_name: str, patch_size: int, num_register_tokens: int = 0
18
- ) -> str:
19
- compact_arch_name = arch_name.replace("_", "")[:4]
20
- registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
21
- return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
22
-
23
-
24
- class CenterPadding(nn.Module):
25
- def __init__(self, multiple):
26
- super().__init__()
27
- self.multiple = multiple
28
-
29
- def _get_pad(self, size):
30
- new_size = math.ceil(size / self.multiple) * self.multiple
31
- pad_size = new_size - size
32
- pad_size_left = pad_size // 2
33
- pad_size_right = pad_size - pad_size_left
34
- return pad_size_left, pad_size_right
35
-
36
- @torch.inference_mode()
37
- def forward(self, x):
38
- pads = list(
39
- itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])
40
- )
41
- output = F.pad(x, pads)
42
- return output
 
mapanything/models/external/dinov2/layers/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- from mapanything.models.external.dinov2.layers.dino_head import DINOHead # noqa
7
- from mapanything.models.external.dinov2.layers.mlp import Mlp # noqa
8
- from mapanything.models.external.dinov2.layers.patch_embed import PatchEmbed # noqa
9
- from mapanything.models.external.dinov2.layers.swiglu_ffn import (
10
- SwiGLUFFN, # noqa
11
- SwiGLUFFNFused, # noqa
12
- )
13
- from mapanything.models.external.dinov2.layers.block import NestedTensorBlock # noqa
14
- from mapanything.models.external.dinov2.layers.attention import MemEffAttention # noqa
 
mapanything/models/external/dinov2/layers/attention.py DELETED
@@ -1,90 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- # References:
7
- # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
-
10
- import logging
11
- import os
12
-
13
- from torch import nn, Tensor
14
-
15
- logger = logging.getLogger("dinov2")
16
-
17
-
18
- XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
19
- try:
20
- if XFORMERS_ENABLED:
21
- from xformers.ops import memory_efficient_attention, unbind
22
-
23
- XFORMERS_AVAILABLE = True
24
- # warnings.warn("xFormers is available (Attention)")
25
- else:
26
- # warnings.warn("xFormers is disabled (Attention)")
27
- raise ImportError
28
- except ImportError:
29
- XFORMERS_AVAILABLE = False
30
- # warnings.warn("xFormers is not available (Attention)")
31
-
32
-
33
- class Attention(nn.Module):
34
- def __init__(
35
- self,
36
- dim: int,
37
- num_heads: int = 8,
38
- qkv_bias: bool = False,
39
- proj_bias: bool = True,
40
- attn_drop: float = 0.0,
41
- proj_drop: float = 0.0,
42
- ) -> None:
43
- super().__init__()
44
- self.num_heads = num_heads
45
- head_dim = dim // num_heads
46
- self.scale = head_dim**-0.5
47
-
48
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
49
- self.attn_drop = nn.Dropout(attn_drop)
50
- self.proj = nn.Linear(dim, dim, bias=proj_bias)
51
- self.proj_drop = nn.Dropout(proj_drop)
52
-
53
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
54
- B, N, C = x.shape
55
- qkv = (
56
- self.qkv(x)
57
- .reshape(B, N, 3, self.num_heads, C // self.num_heads)
58
- .permute(2, 0, 3, 1, 4)
59
- )
60
-
61
- q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
62
- attn = q @ k.transpose(-2, -1)
63
-
64
- attn = attn.softmax(dim=-1)
65
- attn = self.attn_drop(attn)
66
-
67
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
68
- x = self.proj(x)
69
- x = self.proj_drop(x)
70
- return x
71
-
72
-
73
- class MemEffAttention(Attention):
74
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
75
- if not XFORMERS_AVAILABLE:
76
- if attn_bias is not None:
77
- raise AssertionError("xFormers is required for using nested tensors")
78
- return super().forward(x)
79
-
80
- B, N, C = x.shape
81
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
82
-
83
- q, k, v = unbind(qkv, 2)
84
-
85
- x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
86
- x = x.reshape([B, N, C])
87
-
88
- x = self.proj(x)
89
- x = self.proj_drop(x)
90
- return x
 
mapanything/models/external/dinov2/layers/block.py DELETED
@@ -1,290 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- # References:
7
- # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
-
10
- import logging
11
- import os
12
- from typing import Any, Callable, Dict, List, Tuple
13
-
14
- import torch
15
- from torch import nn, Tensor
16
-
17
- from mapanything.models.external.dinov2.layers.attention import (
18
- Attention,
19
- MemEffAttention,
20
- )
21
- from mapanything.models.external.dinov2.layers.drop_path import DropPath
22
- from mapanything.models.external.dinov2.layers.layer_scale import LayerScale
23
- from mapanything.models.external.dinov2.layers.mlp import Mlp
24
-
25
- logger = logging.getLogger("dinov2")
26
-
27
-
28
- XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
29
- try:
30
- if XFORMERS_ENABLED:
31
- from xformers.ops import fmha, index_select_cat, scaled_index_add
32
-
33
- XFORMERS_AVAILABLE = True
34
- # warnings.warn("xFormers is available (Block)")
35
- else:
36
- # warnings.warn("xFormers is disabled (Block)")
37
- raise ImportError
38
- except ImportError:
39
- XFORMERS_AVAILABLE = False
40
- # warnings.warn("xFormers is not available (Block)")
41
-
42
-
43
- class Block(nn.Module):
44
- def __init__(
45
- self,
46
- dim: int,
47
- num_heads: int,
48
- mlp_ratio: float = 4.0,
49
- qkv_bias: bool = False,
50
- proj_bias: bool = True,
51
- ffn_bias: bool = True,
52
- drop: float = 0.0,
53
- attn_drop: float = 0.0,
54
- init_values=None,
55
- drop_path: float = 0.0,
56
- act_layer: Callable[..., nn.Module] = nn.GELU,
57
- norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
58
- attn_class: Callable[..., nn.Module] = Attention,
59
- ffn_layer: Callable[..., nn.Module] = Mlp,
60
- ) -> None:
61
- super().__init__()
62
- # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
63
- self.norm1 = norm_layer(dim)
64
- self.attn = attn_class(
65
- dim,
66
- num_heads=num_heads,
67
- qkv_bias=qkv_bias,
68
- proj_bias=proj_bias,
69
- attn_drop=attn_drop,
70
- proj_drop=drop,
71
- )
72
- self.ls1 = (
73
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
74
- )
75
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
76
-
77
- self.norm2 = norm_layer(dim)
78
- mlp_hidden_dim = int(dim * mlp_ratio)
79
- self.mlp = ffn_layer(
80
- in_features=dim,
81
- hidden_features=mlp_hidden_dim,
82
- act_layer=act_layer,
83
- drop=drop,
84
- bias=ffn_bias,
85
- )
86
- self.ls2 = (
87
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
88
- )
89
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
90
-
91
- self.sample_drop_ratio = drop_path
92
-
93
- def forward(self, x: Tensor) -> Tensor:
94
- def attn_residual_func(x: Tensor) -> Tensor:
95
- return self.ls1(self.attn(self.norm1(x)))
96
-
97
- def ffn_residual_func(x: Tensor) -> Tensor:
98
- return self.ls2(self.mlp(self.norm2(x)))
99
-
100
- if self.training and self.sample_drop_ratio > 0.1:
101
- # the overhead is compensated only for a drop path rate larger than 0.1
102
- x = drop_add_residual_stochastic_depth(
103
- x,
104
- residual_func=attn_residual_func,
105
- sample_drop_ratio=self.sample_drop_ratio,
106
- )
107
- x = drop_add_residual_stochastic_depth(
108
- x,
109
- residual_func=ffn_residual_func,
110
- sample_drop_ratio=self.sample_drop_ratio,
111
- )
112
- elif self.training and self.sample_drop_ratio > 0.0:
113
- x = x + self.drop_path1(attn_residual_func(x))
114
- x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
115
- else:
116
- x = x + attn_residual_func(x)
117
- x = x + ffn_residual_func(x)
118
- return x
119
-
120
-
121
- def drop_add_residual_stochastic_depth(
122
- x: Tensor,
123
- residual_func: Callable[[Tensor], Tensor],
124
- sample_drop_ratio: float = 0.0,
125
- ) -> Tensor:
126
- # 1) extract subset using permutation
127
- b, n, d = x.shape
128
- sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
129
- brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
130
- x_subset = x[brange]
131
-
132
- # 2) apply residual_func to get residual
133
- residual = residual_func(x_subset)
134
-
135
- x_flat = x.flatten(1)
136
- residual = residual.flatten(1)
137
-
138
- residual_scale_factor = b / sample_subset_size
139
-
140
- # 3) add the residual
141
- x_plus_residual = torch.index_add(
142
- x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
143
- )
144
- return x_plus_residual.view_as(x)
145
-
146
-
147
- def get_branges_scales(x, sample_drop_ratio=0.0):
148
- b, n, d = x.shape
149
- sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
150
- brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
151
- residual_scale_factor = b / sample_subset_size
152
- return brange, residual_scale_factor
153
-
154
-
155
- def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
156
- if scaling_vector is None:
157
- x_flat = x.flatten(1)
158
- residual = residual.flatten(1)
159
- x_plus_residual = torch.index_add(
160
- x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
161
- )
162
- else:
163
- x_plus_residual = scaled_index_add(
164
- x,
165
- brange,
166
- residual.to(dtype=x.dtype),
167
- scaling=scaling_vector,
168
- alpha=residual_scale_factor,
169
- )
170
- return x_plus_residual
171
-
172
-
173
- attn_bias_cache: Dict[Tuple, Any] = {}
174
-
175
-
176
- def get_attn_bias_and_cat(x_list, branges=None):
177
- """
178
- this will perform the index select, cat the tensors, and provide the attn_bias from cache
179
- """
180
- batch_sizes = (
181
- [b.shape[0] for b in branges]
182
- if branges is not None
183
- else [x.shape[0] for x in x_list]
184
- )
185
- all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
186
- if all_shapes not in attn_bias_cache.keys():
187
- seqlens = []
188
- for b, x in zip(batch_sizes, x_list):
189
- for _ in range(b):
190
- seqlens.append(x.shape[1])
191
- attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
192
- attn_bias._batch_sizes = batch_sizes
193
- attn_bias_cache[all_shapes] = attn_bias
194
-
195
- if branges is not None:
196
- cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
197
- 1, -1, x_list[0].shape[-1]
198
- )
199
- else:
200
- tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
201
- cat_tensors = torch.cat(tensors_bs1, dim=1)
202
-
203
- return attn_bias_cache[all_shapes], cat_tensors
204
-
205
-
206
- def drop_add_residual_stochastic_depth_list(
207
- x_list: List[Tensor],
208
- residual_func: Callable[[Tensor, Any], Tensor],
209
- sample_drop_ratio: float = 0.0,
210
- scaling_vector=None,
211
- ) -> Tensor:
212
- # 1) generate random set of indices for dropping samples in the batch
213
- branges_scales = [
214
- get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
215
- ]
216
- branges = [s[0] for s in branges_scales]
217
- residual_scale_factors = [s[1] for s in branges_scales]
218
-
219
- # 2) get attention bias and index+concat the tensors
220
- attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
221
-
222
- # 3) apply residual_func to get residual, and split the result
223
- residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
224
-
225
- outputs = []
226
- for x, brange, residual, residual_scale_factor in zip(
227
- x_list, branges, residual_list, residual_scale_factors
228
- ):
229
- outputs.append(
230
- add_residual(
231
- x, brange, residual, residual_scale_factor, scaling_vector
232
- ).view_as(x)
233
- )
234
- return outputs
235
-
236
-
237
- class NestedTensorBlock(Block):
238
- def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
239
- """
240
- x_list contains a list of tensors to nest together and run
241
- """
242
- assert isinstance(self.attn, MemEffAttention)
243
-
244
- if self.training and self.sample_drop_ratio > 0.0:
245
-
246
- def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
247
- return self.attn(self.norm1(x), attn_bias=attn_bias)
248
-
249
- def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
250
- return self.mlp(self.norm2(x))
251
-
252
- x_list = drop_add_residual_stochastic_depth_list(
253
- x_list,
254
- residual_func=attn_residual_func,
255
- sample_drop_ratio=self.sample_drop_ratio,
256
- scaling_vector=self.ls1.gamma
257
- if isinstance(self.ls1, LayerScale)
258
- else None,
259
- )
260
- x_list = drop_add_residual_stochastic_depth_list(
261
- x_list,
262
- residual_func=ffn_residual_func,
263
- sample_drop_ratio=self.sample_drop_ratio,
264
- scaling_vector=self.ls2.gamma
265
- if isinstance(self.ls1, LayerScale)
266
- else None,
267
- )
268
- return x_list
269
- else:
270
-
271
- def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
272
- return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
273
-
274
- def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
275
- return self.ls2(self.mlp(self.norm2(x)))
276
-
277
- attn_bias, x = get_attn_bias_and_cat(x_list)
278
- x = x + attn_residual_func(x, attn_bias=attn_bias)
279
- x = x + ffn_residual_func(x)
280
- return attn_bias.split(x)
281
-
282
- def forward(self, x_or_x_list):
283
- if isinstance(x_or_x_list, Tensor):
284
- return super().forward(x_or_x_list)
285
- elif isinstance(x_or_x_list, list):
286
- if not XFORMERS_AVAILABLE:
287
- raise AssertionError("xFormers is required for using nested tensors")
288
- return self.forward_nested(x_or_x_list)
289
- else:
290
- raise AssertionError
 
mapanything/models/external/dinov2/layers/dino_head.py DELETED
@@ -1,67 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import torch
7
- import torch.nn as nn
8
- from torch.nn.init import trunc_normal_
9
- from torch.nn.utils import weight_norm
10
-
11
-
12
- class DINOHead(nn.Module):
13
- def __init__(
14
- self,
15
- in_dim,
16
- out_dim,
17
- use_bn=False,
18
- nlayers=3,
19
- hidden_dim=2048,
20
- bottleneck_dim=256,
21
- mlp_bias=True,
22
- ):
23
- super().__init__()
24
- nlayers = max(nlayers, 1)
25
- self.mlp = _build_mlp(
26
- nlayers,
27
- in_dim,
28
- bottleneck_dim,
29
- hidden_dim=hidden_dim,
30
- use_bn=use_bn,
31
- bias=mlp_bias,
32
- )
33
- self.apply(self._init_weights)
34
- self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
35
- self.last_layer.weight_g.data.fill_(1)
36
-
37
- def _init_weights(self, m):
38
- if isinstance(m, nn.Linear):
39
- trunc_normal_(m.weight, std=0.02)
40
- if isinstance(m, nn.Linear) and m.bias is not None:
41
- nn.init.constant_(m.bias, 0)
42
-
43
- def forward(self, x):
44
- x = self.mlp(x)
45
- eps = 1e-6 if x.dtype == torch.float16 else 1e-12
46
- x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
47
- x = self.last_layer(x)
48
- return x
49
-
50
-
51
- def _build_mlp(
52
- nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
53
- ):
54
- if nlayers == 1:
55
- return nn.Linear(in_dim, bottleneck_dim, bias=bias)
56
- else:
57
- layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
58
- if use_bn:
59
- layers.append(nn.BatchNorm1d(hidden_dim))
60
- layers.append(nn.GELU())
61
- for _ in range(nlayers - 2):
62
- layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
63
- if use_bn:
64
- layers.append(nn.BatchNorm1d(hidden_dim))
65
- layers.append(nn.GELU())
66
- layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
67
- return nn.Sequential(*layers)
 
mapanything/models/external/dinov2/layers/drop_path.py DELETED
@@ -1,36 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- # References:
7
- # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
-
10
-
11
- from torch import nn
12
-
13
-
14
- def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
- if drop_prob == 0.0 or not training:
16
- return x
17
- keep_prob = 1 - drop_prob
18
- shape = (x.shape[0],) + (1,) * (
19
- x.ndim - 1
20
- ) # work with diff dim tensors, not just 2D ConvNets
21
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
22
- if keep_prob > 0.0:
23
- random_tensor.div_(keep_prob)
24
- output = x * random_tensor
25
- return output
26
-
27
-
28
- class DropPath(nn.Module):
29
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
30
-
31
- def __init__(self, drop_prob=None):
32
- super(DropPath, self).__init__()
33
- self.drop_prob = drop_prob
34
-
35
- def forward(self, x):
36
- return drop_path(x, self.drop_prob, self.training)
 
mapanything/models/external/dinov2/layers/layer_scale.py DELETED
@@ -1,26 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
-
8
- from typing import Union
9
-
10
- import torch
11
- from torch import nn, Tensor
12
-
13
-
14
- class LayerScale(nn.Module):
15
- def __init__(
16
- self,
17
- dim: int,
18
- init_values: Union[float, Tensor] = 1e-5,
19
- inplace: bool = False,
20
- ) -> None:
21
- super().__init__()
22
- self.inplace = inplace
23
- self.gamma = nn.Parameter(init_values * torch.ones(dim))
24
-
25
- def forward(self, x: Tensor) -> Tensor:
26
- return x.mul_(self.gamma) if self.inplace else x * self.gamma
 
mapanything/models/external/dinov2/layers/mlp.py DELETED
@@ -1,40 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- # References:
7
- # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
-
10
-
11
- from typing import Callable, Optional
12
-
13
- from torch import nn, Tensor
14
-
15
-
16
- class Mlp(nn.Module):
17
- def __init__(
18
- self,
19
- in_features: int,
20
- hidden_features: Optional[int] = None,
21
- out_features: Optional[int] = None,
22
- act_layer: Callable[..., nn.Module] = nn.GELU,
23
- drop: float = 0.0,
24
- bias: bool = True,
25
- ) -> None:
26
- super().__init__()
27
- out_features = out_features or in_features
28
- hidden_features = hidden_features or in_features
29
- self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
- self.act = act_layer()
31
- self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
- self.drop = nn.Dropout(drop)
33
-
34
- def forward(self, x: Tensor) -> Tensor:
35
- x = self.fc1(x)
36
- x = self.act(x)
37
- x = self.drop(x)
38
- x = self.fc2(x)
39
- x = self.drop(x)
40
- return x
 
mapanything/models/external/dinov2/layers/patch_embed.py DELETED
@@ -1,100 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- # References:
7
- # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
-
10
- from typing import Callable, Optional, Tuple, Union
11
-
12
- import torch.nn as nn
13
- from torch import Tensor
14
-
15
-
16
- def make_2tuple(x):
17
- if isinstance(x, tuple):
18
- assert len(x) == 2
19
- return x
20
-
21
- assert isinstance(x, int)
22
- return (x, x)
23
-
24
-
25
- class PatchEmbed(nn.Module):
26
- """
27
- 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
-
29
- Args:
30
- img_size: Image size.
31
- patch_size: Patch token size.
32
- in_chans: Number of input image channels.
33
- embed_dim: Number of linear projection output channels.
34
- norm_layer: Normalization layer.
35
- """
36
-
37
- def __init__(
38
- self,
39
- img_size: Union[int, Tuple[int, int]] = 224,
40
- patch_size: Union[int, Tuple[int, int]] = 16,
41
- in_chans: int = 3,
42
- embed_dim: int = 768,
43
- norm_layer: Optional[Callable] = None,
44
- flatten_embedding: bool = True,
45
- ) -> None:
46
- super().__init__()
47
-
48
- image_HW = make_2tuple(img_size)
49
- patch_HW = make_2tuple(patch_size)
50
- patch_grid_size = (
51
- image_HW[0] // patch_HW[0],
52
- image_HW[1] // patch_HW[1],
53
- )
54
-
55
- self.img_size = image_HW
56
- self.patch_size = patch_HW
57
- self.patches_resolution = patch_grid_size
58
- self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
-
60
- self.in_chans = in_chans
61
- self.embed_dim = embed_dim
62
-
63
- self.flatten_embedding = flatten_embedding
64
-
65
- self.proj = nn.Conv2d(
66
- in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
67
- )
68
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
69
-
70
- def forward(self, x: Tensor) -> Tensor:
71
- _, _, H, W = x.shape
72
- patch_H, patch_W = self.patch_size
73
-
74
- assert H % patch_H == 0, (
75
- f"Input image height {H} is not a multiple of patch height {patch_H}"
76
- )
77
- assert W % patch_W == 0, (
78
- f"Input image width {W} is not a multiple of patch width: {patch_W}"
79
- )
80
-
81
- x = self.proj(x) # B C H W
82
- H, W = x.size(2), x.size(3)
83
- x = x.flatten(2).transpose(1, 2) # B HW C
84
- x = self.norm(x)
85
- if not self.flatten_embedding:
86
- x = x.reshape(-1, H, W, self.embed_dim) # B H W C
87
- return x
88
-
89
- def flops(self) -> float:
90
- Ho, Wo = self.patches_resolution
91
- flops = (
92
- Ho
93
- * Wo
94
- * self.embed_dim
95
- * self.in_chans
96
- * (self.patch_size[0] * self.patch_size[1])
97
- )
98
- if self.norm is not None:
99
- flops += Ho * Wo * self.embed_dim
100
- return flops
 
mapanything/models/external/dinov2/layers/swiglu_ffn.py DELETED
@@ -1,71 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import os
7
- from typing import Callable, Optional
8
-
9
- import torch.nn.functional as F
10
- from torch import nn, Tensor
11
-
12
-
13
- class SwiGLUFFN(nn.Module):
14
- def __init__(
15
- self,
16
- in_features: int,
17
- hidden_features: Optional[int] = None,
18
- out_features: Optional[int] = None,
19
- act_layer: Callable[..., nn.Module] = None,
20
- drop: float = 0.0,
21
- bias: bool = True,
22
- ) -> None:
23
- super().__init__()
24
- out_features = out_features or in_features
25
- hidden_features = hidden_features or in_features
26
- self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27
- self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28
-
29
- def forward(self, x: Tensor) -> Tensor:
30
- x12 = self.w12(x)
31
- x1, x2 = x12.chunk(2, dim=-1)
32
- hidden = F.silu(x1) * x2
33
- return self.w3(hidden)
34
-
35
-
36
- XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
37
- try:
38
- if XFORMERS_ENABLED:
39
- from xformers.ops import SwiGLU
40
-
41
- XFORMERS_AVAILABLE = True
42
- # warnings.warn("xFormers is available (SwiGLU)")
43
- else:
44
- # warnings.warn("xFormers is disabled (SwiGLU)")
45
- raise ImportError
46
- except ImportError:
47
- SwiGLU = SwiGLUFFN
48
- XFORMERS_AVAILABLE = False
49
-
50
- # warnings.warn("xFormers is not available (SwiGLU)")
51
-
52
-
53
- class SwiGLUFFNFused(SwiGLU):
54
- def __init__(
55
- self,
56
- in_features: int,
57
- hidden_features: Optional[int] = None,
58
- out_features: Optional[int] = None,
59
- act_layer: Callable[..., nn.Module] = None,
60
- drop: float = 0.0,
61
- bias: bool = True,
62
- ) -> None:
63
- out_features = out_features or in_features
64
- hidden_features = hidden_features or in_features
65
- hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
66
- super().__init__(
67
- in_features=in_features,
68
- hidden_features=hidden_features,
69
- out_features=out_features,
70
- bias=bias,
71
- )
 
mapanything/models/external/dinov2/models/__init__.py DELETED
@@ -1,44 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import logging
7
-
8
- import mapanything.models.external.dinov2.models.vision_transformer as vits
9
-
10
- logger = logging.getLogger("dinov2")
11
-
12
-
13
- def build_model(args, only_teacher=False, img_size=224):
14
- args.arch = args.arch.removesuffix("_memeff")
15
- if "vit" in args.arch:
16
- vit_kwargs = dict(
17
- img_size=img_size,
18
- patch_size=args.patch_size,
19
- init_values=args.layerscale,
20
- ffn_layer=args.ffn_layer,
21
- block_chunks=args.block_chunks,
22
- qkv_bias=args.qkv_bias,
23
- proj_bias=args.proj_bias,
24
- ffn_bias=args.ffn_bias,
25
- num_register_tokens=args.num_register_tokens,
26
- interpolate_offset=args.interpolate_offset,
27
- interpolate_antialias=args.interpolate_antialias,
28
- )
29
- teacher = vits.__dict__[args.arch](**vit_kwargs)
30
- if only_teacher:
31
- return teacher, teacher.embed_dim
32
- student = vits.__dict__[args.arch](
33
- **vit_kwargs,
34
- drop_path_rate=args.drop_path_rate,
35
- drop_path_uniform=args.drop_path_uniform,
36
- )
37
- embed_dim = student.embed_dim
38
- return student, teacher, embed_dim
39
-
40
-
41
- def build_model_from_cfg(cfg, only_teacher=False):
42
- return build_model(
43
- cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size
44
- )
 
mapanything/models/external/dinov2/models/vision_transformer.py DELETED
@@ -1,448 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- # References:
7
- # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
-
10
- import math
11
- from functools import partial
12
- from typing import Callable, Sequence, Tuple, Union
13
-
14
- import torch
15
- import torch.nn as nn
16
- from torch.nn.init import trunc_normal_
17
- from torch.utils.checkpoint import checkpoint
18
-
19
- from mapanything.models.external.dinov2.layers import (
20
- MemEffAttention,
21
- Mlp,
22
- NestedTensorBlock as Block,
23
- PatchEmbed,
24
- SwiGLUFFNFused,
25
- )
26
- from mapanything.models.external.pi3.layers.attention import FlashAttention
27
-
28
- # logger = logging.getLogger("dinov2")
29
-
30
-
31
- def named_apply(
32
- fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
33
- ) -> nn.Module:
34
- if not depth_first and include_root:
35
- fn(module=module, name=name)
36
- for child_name, child_module in module.named_children():
37
- child_name = ".".join((name, child_name)) if name else child_name
38
- named_apply(
39
- fn=fn,
40
- module=child_module,
41
- name=child_name,
42
- depth_first=depth_first,
43
- include_root=True,
44
- )
45
- if depth_first and include_root:
46
- fn(module=module, name=name)
47
- return module
48
-
49
-
50
- class BlockChunk(nn.ModuleList):
51
- def forward(self, x):
52
- for b in self:
53
- x = b(x)
54
- return x
55
-
56
-
57
- class DinoVisionTransformer(nn.Module):
58
- def __init__(
59
- self,
60
- img_size=224,
61
- patch_size=16,
62
- in_chans=3,
63
- embed_dim=768,
64
- depth=12,
65
- num_heads=12,
66
- mlp_ratio=4.0,
67
- qkv_bias=True,
68
- ffn_bias=True,
69
- proj_bias=True,
70
- drop_path_rate=0.0,
71
- drop_path_uniform=False,
72
- init_values=None, # for layerscale: None or 0 => no layerscale
73
- embed_layer=PatchEmbed,
74
- act_layer=nn.GELU,
75
- block_fn=Block,
76
- ffn_layer="mlp",
77
- block_chunks=1,
78
- num_register_tokens=0,
79
- interpolate_antialias=False,
80
- interpolate_offset=0.1,
81
- ):
82
- """
83
- Args:
84
- img_size (int, tuple): input image size
85
- patch_size (int, tuple): patch size
86
- in_chans (int): number of input channels
87
- embed_dim (int): embedding dimension
88
- depth (int): depth of transformer
89
- num_heads (int): number of attention heads
90
- mlp_ratio (int): ratio of mlp hidden dim to embedding dim
91
- qkv_bias (bool): enable bias for qkv if True
92
- proj_bias (bool): enable bias for proj in attn if True
93
- ffn_bias (bool): enable bias for ffn if True
94
- drop_path_rate (float): stochastic depth rate
95
- drop_path_uniform (bool): apply uniform drop rate across blocks
96
- weight_init (str): weight init scheme
97
- init_values (float): layer-scale init values
98
- embed_layer (nn.Module): patch embedding layer
99
- act_layer (nn.Module): MLP activation layer
100
- block_fn (nn.Module): transformer block class
101
- ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
102
- block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
103
- num_register_tokens: (int) number of extra cls tokens (so-called "registers")
104
- interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
105
- interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
106
- """
107
- super().__init__()
108
- norm_layer = partial(nn.LayerNorm, eps=1e-6)
109
-
110
- self.num_features = self.embed_dim = (
111
- embed_dim # num_features for consistency with other models
112
- )
113
- self.num_tokens = 1
114
- self.n_blocks = depth
115
- self.num_heads = num_heads
116
- self.patch_size = patch_size
117
- self.num_register_tokens = num_register_tokens
118
- self.interpolate_antialias = interpolate_antialias
119
- self.interpolate_offset = interpolate_offset
120
-
121
- self.patch_embed = embed_layer(
122
- img_size=img_size,
123
- patch_size=patch_size,
124
- in_chans=in_chans,
125
- embed_dim=embed_dim,
126
- )
127
- num_patches = self.patch_embed.num_patches
128
-
129
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
130
- self.pos_embed = nn.Parameter(
131
- torch.zeros(1, num_patches + self.num_tokens, embed_dim)
132
- )
133
- assert num_register_tokens >= 0
134
- self.register_tokens = (
135
- nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
136
- if num_register_tokens
137
- else None
138
- )
139
-
140
- if drop_path_uniform is True:
141
- dpr = [drop_path_rate] * depth
142
- else:
143
- dpr = [
144
- x.item() for x in torch.linspace(0, drop_path_rate, depth)
145
- ] # stochastic depth decay rule
146
-
147
- if ffn_layer == "mlp":
148
- # logger.info("using MLP layer as FFN")
149
- ffn_layer = Mlp
150
- elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
151
- # logger.info("using SwiGLU layer as FFN")
152
- ffn_layer = SwiGLUFFNFused
153
- elif ffn_layer == "identity":
154
- # logger.info("using Identity layer as FFN")
155
-
156
- def f(*args, **kwargs):
157
- return nn.Identity()
158
-
159
- ffn_layer = f
160
- else:
161
- raise NotImplementedError
162
-
163
- blocks_list = [
164
- block_fn(
165
- dim=embed_dim,
166
- num_heads=num_heads,
167
- mlp_ratio=mlp_ratio,
168
- qkv_bias=qkv_bias,
169
- proj_bias=proj_bias,
170
- ffn_bias=ffn_bias,
171
- drop_path=dpr[i],
172
- norm_layer=norm_layer,
173
- act_layer=act_layer,
174
- ffn_layer=ffn_layer,
175
- init_values=init_values,
176
- attn_class=FlashAttention,
177
- )
178
- for i in range(depth)
179
- ]
180
- if block_chunks > 0:
181
- self.chunked_blocks = True
182
- chunked_blocks = []
183
- chunksize = depth // block_chunks
184
- for i in range(0, depth, chunksize):
185
- # this is to keep the block index consistent if we chunk the block list
186
- chunked_blocks.append(
187
- [nn.Identity()] * i + blocks_list[i : i + chunksize]
188
- )
189
- self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
190
- else:
191
- self.chunked_blocks = False
192
- self.blocks = nn.ModuleList(blocks_list)
193
-
194
- self.norm = norm_layer(embed_dim)
195
- self.head = nn.Identity()
196
-
197
- self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
198
-
199
- self.init_weights()
200
-
201
- def init_weights(self):
202
- trunc_normal_(self.pos_embed, std=0.02)
203
- nn.init.normal_(self.cls_token, std=1e-6)
204
- if self.register_tokens is not None:
205
- nn.init.normal_(self.register_tokens, std=1e-6)
206
- named_apply(init_weights_vit_timm, self)
207
-
208
- def interpolate_pos_encoding(self, x, w, h):
209
- previous_dtype = x.dtype
210
- npatch = x.shape[1] - 1
211
- N = self.pos_embed.shape[1] - 1
212
- if npatch == N and w == h:
213
- return self.pos_embed
214
- pos_embed = self.pos_embed.float()
215
- class_pos_embed = pos_embed[:, 0]
216
- patch_pos_embed = pos_embed[:, 1:]
217
- dim = x.shape[-1]
218
- w0 = w // self.patch_size
219
- h0 = h // self.patch_size
220
- M = int(math.sqrt(N)) # Recover the number of patches in each dimension
221
- assert N == M * M
222
- kwargs = {}
223
- if self.interpolate_offset:
224
- # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
225
- # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
226
- sx = float(w0 + self.interpolate_offset) / M
227
- sy = float(h0 + self.interpolate_offset) / M
228
- kwargs["scale_factor"] = (sx, sy)
229
- else:
230
- # Simply specify an output size instead of a scale factor
231
- kwargs["size"] = (w0, h0)
232
- patch_pos_embed = nn.functional.interpolate(
233
- patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
234
- mode="bicubic",
235
- antialias=self.interpolate_antialias,
236
- **kwargs,
237
- )
238
- assert (w0, h0) == patch_pos_embed.shape[-2:]
239
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
240
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
241
- previous_dtype
242
- )
243
-
244
- def prepare_tokens_with_masks(self, x, masks=None):
245
- B, nc, w, h = x.shape
246
- x = self.patch_embed(x)
247
- if masks is not None:
248
- x = torch.where(
249
- masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
250
- )
251
-
252
- x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
253
- x = x + self.interpolate_pos_encoding(x, w, h)
254
-
255
- if self.register_tokens is not None:
256
- x = torch.cat(
257
- (
258
- x[:, :1],
259
- self.register_tokens.expand(x.shape[0], -1, -1),
260
- x[:, 1:],
261
- ),
262
- dim=1,
263
- )
264
-
265
- return x
266
-
267
- def forward_features_list(self, x_list, masks_list):
268
- x = [
269
- self.prepare_tokens_with_masks(x, masks)
270
- for x, masks in zip(x_list, masks_list)
271
- ]
272
- for blk in self.blocks:
273
- if self.training:
274
- x = checkpoint(blk, x, use_reentrant=False)
275
- else:
276
- x = blk(x)
277
-
278
- all_x = x
279
- output = []
280
- for x, masks in zip(all_x, masks_list):
281
- x_norm = self.norm(x)
282
- output.append(
283
- {
284
- "x_norm_clstoken": x_norm[:, 0],
285
- "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
286
- "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
287
- "x_prenorm": x,
288
- "masks": masks,
289
- }
290
- )
291
- return output
292
-
293
- def forward_features(self, x, masks=None):
294
- if isinstance(x, list):
295
- return self.forward_features_list(x, masks)
296
-
297
- x = self.prepare_tokens_with_masks(x, masks)
298
-
299
- for blk in self.blocks:
300
- if self.training:
301
- x = checkpoint(blk, x, use_reentrant=False)
302
- else:
303
- x = blk(x)
304
-
305
- x_norm = self.norm(x)
306
- return {
307
- "x_norm_clstoken": x_norm[:, 0],
308
- "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
309
- "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
310
- "x_prenorm": x,
311
- "masks": masks,
312
- }
313
-
314
- def _get_intermediate_layers_not_chunked(self, x, n=1):
315
- x = self.prepare_tokens_with_masks(x)
316
- # If n is an int, take the n last blocks. If it's a list, take them
317
- output, total_block_len = [], len(self.blocks)
318
- blocks_to_take = (
319
- range(total_block_len - n, total_block_len) if isinstance(n, int) else n
320
- )
321
- for i, blk in enumerate(self.blocks):
322
- x = blk(x)
323
- if i in blocks_to_take:
324
- output.append(x)
325
- assert len(output) == len(blocks_to_take), (
326
- f"only {len(output)} / {len(blocks_to_take)} blocks found"
327
- )
328
- return output
329
-
330
- def _get_intermediate_layers_chunked(self, x, n=1):
331
- x = self.prepare_tokens_with_masks(x)
332
- output, i, total_block_len = [], 0, len(self.blocks[-1])
333
- # If n is an int, take the n last blocks. If it's a list, take them
334
- blocks_to_take = (
335
- range(total_block_len - n, total_block_len) if isinstance(n, int) else n
336
- )
337
- for block_chunk in self.blocks:
338
- for blk in block_chunk[i:]: # Passing the nn.Identity()
339
- x = blk(x)
340
- if i in blocks_to_take:
341
- output.append(x)
342
- i += 1
343
- assert len(output) == len(blocks_to_take), (
344
- f"only {len(output)} / {len(blocks_to_take)} blocks found"
345
- )
346
- return output
347
-
348
- def get_intermediate_layers(
349
- self,
350
- x: torch.Tensor,
351
- n: Union[int, Sequence] = 1, # Layers or n last layers to take
352
- reshape: bool = False,
353
- return_class_token: bool = False,
354
- norm=True,
355
- ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
356
- if self.chunked_blocks:
357
- outputs = self._get_intermediate_layers_chunked(x, n)
358
- else:
359
- outputs = self._get_intermediate_layers_not_chunked(x, n)
360
- if norm:
361
- outputs = [self.norm(out) for out in outputs]
362
- class_tokens = [out[:, 0] for out in outputs]
363
- outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
364
- if reshape:
365
- B, _, w, h = x.shape
366
- outputs = [
367
- out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
368
- .permute(0, 3, 1, 2)
369
- .contiguous()
370
- for out in outputs
371
- ]
372
- if return_class_token:
373
- return tuple(zip(outputs, class_tokens))
374
- return tuple(outputs)
375
-
376
- def forward(self, *args, is_training=False, **kwargs):
377
- ret = self.forward_features(*args, **kwargs)
378
- if is_training:
379
- return ret
380
- else:
381
- return self.head(ret["x_norm_clstoken"])
382
-
383
-
384
- def init_weights_vit_timm(module: nn.Module, name: str = ""):
385
- """ViT weight initialization, original timm impl (for reproducibility)"""
386
- if isinstance(module, nn.Linear):
387
- trunc_normal_(module.weight, std=0.02)
388
- if module.bias is not None:
389
- nn.init.zeros_(module.bias)
390
-
391
-
392
- def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
393
- model = DinoVisionTransformer(
394
- patch_size=patch_size,
395
- embed_dim=384,
396
- depth=12,
397
- num_heads=6,
398
- mlp_ratio=4,
399
- block_fn=partial(Block, attn_class=MemEffAttention),
400
- num_register_tokens=num_register_tokens,
401
- **kwargs,
402
- )
403
- return model
404
-
405
-
406
- def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
407
- model = DinoVisionTransformer(
408
- patch_size=patch_size,
409
- embed_dim=768,
410
- depth=12,
411
- num_heads=12,
412
- mlp_ratio=4,
413
- block_fn=partial(Block, attn_class=MemEffAttention),
414
- num_register_tokens=num_register_tokens,
415
- **kwargs,
416
- )
417
- return model
418
-
419
-
420
- def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
421
- model = DinoVisionTransformer(
422
- patch_size=patch_size,
423
- embed_dim=1024,
424
- depth=24,
425
- num_heads=16,
426
- mlp_ratio=4,
427
- block_fn=partial(Block, attn_class=MemEffAttention),
428
- num_register_tokens=num_register_tokens,
429
- **kwargs,
430
- )
431
- return model
432
-
433
-
434
- def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
435
- """
436
- Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
437
- """
438
- model = DinoVisionTransformer(
439
- patch_size=patch_size,
440
- embed_dim=1536,
441
- depth=40,
442
- num_heads=24,
443
- mlp_ratio=4,
444
- block_fn=partial(Block, attn_class=MemEffAttention),
445
- num_register_tokens=num_register_tokens,
446
- **kwargs,
447
- )
448
- return model
 
mapanything/models/external/dinov2/utils/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
 
mapanything/models/external/dinov2/utils/cluster.py DELETED
@@ -1,102 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import os
7
- from enum import Enum
8
- from pathlib import Path
9
- from typing import Any, Dict, Optional
10
-
11
-
12
- class ClusterType(Enum):
13
- AWS = "aws"
14
- FAIR = "fair"
15
- RSC = "rsc"
16
-
17
-
18
- def _guess_cluster_type() -> ClusterType:
19
- uname = os.uname()
20
- if uname.sysname == "Linux":
21
- if uname.release.endswith("-aws"):
22
- # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
23
- return ClusterType.AWS
24
- elif uname.nodename.startswith("rsc"):
25
- # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
26
- return ClusterType.RSC
27
-
28
- return ClusterType.FAIR
29
-
30
-
31
- def get_cluster_type(
32
- cluster_type: Optional[ClusterType] = None,
33
- ) -> Optional[ClusterType]:
34
- if cluster_type is None:
35
- return _guess_cluster_type()
36
-
37
- return cluster_type
38
-
39
-
40
- def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
41
- cluster_type = get_cluster_type(cluster_type)
42
- if cluster_type is None:
43
- return None
44
-
45
- CHECKPOINT_DIRNAMES = {
46
- ClusterType.AWS: "checkpoints",
47
- ClusterType.FAIR: "checkpoint",
48
- ClusterType.RSC: "checkpoint/dino",
49
- }
50
- return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
51
-
52
-
53
- def get_user_checkpoint_path(
54
- cluster_type: Optional[ClusterType] = None,
55
- ) -> Optional[Path]:
56
- checkpoint_path = get_checkpoint_path(cluster_type)
57
- if checkpoint_path is None:
58
- return None
59
-
60
- username = os.environ.get("USER")
61
- assert username is not None
62
- return checkpoint_path / username
63
-
64
-
65
- def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
66
- cluster_type = get_cluster_type(cluster_type)
67
- if cluster_type is None:
68
- return None
69
-
70
- SLURM_PARTITIONS = {
71
- ClusterType.AWS: "learnlab",
72
- ClusterType.FAIR: "learnlab",
73
- ClusterType.RSC: "learn",
74
- }
75
- return SLURM_PARTITIONS[cluster_type]
76
-
77
-
78
- def get_slurm_executor_parameters(
79
- nodes: int,
80
- num_gpus_per_node: int,
81
- cluster_type: Optional[ClusterType] = None,
82
- **kwargs,
83
- ) -> Dict[str, Any]:
84
- # create default parameters
85
- params = {
86
- "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
87
- "gpus_per_node": num_gpus_per_node,
88
- "tasks_per_node": num_gpus_per_node, # one task per GPU
89
- "cpus_per_task": 10,
90
- "nodes": nodes,
91
- "slurm_partition": get_slurm_partition(cluster_type),
92
- }
93
- # apply cluster-specific adjustments
94
- cluster_type = get_cluster_type(cluster_type)
95
- if cluster_type == ClusterType.AWS:
96
- params["cpus_per_task"] = 12
97
- del params["mem_gb"]
98
- elif cluster_type == ClusterType.RSC:
99
- params["cpus_per_task"] = 12
100
- # set additional parameters / apply overrides
101
- params.update(kwargs)
102
- return params
 
mapanything/models/external/dinov2/utils/config.py DELETED
@@ -1,74 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import logging
7
- import math
8
- import os
9
-
10
- import dinov2.distributed as distributed
11
- from dinov2.configs import dinov2_default_config
12
- from dinov2.logging import setup_logging
13
- from dinov2.utils import utils
14
- from omegaconf import OmegaConf
15
-
16
- logger = logging.getLogger("dinov2")
17
-
18
-
19
- def apply_scaling_rules_to_cfg(cfg): # to fix
20
- if cfg.optim.scaling_rule == "sqrt_wrt_1024":
21
- base_lr = cfg.optim.base_lr
22
- cfg.optim.lr = base_lr
23
- cfg.optim.lr *= math.sqrt(
24
- cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0
25
- )
26
- logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
27
- else:
28
- raise NotImplementedError
29
- return cfg
30
-
31
-
32
- def write_config(cfg, output_dir, name="config.yaml"):
33
- logger.info(OmegaConf.to_yaml(cfg))
34
- saved_cfg_path = os.path.join(output_dir, name)
35
- with open(saved_cfg_path, "w") as f:
36
- OmegaConf.save(config=cfg, f=f)
37
- return saved_cfg_path
38
-
39
-
40
- def get_cfg_from_args(args):
41
- args.output_dir = os.path.abspath(args.output_dir)
42
- args.opts += [f"train.output_dir={args.output_dir}"]
43
- default_cfg = OmegaConf.create(dinov2_default_config)
44
- cfg = OmegaConf.load(args.config_file)
45
- cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
46
- return cfg
47
-
48
-
49
- def default_setup(args):
50
- distributed.enable(overwrite=True)
51
- seed = getattr(args, "seed", 0)
52
- rank = distributed.get_global_rank()
53
-
54
- global logger
55
- setup_logging(output=args.output_dir, level=logging.INFO)
56
- logger = logging.getLogger("dinov2")
57
-
58
- utils.fix_random_seeds(seed + rank)
59
- logger.info("git:\n {}\n".format(utils.get_sha()))
60
- logger.info(
61
- "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
62
- )
63
-
64
-
65
- def setup(args):
66
- """
67
- Create configs and perform basic setups.
68
- """
69
- cfg = get_cfg_from_args(args)
70
- os.makedirs(args.output_dir, exist_ok=True)
71
- default_setup(args)
72
- apply_scaling_rules_to_cfg(cfg)
73
- write_config(cfg, args.output_dir)
74
- return cfg
 
 
mapanything/models/external/dinov2/utils/dtype.py DELETED
@@ -1,38 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
-
7
- from typing import Dict, Union
8
-
9
- import numpy as np
10
- import torch
11
-
12
- TypeSpec = Union[str, np.dtype, torch.dtype]
13
-
14
-
15
- _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
16
- np.dtype("bool"): torch.bool,
17
- np.dtype("uint8"): torch.uint8,
18
- np.dtype("int8"): torch.int8,
19
- np.dtype("int16"): torch.int16,
20
- np.dtype("int32"): torch.int32,
21
- np.dtype("int64"): torch.int64,
22
- np.dtype("float16"): torch.float16,
23
- np.dtype("float32"): torch.float32,
24
- np.dtype("float64"): torch.float64,
25
- np.dtype("complex64"): torch.complex64,
26
- np.dtype("complex128"): torch.complex128,
27
- }
28
-
29
-
30
- def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
31
- if isinstance(dtype, torch.dtype):
32
- return dtype
33
- if isinstance(dtype, str):
34
- dtype = np.dtype(dtype)
35
- assert isinstance(dtype, np.dtype), (
36
-         f"Expected an instance of numpy dtype, got {type(dtype)}"
37
- )
38
- return _NUMPY_TO_TORCH_DTYPE[dtype]
 
mapanything/models/external/dinov2/utils/param_groups.py DELETED
@@ -1,122 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import logging
7
- from collections import defaultdict
8
-
9
- logger = logging.getLogger("dinov2")
10
-
11
-
12
- def get_vit_lr_decay_rate(
13
- name,
14
- lr_decay_rate=1.0,
15
- num_layers=12,
16
- force_is_backbone=False,
17
- chunked_blocks=False,
18
- ):
19
- """
20
- Calculate lr decay rate for different ViT blocks.
21
- Args:
22
- name (string): parameter name.
23
- lr_decay_rate (float): base lr decay rate.
24
- num_layers (int): number of ViT blocks.
25
- Returns:
26
- lr decay rate for the given parameter.
27
- """
28
- layer_id = num_layers + 1
29
- if name.startswith("backbone") or force_is_backbone:
30
- if (
31
- ".pos_embed" in name
32
- or ".patch_embed" in name
33
- or ".mask_token" in name
34
- or ".cls_token" in name
35
- or ".register_tokens" in name
36
- ):
37
- layer_id = 0
38
- elif force_is_backbone and (
39
- "pos_embed" in name
40
- or "patch_embed" in name
41
- or "mask_token" in name
42
- or "cls_token" in name
43
- or "register_tokens" in name
44
- ):
45
- layer_id = 0
46
- elif ".blocks." in name and ".residual." not in name:
47
- layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
48
- elif chunked_blocks and "blocks." in name and "residual." not in name:
49
- layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
50
- elif "blocks." in name and "residual." not in name:
51
- layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
52
-
53
- return lr_decay_rate ** (num_layers + 1 - layer_id)
54
-
55
-
56
- def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
57
- chunked_blocks = False
58
- if hasattr(model, "n_blocks"):
59
- logger.info("chunked fsdp")
60
- n_blocks = model.n_blocks
61
- chunked_blocks = model.chunked_blocks
62
- elif hasattr(model, "blocks"):
63
- logger.info("first code branch")
64
- n_blocks = len(model.blocks)
65
- elif hasattr(model, "backbone"):
66
- logger.info("second code branch")
67
- n_blocks = len(model.backbone.blocks)
68
- else:
69
- logger.info("else code branch")
70
- n_blocks = 0
71
- all_param_groups = []
72
-
73
- for name, param in model.named_parameters():
74
- name = name.replace("_fsdp_wrapped_module.", "")
75
- if not param.requires_grad:
76
- continue
77
- decay_rate = get_vit_lr_decay_rate(
78
- name,
79
- lr_decay_rate,
80
- num_layers=n_blocks,
81
- force_is_backbone=n_blocks > 0,
82
- chunked_blocks=chunked_blocks,
83
- )
84
- d = {
85
- "params": param,
86
- "is_last_layer": False,
87
- "lr_multiplier": decay_rate,
88
- "wd_multiplier": 1.0,
89
- "name": name,
90
- }
91
-
92
- if "last_layer" in name:
93
- d.update({"is_last_layer": True})
94
-
95
- if name.endswith(".bias") or "norm" in name or "gamma" in name:
96
- d.update({"wd_multiplier": 0.0})
97
-
98
- if "patch_embed" in name:
99
- d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
100
-
101
- all_param_groups.append(d)
102
- logger.info(
103
- f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}"""
104
- )
105
-
106
- return all_param_groups
107
-
108
-
109
- def fuse_params_groups(
110
- all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")
111
- ):
112
- fused_params_groups = defaultdict(lambda: {"params": []})
113
- for d in all_params_groups:
114
- identifier = ""
115
- for k in keys:
116
- identifier += k + str(d[k]) + "_"
117
-
118
- for k in keys:
119
- fused_params_groups[identifier][k] = d[k]
120
- fused_params_groups[identifier]["params"].append(d["params"])
121
-
122
- return fused_params_groups.values()
 
mapanything/models/external/dinov2/utils/utils.py DELETED
@@ -1,105 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- import os
7
- import random
8
- import subprocess
9
- from urllib.parse import urlparse
10
-
11
- import numpy as np
12
- import torch
13
- from torch import nn
14
-
15
- # logger = logging.getLogger("dinov2")
16
-
17
-
18
- def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
19
- if urlparse(pretrained_weights).scheme: # If it looks like an URL
20
- state_dict = torch.hub.load_state_dict_from_url(
21
- pretrained_weights, map_location="cpu"
22
- )
23
- else:
24
- state_dict = torch.load(pretrained_weights, map_location="cpu")
25
- if checkpoint_key is not None and checkpoint_key in state_dict:
26
- # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
27
- state_dict = state_dict[checkpoint_key]
28
- # remove `module.` prefix
29
- state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
30
- # remove `backbone.` prefix induced by multicrop wrapper
31
- state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
32
- _ = model.load_state_dict(state_dict, strict=False)
33
- # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
34
-
35
-
36
- def fix_random_seeds(seed=31):
37
- """
38
- Fix random seeds.
39
- """
40
- torch.manual_seed(seed)
41
- torch.cuda.manual_seed_all(seed)
42
- np.random.seed(seed)
43
- random.seed(seed)
44
-
45
-
46
- def get_sha():
47
- cwd = os.path.dirname(os.path.abspath(__file__))
48
-
49
- def _run(command):
50
- return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
51
-
52
- sha = "N/A"
53
- diff = "clean"
54
- branch = "N/A"
55
- try:
56
- sha = _run(["git", "rev-parse", "HEAD"])
57
- subprocess.check_output(["git", "diff"], cwd=cwd)
58
- diff = _run(["git", "diff-index", "HEAD"])
59
- diff = "has uncommitted changes" if diff else "clean"
60
- branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
61
- except Exception:
62
- pass
63
- message = f"sha: {sha}, status: {diff}, branch: {branch}"
64
- return message
65
-
66
-
67
- class CosineScheduler(object):
68
- def __init__(
69
- self,
70
- base_value,
71
- final_value,
72
- total_iters,
73
- warmup_iters=0,
74
- start_warmup_value=0,
75
- freeze_iters=0,
76
- ):
77
- super().__init__()
78
- self.final_value = final_value
79
- self.total_iters = total_iters
80
-
81
- freeze_schedule = np.zeros((freeze_iters))
82
-
83
- warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
84
-
85
- iters = np.arange(total_iters - warmup_iters - freeze_iters)
86
- schedule = final_value + 0.5 * (base_value - final_value) * (
87
- 1 + np.cos(np.pi * iters / len(iters))
88
- )
89
- self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
90
-
91
- assert len(self.schedule) == self.total_iters
92
-
93
- def __getitem__(self, it):
94
- if it >= self.total_iters:
95
- return self.final_value
96
- else:
97
- return self.schedule[it]
98
-
99
-
100
- def has_batchnorms(model):
101
- bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
102
- for name, module in model.named_modules():
103
- if isinstance(module, bn_types):
104
- return True
105
- return False
 
mapanything/models/external/dust3r/__init__.py DELETED
@@ -1,217 +0,0 @@
1
- """
2
- Inference wrapper for DUSt3R
3
- """
4
-
5
- import warnings
6
-
7
- import torch
8
- from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
9
- from dust3r.image_pairs import make_pairs
10
- from dust3r.inference import inference
11
- from dust3r.model import AsymmetricCroCo3DStereo # noqa
12
-
13
- from mapanything.models.external.vggt.utils.rotation import mat_to_quat
14
- from mapanything.utils.geometry import (
15
- convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
16
- convert_z_depth_to_depth_along_ray,
17
- depthmap_to_camera_frame,
18
- get_rays_in_camera_frame,
19
- )
20
-
21
- inf = float("inf")
22
-
23
-
24
- def load_model(model_path, device, verbose=True):
25
- if verbose:
26
- print("Loading model from", model_path)
27
- ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
28
- args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
29
- if "landscape_only" not in args:
30
- args = args[:-1] + ", landscape_only=False)"
31
- else:
32
- args = args.replace(" ", "").replace(
33
- "landscape_only=True", "landscape_only=False"
34
- )
35
- assert "landscape_only=False" in args
36
- if verbose:
37
- print(f"Instantiating: {args}")
38
- try:
39
- net = eval(args)
40
- except NameError:
41
- net = AsymmetricCroCo3DStereo(
42
- enc_depth=24,
43
- dec_depth=12,
44
- enc_embed_dim=1024,
45
- dec_embed_dim=768,
46
- enc_num_heads=16,
47
- dec_num_heads=12,
48
- pos_embed="RoPE100",
49
- patch_embed_cls="PatchEmbedDust3R",
50
- img_size=(512, 512),
51
- head_type="dpt",
52
- output_mode="pts3d",
53
- depth_mode=("exp", -inf, inf),
54
- conf_mode=("exp", 1, inf),
55
- landscape_only=False,
56
- )
57
- s = net.load_state_dict(ckpt["model"], strict=False)
58
- if verbose:
59
- print(s)
60
- return net.to(device)
61
-
62
-
63
- class DUSt3RBAWrapper(torch.nn.Module):
64
- def __init__(
65
- self,
66
- name,
67
- ckpt_path,
68
- scene_graph="complete",
69
- inference_batch_size=32,
70
- global_optim_schedule="cosine",
71
- global_optim_lr=0.01,
72
- global_optim_niter=300,
73
- **kwargs,
74
- ):
75
- super().__init__()
76
- self.name = name
77
- self.ckpt_path = ckpt_path
78
- self.scene_graph = scene_graph
79
- self.inference_batch_size = inference_batch_size
80
- self.global_optim_schedule = global_optim_schedule
81
- self.global_optim_lr = global_optim_lr
82
- self.global_optim_niter = global_optim_niter
83
-
84
- # Init the model and load the checkpoint
85
- self.model = load_model(self.ckpt_path, device="cpu")
86
-
87
- # Init the global aligner mode
88
- self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer
89
-
90
- def forward(self, views):
91
- """
92
- Forward pass wrapper for DUSt3R using the global aligner.
93
-
94
- Assumption:
95
- - The batch size of input views is 1.
96
-
97
- Args:
98
- views (List[dict]): List of dictionaries containing the input views' images and instance information.
99
- Each dictionary should contain the following keys, where B is the batch size and is 1:
100
- "img" (tensor): Image tensor of shape (B, C, H, W).
101
- "data_norm_type" (list): ["dust3r"]
102
-
103
- Returns:
104
- List[dict]: A list containing the final outputs for the input views.
105
- """
106
- # Check the batch size of input views
107
- batch_size_per_view, _, height, width = views[0]["img"].shape
108
- device = views[0]["img"].device
109
- num_views = len(views)
110
- assert batch_size_per_view == 1, (
111
- f"Batch size of input views should be 1, but got {batch_size_per_view}."
112
- )
113
-
114
- # Check the data norm type
115
- data_norm_type = views[0]["data_norm_type"][0]
116
- assert data_norm_type == "dust3r", (
117
- "DUSt3R expects a normalized image with the DUSt3R normalization scheme applied"
118
- )
119
-
120
- # Convert the input views to the expected input format
121
- images = []
122
- for view in views:
123
- images.append(
124
- dict(
125
- img=view["img"],
126
- idx=len(images),
127
- instance=str(len(images)),
128
- )
129
- )
130
-
131
- # Make image pairs and run inference pair-wise
132
- pairs = make_pairs(
133
- images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
134
- )
135
- with warnings.catch_warnings():
136
- warnings.simplefilter("ignore", category=FutureWarning)
137
- output = inference(
138
- pairs,
139
- self.model,
140
- device,
141
- batch_size=self.inference_batch_size,
142
- verbose=False,
143
- )
144
-
145
- # Global optimization
146
- with torch.enable_grad():
147
- scene = global_aligner(
148
- output, device=device, mode=self.global_aligner_mode, verbose=False
149
- )
150
- _ = scene.compute_global_alignment(
151
- init="mst",
152
- niter=self.global_optim_niter,
153
- schedule=self.global_optim_schedule,
154
- lr=self.global_optim_lr,
155
- )
156
-
157
- # Make sure scene is not None
158
- if scene is None:
159
- raise RuntimeError("Global optimization failed.")
160
-
161
- # Get the predictions
162
- intrinsics = scene.get_intrinsics()
163
- c2w_poses = scene.get_im_poses()
164
- depths = scene.get_depthmaps()
165
-
166
- # Convert the output to the MapAnything format
167
- with torch.autocast("cuda", enabled=False):
168
- res = []
169
- for view_idx in range(num_views):
170
- # Get the current view predictions
171
- curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
172
- curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
173
- curr_view_depth_z = depths[view_idx].unsqueeze(0)
174
-
175
- # Convert the pose to quaternions and translation
176
- curr_view_cam_translations = curr_view_pose[..., :3, 3]
177
- curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])
178
-
179
- # Get the camera frame pointmaps
180
- curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
181
- curr_view_depth_z, curr_view_intrinsic
182
- )
183
-
184
- # Convert the z depth to depth along ray
185
- curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
186
- curr_view_depth_z, curr_view_intrinsic
187
- )
188
- curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)
189
-
190
- # Get the ray directions on the unit sphere in the camera frame
191
- _, curr_view_ray_dirs = get_rays_in_camera_frame(
192
- curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
193
- )
194
-
195
- # Get the pointmaps
196
- curr_view_pts3d = (
197
- convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
198
- curr_view_ray_dirs,
199
- curr_view_depth_along_ray,
200
- curr_view_cam_translations,
201
- curr_view_cam_quats,
202
- )
203
- )
204
-
205
- # Append the outputs to the result list
206
- res.append(
207
- {
208
- "pts3d": curr_view_pts3d,
209
- "pts3d_cam": curr_view_pts3d_cam,
210
- "ray_directions": curr_view_ray_dirs,
211
- "depth_along_ray": curr_view_depth_along_ray,
212
- "cam_trans": curr_view_cam_translations,
213
- "cam_quats": curr_view_cam_quats,
214
- }
215
- )
216
-
217
- return res
 
mapanything/models/external/mast3r/__init__.py DELETED
@@ -1,191 +0,0 @@
1
- """
2
- Inference wrapper for MASt3R + Sparse GA
3
- """
4
-
5
- import os
6
- import tempfile
7
- import warnings
8
-
9
- import torch
10
- from dust3r.image_pairs import make_pairs
11
- from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
12
- from mast3r.model import load_model
13
-
14
- from mapanything.models.external.vggt.utils.rotation import mat_to_quat
15
- from mapanything.utils.geometry import (
16
- convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
17
- convert_z_depth_to_depth_along_ray,
18
- depthmap_to_camera_frame,
19
- get_rays_in_camera_frame,
20
- )
21
-
22
-
23
- class MASt3RSGAWrapper(torch.nn.Module):
24
- def __init__(
25
- self,
26
- name,
27
- ckpt_path,
28
- cache_dir,
29
- scene_graph="complete",
30
- sparse_ga_lr1=0.07,
31
- sparse_ga_niter1=300,
32
- sparse_ga_lr2=0.01,
33
- sparse_ga_niter2=300,
34
- sparse_ga_optim_level="refine+depth",
35
- sparse_ga_shared_intrinsics=False,
36
- sparse_ga_matching_conf_thr=5.0,
37
- **kwargs,
38
- ):
39
- super().__init__()
40
- self.name = name
41
- self.ckpt_path = ckpt_path
42
- self.cache_dir = cache_dir
43
- self.scene_graph = scene_graph
44
- self.sparse_ga_lr1 = sparse_ga_lr1
45
- self.sparse_ga_niter1 = sparse_ga_niter1
46
- self.sparse_ga_lr2 = sparse_ga_lr2
47
- self.sparse_ga_niter2 = sparse_ga_niter2
48
- self.sparse_ga_optim_level = sparse_ga_optim_level
49
- self.sparse_ga_shared_intrinsics = sparse_ga_shared_intrinsics
50
- self.sparse_ga_matching_conf_thr = sparse_ga_matching_conf_thr
51
-
52
- # Init the model and load the checkpoint
53
- self.model = load_model(self.ckpt_path, device="cpu")
54
-
55
- def forward(self, views):
56
- """
57
- Forward pass wrapper for MASt3R using the sparse global aligner.
58
-
59
- Assumption:
60
- - The batch size of input views is 1.
61
-
62
- Args:
63
- views (List[dict]): List of dictionaries containing the input views' images and instance information.
64
- Each dictionary should contain the following keys, where B is the batch size and is 1:
65
- "img" (tensor): Image tensor of shape (B, C, H, W).
66
- "data_norm_type" (list): ["dust3r"]
67
- "label" (list): ["scene_name"]
68
- "instance" (list): ["image_name"]
69
-
70
- Returns:
71
- List[dict]: A list containing the final outputs for the input views.
72
- """
73
- # Check the batch size of input views
74
- batch_size_per_view, _, height, width = views[0]["img"].shape
75
- device = views[0]["img"].device
76
- num_views = len(views)
77
- assert batch_size_per_view == 1, (
78
- f"Batch size of input views should be 1, but got {batch_size_per_view}."
79
- )
80
-
81
- # Check the data norm type
82
- data_norm_type = views[0]["data_norm_type"][0]
83
- assert data_norm_type == "dust3r", (
84
- "MASt3R expects a normalized image with the DUSt3R normalization scheme applied"
85
- )
86
-
87
- # Convert the input views to the expected input format
88
- images = []
89
- image_paths = []
90
- for view in views:
91
- images.append(
92
- dict(
93
- img=view["img"].cpu(),
94
- idx=len(images),
95
- instance=str(len(images)),
96
- true_shape=torch.tensor(view["img"].shape[-2:])[None]
97
- .repeat(batch_size_per_view, 1)
98
- .numpy(),
99
- )
100
- )
101
- view_name = os.path.join(view["label"][0], view["instance"][0])
102
- image_paths.append(view_name)
103
-
104
- # Make image pairs and run inference
105
- # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
106
- pairs = make_pairs(
107
- images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
108
- )
109
- with torch.enable_grad():
110
- with warnings.catch_warnings():
111
- warnings.simplefilter("ignore", category=FutureWarning)
112
- tempfile.mkdtemp(dir=self.cache_dir)
113
- scene = sparse_global_alignment(
114
- image_paths,
115
- pairs,
116
- self.cache_dir,
117
- self.model,
118
- lr1=self.sparse_ga_lr1,
119
- niter1=self.sparse_ga_niter1,
120
- lr2=self.sparse_ga_lr2,
121
- niter2=self.sparse_ga_niter2,
122
- device=device,
123
- opt_depth="depth" in self.sparse_ga_optim_level,
124
- shared_intrinsics=self.sparse_ga_shared_intrinsics,
125
- matching_conf_thr=self.sparse_ga_matching_conf_thr,
126
- verbose=False,
127
- )
128
-
129
- # Make sure scene is not None
130
- if scene is None:
131
- raise RuntimeError("Global optimization failed.")
132
-
133
- # Get the predictions
134
- intrinsics = scene.intrinsics
135
- c2w_poses = scene.get_im_poses()
136
- _, depths, _ = scene.get_dense_pts3d()
137
-
138
- # Convert the output to the MapAnything format
139
- with torch.autocast("cuda", enabled=False):
140
- res = []
141
- for view_idx in range(num_views):
142
- # Get the current view predictions
143
- curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
144
- curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
145
- curr_view_depth_z = (
146
- depths[view_idx].reshape((height, width)).unsqueeze(0)
147
- )
148
-
149
- # Convert the pose to quaternions and translation
150
- curr_view_cam_translations = curr_view_pose[..., :3, 3]
151
- curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])
152
-
153
- # Get the camera frame pointmaps
154
- curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
155
- curr_view_depth_z, curr_view_intrinsic
156
- )
157
-
158
- # Convert the z depth to depth along ray
159
- curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
160
- curr_view_depth_z, curr_view_intrinsic
161
- )
162
- curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)
163
-
164
- # Get the ray directions on the unit sphere in the camera frame
165
- _, curr_view_ray_dirs = get_rays_in_camera_frame(
166
- curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
167
- )
168
-
169
- # Get the pointmaps
170
- curr_view_pts3d = (
171
- convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
172
- curr_view_ray_dirs,
173
- curr_view_depth_along_ray,
174
- curr_view_cam_translations,
175
- curr_view_cam_quats,
176
- )
177
- )
178
-
179
- # Append the outputs to the result list
180
- res.append(
181
- {
182
- "pts3d": curr_view_pts3d,
183
- "pts3d_cam": curr_view_pts3d_cam,
184
- "ray_directions": curr_view_ray_dirs,
185
- "depth_along_ray": curr_view_depth_along_ray,
186
- "cam_trans": curr_view_cam_translations,
187
- "cam_quats": curr_view_cam_quats,
188
- }
189
- )
190
-
191
- return res
 
mapanything/models/external/moge/__init__.py DELETED
@@ -1,114 +0,0 @@
1
- """
2
- Inference wrapper for MoGe
3
- """
4
-
5
- import torch
6
-
7
- from mapanything.models.external.moge.models.v1 import MoGeModel as MoGeModelV1
8
- from mapanything.models.external.moge.models.v2 import MoGeModel as MoGeModelV2
9
-
10
-
11
- class MoGeWrapper(torch.nn.Module):
12
- def __init__(
13
- self,
14
- name,
15
- model_string="Ruicheng/moge-2-vitl",
16
- torch_hub_force_reload=False,
17
- load_custom_ckpt=False,
18
- custom_ckpt_path=None,
19
- ):
20
- super().__init__()
21
- self.name = name
22
- self.model_string = model_string
23
- self.torch_hub_force_reload = torch_hub_force_reload
24
- self.load_custom_ckpt = load_custom_ckpt
25
- self.custom_ckpt_path = custom_ckpt_path
26
-
27
- # Mapping of MoGe model version to checkpoint strings
28
- self.moge_model_map = {
29
- "v1": ["Ruicheng/moge-vitl"],
30
- "v2": [
31
- "Ruicheng/moge-2-vits-normal",
32
- "Ruicheng/moge-2-vitb-normal",
33
- "Ruicheng/moge-2-vitl-normal",
34
- "Ruicheng/moge-2-vitl",
35
- ],
36
- }
37
-
38
- # Initialize the model
39
- if self.model_string in self.moge_model_map["v1"]:
40
- self.model = MoGeModelV1.from_pretrained(self.model_string)
41
- elif self.model_string in self.moge_model_map["v2"]:
42
- self.model = MoGeModelV2.from_pretrained(self.model_string)
43
- else:
44
- raise ValueError(
45
- f"Invalid model string: {self.model_string}. Valid strings are: {self.moge_model_map}"
46
- )
47
-
48
- # Load custom checkpoint if requested
49
- if self.load_custom_ckpt:
50
- print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
51
- assert self.custom_ckpt_path is not None, (
52
- "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
53
- )
54
- custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False)
55
- print(self.model.load_state_dict(custom_ckpt, strict=True))
56
- del custom_ckpt # in case it occupies memory
57
-
58
- def forward(self, views):
59
- """
60
- Forward pass wrapper for MoGe-2.
61
- The predicted MoGe-2 mask is not applied to the outputs.
62
- The number of tokens for inference is determined by the image shape.
63
-
64
- Assumption:
65
- - The number of input views is 1.
66
-
67
- Args:
68
- views (List[dict]): List of dictionaries containing the input views' images and instance information.
69
- Length of the list should be 1.
70
- Each dictionary should contain the following keys:
71
- "img" (tensor): Image tensor of shape (B, C, H, W).
72
- "data_norm_type" (list): ["identity"]
73
-
74
- Returns:
75
- List[dict]: A list containing the final outputs for the single view. Length of the list will be 1.
76
- """
77
- # Check that the number of input views is 1
78
- assert len(views) == 1, "MoGe only supports 1 input view."
79
-
80
- # Get input shape of the images, number of tokens for inference, and batch size per view
81
- _, _, height, width = views[0]["img"].shape
82
- num_tokens = int(height // 14) * int(width // 14)
83
-
84
- # Check the data norm type
85
- # MoGe expects a normalized image but without the DINOv2 mean and std applied ("identity")
86
- data_norm_type = views[0]["data_norm_type"][0]
87
- assert data_norm_type == "identity", (
88
- "MoGe expects a normalized image but without the DINOv2 mean and std applied"
89
- )
90
-
91
- # Run MoGe inference
92
- # Output dict contains: "points", "depth", "mask", "intrinsics", "normal" (based on model config)
93
- model_outputs = self.model.infer(
94
- image=views[0]["img"], num_tokens=num_tokens, apply_mask=False
95
- )
96
-
97
- # Get the ray directions and depth along ray
98
- with torch.autocast("cuda", enabled=False):
99
- depth_along_ray = torch.norm(model_outputs["points"], dim=-1, keepdim=True)
100
- ray_directions = model_outputs["points"] / depth_along_ray
101
-
102
- # Convert the output to MapAnything format
103
- result_dict = {
104
- "pts3d": model_outputs["points"],
105
- "pts3d_cam": model_outputs["points"],
106
- "depth_z": model_outputs["depth"].unsqueeze(-1),
107
- "intrinsics": model_outputs["intrinsics"],
108
- "non_ambiguous_mask": model_outputs["mask"],
109
- "ray_directions": ray_directions,
110
- "depth_along_ray": depth_along_ray,
111
- }
112
- res = [result_dict]
113
-
114
- return res
 
mapanything/models/external/moge/models/modules.py DELETED
@@ -1,467 +0,0 @@
1
- import functools
2
- import importlib
3
- import itertools
4
- from typing import List, Literal, Optional, Sequence, Tuple, Union
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
- from mapanything.models.external.dinov2.models.vision_transformer import (
11
- DinoVisionTransformer,
12
- )
13
- from mapanything.models.external.moge.models.utils import (
14
- wrap_dinov2_attention_with_sdpa,
15
- wrap_module_with_gradient_checkpointing,
16
- )
17
-
18
-
19
- class ResidualConvBlock(nn.Module):
20
- def __init__(
21
- self,
22
- in_channels: int,
23
- out_channels: int = None,
24
- hidden_channels: int = None,
25
- kernel_size: int = 3,
26
- padding_mode: str = "replicate",
27
- activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu",
28
- in_norm: Literal[
29
- "group_norm", "layer_norm", "instance_norm", "none"
30
- ] = "layer_norm",
31
- hidden_norm: Literal[
32
- "group_norm", "layer_norm", "instance_norm"
33
- ] = "group_norm",
34
- ):
35
- super(ResidualConvBlock, self).__init__()
36
- if out_channels is None:
37
- out_channels = in_channels
38
- if hidden_channels is None:
39
- hidden_channels = in_channels
40
-
41
- if activation == "relu":
42
- activation_cls = nn.ReLU
43
- elif activation == "leaky_relu":
44
- activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2)
45
- elif activation == "silu":
46
- activation_cls = nn.SiLU
47
- elif activation == "elu":
48
- activation_cls = nn.ELU
49
- else:
50
- raise ValueError(f"Unsupported activation function: {activation}")
51
-
52
- self.layers = nn.Sequential(
53
- nn.GroupNorm(in_channels // 32, in_channels)
54
- if in_norm == "group_norm"
55
- else nn.GroupNorm(1, in_channels)
56
- if in_norm == "layer_norm"
57
- else nn.InstanceNorm2d(in_channels)
58
- if in_norm == "instance_norm"
59
- else nn.Identity(),
60
- activation_cls(),
61
- nn.Conv2d(
62
- in_channels,
63
- hidden_channels,
64
- kernel_size=kernel_size,
65
- padding=kernel_size // 2,
66
- padding_mode=padding_mode,
67
- ),
68
- nn.GroupNorm(hidden_channels // 32, hidden_channels)
69
- if hidden_norm == "group_norm"
70
- else nn.GroupNorm(1, hidden_channels)
71
- if hidden_norm == "layer_norm"
72
- else nn.InstanceNorm2d(hidden_channels)
73
- if hidden_norm == "instance_norm"
74
- else nn.Identity(),
75
- activation_cls(),
76
- nn.Conv2d(
77
- hidden_channels,
78
- out_channels,
79
- kernel_size=kernel_size,
80
- padding=kernel_size // 2,
81
- padding_mode=padding_mode,
82
- ),
83
- )
84
-
85
- self.skip_connection = (
86
- nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
87
- if in_channels != out_channels
88
- else nn.Identity()
89
- )
90
-
91
- def forward(self, x):
92
- skip = self.skip_connection(x)
93
- x = self.layers(x)
94
- x = x + skip
95
- return x
96
-
97
-
98
- class DINOv2Encoder(nn.Module):
99
- "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]."
100
-
101
- backbone: DinoVisionTransformer
102
- image_mean: torch.Tensor
103
- image_std: torch.Tensor
104
- dim_features: int
105
-
106
- def __init__(
107
- self,
108
- backbone: str,
109
- intermediate_layers: Union[int, List[int]],
110
- dim_out: int,
111
- **deprecated_kwargs,
112
- ):
113
- super(DINOv2Encoder, self).__init__()
114
-
115
- self.intermediate_layers = intermediate_layers
116
-
117
- # Load the backbone
118
- self.hub_loader = getattr(
119
- importlib.import_module(
120
- "mapanything.models.external.dinov2.hub.backbones", __package__
121
- ),
122
- backbone,
123
- )
124
- self.backbone_name = backbone
125
- self.backbone = self.hub_loader(pretrained=False)
126
-
127
- self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
128
- self.num_features = (
129
- intermediate_layers
130
- if isinstance(intermediate_layers, int)
131
- else len(intermediate_layers)
132
- )
133
-
134
- self.output_projections = nn.ModuleList(
135
- [
136
- nn.Conv2d(
137
- in_channels=self.dim_features,
138
- out_channels=dim_out,
139
- kernel_size=1,
140
- stride=1,
141
- padding=0,
142
- )
143
- for _ in range(self.num_features)
144
- ]
145
- )
146
-
147
- self.register_buffer(
148
- "image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
149
- )
150
- self.register_buffer(
151
- "image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
152
- )
153
-
154
- @property
155
- def onnx_compatible_mode(self):
156
- return getattr(self, "_onnx_compatible_mode", False)
157
-
158
- @onnx_compatible_mode.setter
159
- def onnx_compatible_mode(self, value: bool):
160
- self._onnx_compatible_mode = value
161
- self.backbone.onnx_compatible_mode = value
162
-
163
- def init_weights(self):
164
- pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
165
- self.backbone.load_state_dict(pretrained_backbone_state_dict)
166
-
167
- def enable_gradient_checkpointing(self):
168
- for i in range(len(self.backbone.blocks)):
169
- wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
170
-
171
- def enable_pytorch_native_sdpa(self):
172
- for i in range(len(self.backbone.blocks)):
173
- wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
174
-
175
- def forward(
176
- self,
177
- image: torch.Tensor,
178
- token_rows: Union[int, torch.LongTensor],
179
- token_cols: Union[int, torch.LongTensor],
180
- return_class_token: bool = False,
181
- ) -> Tuple[torch.Tensor, torch.Tensor]:
182
- image_14 = F.interpolate(
183
- image,
184
- (token_rows * 14, token_cols * 14),
185
- mode="bilinear",
186
- align_corners=False,
187
- antialias=not self.onnx_compatible_mode,
188
- )
189
- image_14 = (image_14 - self.image_mean) / self.image_std
190
-
191
- # Get intermediate layers from the backbone
192
- features = self.backbone.get_intermediate_layers(
193
- image_14, n=self.intermediate_layers, return_class_token=True
194
- )
195
-
196
- # Project features to the desired dimensionality
197
- x = torch.stack(
198
- [
199
- proj(
200
- feat.permute(0, 2, 1)
201
- .unflatten(2, (token_rows, token_cols))
202
- .contiguous()
203
- )
204
- for proj, (feat, clstoken) in zip(self.output_projections, features)
205
- ],
206
- dim=1,
207
- ).sum(dim=1)
208
-
209
- if return_class_token:
210
- return x, features[-1][1]
211
- else:
212
- return x
213
-
214
-
215
- class Resampler(nn.Sequential):
216
- def __init__(
217
- self,
218
- in_channels: int,
219
- out_channels: int,
220
- type_: Literal[
221
- "pixel_shuffle",
222
- "nearest",
223
- "bilinear",
224
- "conv_transpose",
225
- "pixel_unshuffle",
226
- "avg_pool",
227
- "max_pool",
228
- ],
229
- scale_factor: int = 2,
230
- ):
231
- if type_ == "pixel_shuffle":
232
- nn.Sequential.__init__(
233
- self,
234
- nn.Conv2d(
235
- in_channels,
236
- out_channels * (scale_factor**2),
237
- kernel_size=3,
238
- stride=1,
239
- padding=1,
240
- padding_mode="replicate",
241
- ),
242
- nn.PixelShuffle(scale_factor),
243
- nn.Conv2d(
244
- out_channels,
245
- out_channels,
246
- kernel_size=3,
247
- stride=1,
248
- padding=1,
249
- padding_mode="replicate",
250
- ),
251
- )
252
- for i in range(1, scale_factor**2):
253
- self[0].weight.data[i :: scale_factor**2] = self[0].weight.data[
254
- 0 :: scale_factor**2
255
- ]
256
- self[0].bias.data[i :: scale_factor**2] = self[0].bias.data[
257
- 0 :: scale_factor**2
258
- ]
259
- elif type_ in ["nearest", "bilinear"]:
260
- nn.Sequential.__init__(
261
- self,
262
- nn.Upsample(
263
- scale_factor=scale_factor,
264
- mode=type_,
265
- align_corners=False if type_ == "bilinear" else None,
266
- ),
267
- nn.Conv2d(
268
- in_channels,
269
- out_channels,
270
- kernel_size=3,
271
- stride=1,
272
- padding=1,
273
- padding_mode="replicate",
274
- ),
275
- )
276
- elif type_ == "conv_transpose":
277
- nn.Sequential.__init__(
278
- self,
279
- nn.ConvTranspose2d(
280
- in_channels,
281
- out_channels,
282
- kernel_size=scale_factor,
283
- stride=scale_factor,
284
- ),
285
- nn.Conv2d(
286
- out_channels,
287
- out_channels,
288
- kernel_size=3,
289
- stride=1,
290
- padding=1,
291
- padding_mode="replicate",
292
- ),
293
- )
294
- self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
295
- elif type_ == "pixel_unshuffle":
296
- nn.Sequential.__init__(
297
- self,
298
- nn.PixelUnshuffle(scale_factor),
299
- nn.Conv2d(
300
- in_channels * (scale_factor**2),
301
- out_channels,
302
- kernel_size=3,
303
- stride=1,
304
- padding=1,
305
- padding_mode="replicate",
306
- ),
307
- )
308
- elif type_ == "avg_pool":
309
- nn.Sequential.__init__(
310
- self,
311
- nn.Conv2d(
312
- in_channels,
313
- out_channels,
314
- kernel_size=3,
315
- stride=1,
316
- padding=1,
317
- padding_mode="replicate",
318
- ),
319
- nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
320
- )
321
- elif type_ == "max_pool":
322
- nn.Sequential.__init__(
323
- self,
324
- nn.Conv2d(
325
- in_channels,
326
- out_channels,
327
- kernel_size=3,
328
- stride=1,
329
- padding=1,
330
- padding_mode="replicate",
331
- ),
332
- nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
333
- )
334
- else:
335
- raise ValueError(f"Unsupported resampler type: {type_}")
336
-
337
-
338
- class MLP(nn.Sequential):
339
- def __init__(self, dims: Sequence[int]):
340
- nn.Sequential.__init__(
341
- self,
342
- *itertools.chain(
343
- *[
344
- (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True))
345
- for dim_in, dim_out in zip(dims[:-2], dims[1:-1])
346
- ]
347
- ),
348
- nn.Linear(dims[-2], dims[-1]),
349
- )
350
-
351
-
352
- class ConvStack(nn.Module):
353
- def __init__(
354
- self,
355
- dim_in: List[Optional[int]],
356
- dim_res_blocks: List[int],
357
- dim_out: List[Optional[int]],
358
- resamplers: Union[
359
- Literal[
360
- "pixel_shuffle",
361
- "nearest",
362
- "bilinear",
363
- "conv_transpose",
364
- "pixel_unshuffle",
365
- "avg_pool",
366
- "max_pool",
367
- ],
368
- List,
369
- ],
370
- dim_times_res_block_hidden: int = 1,
371
- num_res_blocks: int = 1,
372
- res_block_in_norm: Literal[
373
- "layer_norm", "group_norm", "instance_norm", "none"
374
- ] = "layer_norm",
375
- res_block_hidden_norm: Literal[
376
- "layer_norm", "group_norm", "instance_norm", "none"
377
- ] = "group_norm",
378
- activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu",
379
- ):
380
- super().__init__()
381
- self.input_blocks = nn.ModuleList(
382
- [
383
- nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0)
384
- if dim_in_ is not None
385
- else nn.Identity()
386
- for dim_in_, dim_res_block_ in zip(
387
- dim_in
388
- if isinstance(dim_in, Sequence)
389
- else itertools.repeat(dim_in),
390
- dim_res_blocks,
391
- )
392
- ]
393
- )
394
- self.resamplers = nn.ModuleList(
395
- [
396
- Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
397
- for i, (dim_prev, dim_succ, resampler) in enumerate(
398
- zip(
399
- dim_res_blocks[:-1],
400
- dim_res_blocks[1:],
401
- resamplers
402
- if isinstance(resamplers, Sequence)
403
- else itertools.repeat(resamplers),
404
- )
405
- )
406
- ]
407
- )
408
- self.res_blocks = nn.ModuleList(
409
- [
410
- nn.Sequential(
411
- *(
412
- ResidualConvBlock(
413
- dim_res_block_,
414
- dim_res_block_,
415
- dim_times_res_block_hidden * dim_res_block_,
416
- activation=activation,
417
- in_norm=res_block_in_norm,
418
- hidden_norm=res_block_hidden_norm,
419
- )
420
- for _ in range(
421
- num_res_blocks[i]
422
- if isinstance(num_res_blocks, list)
423
- else num_res_blocks
424
- )
425
- )
426
- )
427
- for i, dim_res_block_ in enumerate(dim_res_blocks)
428
- ]
429
- )
430
- self.output_blocks = nn.ModuleList(
431
- [
432
- nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0)
433
- if dim_out_ is not None
434
- else nn.Identity()
435
- for dim_out_, dim_res_block_ in zip(
436
- dim_out
437
- if isinstance(dim_out, Sequence)
438
- else itertools.repeat(dim_out),
439
- dim_res_blocks,
440
- )
441
- ]
442
- )
443
-
444
- def enable_gradient_checkpointing(self):
445
- for i in range(len(self.resamplers)):
446
- self.resamplers[i] = wrap_module_with_gradient_checkpointing(
447
- self.resamplers[i]
448
- )
449
- for i in range(len(self.res_blocks)):
450
- for j in range(len(self.res_blocks[i])):
451
- self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(
452
- self.res_blocks[i][j]
453
- )
454
-
455
- def forward(self, in_features: List[torch.Tensor]):
456
- out_features = []
457
- for i in range(len(self.res_blocks)):
458
- feature = self.input_blocks[i](in_features[i])
459
- if i == 0:
460
- x = feature
461
- elif feature is not None:
462
- x = x + feature
463
- x = self.res_blocks[i](x)
464
- out_features.append(self.output_blocks[i](x))
465
- if i < len(self.res_blocks) - 1:
466
- x = self.resamplers[i](x)
467
- return out_features
 
mapanything/models/external/moge/models/utils.py DELETED
@@ -1,477 +0,0 @@
1
- import inspect
2
- from functools import partial, wraps
3
- from numbers import Number
4
- from typing import Tuple, Union
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
-
12
- def wrap_module_with_gradient_checkpointing(module: nn.Module):
13
- from torch.utils.checkpoint import checkpoint
14
-
15
- class _CheckpointingWrapper(module.__class__):
16
- _restore_cls = module.__class__
17
-
18
- def forward(self, *args, **kwargs):
19
- return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
20
-
21
- module.__class__ = _CheckpointingWrapper
22
- return module
23
-
24
-
25
- def unwrap_module_with_gradient_checkpointing(module: nn.Module):
26
- module.__class__ = module.__class__._restore_cls
27
-
28
-
29
- def wrap_dinov2_attention_with_sdpa(module: nn.Module):
30
- assert torch.__version__ >= "2.0", "SDPA requires PyTorch 2.0 or later"
31
-
32
- class _AttentionWrapper(module.__class__):
33
- def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
34
- B, N, C = x.shape
35
- qkv = (
36
- self.qkv(x)
37
- .reshape(B, N, 3, self.num_heads, C // self.num_heads)
38
- .permute(2, 0, 3, 1, 4)
39
- ) # (3, B, H, N, C // H)
40
-
41
- q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
42
-
43
- x = F.scaled_dot_product_attention(q, k, v, attn_bias)
44
- x = x.permute(0, 2, 1, 3).reshape(B, N, C)
45
-
46
- x = self.proj(x)
47
- x = self.proj_drop(x)
48
- return x
49
-
50
- module.__class__ = _AttentionWrapper
51
- return module
52
-
53
-
54
- def sync_ddp_hook(
55
- state, bucket: torch.distributed.GradBucket
56
- ) -> torch.futures.Future[torch.Tensor]:
57
- group_to_use = torch.distributed.group.WORLD
58
- world_size = group_to_use.size()
59
- grad = bucket.buffer()
60
- grad.div_(world_size)
61
- torch.distributed.all_reduce(grad, group=group_to_use)
62
- fut = torch.futures.Future()
63
- fut.set_result(grad)
64
- return fut
65
-
66
-
67
- def normalized_view_plane_uv(
68
- width: int,
69
- height: int,
70
- aspect_ratio: float = None,
71
- dtype: torch.dtype = None,
72
- device: torch.device = None,
73
- ) -> torch.Tensor:
74
- "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
75
- if aspect_ratio is None:
76
- aspect_ratio = width / height
77
-
78
- span_x = aspect_ratio / (1 + aspect_ratio**2) ** 0.5
79
- span_y = 1 / (1 + aspect_ratio**2) ** 0.5
80
-
81
- u = torch.linspace(
82
- -span_x * (width - 1) / width,
83
- span_x * (width - 1) / width,
84
- width,
85
- dtype=dtype,
86
- device=device,
87
- )
88
- v = torch.linspace(
89
- -span_y * (height - 1) / height,
90
- span_y * (height - 1) / height,
91
- height,
92
- dtype=dtype,
93
- device=device,
94
- )
95
- u, v = torch.meshgrid(u, v, indexing="xy")
96
- uv = torch.stack([u, v], dim=-1)
97
- return uv
98
-
99
-
100
- def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
101
- "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
102
- from scipy.optimize import least_squares
103
-
104
- uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
105
-
106
- def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
107
- xy_proj = xy / (z + shift)[:, None]
108
- f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
109
- err = (f * xy_proj - uv).ravel()
110
- return err
111
-
112
- solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method="lm")
113
- optim_shift = solution["x"].squeeze().astype(np.float32)
114
-
115
- xy_proj = xy / (z + optim_shift)[:, None]
116
- optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
117
-
118
- return optim_shift, optim_focal
119
-
120
-
121
- def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
122
- "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
123
- from scipy.optimize import least_squares
124
-
125
- uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
126
-
127
- def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
128
- xy_proj = xy / (z + shift)[:, None]
129
- err = (focal * xy_proj - uv).ravel()
130
- return err
131
-
132
- solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method="lm")
133
- optim_shift = solution["x"].squeeze().astype(np.float32)
134
-
135
- return optim_shift
136
-
137
-
138
- def recover_focal_shift(
139
- points: torch.Tensor,
140
- mask: torch.Tensor = None,
141
- focal: torch.Tensor = None,
142
- downsample_size: Tuple[int, int] = (64, 64),
143
- ):
144
- """
145
- Recover the depth map and FoV from a point map with unknown z shift and focal.
146
-
147
- Note that it assumes:
148
- - the optical center is at the center of the map
149
- - the map is undistorted
150
- - the map is isometric in the x and y directions
151
-
152
- ### Parameters:
153
- - `points: torch.Tensor` of shape (..., H, W, 3)
154
- - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
155
-
156
- ### Returns:
157
- - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
158
- - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
159
- """
160
- shape = points.shape
161
- height, width = points.shape[-3], points.shape[-2]
162
-
163
- points = points.reshape(-1, *shape[-3:])
164
- mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
165
- focal = focal.reshape(-1) if focal is not None else None
166
- uv = normalized_view_plane_uv(
167
- width, height, dtype=points.dtype, device=points.device
168
- ) # (H, W, 2)
169
-
170
- points_lr = F.interpolate(
171
- points.permute(0, 3, 1, 2), downsample_size, mode="nearest"
172
- ).permute(0, 2, 3, 1)
173
- uv_lr = (
174
- F.interpolate(
175
- uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode="nearest"
176
- )
177
- .squeeze(0)
178
- .permute(1, 2, 0)
179
- )
180
- mask_lr = (
181
- None
182
- if mask is None
183
- else F.interpolate(
184
- mask.to(torch.float32).unsqueeze(1), downsample_size, mode="nearest"
185
- ).squeeze(1)
186
- > 0
187
- )
188
-
189
- uv_lr_np = uv_lr.cpu().numpy()
190
- points_lr_np = points_lr.detach().cpu().numpy()
191
- focal_np = focal.cpu().numpy() if focal is not None else None
192
- mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
193
- optim_shift, optim_focal = [], []
194
- for i in range(points.shape[0]):
195
- points_lr_i_np = (
196
- points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
197
- )
198
- uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
199
- if uv_lr_i_np.shape[0] < 2:
200
- optim_focal.append(1)
201
- optim_shift.append(0)
202
- continue
203
- if focal is None:
204
- optim_shift_i, optim_focal_i = solve_optimal_focal_shift(
205
- uv_lr_i_np, points_lr_i_np
206
- )
207
- optim_focal.append(float(optim_focal_i))
208
- else:
209
- optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
210
- optim_shift.append(float(optim_shift_i))
211
- optim_shift = torch.tensor(
212
- optim_shift, device=points.device, dtype=points.dtype
213
- ).reshape(shape[:-3])
214
-
215
- if focal is None:
216
- optim_focal = torch.tensor(
217
- optim_focal, device=points.device, dtype=points.dtype
218
- ).reshape(shape[:-3])
219
- else:
220
- optim_focal = focal.reshape(shape[:-3])
221
-
222
- return optim_focal, optim_shift
223
-
224
-
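
A minimal usage sketch of `recover_focal_shift` on synthetic data (the tensors are random and purely an API demo; real point maps come from the network head). The import path mirrors the one used by the model files in this folder:

    import torch
    from mapanything.models.external.moge.models.utils import recover_focal_shift

    points = 0.1 * torch.randn(2, 64, 64, 3)            # (..., H, W, 3) affine-invariant point map
    points[..., 2] = 2.0 + 0.1 * torch.rand(2, 64, 64)  # keep z well away from zero for a stable solve

    focal, shift = recover_focal_shift(points)          # an optional mask and/or known focal can also be passed
    depth = points[..., 2] + shift[..., None, None]     # adding the z-shift moves the map into camera space
    print(focal.shape, shift.shape, depth.shape)        # torch.Size([2]) torch.Size([2]) torch.Size([2, 64, 64])
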
225
- def suppress_traceback(fn):
226
- @wraps(fn)
227
- def wrapper(*args, **kwargs):
228
- try:
229
- return fn(*args, **kwargs)
230
- except Exception as e:
231
- e.__traceback__ = e.__traceback__.tb_next.tb_next
232
- raise
233
-
234
- return wrapper
235
-
236
-
237
- def get_device(args, kwargs):
238
- device = None
239
- for arg in list(args) + list(kwargs.values()):
240
- if isinstance(arg, torch.Tensor):
241
- if device is None:
242
- device = arg.device
243
- elif device != arg.device:
244
- raise ValueError("All tensors must be on the same device.")
245
- return device
246
-
247
-
248
- def get_args_order(func, args, kwargs):
249
- """
250
- Get the order of the arguments of a function.
251
- """
252
- names = inspect.getfullargspec(func).args
253
- names_idx = {name: i for i, name in enumerate(names)}
254
- args_order = []
255
- kwargs_order = {}
256
- for name, arg in kwargs.items():
257
- if name in names:
258
- kwargs_order[name] = names_idx[name]
259
- names.remove(name)
260
- for i, arg in enumerate(args):
261
- if i < len(names):
262
- args_order.append(names_idx[names[i]])
263
- return args_order, kwargs_order
264
-
265
-
266
- def broadcast_args(args, kwargs, args_dim, kwargs_dim):
267
- spatial = []
268
- for arg, arg_dim in zip(
269
- args + list(kwargs.values()), args_dim + list(kwargs_dim.values())
270
- ):
271
- if isinstance(arg, torch.Tensor) and arg_dim is not None:
272
- arg_spatial = arg.shape[: arg.ndim - arg_dim]
273
- if len(arg_spatial) > len(spatial):
274
- spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
275
- for j in range(len(arg_spatial)):
276
- if spatial[-j] < arg_spatial[-j]:
277
- if spatial[-j] == 1:
278
- spatial[-j] = arg_spatial[-j]
279
- else:
280
- raise ValueError("Cannot broadcast arguments.")
281
- for i, arg in enumerate(args):
282
- if isinstance(arg, torch.Tensor) and args_dim[i] is not None:
283
- args[i] = torch.broadcast_to(
284
- arg, [*spatial, *arg.shape[arg.ndim - args_dim[i] :]]
285
- )
286
- for key, arg in kwargs.items():
287
- if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
288
- kwargs[key] = torch.broadcast_to(
289
- arg, [*spatial, *arg.shape[arg.ndim - kwargs_dim[key] :]]
290
- )
291
- return args, kwargs, spatial
292
-
293
-
294
- @suppress_traceback
295
- def batched(*dims):
296
- """
297
- Decorator that allows a function to be called with batched arguments.
298
- """
299
-
300
- def decorator(func):
301
- @wraps(func)
302
- def wrapper(*args, device=torch.device("cpu"), **kwargs):
303
- args = list(args)
304
- # get arguments dimensions
305
- args_order, kwargs_order = get_args_order(func, args, kwargs)
306
- args_dim = [dims[i] for i in args_order]
307
- kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
308
- # convert to torch tensor
309
- device = get_device(args, kwargs) or device
310
- for i, arg in enumerate(args):
311
- if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
312
- args[i] = torch.tensor(arg, device=device)
313
- for key, arg in kwargs.items():
314
- if (
315
- isinstance(arg, (Number, list, tuple))
316
- and kwargs_dim[key] is not None
317
- ):
318
- kwargs[key] = torch.tensor(arg, device=device)
319
- # broadcast arguments
320
- args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
321
- for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
322
- if isinstance(arg, torch.Tensor) and arg_dim is not None:
323
- args[i] = arg.reshape([-1, *arg.shape[arg.ndim - arg_dim :]])
324
- for key, arg in kwargs.items():
325
- if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
326
- kwargs[key] = arg.reshape(
327
- [-1, *arg.shape[arg.ndim - kwargs_dim[key] :]]
328
- )
329
- # call function
330
- results = func(*args, **kwargs)
331
- type_results = type(results)
332
- results = list(results) if isinstance(results, (tuple, list)) else [results]
333
- # restore spatial dimensions
334
- for i, result in enumerate(results):
335
- results[i] = result.reshape([*spatial, *result.shape[1:]])
336
- if type_results is tuple:
337
- results = tuple(results)
338
- elif type_results is list:
339
- results = list(results)
340
- else:
341
- results = results[0]
342
- return results
343
-
344
- return wrapper
345
-
346
- return decorator
347
-
348
-
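
The `batched` decorator flattens arbitrary leading batch dimensions before the call and restores them afterwards; the integers passed to it are the number of trailing "core" dimensions of each argument (for `unproject_cv` below, `@batched(2, 1, 2, 2)` treats `uv_coord` as (..., N, 2), `depth` as (..., N), and the matrices as (..., 4, 4) / (..., 3, 3)). A small sketch with a hypothetical helper, not part of the module:

    import torch

    @batched(2, 0)                        # m has 2 core dims (3, 3); s is a per-element scalar
    def scale_matrix(m, s):
        # Inside the wrapper, all leading dims have been flattened into a single batch dim.
        return m * s[:, None, None]

    m = torch.eye(3).expand(4, 5, 3, 3)   # leading batch shape (4, 5)
    s = torch.arange(20.0).reshape(4, 5)
    print(scale_matrix(m, s).shape)       # torch.Size([4, 5, 3, 3])
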
349
- def image_uv(
350
- height: int,
351
- width: int,
352
- left: int = None,
353
- top: int = None,
354
- right: int = None,
355
- bottom: int = None,
356
- device: torch.device = None,
357
- dtype: torch.dtype = None,
358
- ) -> torch.Tensor:
359
- """
360
- Get image space UV grid, ranging in [0, 1].
361
-
362
- >>> image_uv(10, 10):
363
- [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
364
- [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
365
- ... ... ...
366
- [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]
367
-
368
- Args:
369
- height (int): image height
371
- width (int): image width
371
-
372
- Returns:
373
- torch.Tensor: shape (height, width, 2)
374
- """
375
- if left is None:
376
- left = 0
377
- if top is None:
378
- top = 0
379
- if right is None:
380
- right = width
381
- if bottom is None:
382
- bottom = height
383
- u = torch.linspace(
384
- (left + 0.5) / width,
385
- (right - 0.5) / width,
386
- right - left,
387
- device=device,
388
- dtype=dtype,
389
- )
390
- v = torch.linspace(
391
- (top + 0.5) / height,
392
- (bottom - 0.5) / height,
393
- bottom - top,
394
- device=device,
395
- dtype=dtype,
396
- )
397
- u, v = torch.meshgrid(u, v, indexing="xy")
398
- uv = torch.stack([u, v], dim=-1)
399
-
400
- return uv
401
-
402
-
403
- @batched(2, 1, 2, 2)
404
- def unproject_cv(
405
- uv_coord: torch.Tensor,
406
- depth: torch.Tensor = None,
407
- extrinsics: torch.Tensor = None,
408
- intrinsics: torch.Tensor = None,
409
- ) -> torch.Tensor:
410
- """
411
- Unproject uv coordinates to 3D view space following the OpenCV convention
412
-
413
- Args:
414
- uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
415
- The origin (0., 0.) is corresponding to the left & top
416
- depth (torch.Tensor): [..., N] depth value
417
- extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
418
- intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix
419
-
420
- Returns:
421
- points (torch.Tensor): [..., N, 3] 3d points
422
- """
423
- assert intrinsics is not None, "intrinsics matrix is required"
424
- points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1)
425
- points = points @ torch.inverse(intrinsics).transpose(-2, -1)
426
- if depth is not None:
427
- points = points * depth[..., None]
428
- if extrinsics is not None:
429
- points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
430
- points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3]
431
- return points
432
-
433
-
434
- def depth_to_points(
435
- depth: torch.Tensor, intrinsics: torch.Tensor, extrinsics: torch.Tensor = None
436
- ):
437
- height, width = depth.shape[-2:]
438
- uv = image_uv(width=width, height=height, dtype=depth.dtype, device=depth.device)
439
- pts = unproject_cv(
440
- uv,
441
- depth,
442
- intrinsics=intrinsics[..., None, :, :],
443
- extrinsics=extrinsics[..., None, :, :] if extrinsics is not None else None,
444
- )
445
-
446
- return pts
447
-
448
-
449
- @batched(0, 0, 0, 0, 0, 0)
450
- def intrinsics_from_focal_center(
451
- fx: Union[float, torch.Tensor],
452
- fy: Union[float, torch.Tensor],
453
- cx: Union[float, torch.Tensor],
454
- cy: Union[float, torch.Tensor],
455
- ) -> torch.Tensor:
456
- """
457
- Get OpenCV intrinsics matrix
458
-
459
- Args:
460
- focal_x (float | torch.Tensor): focal length in x axis
461
- focal_y (float | torch.Tensor): focal length in y axis
462
- cx (float | torch.Tensor): principal point in x axis
463
- cy (float | torch.Tensor): principal point in y axis
464
-
465
- Returns:
466
- (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
467
- """
468
- N = fx.shape[0]
469
- ret = torch.zeros((N, 3, 3), dtype=fx.dtype, device=fx.device)
470
- zeros, ones = (
471
- torch.zeros(N, dtype=fx.dtype, device=fx.device),
472
- torch.ones(N, dtype=fx.dtype, device=fx.device),
473
- )
474
- ret = torch.stack(
475
- [fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1
476
- ).unflatten(-1, (3, 3))
477
- return ret
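
A minimal sketch tying the helpers above together: build a normalized OpenCV-style intrinsics matrix and lift a depth map to a camera-space point map (shapes and focal values are illustrative). The import path mirrors the one used by the model files in this folder:

    import torch
    from mapanything.models.external.moge.models.utils import (
        depth_to_points,
        intrinsics_from_focal_center,
    )

    depth = torch.ones(2, 48, 64)                                     # (..., H, W) depth map
    K = intrinsics_from_focal_center(fx=1.2, fy=1.6, cx=0.5, cy=0.5)  # (3, 3), UV-normalized units
    points = depth_to_points(depth, intrinsics=K)                     # (..., H, W, 3) camera-space points
    print(K.shape, points.shape)                                      # torch.Size([3, 3]) torch.Size([2, 48, 64, 3])
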
mapanything/models/external/moge/models/v1.py DELETED
@@ -1,595 +0,0 @@
1
- import importlib
2
- from numbers import Number
3
- from pathlib import Path
4
- from typing import Any, Dict, IO, List, Literal, Optional, Tuple, Union
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- import torch.utils
10
- import torch.utils.checkpoint
11
- import torch.version
12
- from huggingface_hub import hf_hub_download
13
-
14
- from mapanything.models.external.moge.models.utils import (
15
- depth_to_points,
16
- intrinsics_from_focal_center,
17
- normalized_view_plane_uv,
18
- recover_focal_shift,
19
- wrap_module_with_gradient_checkpointing,
20
- )
21
-
22
-
23
- class ResidualConvBlock(nn.Module):
24
- def __init__(
25
- self,
26
- in_channels: int,
27
- out_channels: int = None,
28
- hidden_channels: int = None,
29
- padding_mode: str = "replicate",
30
- activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu",
31
- norm: Literal["group_norm", "layer_norm"] = "group_norm",
32
- ):
33
- super(ResidualConvBlock, self).__init__()
34
- if out_channels is None:
35
- out_channels = in_channels
36
- if hidden_channels is None:
37
- hidden_channels = in_channels
38
-
39
- if activation == "relu":
40
- activation_cls = lambda: nn.ReLU(inplace=True) # noqa
41
- elif activation == "leaky_relu":
42
- activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True) # noqa
43
- elif activation == "silu":
44
- activation_cls = lambda: nn.SiLU(inplace=True) # noqa
45
- elif activation == "elu":
46
- activation_cls = lambda: nn.ELU(inplace=True) # noqa
47
- else:
48
- raise ValueError(f"Unsupported activation function: {activation}")
49
-
50
- self.layers = nn.Sequential(
51
- nn.GroupNorm(1, in_channels),
52
- activation_cls(),
53
- nn.Conv2d(
54
- in_channels,
55
- hidden_channels,
56
- kernel_size=3,
57
- padding=1,
58
- padding_mode=padding_mode,
59
- ),
60
- nn.GroupNorm(
61
- hidden_channels // 32 if norm == "group_norm" else 1, hidden_channels
62
- ),
63
- activation_cls(),
64
- nn.Conv2d(
65
- hidden_channels,
66
- out_channels,
67
- kernel_size=3,
68
- padding=1,
69
- padding_mode=padding_mode,
70
- ),
71
- )
72
-
73
- self.skip_connection = (
74
- nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
75
- if in_channels != out_channels
76
- else nn.Identity()
77
- )
78
-
79
- def forward(self, x):
80
- skip = self.skip_connection(x)
81
- x = self.layers(x)
82
- x = x + skip
83
- return x
84
-
85
-
86
- class Head(nn.Module):
87
- def __init__(
88
- self,
89
- num_features: int,
90
- dim_in: int,
91
- dim_out: List[int],
92
- dim_proj: int = 512,
93
- dim_upsample: List[int] = [256, 128, 128],
94
- dim_times_res_block_hidden: int = 1,
95
- num_res_blocks: int = 1,
96
- res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm",
97
- last_res_blocks: int = 0,
98
- last_conv_channels: int = 32,
99
- last_conv_size: int = 1,
100
- ):
101
- super().__init__()
102
-
103
- self.projects = nn.ModuleList(
104
- [
105
- nn.Conv2d(
106
- in_channels=dim_in,
107
- out_channels=dim_proj,
108
- kernel_size=1,
109
- stride=1,
110
- padding=0,
111
- )
112
- for _ in range(num_features)
113
- ]
114
- )
115
-
116
- self.upsample_blocks = nn.ModuleList(
117
- [
118
- nn.Sequential(
119
- self._make_upsampler(in_ch + 2, out_ch),
120
- *(
121
- ResidualConvBlock(
122
- out_ch,
123
- out_ch,
124
- dim_times_res_block_hidden * out_ch,
125
- activation="relu",
126
- norm=res_block_norm,
127
- )
128
- for _ in range(num_res_blocks)
129
- ),
130
- )
131
- for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
132
- ]
133
- )
134
-
135
- self.output_block = nn.ModuleList(
136
- [
137
- self._make_output_block(
138
- dim_upsample[-1] + 2,
139
- dim_out_,
140
- dim_times_res_block_hidden,
141
- last_res_blocks,
142
- last_conv_channels,
143
- last_conv_size,
144
- res_block_norm,
145
- )
146
- for dim_out_ in dim_out
147
- ]
148
- )
149
-
150
- def _make_upsampler(self, in_channels: int, out_channels: int):
151
- upsampler = nn.Sequential(
152
- nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
153
- nn.Conv2d(
154
- out_channels,
155
- out_channels,
156
- kernel_size=3,
157
- stride=1,
158
- padding=1,
159
- padding_mode="replicate",
160
- ),
161
- )
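- # Tie the 2x2 transposed-conv kernel to its top-left weight so that, at init, the upsampling is equivalent to nearest-neighbor interpolation followed by a 1x1 convolution.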
162
- upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
163
- return upsampler
164
-
165
- def _make_output_block(
166
- self,
167
- dim_in: int,
168
- dim_out: int,
169
- dim_times_res_block_hidden: int,
170
- last_res_blocks: int,
171
- last_conv_channels: int,
172
- last_conv_size: int,
173
- res_block_norm: Literal["group_norm", "layer_norm"],
174
- ):
175
- return nn.Sequential(
176
- nn.Conv2d(
177
- dim_in,
178
- last_conv_channels,
179
- kernel_size=3,
180
- stride=1,
181
- padding=1,
182
- padding_mode="replicate",
183
- ),
184
- *(
185
- ResidualConvBlock(
186
- last_conv_channels,
187
- last_conv_channels,
188
- dim_times_res_block_hidden * last_conv_channels,
189
- activation="relu",
190
- norm=res_block_norm,
191
- )
192
- for _ in range(last_res_blocks)
193
- ),
194
- nn.ReLU(inplace=True),
195
- nn.Conv2d(
196
- last_conv_channels,
197
- dim_out,
198
- kernel_size=last_conv_size,
199
- stride=1,
200
- padding=last_conv_size // 2,
201
- padding_mode="replicate",
202
- ),
203
- )
204
-
205
- def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
206
- img_h, img_w = image.shape[-2:]
207
- patch_h, patch_w = img_h // 14, img_w // 14
208
-
209
- # Process the hidden states
210
- x = torch.stack(
211
- [
212
- proj(
213
- feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()
214
- )
215
- for proj, (feat, clstoken) in zip(self.projects, hidden_states)
216
- ],
217
- dim=1,
218
- ).sum(dim=1)
219
-
220
- # Upsample stage
221
- # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
222
- for i, block in enumerate(self.upsample_blocks):
223
- # UV coordinates are concatenated so the head is aware of the image aspect ratio
224
- uv = normalized_view_plane_uv(
225
- width=x.shape[-1],
226
- height=x.shape[-2],
227
- aspect_ratio=img_w / img_h,
228
- dtype=x.dtype,
229
- device=x.device,
230
- )
231
- uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
232
- x = torch.cat([x, uv], dim=1)
233
- for layer in block:
234
- x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
235
-
236
- # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
237
- x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
238
- uv = normalized_view_plane_uv(
239
- width=x.shape[-1],
240
- height=x.shape[-2],
241
- aspect_ratio=img_w / img_h,
242
- dtype=x.dtype,
243
- device=x.device,
244
- )
245
- uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
246
- x = torch.cat([x, uv], dim=1)
247
-
248
- if isinstance(self.output_block, nn.ModuleList):
249
- output = [
250
- torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
251
- for block in self.output_block
252
- ]
253
- else:
254
- output = torch.utils.checkpoint.checkpoint(
255
- self.output_block, x, use_reentrant=False
256
- )
257
-
258
- return output
259
-
260
-
261
- class MoGeModel(nn.Module):
262
- image_mean: torch.Tensor
263
- image_std: torch.Tensor
264
-
265
- def __init__(
266
- self,
267
- encoder: str = "dinov2_vitb14",
268
- intermediate_layers: Union[int, List[int]] = 4,
269
- dim_proj: int = 512,
270
- dim_upsample: List[int] = [256, 128, 128],
271
- dim_times_res_block_hidden: int = 1,
272
- num_res_blocks: int = 1,
273
- remap_output: Literal[
274
- False, True, "linear", "sinh", "exp", "sinh_exp"
275
- ] = "linear",
276
- res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm",
277
- num_tokens_range: Tuple[Number, Number] = [1200, 2500],
278
- last_res_blocks: int = 0,
279
- last_conv_channels: int = 32,
280
- last_conv_size: int = 1,
281
- mask_threshold: float = 0.5,
282
- **deprecated_kwargs,
283
- ):
284
- super(MoGeModel, self).__init__()
285
-
286
- if deprecated_kwargs:
287
- # Process legacy arguments
288
- if "trained_area_range" in deprecated_kwargs:
289
- num_tokens_range = [
290
- deprecated_kwargs["trained_area_range"][0] // 14**2,
291
- deprecated_kwargs["trained_area_range"][1] // 14**2,
292
- ]
293
- del deprecated_kwargs["trained_area_range"]
294
- # warnings.warn(
295
- # f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}"
296
- # )
297
-
298
- self.encoder = encoder
299
- self.remap_output = remap_output
300
- self.intermediate_layers = intermediate_layers
301
- self.num_tokens_range = num_tokens_range
302
- self.mask_threshold = mask_threshold
303
-
304
- # NOTE: We have copied the DINOv2 code from torch hub into this repository.
306
- # Minimal modifications have been made: removing irrelevant code and unnecessary warnings, and fixing import issues.
306
- hub_loader = getattr(
307
- importlib.import_module(
308
- "mapanything.models.external.dinov2.hub.backbones", __package__
309
- ),
310
- encoder,
311
- )
312
- self.backbone = hub_loader(pretrained=False)
313
- dim_feature = self.backbone.blocks[0].attn.qkv.in_features
314
-
315
- self.head = Head(
316
- num_features=intermediate_layers
317
- if isinstance(intermediate_layers, int)
318
- else len(intermediate_layers),
319
- dim_in=dim_feature,
320
- dim_out=[3, 1],
321
- dim_proj=dim_proj,
322
- dim_upsample=dim_upsample,
323
- dim_times_res_block_hidden=dim_times_res_block_hidden,
324
- num_res_blocks=num_res_blocks,
325
- res_block_norm=res_block_norm,
326
- last_res_blocks=last_res_blocks,
327
- last_conv_channels=last_conv_channels,
328
- last_conv_size=last_conv_size,
329
- )
330
-
331
- image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
332
- image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
333
-
334
- self.register_buffer("image_mean", image_mean)
335
- self.register_buffer("image_std", image_std)
336
-
337
- @property
338
- def device(self) -> torch.device:
339
- return next(self.parameters()).device
340
-
341
- @property
342
- def dtype(self) -> torch.dtype:
343
- return next(self.parameters()).dtype
344
-
345
- @classmethod
346
- def from_pretrained(
347
- cls,
348
- pretrained_model_name_or_path: Union[str, Path, IO[bytes]],
349
- model_kwargs: Optional[Dict[str, Any]] = None,
350
- **hf_kwargs,
351
- ) -> "MoGeModel":
352
- """
353
- Load a model from a checkpoint file.
354
-
355
- ### Parameters:
356
- - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
357
- - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
358
- - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
359
-
360
- ### Returns:
361
- - A new instance of `MoGe` with the parameters loaded from the checkpoint.
362
- """
363
- if Path(pretrained_model_name_or_path).exists():
364
- checkpoint = torch.load(
365
- pretrained_model_name_or_path, map_location="cpu", weights_only=True
366
- )
367
- else:
368
- cached_checkpoint_path = hf_hub_download(
369
- repo_id=pretrained_model_name_or_path,
370
- repo_type="model",
371
- filename="model.pt",
372
- **hf_kwargs,
373
- )
374
- checkpoint = torch.load(
375
- cached_checkpoint_path, map_location="cpu", weights_only=True
376
- )
377
- model_config = checkpoint["model_config"]
378
- if model_kwargs is not None:
379
- model_config.update(model_kwargs)
380
- model = cls(**model_config)
381
- model.load_state_dict(checkpoint["model"])
382
- return model
383
-
384
- def init_weights(self):
385
- "Load the backbone with pretrained dinov2 weights from torch hub"
386
- state_dict = torch.hub.load(
387
- "facebookresearch/dinov2", self.encoder, pretrained=True
388
- ).state_dict()
389
- self.backbone.load_state_dict(state_dict)
390
-
391
- def enable_gradient_checkpointing(self):
392
- for i in range(len(self.backbone.blocks)):
393
- self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(
394
- self.backbone.blocks[i]
395
- )
396
-
397
- def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
398
- if self.remap_output == "linear":
399
- pass
400
- elif self.remap_output == "sinh":
401
- points = torch.sinh(points)
402
- elif self.remap_output == "exp":
403
- xy, z = points.split([2, 1], dim=-1)
404
- z = torch.exp(z)
405
- points = torch.cat([xy * z, z], dim=-1)
406
- elif self.remap_output == "sinh_exp":
407
- xy, z = points.split([2, 1], dim=-1)
408
- points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
409
- else:
410
- raise ValueError(f"Invalid remap output type: {self.remap_output}")
411
- return points
412
-
413
- def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
414
- original_height, original_width = image.shape[-2:]
415
-
416
- # Resize to expected resolution defined by num_tokens
417
- resize_factor = (
418
- (num_tokens * 14**2) / (original_height * original_width)
419
- ) ** 0.5
420
- resized_width, resized_height = (
421
- int(original_width * resize_factor),
422
- int(original_height * resize_factor),
423
- )
424
- image = F.interpolate(
425
- image,
426
- (resized_height, resized_width),
427
- mode="bicubic",
428
- align_corners=False,
429
- antialias=True,
430
- )
431
-
432
- # Apply image transformation for DINOv2
433
- image = (image - self.image_mean) / self.image_std
434
- image_14 = F.interpolate(
435
- image,
436
- (resized_height // 14 * 14, resized_width // 14 * 14),
437
- mode="bilinear",
438
- align_corners=False,
439
- antialias=True,
440
- )
441
-
442
- # Get intermediate layers from the backbone
443
- features = self.backbone.get_intermediate_layers(
444
- image_14, self.intermediate_layers, return_class_token=True
445
- )
446
-
447
- # Predict points (and mask)
448
- output = self.head(features, image)
449
- points, mask = output
450
-
451
- # Make sure fp32 precision for output
452
- with torch.autocast(device_type=image.device.type, dtype=torch.float32):
453
- # Resize to original resolution
454
- points = F.interpolate(
455
- points,
456
- (original_height, original_width),
457
- mode="bilinear",
458
- align_corners=False,
459
- antialias=False,
460
- )
461
- mask = F.interpolate(
462
- mask,
463
- (original_height, original_width),
464
- mode="bilinear",
465
- align_corners=False,
466
- antialias=False,
467
- )
468
-
469
- # Post-process points and mask
470
- points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
471
- points = self._remap_points(
472
- points
473
- ) # slightly improves the performance in case of very large output values
474
-
475
- return_dict = {"points": points, "mask": mask}
476
- return return_dict
477
-
478
- # @torch.inference_mode()
479
- def infer(
480
- self,
481
- image: torch.Tensor,
482
- fov_x: Union[Number, torch.Tensor] = None,
483
- resolution_level: int = 9,
484
- num_tokens: int = None,
485
- apply_mask: bool = True,
486
- force_projection: bool = True,
487
- use_fp16: bool = True,
488
- ) -> Dict[str, torch.Tensor]:
489
- """
490
- User-friendly inference function
491
-
492
- ### Parameters
493
- - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
494
- - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
495
- - `resolution_level`: An integer in [0, 9] for the inference resolution level.
497
- Higher values capture finer details but are slower. Defaults to 9. Note that it does not affect the output size, which always matches the input size.
498
- `resolution_level` actually controls `num_tokens`; see `num_tokens` for more details.
498
- - `num_tokens`: number of tokens used for inference. An integer in the suggested range `[1200, 2500]`.
499
- `resolution_level` will be ignored if `num_tokens` is provided. Default: None
500
- - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
501
- - `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True
502
- - `use_fp16`: if True, use mixed precision to speed up inference. Default: True
503
-
504
- ### Returns
505
-
506
- A dictionary containing the following keys:
507
- - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
508
- - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
509
- - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
- - `mask`: boolean tensor of shape (B, H, W) or (H, W) marking valid pixels.
510
- """
511
- if image.dim() == 3:
512
- omit_batch_dim = True
513
- image = image.unsqueeze(0)
514
- else:
515
- omit_batch_dim = False
516
- image = image.to(dtype=self.dtype, device=self.device)
517
-
518
- original_height, original_width = image.shape[-2:]
519
- aspect_ratio = original_width / original_height
520
-
521
- if num_tokens is None:
522
- min_tokens, max_tokens = self.num_tokens_range
523
- num_tokens = int(
524
- min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)
525
- )
526
-
527
- with torch.autocast(
528
- device_type=self.device.type,
529
- dtype=torch.float16,
530
- enabled=use_fp16 and self.dtype != torch.float16,
531
- ):
532
- output = self.forward(image, num_tokens)
533
- points, mask = output["points"], output["mask"]
534
-
535
- # Always process the output in fp32 precision
536
- with torch.autocast(device_type=self.device.type, dtype=torch.float32):
537
- points, mask, fov_x = map(
538
- lambda x: x.float() if isinstance(x, torch.Tensor) else x,
539
- [points, mask, fov_x],
540
- )
541
-
542
- mask_binary = mask > self.mask_threshold
543
-
544
- # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
545
- if fov_x is None:
546
- focal, shift = recover_focal_shift(points, mask_binary)
547
- else:
548
- focal = (
549
- aspect_ratio
550
- / (1 + aspect_ratio**2) ** 0.5
551
- / torch.tan(
552
- torch.deg2rad(
553
- torch.as_tensor(
554
- fov_x, device=points.device, dtype=points.dtype
555
- )
556
- / 2
557
- )
558
- )
559
- )
560
- if focal.ndim == 0:
561
- focal = focal[None].expand(points.shape[0])
562
- _, shift = recover_focal_shift(points, mask_binary, focal=focal)
563
- fx = focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio
564
- fy = focal / 2 * (1 + aspect_ratio**2) ** 0.5
565
- intrinsics = intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
566
- depth = points[..., 2] + shift[..., None, None]
567
-
568
- # If projection constraint is forced, recompute the point map using the actual depth map
569
- if force_projection:
570
- points = depth_to_points(depth, intrinsics=intrinsics)
571
- else:
572
- points = (
573
- points
574
- + torch.stack(
575
- [torch.zeros_like(shift), torch.zeros_like(shift), shift],
576
- dim=-1,
577
- )[..., None, None, :]
578
- )
579
-
580
- # Apply mask if needed
581
- if apply_mask:
582
- points = torch.where(mask_binary[..., None], points, torch.inf)
583
- depth = torch.where(mask_binary, depth, torch.inf)
584
-
585
- return_dict = {
586
- "points": points,
587
- "intrinsics": intrinsics,
588
- "depth": depth,
589
- "mask": mask_binary,
590
- }
591
-
592
- if omit_batch_dim:
593
- return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
594
-
595
- return return_dict
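
A minimal, hedged usage sketch of the v1 model above. The checkpoint id `Ruicheng/moge-vitl` is the upstream MoGe release and is an assumption here (this repository does not pin it); the random tensor only stands in for a real RGB image in [0, 1]:

    import torch
    from mapanything.models.external.moge.models.v1 import MoGeModel

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device).eval()  # repo id is an assumption

    image = torch.rand(3, 480, 640, device=device)        # (3, H, W), RGB in [0, 1]
    with torch.no_grad():
        out = model.infer(image, resolution_level=9, use_fp16=torch.cuda.is_available())
    print(out["points"].shape, out["depth"].shape, out["intrinsics"].shape)
    # torch.Size([480, 640, 3]) torch.Size([480, 640]) torch.Size([3, 3])
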
mapanything/models/external/moge/models/v2.py DELETED
@@ -1,379 +0,0 @@
1
- import warnings
2
- from numbers import Number
3
- from pathlib import Path
4
- from typing import Any, Dict, IO, List, Literal, Optional, Union
5
-
6
- import torch
7
- import torch.amp
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- import torch.utils
11
- import torch.utils.checkpoint
12
- import torch.version
13
- from huggingface_hub import hf_hub_download
14
-
15
- from mapanything.models.external.moge.models.modules import (
16
- ConvStack,
17
- DINOv2Encoder,
18
- MLP,
19
- )
20
- from mapanything.models.external.moge.models.utils import (
21
- depth_to_points,
22
- intrinsics_from_focal_center,
23
- normalized_view_plane_uv,
24
- recover_focal_shift,
25
- )
26
-
27
-
28
- class MoGeModel(nn.Module):
29
- encoder: DINOv2Encoder
30
- neck: ConvStack
31
- points_head: ConvStack
32
- mask_head: ConvStack
33
- scale_head: MLP
34
- onnx_compatible_mode: bool
35
-
36
- def __init__(
37
- self,
38
- encoder: Dict[str, Any],
39
- neck: Dict[str, Any],
40
- points_head: Dict[str, Any] = None,
41
- mask_head: Dict[str, Any] = None,
42
- normal_head: Dict[str, Any] = None,
43
- scale_head: Dict[str, Any] = None,
44
- remap_output: Literal["linear", "sinh", "exp", "sinh_exp"] = "linear",
45
- num_tokens_range: List[int] = [1200, 3600],
46
- **deprecated_kwargs,
47
- ):
48
- super(MoGeModel, self).__init__()
49
- if deprecated_kwargs:
50
- warnings.warn(
51
- f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}"
52
- )
53
-
54
- self.remap_output = remap_output
55
- self.num_tokens_range = num_tokens_range
56
-
57
- self.encoder = DINOv2Encoder(**encoder)
58
- self.neck = ConvStack(**neck)
59
- if points_head is not None:
60
- self.points_head = ConvStack(**points_head)
61
- if mask_head is not None:
62
- self.mask_head = ConvStack(**mask_head)
63
- if normal_head is not None:
64
- self.normal_head = ConvStack(**normal_head)
65
- if scale_head is not None:
66
- self.scale_head = MLP(**scale_head)
67
-
68
- @property
69
- def device(self) -> torch.device:
70
- return next(self.parameters()).device
71
-
72
- @property
73
- def dtype(self) -> torch.dtype:
74
- return next(self.parameters()).dtype
75
-
76
- @property
77
- def onnx_compatible_mode(self) -> bool:
78
- return getattr(self, "_onnx_compatible_mode", False)
79
-
80
- @onnx_compatible_mode.setter
81
- def onnx_compatible_mode(self, value: bool):
82
- self._onnx_compatible_mode = value
83
- self.encoder.onnx_compatible_mode = value
84
-
85
- @classmethod
86
- def from_pretrained(
87
- cls,
88
- pretrained_model_name_or_path: Union[str, Path, IO[bytes]],
89
- model_kwargs: Optional[Dict[str, Any]] = None,
90
- **hf_kwargs,
91
- ) -> "MoGeModel":
92
- """
93
- Load a model from a checkpoint file.
94
-
95
- ### Parameters:
96
- - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
98
- - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
99
- - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
100
-
101
- ### Returns:
102
- - A new instance of `MoGe` with the parameters loaded from the checkpoint.
103
- """
104
- if Path(pretrained_model_name_or_path).exists():
105
- checkpoint_path = pretrained_model_name_or_path
106
- else:
107
- checkpoint_path = hf_hub_download(
108
- repo_id=pretrained_model_name_or_path,
109
- repo_type="model",
110
- filename="model.pt",
111
- **hf_kwargs,
112
- )
113
- checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
114
-
115
- model_config = checkpoint["model_config"]
116
- if model_kwargs is not None:
117
- model_config.update(model_kwargs)
118
- model = cls(**model_config)
119
- model.load_state_dict(checkpoint["model"], strict=False)
120
-
121
- return model
122
-
123
- def init_weights(self):
124
- self.encoder.init_weights()
125
-
126
- def enable_gradient_checkpointing(self):
127
- self.encoder.enable_gradient_checkpointing()
128
- self.neck.enable_gradient_checkpointing()
129
- for head in ["points_head", "normal_head", "mask_head"]:
130
- if hasattr(self, head):
131
- getattr(self, head).enable_gradient_checkpointing()
132
-
133
- def enable_pytorch_native_sdpa(self):
134
- self.encoder.enable_pytorch_native_sdpa()
135
-
136
- def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
137
- if self.remap_output == "linear":
138
- pass
139
- elif self.remap_output == "sinh":
140
- points = torch.sinh(points)
141
- elif self.remap_output == "exp":
142
- xy, z = points.split([2, 1], dim=-1)
143
- z = torch.exp(z)
144
- points = torch.cat([xy * z, z], dim=-1)
145
- elif self.remap_output == "sinh_exp":
146
- xy, z = points.split([2, 1], dim=-1)
147
- points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
148
- else:
149
- raise ValueError(f"Invalid remap output type: {self.remap_output}")
150
- return points
151
-
152
- def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
153
- batch_size, _, img_h, img_w = image.shape
154
- device, dtype = image.device, image.dtype
155
-
156
- aspect_ratio = img_w / img_h
157
- base_h, base_w = (
158
- int((num_tokens / aspect_ratio) ** 0.5),
159
- int((num_tokens * aspect_ratio) ** 0.5),
160
- )
161
- num_tokens = base_h * base_w
162
-
163
- # Backbone encoding
164
- features, cls_token = self.encoder(
165
- image, base_h, base_w, return_class_token=True
166
- )
167
- features = [features, None, None, None, None]
168
-
169
- # Concat UVs for aspect ratio input
170
- for level in range(5):
171
- uv = normalized_view_plane_uv(
172
- width=base_w * 2**level,
173
- height=base_h * 2**level,
174
- aspect_ratio=aspect_ratio,
175
- dtype=dtype,
176
- device=device,
177
- )
178
- uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
179
- if features[level] is None:
180
- features[level] = uv
181
- else:
182
- features[level] = torch.concat([features[level], uv], dim=1)
183
-
184
- # Shared neck
185
- features = self.neck(features)
186
-
187
- # Heads decoding
188
- points, normal, mask = (
189
- getattr(self, head)(features)[-1] if hasattr(self, head) else None
190
- for head in ["points_head", "normal_head", "mask_head"]
191
- )
192
- metric_scale = (
193
- self.scale_head(cls_token) if hasattr(self, "scale_head") else None
194
- )
195
-
196
- # Resize
197
- points, normal, mask = (
198
- F.interpolate(
199
- v, (img_h, img_w), mode="bilinear", align_corners=False, antialias=False
200
- )
201
- if v is not None
202
- else None
203
- for v in [points, normal, mask]
204
- )
205
-
206
- # Remap output
207
- if points is not None:
208
- points = points.permute(0, 2, 3, 1)
209
- points = self._remap_points(
210
- points
211
- ) # slightly improves the performance in case of very large output values
212
- if normal is not None:
213
- normal = normal.permute(0, 2, 3, 1)
214
- normal = F.normalize(normal, dim=-1)
215
- if mask is not None:
216
- mask = mask.squeeze(1).sigmoid()
217
- if metric_scale is not None:
218
- metric_scale = metric_scale.squeeze(1).exp()
219
-
220
- return_dict = {
221
- "points": points,
222
- "normal": normal,
223
- "mask": mask,
224
- "metric_scale": metric_scale,
225
- }
226
- return_dict = {k: v for k, v in return_dict.items() if v is not None}
227
-
228
- return return_dict
229
-
230
- # @torch.inference_mode()
231
- def infer(
232
- self,
233
- image: torch.Tensor,
234
- num_tokens: int = None,
235
- resolution_level: int = 9,
236
- force_projection: bool = True,
237
- apply_mask: Literal[False, True, "blend"] = True,
238
- fov_x: Optional[Union[Number, torch.Tensor]] = None,
239
- use_fp16: bool = True,
240
- ) -> Dict[str, torch.Tensor]:
241
- """
242
- User-friendly inference function
243
-
244
- ### Parameters
245
- - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
246
- - `num_tokens`: the number of base ViT tokens to use for inference. Suggested range: 1200 ~ 2500.
247
- More tokens yield significantly higher accuracy and finer details, but slower inference. If None, it is derived from `resolution_level`.
248
- - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
249
- - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
250
- - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
251
- - `use_fp16`: if True, use mixed precision to speed up inference. Default: True
252
-
253
- ### Returns
254
-
255
- A dictionary containing the following keys:
256
- - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
257
- - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
258
- - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
- - `mask`: boolean tensor of shape (B, H, W) or (H, W) marking valid pixels (present when the mask head is enabled).
- - `normal`: tensor of shape (B, H, W, 3) or (H, W, 3) containing unit normals (present when the normal head is enabled).
259
- """
260
- if image.dim() == 3:
261
- omit_batch_dim = True
262
- image = image.unsqueeze(0)
263
- else:
264
- omit_batch_dim = False
265
- image = image.to(dtype=self.dtype, device=self.device)
266
-
267
- original_height, original_width = image.shape[-2:]
268
- aspect_ratio = original_width / original_height
269
-
270
- # Determine the number of base tokens to use
271
- if num_tokens is None:
272
- min_tokens, max_tokens = self.num_tokens_range
273
- num_tokens = int(
274
- min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)
275
- )
276
-
277
- # Forward pass
278
- with torch.autocast(
279
- device_type=self.device.type,
280
- dtype=torch.float16,
281
- enabled=use_fp16 and self.dtype != torch.float16,
282
- ):
283
- output = self.forward(image, num_tokens=num_tokens)
284
- points, normal, mask, metric_scale = (
285
- output.get(k, None) for k in ["points", "normal", "mask", "metric_scale"]
286
- )
287
-
288
- # Always process the output in fp32 precision
289
- points, normal, mask, metric_scale, fov_x = map(
290
- lambda x: x.float() if isinstance(x, torch.Tensor) else x,
291
- [points, normal, mask, metric_scale, fov_x],
292
- )
293
- with torch.autocast(device_type=self.device.type, dtype=torch.float32):
294
- if mask is not None:
295
- mask_binary = mask > 0.5
296
- else:
297
- mask_binary = None
298
-
299
- if points is not None:
300
- # Convert affine point map to camera-space. Recover depth and intrinsics from point map.
301
- # NOTE: Focal here is the focal length relative to half the image diagonal
302
- if fov_x is None:
303
- # Recover focal and shift from predicted point map
304
- focal, shift = recover_focal_shift(points, mask_binary)
305
- else:
306
- # Focal is known, recover shift only
307
- focal = (
308
- aspect_ratio
309
- / (1 + aspect_ratio**2) ** 0.5
310
- / torch.tan(
311
- torch.deg2rad(
312
- torch.as_tensor(
313
- fov_x, device=points.device, dtype=points.dtype
314
- )
315
- / 2
316
- )
317
- )
318
- )
319
- if focal.ndim == 0:
320
- focal = focal[None].expand(points.shape[0])
321
- _, shift = recover_focal_shift(points, mask_binary, focal=focal)
322
- fx, fy = (
323
- focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio,
324
- focal / 2 * (1 + aspect_ratio**2) ** 0.5,
325
- )
326
- intrinsics = intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
327
- points[..., 2] += shift[..., None, None]
328
- if mask_binary is not None:
329
- mask_binary &= (
330
- points[..., 2] > 0
331
- ) # in case depth contains negative values (which should never happen in practice)
332
- depth = points[..., 2].clone()
333
- else:
334
- depth, intrinsics = None, None
335
-
336
- # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
337
- if force_projection and depth is not None:
338
- points = depth_to_points(depth, intrinsics=intrinsics)
339
-
340
- # Apply metric scale
341
- if metric_scale is not None:
342
- if points is not None:
343
- points *= metric_scale[:, None, None, None]
344
- if depth is not None:
345
- depth *= metric_scale[:, None, None]
346
-
347
- # Apply mask
348
- if apply_mask and mask_binary is not None:
349
- points = (
350
- torch.where(mask_binary[..., None], points, torch.inf)
351
- if points is not None
352
- else None
353
- )
354
- depth = (
355
- torch.where(mask_binary, depth, torch.inf)
356
- if depth is not None
357
- else None
358
- )
359
- normal = (
360
- torch.where(
361
- mask_binary[..., None], normal, torch.zeros_like(normal)
362
- )
363
- if normal is not None
364
- else None
365
- )
366
-
367
- return_dict = {
368
- "points": points,
369
- "intrinsics": intrinsics,
370
- "depth": depth,
371
- "mask": mask_binary,
372
- "normal": normal,
373
- }
374
- return_dict = {k: v for k, v in return_dict.items() if v is not None}
375
-
376
- if omit_batch_dim:
377
- return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
378
-
379
- return return_dict
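
In both v1 and v2, `focal` is measured relative to half the image diagonal, while the returned intrinsics are UV-normalized (principal point at 0.5). The conversion used in `infer` reduces to fx = 1 / (2 * tan(fov_x / 2)) and fy = aspect_ratio * fx, which this sketch double-checks numerically:

    import math

    fov_x_deg, aspect_ratio = 60.0, 4 / 3
    t = math.tan(math.radians(fov_x_deg) / 2)

    focal = aspect_ratio / (1 + aspect_ratio**2) ** 0.5 / t       # relative to half the image diagonal
    fx = focal / 2 * (1 + aspect_ratio**2) ** 0.5 / aspect_ratio  # as computed in infer()
    fy = focal / 2 * (1 + aspect_ratio**2) ** 0.5

    assert math.isclose(fx, 1 / (2 * t))
    assert math.isclose(fy, aspect_ratio / (2 * t))
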
mapanything/models/external/must3r/__init__.py DELETED
@@ -1,283 +0,0 @@
1
- """
2
- Inference wrapper for MUSt3R
3
- """
4
-
5
- import datetime
6
- import os
7
-
8
- import numpy as np
9
- import torch
10
- from dust3r.viz import rgb
11
- from must3r.demo.inference import SceneState
12
- from must3r.engine.inference import inference_multi_ar, postprocess
13
- from must3r.model import get_pointmaps_activation, load_model
14
-
15
- from mapanything.models.external.vggt.utils.rotation import mat_to_quat
16
-
17
-
18
- def must3r_inference(
19
- views,
20
- filelist,
21
- model,
22
- retrieval,
23
- device,
24
- amp,
25
- num_mem_images,
26
- max_bs,
27
- init_num_images=2,
28
- batch_num_views=1,
29
- render_once=False,
30
- is_sequence=False,
31
- viser_server=None,
32
- num_refinements_iterations=2,
33
- verbose=True,
34
- ):
35
- if amp == "fp16":
36
- dtype = torch.float16
37
- elif amp == "bf16":
38
- assert torch.cuda.is_bf16_supported()
39
- dtype = torch.bfloat16
40
- else:
41
- assert not amp
42
- dtype = torch.float32
43
-
44
- max_bs = None if max_bs == 0 else max_bs
45
- encoder, decoder = model
46
- pointmaps_activation = get_pointmaps_activation(decoder, verbose=verbose)
47
-
48
- def post_process_function(x):
49
- return postprocess(
50
- x, pointmaps_activation=pointmaps_activation, compute_cam=True
51
- )
52
-
53
- if verbose:
54
- print("loading images")
55
- time_start = datetime.datetime.now()
56
- nimgs = len(views)
57
-
58
- ellapsed = datetime.datetime.now() - time_start
59
- if verbose:
60
- print(f"loaded in {ellapsed}")
61
- print("running inference")
62
- time_start = datetime.datetime.now()
63
- if viser_server is not None:
64
- viser_server.reset(nimgs)
65
-
66
- imgs = [b["img"].to("cpu") for b in views]
67
- true_shape = [torch.from_numpy(b["true_shape"]).to("cpu") for b in views]
68
- true_shape = torch.stack(true_shape, dim=0)
69
- nimgs = true_shape.shape[0]
70
-
71
- # Select evenly spaced keyframes (all images when num_mem_images equals the number of images)
72
- keyframes = np.linspace(0, len(imgs) - 1, num_mem_images, dtype=int).tolist()
73
- encoder_precomputed_features = None
74
-
75
- not_keyframes = sorted(set(range(nimgs)).difference(set(keyframes)))
76
- assert (len(keyframes) + len(not_keyframes)) == nimgs
77
- # reorder images
78
- views = [views[i] for i in keyframes] + [views[i] for i in not_keyframes]
79
- imgs = [b["img"].to(device) for b in views]
80
- true_shape = [torch.from_numpy(b["true_shape"]).to(device) for b in views]
81
- filenames = [filelist[i] for i in keyframes + not_keyframes]
82
- img_ids = [torch.tensor(v) for v in keyframes + not_keyframes]
83
-
84
- if encoder_precomputed_features is not None:
85
- x_start, pos_start = encoder_precomputed_features
86
- x = [x_start[i] for i in keyframes] + [x_start[i] for i in not_keyframes]
87
- pos = [pos_start[i] for i in keyframes] + [pos_start[i] for i in not_keyframes]
88
- encoder_precomputed_features = (x, pos)
89
-
90
- mem_batches = [init_num_images]
91
- while (sum_b := sum(mem_batches)) != max(num_mem_images, init_num_images):
92
- size_b = min(batch_num_views, num_mem_images - sum_b)
93
- mem_batches.append(size_b)
94
-
95
- if render_once:
96
- to_render = list(range(num_mem_images, nimgs))
97
- else:
98
- to_render = None
99
-
100
- with torch.autocast("cuda", dtype=dtype):
101
- x_out_0, x_out = inference_multi_ar(
102
- encoder,
103
- decoder,
104
- imgs,
105
- img_ids,
106
- true_shape,
107
- mem_batches,
108
- max_bs=max_bs,
109
- verbose=verbose,
110
- to_render=to_render,
111
- encoder_precomputed_features=encoder_precomputed_features,
112
- device=device,
113
- preserve_gpu_mem=True,
114
- post_process_function=post_process_function,
115
- viser_server=viser_server,
116
- num_refinements_iterations=num_refinements_iterations,
117
- )
118
- if to_render is not None:
119
- x_out = x_out_0 + x_out
120
-
121
- ellapsed = datetime.datetime.now() - time_start
122
- if verbose:
123
- print(f"inference in {ellapsed}")
124
- try:
125
- print(str(int(torch.cuda.max_memory_reserved(device) / (1024**2))) + " MB")
126
- except Exception:
127
- pass
128
-
129
- if viser_server is not None:
130
- viser_server.reset_cam_visility()
131
- viser_server.send_message("Finished")
132
-
133
- if verbose:
134
- print("preparing pointcloud")
135
- time_start = datetime.datetime.now()
136
- focals = []
137
- cams2world = []
138
- for i in range(nimgs):
139
- focals.append(float(x_out[i]["focal"].cpu()))
140
- cams2world.append(x_out[i]["c2w"].cpu())
141
-
142
- # x_out to cpu
143
- for i in range(len(x_out)):
144
- for k in x_out[i].keys():
145
- x_out[i][k] = x_out[i][k].cpu()
146
-
147
- rgbimg = [rgb(imgs[i], true_shape[i]) for i in range(nimgs)]
148
- scene = SceneState(x_out, rgbimg, true_shape, focals, cams2world, filenames)
149
-
150
- ellapsed = datetime.datetime.now() - time_start
151
- if verbose:
152
- print(f"pointcloud prepared in {ellapsed}")
153
-
154
- return scene
155
-
156
-
157
- class MUSt3RWrapper(torch.nn.Module):
158
- def __init__(
159
- self,
160
- name,
161
- ckpt_path,
162
- retrieval_ckpt_path,
163
- img_size=512,
164
- amp="bf16",
165
- max_bs=1,
166
- **kwargs,
167
- ):
168
- super().__init__()
169
- self.name = name
170
- self.ckpt_path = ckpt_path
171
- self.retrieval_ckpt_path = retrieval_ckpt_path
172
- self.amp = amp
173
- self.max_bs = max_bs
174
-
175
- # Init the model and load the checkpoint
176
- self.model = load_model(self.ckpt_path, img_size=512)
177
-
178
- def forward(self, views):
179
- """
180
- Forward pass wrapper for MUSt3R.
181
-
182
- Assumption:
183
- - The batch size of input views is 1.
184
-
185
- Args:
186
- views (List[dict]): List of dictionaries containing the input views' images and instance information.
187
- Each dictionary should contain the following keys, where B is the batch size (which must be 1):
188
- "img" (tensor): Image tensor of shape (B, C, H, W).
189
- "data_norm_type" (list): ["dust3r"]
190
- "label" (list): ["scene_name"]
191
- "instance" (list): ["image_name"]
192
-
193
- Returns:
194
- List[dict]: A list containing the final outputs for the input views.
195
- """
196
- # Check the batch size of input views
197
- batch_size_per_view, _, height, width = views[0]["img"].shape
198
- device = views[0]["img"].device
199
- num_views = len(views)
200
- assert batch_size_per_view == 1, (
201
- f"Batch size of input views should be 1, but got {batch_size_per_view}."
202
- )
203
-
204
- # Check the data norm type
205
- data_norm_type = views[0]["data_norm_type"][0]
206
- assert data_norm_type == "dust3r", (
207
- "MUSt3R expects a normalized image with the DUSt3R normalization scheme applied"
208
- )
209
-
210
- # Convert the input views to the expected input format
211
- images = []
212
- image_paths = []
213
- for view in views:
214
- images.append(
215
- dict(
216
- img=view["img"][0].cpu(),
217
- idx=len(images),
218
- instance=str(len(images)),
219
- true_shape=np.int32([view["img"].shape[-2], view["img"].shape[-1]]),
220
- )
221
- )
222
- view_name = os.path.join(view["label"][0], view["instance"][0])
223
- image_paths.append(view_name)
224
-
225
- # Run MUSt3R inference
226
- scene = must3r_inference(
227
- images,
228
- image_paths,
229
- self.model,
230
- self.retrieval_ckpt_path,
231
- device,
232
- self.amp,
233
- num_views,
234
- self.max_bs,
235
- verbose=False,
236
- )
237
-
238
- # Make sure scene is not None
239
- if scene is None:
240
- raise RuntimeError("MUSt3R failed.")
241
-
242
- # Get the predictions
243
- predictions = scene.x_out
244
-
245
- # Convert the output to the MapAnything format
246
- with torch.autocast("cuda", enabled=False):
247
- res = []
248
- for view_idx in range(num_views):
249
- # Get the current view predictions
250
- curr_view_prediction = predictions[view_idx]
251
- curr_view_conf = curr_view_prediction["conf"]
252
- curr_view_pose = curr_view_prediction["c2w"].unsqueeze(0)
253
-
254
- # Convert the pose to quaternions and translation
255
- curr_view_cam_translations = curr_view_pose[..., :3, 3]
256
- curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])
257
-
258
- # Get the camera frame pointmaps
259
- curr_view_pts3d_cam = curr_view_prediction["pts3d_local"].unsqueeze(0)
260
-
261
- # Get the depth along ray and ray directions
262
- curr_view_depth_along_ray = torch.norm(
263
- curr_view_pts3d_cam, dim=-1, keepdim=True
264
- )
265
- curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray
266
-
267
- # Get the pointmaps
268
- curr_view_pts3d = curr_view_prediction["pts3d"].unsqueeze(0)
269
-
270
- # Append the outputs to the result list
271
- res.append(
272
- {
273
- "pts3d": curr_view_pts3d.to(device),
274
- "pts3d_cam": curr_view_pts3d_cam.to(device),
275
- "ray_directions": curr_view_ray_dirs.to(device),
276
- "depth_along_ray": curr_view_depth_along_ray.to(device),
277
- "cam_trans": curr_view_cam_translations.to(device),
278
- "cam_quats": curr_view_cam_quats.to(device),
279
- "conf": curr_view_conf.to(device),
280
- }
281
- )
282
-
283
- return res
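
A minimal sketch of the view dictionaries the wrapper above expects (batch size 1 per view, DUSt3R-normalized images; file names and resolutions are illustrative, and the checkpoint paths are deployment-specific):

    import torch

    views = [
        {
            "img": torch.randn(1, 3, 384, 512),   # DUSt3R-normalized image, batch size 1
            "data_norm_type": ["dust3r"],
            "label": ["scene_000"],
            "instance": [f"frame_{i:04d}.png"],
        }
        for i in range(4)
    ]
    # wrapper = MUSt3RWrapper(name="must3r", ckpt_path="<path>", retrieval_ckpt_path="<path>")
    # preds = wrapper(views)  # list of 4 dicts with pts3d, ray_directions, depth_along_ray, cam_quats, ...
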
mapanything/models/external/pi3/__init__.py DELETED
@@ -1,119 +0,0 @@
1
- """
2
- Inference wrapper for Pi3
3
- """
4
-
5
- import torch
6
-
7
- from mapanything.models.external.pi3.models.pi3 import Pi3
8
- from mapanything.models.external.vggt.utils.rotation import mat_to_quat
9
-
10
-
11
- class Pi3Wrapper(torch.nn.Module):
12
- def __init__(
13
- self,
14
- name,
15
- torch_hub_force_reload,
16
- load_pretrained_weights=True,
17
- pos_type="rope100",
18
- decoder_size="large",
19
- ):
20
- super().__init__()
21
- self.name = name
22
- self.torch_hub_force_reload = torch_hub_force_reload
23
-
24
- if load_pretrained_weights:
25
- # Load pre-trained weights
26
- if not torch_hub_force_reload:
27
- # Initialize the Pi3 model from huggingface hub cache
28
- print("Loading Pi3 from huggingface cache ...")
29
- self.model = Pi3.from_pretrained(
30
- "yyfz233/Pi3",
31
- )
32
- else:
33
- # Initialize the Pi3 model
34
- self.model = Pi3.from_pretrained("yyfz233/Pi3", force_download=True)
35
- else:
36
- # Load the Pi3 class
37
- self.model = Pi3(
38
- pos_type=pos_type,
39
- decoder_size=decoder_size,
40
- )
41
-
42
- # Get the dtype for Pi3 inference
43
- # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
44
- self.dtype = (
45
- torch.bfloat16
46
- if torch.cuda.get_device_capability()[0] >= 8
47
- else torch.float16
48
- )
49
-
50
- def forward(self, views):
51
- """
52
- Forward pass wrapper for Pi3
53
-
54
- Assumption:
55
- - All the input views have the same image shape.
56
-
57
- Args:
58
- views (List[dict]): List of dictionaries containing the input views' images and instance information.
59
- Each dictionary should contain the following keys:
60
- "img" (tensor): Image tensor of shape (B, C, H, W).
61
- "data_norm_type" (list): ["identity"]
62
-
63
- Returns:
64
- List[dict]: A list containing the final outputs for all N views.
65
- """
66
- # Get input shape of the images, number of views, and batch size per view
67
- batch_size_per_view, _, height, width = views[0]["img"].shape
68
- num_views = len(views)
69
-
70
- # Check the data norm type
71
- # Pi3 expects a normalized image but without the DINOv2 mean and std applied ("identity")
72
- data_norm_type = views[0]["data_norm_type"][0]
73
- assert data_norm_type == "identity", (
74
- "Pi3 expects a normalized image but without the DINOv2 mean and std applied"
75
- )
76
-
77
- # Concatenate the images to create a single (B, V, C, H, W) tensor
78
- img_list = [view["img"] for view in views]
79
- images = torch.stack(img_list, dim=1)
80
-
81
- # Run the Pi3 aggregator
82
- with torch.autocast("cuda", dtype=self.dtype):
83
- results = self.model(images)
84
-
85
- # Need high precision for transformations
86
- with torch.autocast("cuda", enabled=False):
87
- # Convert the output to MapAnything format
88
- res = []
89
- for view_idx in range(num_views):
90
- # Get the extrinsics
91
- curr_view_extrinsic = results["camera_poses"][:, view_idx, ...]
92
- curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
93
- curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])
94
-
95
- # Get the depth along ray, ray directions, local point cloud & global point cloud
96
- curr_view_pts3d_cam = results["local_points"][:, view_idx, ...]
97
- curr_view_depth_along_ray = torch.norm(
98
- curr_view_pts3d_cam, dim=-1, keepdim=True
99
- )
100
- curr_view_ray_dirs = curr_view_pts3d_cam / curr_view_depth_along_ray
101
- curr_view_pts3d = results["points"][:, view_idx, ...]
102
-
103
- # Get the confidence
104
- curr_view_confidence = results["conf"][:, view_idx, ...]
105
-
106
- # Append the outputs to the result list
107
- res.append(
108
- {
109
- "pts3d": curr_view_pts3d,
110
- "pts3d_cam": curr_view_pts3d_cam,
111
- "ray_directions": curr_view_ray_dirs,
112
- "depth_along_ray": curr_view_depth_along_ray,
113
- "cam_trans": curr_view_cam_translations,
114
- "cam_quats": curr_view_cam_quats,
115
- "conf": curr_view_confidence,
116
- }
117
- )
118
-
119
- return res
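
Note for reference: the deleted `Pi3Wrapper.forward` factors Pi3's camera-frame pointmap into the depth-along-ray and unit ray-direction representation used by the rest of the repo. A minimal, standalone sketch of just that decomposition, assuming only `torch` and random stand-in tensors:

```python
import torch

# Stand-in for results["local_points"][:, view_idx] from the Pi3 model output.
B, H, W = 2, 28, 28
pts3d_cam = torch.randn(B, H, W, 3)

# Depth along ray is the Euclidean norm of each camera-frame point;
# dividing by it yields unit-norm ray directions, as in the wrapper.
depth_along_ray = torch.norm(pts3d_cam, dim=-1, keepdim=True)   # (B, H, W, 1)
ray_dirs = pts3d_cam / depth_along_ray                          # (B, H, W, 3)

# Sanity checks: unit-norm directions that recompose to the input points.
assert torch.allclose(ray_dirs.norm(dim=-1), torch.ones(B, H, W), atol=1e-5)
assert torch.allclose(ray_dirs * depth_along_ray, pts3d_cam, atol=1e-5)
```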
 
mapanything/models/external/pi3/layers/__init__.py DELETED
File without changes
mapanything/models/external/pi3/layers/attention.py DELETED
@@ -1,429 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- # References:
7
- # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
-
10
-
11
- import os
12
-
13
- import torch
14
- from torch import nn, Tensor
15
- from torch.nn.attention import SDPBackend
16
- from torch.nn.functional import scaled_dot_product_attention
17
-
18
- XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
19
- try:
20
- if XFORMERS_ENABLED:
21
- from xformers.ops import memory_efficient_attention
22
-
23
- XFORMERS_AVAILABLE = True
24
- # warnings.warn("xFormers is available (Attention)")
25
- else:
26
- # warnings.warn("xFormers is disabled (Attention)")
27
- raise ImportError
28
- except ImportError:
29
- XFORMERS_AVAILABLE = False
30
- # warnings.warn("xFormers is not available (Attention)")
31
-
32
-
33
- class Attention(nn.Module):
34
- def __init__(
35
- self,
36
- dim: int,
37
- num_heads: int = 8,
38
- qkv_bias: bool = False,
39
- proj_bias: bool = True,
40
- attn_drop: float = 0.0,
41
- proj_drop: float = 0.0,
42
- ) -> None:
43
- super().__init__()
44
- self.num_heads = num_heads
45
- head_dim = dim // num_heads
46
- self.scale = head_dim**-0.5
47
-
48
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
49
- self.attn_drop = nn.Dropout(attn_drop)
50
- self.proj = nn.Linear(dim, dim, bias=proj_bias)
51
- self.proj_drop = nn.Dropout(proj_drop)
52
-
53
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
54
- B, N, C = x.shape
55
- qkv = (
56
- self.qkv(x)
57
- .reshape(B, N, 3, self.num_heads, C // self.num_heads)
58
- .permute(2, 0, 3, 1, 4)
59
- )
60
-
61
- q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
62
- attn = q @ k.transpose(-2, -1)
63
-
64
- attn = attn.softmax(dim=-1)
65
- attn = self.attn_drop(attn)
66
-
67
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
68
- x = self.proj(x)
69
- x = self.proj_drop(x)
70
- return x
71
-
72
-
73
- class MemEffAttention(Attention):
74
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
75
- if not XFORMERS_AVAILABLE:
76
- if attn_bias is not None:
77
- raise AssertionError("xFormers is required for using nested tensors")
78
- return super().forward(x)
79
-
80
- B, N, C = x.shape
81
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
82
-
83
- # q, k, v = unbind(qkv, 2)
84
- q, k, v = [qkv[:, :, i] for i in range(3)]
85
-
86
- x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
87
- x = x.reshape([B, N, C])
88
-
89
- x = self.proj(x)
90
- x = self.proj_drop(x)
91
- return x
92
-
93
-
94
- class FlashAttention(Attention):
95
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
96
- B, N, C = x.shape
97
- qkv = (
98
- self.qkv(x)
99
- .reshape(B, N, 3, self.num_heads, C // self.num_heads)
100
- .transpose(1, 3)
101
- )
102
-
103
- # q, k, v = unbind(qkv, 2)
104
- q, k, v = [qkv[:, :, i] for i in range(3)]
105
-
106
- if q.dtype == torch.bfloat16:
107
- with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
108
- x = scaled_dot_product_attention(q, k, v)
109
- else:
110
- with nn.attention.sdpa_kernel(
111
- [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]
112
- ):
113
- x = scaled_dot_product_attention(q, k, v)
114
-
115
- x = x.transpose(1, 2).reshape([B, N, C])
116
-
117
- x = self.proj(x)
118
- x = self.proj_drop(x)
119
- return x
120
-
121
-
122
- """
123
- Following is written by GPT-4o
124
- """
125
-
126
-
127
- class CrossAttentionRope(nn.Module):
128
- def __init__(
129
- self,
130
- dim: int,
131
- num_heads: int = 8,
132
- qkv_bias: bool = False,
133
- proj_bias: bool = True,
134
- attn_drop: float = 0.0,
135
- proj_drop: float = 0.0,
136
- qk_norm: bool = False,
137
- norm_layer: nn.Module = nn.LayerNorm,
138
- rope=None,
139
- ) -> None:
140
- super().__init__()
141
- self.num_heads = num_heads
142
- head_dim = dim // num_heads
143
- self.scale = head_dim**-0.5
144
-
145
- # Separate projection layers for query, key, and value
146
- self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
147
- self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
148
- self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
149
-
150
- self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
151
- self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
152
-
153
- self.attn_drop = nn.Dropout(attn_drop)
154
- self.proj = nn.Linear(dim, dim, bias=proj_bias)
155
- self.proj_drop = nn.Dropout(proj_drop)
156
-
157
- self.rope = rope
158
-
159
- def forward(
160
- self,
161
- query: Tensor,
162
- key: Tensor,
163
- value: Tensor,
164
- attn_bias=None,
165
- qpos=None,
166
- kpos=None,
167
- ) -> Tensor:
168
- """
169
- Args:
170
- query: Tensor of shape (B, N, C), input query
171
- key: Tensor of shape (B, M, C), input key
172
- value: Tensor of shape (B, M, C), input value
173
- attn_bias: Optional tensor for attention bias
174
- Returns:
175
- Tensor of shape (B, N, C), output of cross-attention
176
- """
177
- B, N, C = query.shape
178
- _, M, _ = key.shape
179
-
180
- # Project query, key, and value
181
- q = (
182
- self.q_proj(query)
183
- .reshape(B, N, self.num_heads, C // self.num_heads)
184
- .permute(0, 2, 1, 3)
185
- )
186
- k = (
187
- self.k_proj(key)
188
- .reshape(B, M, self.num_heads, C // self.num_heads)
189
- .permute(0, 2, 1, 3)
190
- )
191
- v = (
192
- self.v_proj(value)
193
- .reshape(B, M, self.num_heads, C // self.num_heads)
194
- .permute(0, 2, 1, 3)
195
- )
196
- q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
197
-
198
- if self.rope is not None:
199
- q = self.rope(q, qpos)
200
- k = self.rope(k, kpos)
201
-
202
- # Scale query
203
- q = q * self.scale
204
-
205
- # Compute attention scores
206
- attn = q @ k.transpose(-2, -1) # (B, num_heads, N, M)
207
- if attn_bias is not None:
208
- attn = attn + attn_bias
209
-
210
- attn = attn.softmax(dim=-1)
211
- attn = self.attn_drop(attn)
212
-
213
- # Compute attention output
214
- x = (attn @ v).transpose(1, 2).reshape(B, N, C) # (B, N, C)
215
-
216
- # Final projection
217
- x = self.proj(x)
218
- x = self.proj_drop(x)
219
- return x
220
-
221
-
222
- class MemEffCrossAttentionRope(CrossAttentionRope):
223
- def forward(
224
- self,
225
- query: Tensor,
226
- key: Tensor,
227
- value: Tensor,
228
- attn_bias=None,
229
- qpos=None,
230
- kpos=None,
231
- ) -> Tensor:
232
- """
233
- Args:
234
- query: Tensor of shape (B, N, C), input query
235
- key: Tensor of shape (B, M, C), input key
236
- value: Tensor of shape (B, M, C), input value
237
- attn_bias: Optional tensor for attention bias
238
- Returns:
239
- Tensor of shape (B, N, C), output of cross-attention
240
- """
241
- if not XFORMERS_AVAILABLE:
242
- if attn_bias is not None:
243
- raise AssertionError("xFormers is required for using nested tensors")
244
- return super().forward(query, key, value, attn_bias)
245
-
246
- B, N, C = query.shape
247
- _, M, _ = key.shape
248
-
249
- # Project query, key, and value
250
- q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads)
251
- k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads)
252
- v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads)
253
-
254
- q = q.transpose(1, 2)
255
- k = k.transpose(1, 2)
256
- q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
257
-
258
- if self.rope is not None:
259
- q = self.rope(q, qpos)
260
- k = self.rope(k, kpos)
261
-
262
- q = q.transpose(1, 2)
263
- k = k.transpose(1, 2)
264
-
265
- # Compute memory-efficient attention
266
- x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
267
- x = x.reshape(B, N, C)
268
-
269
- # Final projection
270
- x = self.proj(x)
271
- x = self.proj_drop(x)
272
- return x
273
-
274
-
275
- class AttentionRope(nn.Module):
276
- def __init__(
277
- self,
278
- dim: int,
279
- num_heads: int = 8,
280
- qkv_bias: bool = False,
281
- proj_bias: bool = True,
282
- attn_drop: float = 0.0,
283
- proj_drop: float = 0.0,
284
- qk_norm: bool = False,
285
- norm_layer: nn.Module = nn.LayerNorm,
286
- rope=None,
287
- ) -> None:
288
- super().__init__()
289
- self.num_heads = num_heads
290
- head_dim = dim // num_heads
291
- self.scale = head_dim**-0.5
292
-
293
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
294
- self.attn_drop = nn.Dropout(attn_drop)
295
- self.proj = nn.Linear(dim, dim, bias=proj_bias)
296
- self.proj_drop = nn.Dropout(proj_drop)
297
-
298
- self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
299
- self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
300
-
301
- self.rope = rope
302
-
303
- def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
304
- B, N, C = x.shape
305
- qkv = (
306
- self.qkv(x)
307
- .reshape(B, N, 3, self.num_heads, C // self.num_heads)
308
- .permute(2, 0, 3, 1, 4)
309
- )
310
- q, k, v = qkv[0], qkv[1], qkv[2]
311
- q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
312
-
313
- if self.rope is not None:
314
- q = self.rope(q, xpos)
315
- k = self.rope(k, xpos)
316
-
317
- q = q * self.scale
318
- attn = q @ k.transpose(-2, -1)
319
-
320
- attn = attn.softmax(dim=-1)
321
- attn = self.attn_drop(attn)
322
-
323
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
324
- x = self.proj(x)
325
- x = self.proj_drop(x)
326
- return x
327
-
328
-
329
- class MemEffAttentionRope(AttentionRope):
330
- def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
331
- if not XFORMERS_AVAILABLE:
332
- if attn_bias is not None:
333
- raise AssertionError("xFormers is required for using nested tensors")
334
- return super().forward(x)
335
-
336
- B, N, C = x.shape
337
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
338
-
339
- qkv = qkv.transpose(1, 3)
340
- # q, k, v = unbind(qkv, 2)
341
- q, k, v = [qkv[:, :, i] for i in range(3)]
342
- q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
343
-
344
- if self.rope is not None:
345
- q = self.rope(q, xpos)
346
- k = self.rope(k, xpos)
347
-
348
- q = q.transpose(1, 2)
349
- k = k.transpose(1, 2)
350
- v = v.transpose(1, 2)
351
-
352
- x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
353
- x = x.reshape([B, N, C])
354
-
355
- # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(frame_num, 261, frame_num, 261).mean(dim=[1, 3]).sum(1) # for frame attention matrix
356
- # global_valid_id = torch.where(score_matrix > 0)
357
- # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1)
358
-
359
- x = self.proj(x)
360
- x = self.proj_drop(x)
361
- return x
362
-
363
-
364
- class FlashAttentionRope(AttentionRope):
365
- def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
366
- B, N, C = x.shape
367
- qkv = (
368
- self.qkv(x)
369
- .reshape(B, N, 3, self.num_heads, C // self.num_heads)
370
- .transpose(1, 3)
371
- )
372
-
373
- # q, k, v = unbind(qkv, 2)
374
- q, k, v = [qkv[:, :, i] for i in range(3)]
375
- q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
376
-
377
- if self.rope is not None:
378
- q = self.rope(q, xpos)
379
- k = self.rope(k, xpos)
380
-
381
- if q.dtype == torch.bfloat16:
382
- with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
383
- x = scaled_dot_product_attention(q, k, v)
384
- else:
385
- with nn.attention.sdpa_kernel(
386
- [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]
387
- ):
388
- x = scaled_dot_product_attention(q, k, v)
389
-
390
- x = x.transpose(1, 2).reshape([B, N, C])
391
-
392
- x = self.proj(x)
393
- x = self.proj_drop(x)
394
- return x
395
-
396
-
397
- def get_attn_score(blk_class, x, frame_num, token_length, xpos=None):
398
- x = blk_class.norm1(x)
399
-
400
- B, N, C = x.shape
401
- qkv = blk_class.attn.qkv(x).reshape(
402
- B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads
403
- )
404
-
405
- qkv = qkv.transpose(1, 3)
406
- # q, k, v = unbind(qkv, 2)
407
- q, k, v = [qkv[:, :, i] for i in range(3)]
408
- q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype)
409
-
410
- if blk_class.attn.rope is not None:
411
- q = blk_class.attn.rope(q, xpos)
412
- k = blk_class.attn.rope(k, xpos)
413
-
414
- q = q.transpose(1, 2)
415
- k = k.transpose(1, 2)
416
-
417
- score = (
418
- (
419
- q.permute(0, 2, 1, 3)
420
- * blk_class.attn.scale
421
- @ k.permute(0, 2, 1, 3).transpose(-2, -1)
422
- )
423
- .sum(dim=1)
424
- .reshape(B, frame_num, token_length, frame_num, token_length)
425
- .mean(dim=[2, 4])
426
- .sum(-1)
427
- )
428
-
429
- return score
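
For reference, the deleted `Attention.forward` packs q/k/v from a single fused linear layer and applies plain softmax attention. A minimal sketch of that packing, cross-checked against `torch.nn.functional.scaled_dot_product_attention` (assumes PyTorch >= 2.0; sizes are arbitrary stand-ins):

```python
import torch
import torch.nn.functional as F

B, N, num_heads, head_dim = 2, 16, 4, 8
C = num_heads * head_dim
qkv = torch.randn(B, N, 3 * C)                       # output of the fused qkv linear

# Same reshape/permute as the deleted module: (3, B, heads, N, head_dim).
qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

scale = head_dim ** -0.5
attn = (q * scale) @ k.transpose(-2, -1)             # (B, heads, N, N)
out_manual = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)

# SDPA uses the same 1/sqrt(head_dim) scaling by default.
out_sdpa = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(B, N, C)
assert torch.allclose(out_manual, out_sdpa, atol=1e-4)
```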
 
mapanything/models/external/pi3/layers/block.py DELETED
@@ -1,448 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- #
3
- # This source code is licensed under the Apache License, Version 2.0
4
- # found in the LICENSE file in the root directory of this source tree.
5
-
6
- # References:
7
- # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
-
10
- import os
11
- from typing import Any, Callable, Dict, List, Tuple
12
-
13
- import torch
14
- from torch import nn, Tensor
15
-
16
- from mapanything.models.external.dinov2.layers.drop_path import DropPath
17
- from mapanything.models.external.dinov2.layers.layer_scale import LayerScale
18
- from mapanything.models.external.dinov2.layers.mlp import Mlp
19
- from mapanything.models.external.pi3.layers.attention import (
20
- Attention,
21
- CrossAttentionRope,
22
- MemEffAttention,
23
- )
24
-
25
- XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
26
- try:
27
- if XFORMERS_ENABLED:
28
- from xformers.ops import fmha, index_select_cat, scaled_index_add
29
-
30
- XFORMERS_AVAILABLE = True
31
- # warnings.warn("xFormers is available (Block)")
32
- else:
33
- # warnings.warn("xFormers is disabled (Block)")
34
- raise ImportError
35
- except ImportError:
36
- XFORMERS_AVAILABLE = False
37
- # warnings.warn("xFormers is not available (Block)")
38
-
39
-
40
- class Block(nn.Module):
41
- def __init__(
42
- self,
43
- dim: int,
44
- num_heads: int,
45
- mlp_ratio: float = 4.0,
46
- qkv_bias: bool = False,
47
- proj_bias: bool = True,
48
- ffn_bias: bool = True,
49
- drop: float = 0.0,
50
- attn_drop: float = 0.0,
51
- init_values=None,
52
- drop_path: float = 0.0,
53
- act_layer: Callable[..., nn.Module] = nn.GELU,
54
- norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
55
- attn_class: Callable[..., nn.Module] = Attention,
56
- ffn_layer: Callable[..., nn.Module] = Mlp,
57
- ) -> None:
58
- super().__init__()
59
- # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
60
- self.norm1 = norm_layer(dim)
61
- self.attn = attn_class(
62
- dim,
63
- num_heads=num_heads,
64
- qkv_bias=qkv_bias,
65
- proj_bias=proj_bias,
66
- attn_drop=attn_drop,
67
- proj_drop=drop,
68
- )
69
-
70
- self.ls1 = (
71
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
72
- )
73
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
74
-
75
- self.norm2 = norm_layer(dim)
76
- mlp_hidden_dim = int(dim * mlp_ratio)
77
- self.mlp = ffn_layer(
78
- in_features=dim,
79
- hidden_features=mlp_hidden_dim,
80
- act_layer=act_layer,
81
- drop=drop,
82
- bias=ffn_bias,
83
- )
84
- self.ls2 = (
85
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
86
- )
87
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
88
-
89
- self.sample_drop_ratio = drop_path
90
-
91
- def forward(self, x: Tensor) -> Tensor:
92
- def attn_residual_func(x: Tensor) -> Tensor:
93
- return self.ls1(self.attn(self.norm1(x)))
94
-
95
- def ffn_residual_func(x: Tensor) -> Tensor:
96
- return self.ls2(self.mlp(self.norm2(x)))
97
-
98
- if self.training and self.sample_drop_ratio > 0.1:
99
- # the overhead is compensated only for a drop path rate larger than 0.1
100
- x = drop_add_residual_stochastic_depth(
101
- x,
102
- residual_func=attn_residual_func,
103
- sample_drop_ratio=self.sample_drop_ratio,
104
- )
105
- x = drop_add_residual_stochastic_depth(
106
- x,
107
- residual_func=ffn_residual_func,
108
- sample_drop_ratio=self.sample_drop_ratio,
109
- )
110
- elif self.training and self.sample_drop_ratio > 0.0:
111
- x = x + self.drop_path1(attn_residual_func(x))
112
- x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
113
- else:
114
- x = x + attn_residual_func(x)
115
- x = x + ffn_residual_func(x)
116
- return x
117
-
118
-
119
- def drop_add_residual_stochastic_depth(
120
- x: Tensor,
121
- residual_func: Callable[[Tensor], Tensor],
122
- sample_drop_ratio: float = 0.0,
123
- ) -> Tensor:
124
- # 1) extract subset using permutation
125
- b, n, d = x.shape
126
- sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
127
- brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
128
- x_subset = x[brange]
129
-
130
- # 2) apply residual_func to get residual
131
- residual = residual_func(x_subset)
132
-
133
- x_flat = x.flatten(1)
134
- residual = residual.flatten(1)
135
-
136
- residual_scale_factor = b / sample_subset_size
137
-
138
- # 3) add the residual
139
- x_plus_residual = torch.index_add(
140
- x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
141
- )
142
- return x_plus_residual.view_as(x)
143
-
144
-
145
- def get_branges_scales(x, sample_drop_ratio=0.0):
146
- b, n, d = x.shape
147
- sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
148
- brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
149
- residual_scale_factor = b / sample_subset_size
150
- return brange, residual_scale_factor
151
-
152
-
153
- def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
154
- if scaling_vector is None:
155
- x_flat = x.flatten(1)
156
- residual = residual.flatten(1)
157
- x_plus_residual = torch.index_add(
158
- x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
159
- )
160
- else:
161
- x_plus_residual = scaled_index_add(
162
- x,
163
- brange,
164
- residual.to(dtype=x.dtype),
165
- scaling=scaling_vector,
166
- alpha=residual_scale_factor,
167
- )
168
- return x_plus_residual
169
-
170
-
171
- attn_bias_cache: Dict[Tuple, Any] = {}
172
-
173
-
174
- def get_attn_bias_and_cat(x_list, branges=None):
175
- """
176
- this will perform the index select, cat the tensors, and provide the attn_bias from cache
177
- """
178
- batch_sizes = (
179
- [b.shape[0] for b in branges]
180
- if branges is not None
181
- else [x.shape[0] for x in x_list]
182
- )
183
- all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
184
- if all_shapes not in attn_bias_cache.keys():
185
- seqlens = []
186
- for b, x in zip(batch_sizes, x_list):
187
- for _ in range(b):
188
- seqlens.append(x.shape[1])
189
- attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
190
- attn_bias._batch_sizes = batch_sizes
191
- attn_bias_cache[all_shapes] = attn_bias
192
-
193
- if branges is not None:
194
- cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
195
- 1, -1, x_list[0].shape[-1]
196
- )
197
- else:
198
- tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
199
- cat_tensors = torch.cat(tensors_bs1, dim=1)
200
-
201
- return attn_bias_cache[all_shapes], cat_tensors
202
-
203
-
204
- def drop_add_residual_stochastic_depth_list(
205
- x_list: List[Tensor],
206
- residual_func: Callable[[Tensor, Any], Tensor],
207
- sample_drop_ratio: float = 0.0,
208
- scaling_vector=None,
209
- ) -> Tensor:
210
- # 1) generate random set of indices for dropping samples in the batch
211
- branges_scales = [
212
- get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
213
- ]
214
- branges = [s[0] for s in branges_scales]
215
- residual_scale_factors = [s[1] for s in branges_scales]
216
-
217
- # 2) get attention bias and index+concat the tensors
218
- attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
219
-
220
- # 3) apply residual_func to get residual, and split the result
221
- residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
222
-
223
- outputs = []
224
- for x, brange, residual, residual_scale_factor in zip(
225
- x_list, branges, residual_list, residual_scale_factors
226
- ):
227
- outputs.append(
228
- add_residual(
229
- x, brange, residual, residual_scale_factor, scaling_vector
230
- ).view_as(x)
231
- )
232
- return outputs
233
-
234
-
235
- class NestedTensorBlock(Block):
236
- def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
237
- """
238
- x_list contains a list of tensors to nest together and run
239
- """
240
- assert isinstance(self.attn, MemEffAttention)
241
-
242
- if self.training and self.sample_drop_ratio > 0.0:
243
-
244
- def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
245
- return self.attn(self.norm1(x), attn_bias=attn_bias)
246
-
247
- def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
248
- return self.mlp(self.norm2(x))
249
-
250
- x_list = drop_add_residual_stochastic_depth_list(
251
- x_list,
252
- residual_func=attn_residual_func,
253
- sample_drop_ratio=self.sample_drop_ratio,
254
- scaling_vector=self.ls1.gamma
255
- if isinstance(self.ls1, LayerScale)
256
- else None,
257
- )
258
- x_list = drop_add_residual_stochastic_depth_list(
259
- x_list,
260
- residual_func=ffn_residual_func,
261
- sample_drop_ratio=self.sample_drop_ratio,
262
- scaling_vector=self.ls2.gamma
263
- if isinstance(self.ls1, LayerScale)
264
- else None,
265
- )
266
- return x_list
267
- else:
268
-
269
- def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
270
- return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
271
-
272
- def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
273
- return self.ls2(self.mlp(self.norm2(x)))
274
-
275
- attn_bias, x = get_attn_bias_and_cat(x_list)
276
- x = x + attn_residual_func(x, attn_bias=attn_bias)
277
- x = x + ffn_residual_func(x)
278
- return attn_bias.split(x)
279
-
280
- def forward(self, x_or_x_list):
281
- if isinstance(x_or_x_list, Tensor):
282
- return super().forward(x_or_x_list)
283
- elif isinstance(x_or_x_list, list):
284
- if not XFORMERS_AVAILABLE:
285
- raise AssertionError("xFormers is required for using nested tensors")
286
- return self.forward_nested(x_or_x_list)
287
- else:
288
- raise AssertionError
289
-
290
-
291
- class BlockRope(nn.Module):
292
- def __init__(
293
- self,
294
- dim: int,
295
- num_heads: int,
296
- mlp_ratio: float = 4.0,
297
- qkv_bias: bool = False,
298
- proj_bias: bool = True,
299
- ffn_bias: bool = True,
300
- drop: float = 0.0,
301
- attn_drop: float = 0.0,
302
- init_values=None,
303
- drop_path: float = 0.0,
304
- act_layer: Callable[..., nn.Module] = nn.GELU,
305
- norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
306
- attn_class: Callable[..., nn.Module] = Attention,
307
- ffn_layer: Callable[..., nn.Module] = Mlp,
308
- qk_norm: bool = False,
309
- rope=None,
310
- ) -> None:
311
- super().__init__()
312
- # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
313
- self.norm1 = norm_layer(dim)
314
- self.attn = attn_class(
315
- dim,
316
- num_heads=num_heads,
317
- qkv_bias=qkv_bias,
318
- proj_bias=proj_bias,
319
- attn_drop=attn_drop,
320
- proj_drop=drop,
321
- qk_norm=qk_norm,
322
- rope=rope,
323
- )
324
-
325
- self.ls1 = (
326
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
327
- )
328
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
329
-
330
- self.norm2 = norm_layer(dim)
331
- mlp_hidden_dim = int(dim * mlp_ratio)
332
- self.mlp = ffn_layer(
333
- in_features=dim,
334
- hidden_features=mlp_hidden_dim,
335
- act_layer=act_layer,
336
- drop=drop,
337
- bias=ffn_bias,
338
- )
339
- self.ls2 = (
340
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
341
- )
342
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
343
-
344
- self.sample_drop_ratio = drop_path
345
-
346
- def forward(self, x: Tensor, xpos=None) -> Tensor:
347
- def attn_residual_func(x: Tensor) -> Tensor:
348
- return self.ls1(self.attn(self.norm1(x), xpos=xpos))
349
-
350
- def ffn_residual_func(x: Tensor) -> Tensor:
351
- return self.ls2(self.mlp(self.norm2(x)))
352
-
353
- if self.training and self.sample_drop_ratio > 0.1:
354
- # the overhead is compensated only for a drop path rate larger than 0.1
355
- x = drop_add_residual_stochastic_depth(
356
- x,
357
- residual_func=attn_residual_func,
358
- sample_drop_ratio=self.sample_drop_ratio,
359
- )
360
- x = drop_add_residual_stochastic_depth(
361
- x,
362
- residual_func=ffn_residual_func,
363
- sample_drop_ratio=self.sample_drop_ratio,
364
- )
365
- elif self.training and self.sample_drop_ratio > 0.0:
366
- x = x + self.drop_path1(attn_residual_func(x))
367
- x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
368
- else:
369
- x = x + attn_residual_func(x)
370
- x = x + ffn_residual_func(x)
371
- return x
372
-
373
-
374
- class CrossBlockRope(nn.Module):
375
- def __init__(
376
- self,
377
- dim: int,
378
- num_heads: int,
379
- mlp_ratio: float = 4.0,
380
- qkv_bias: bool = False,
381
- proj_bias: bool = True,
382
- ffn_bias: bool = True,
383
- act_layer: Callable[..., nn.Module] = nn.GELU,
384
- norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
385
- attn_class: Callable[..., nn.Module] = Attention,
386
- cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope,
387
- ffn_layer: Callable[..., nn.Module] = Mlp,
388
- init_values=None,
389
- qk_norm: bool = False,
390
- rope=None,
391
- ) -> None:
392
- super().__init__()
393
- # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
394
- self.ls1 = (
395
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
396
- )
397
- self.norm1 = norm_layer(dim)
398
- self.attn = attn_class(
399
- dim,
400
- num_heads=num_heads,
401
- qkv_bias=qkv_bias,
402
- proj_bias=proj_bias,
403
- rope=rope,
404
- qk_norm=qk_norm,
405
- )
406
-
407
- self.ls2 = (
408
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
409
- )
410
- self.ls_y = (
411
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
412
- )
413
- self.norm2 = norm_layer(dim)
414
- self.norm_y = norm_layer(dim)
415
- self.cross_attn = cross_attn_class(
416
- dim,
417
- num_heads=num_heads,
418
- qkv_bias=qkv_bias,
419
- proj_bias=proj_bias,
420
- rope=rope,
421
- qk_norm=qk_norm,
422
- )
423
-
424
- self.norm3 = norm_layer(dim)
425
- mlp_hidden_dim = int(dim * mlp_ratio)
426
- self.mlp = ffn_layer(
427
- in_features=dim,
428
- hidden_features=mlp_hidden_dim,
429
- act_layer=act_layer,
430
- bias=ffn_bias,
431
- )
432
-
433
- def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor:
434
- def attn_residual_func(x: Tensor) -> Tensor:
435
- return self.ls1(self.attn(self.norm1(x), xpos=xpos))
436
-
437
- def cross_attn_residual_func(x: Tensor, y: Tensor) -> Tensor:
438
- return self.ls_y(self.cross_attn(self.norm2(x), y, y, qpos=xpos, kpos=ypos))
439
-
440
- def ffn_residual_func(x: Tensor) -> Tensor:
441
- return self.ls2(self.mlp(self.norm3(x)))
442
-
443
- x = x + attn_residual_func(x)
444
- y_ = self.norm_y(y)
445
- x = x + cross_attn_residual_func(x, y_)
446
- x = x + ffn_residual_func(x)
447
-
448
- return x
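
For reference, the deleted `drop_add_residual_stochastic_depth` applies the residual branch to a random subset of the batch and rescales it by `b / subset_size`, so the update is unbiased in expectation. A minimal sketch of that trick, assuming only `torch` and a toy linear stand-in for the residual branch:

```python
import torch

def drop_add_residual(x, residual_func, drop_ratio):
    b = x.shape[0]
    keep = max(int(b * (1 - drop_ratio)), 1)
    brange = torch.randperm(b)[:keep]                 # random subset of the batch
    residual = residual_func(x[brange])
    # Scale by b / keep so the expected update matches the full residual update.
    return torch.index_add(
        x.flatten(1), 0, brange, residual.flatten(1), alpha=b / keep
    ).view_as(x)

x = torch.randn(8, 4, 16)
branch = lambda t: 0.5 * t                            # stand-in for an attn/ffn branch

# Averaged over many random draws, the stochastic update approaches x + branch(x).
runs = torch.stack([drop_add_residual(x, branch, drop_ratio=0.5) for _ in range(2000)])
print((runs.mean(0) - (x + branch(x))).abs().mean())  # close to zero
```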
 
mapanything/models/external/pi3/layers/camera_head.py DELETED
@@ -1,106 +0,0 @@
1
- from copy import deepcopy
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
-
8
- # code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172'
9
- class ResConvBlock(nn.Module):
10
- """
11
- 1x1 convolution residual block
12
- """
13
-
14
- def __init__(self, in_channels, out_channels):
15
- super().__init__()
16
- self.in_channels = in_channels
17
- self.out_channels = out_channels
18
- self.head_skip = (
19
- nn.Identity()
20
- if self.in_channels == self.out_channels
21
- else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
22
- )
23
- # self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
24
- # self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
25
- # self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
26
-
27
- # change 1x1 convolution to linear
28
- self.res_conv1 = nn.Linear(self.in_channels, self.out_channels)
29
- self.res_conv2 = nn.Linear(self.out_channels, self.out_channels)
30
- self.res_conv3 = nn.Linear(self.out_channels, self.out_channels)
31
-
32
- def forward(self, res):
33
- x = F.relu(self.res_conv1(res))
34
- x = F.relu(self.res_conv2(x))
35
- x = F.relu(self.res_conv3(x))
36
- res = self.head_skip(res) + x
37
- return res
38
-
39
-
40
- class CameraHead(nn.Module):
41
- def __init__(self, dim=512):
42
- super().__init__()
43
- output_dim = dim
44
- self.res_conv = nn.ModuleList(
45
- [deepcopy(ResConvBlock(output_dim, output_dim)) for _ in range(2)]
46
- )
47
- self.avgpool = nn.AdaptiveAvgPool2d(1)
48
- self.more_mlps = nn.Sequential(
49
- nn.Linear(output_dim, output_dim),
50
- nn.ReLU(),
51
- nn.Linear(output_dim, output_dim),
52
- nn.ReLU(),
53
- )
54
- self.fc_t = nn.Linear(output_dim, 3)
55
- self.fc_rot = nn.Linear(output_dim, 9)
56
-
57
- def forward(self, feat, patch_h, patch_w):
58
- BN, hw, c = feat.shape
59
-
60
- for i in range(2):
61
- feat = self.res_conv[i](feat)
62
-
63
- # feat = self.avgpool(feat)
64
- feat = self.avgpool(
65
- feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous()
66
- ) ##########
67
- feat = feat.view(feat.size(0), -1)
68
-
69
- feat = self.more_mlps(feat) # [B, D_]
70
- with torch.amp.autocast(device_type="cuda", enabled=False):
71
- out_t = self.fc_t(feat.float()) # [B,3]
72
- out_r = self.fc_rot(feat.float()) # [B,9]
73
- pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)
74
-
75
- return pose
76
-
77
- def convert_pose_to_4x4(self, B, out_r, out_t, device):
78
- out_r = self.svd_orthogonalize(out_r) # [N,3,3]
79
- pose = torch.zeros((B, 4, 4), device=device)
80
- pose[:, :3, :3] = out_r
81
- pose[:, :3, 3] = out_t
82
- pose[:, 3, 3] = 1.0
83
- return pose
84
-
85
- def svd_orthogonalize(self, m):
86
- """Convert 9D representation to SO(3) using SVD orthogonalization.
87
-
88
- Args:
89
- m: [BATCH, 3, 3] 3x3 matrices.
90
-
91
- Returns:
92
- [BATCH, 3, 3] SO(3) rotation matrices.
93
- """
94
- if m.dim() < 3:
95
- m = m.reshape((-1, 3, 3))
96
- m_transpose = torch.transpose(
97
- torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2
98
- )
99
- u, s, v = torch.svd(m_transpose)
100
- det = torch.det(torch.matmul(v, u.transpose(-2, -1)))
101
- # Check orientation reflection.
102
- r = torch.matmul(
103
- torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2),
104
- u.transpose(-2, -1),
105
- )
106
- return r
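
Note for reference: the deleted `CameraHead` predicts 9 numbers per pose and projects them onto SO(3) with an SVD, flipping the sign of the last singular direction so the determinant is +1. A minimal sketch of that projection (assumes only `torch`; the row normalization of the original is omitted since it does not change the resulting rotation):

```python
import torch

def svd_orthogonalize(m):
    """Project arbitrary (B, 9) / (B, 3, 3) predictions onto SO(3)."""
    m = m.reshape(-1, 3, 3)
    u, s, v = torch.svd(m.transpose(-2, -1))          # torch.linalg.svd would return Vh instead
    det = torch.det(v @ u.transpose(-2, -1))
    v = torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2)
    return v @ u.transpose(-2, -1)

r = svd_orthogonalize(torch.randn(4, 9))
# Proper rotations: R R^T = I and det(R) = +1.
assert torch.allclose(r @ r.transpose(-2, -1), torch.eye(3).expand(4, 3, 3), atol=1e-5)
print(torch.det(r))                                   # all approximately +1
```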
 
mapanything/models/external/pi3/layers/pos_embed.py DELETED
@@ -1,190 +0,0 @@
1
- # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
-
4
-
5
- # --------------------------------------------------------
6
- # Position embedding utils
7
- # --------------------------------------------------------
8
-
9
-
10
- import numpy as np
11
- import torch
12
-
13
-
14
- # --------------------------------------------------------
15
- # 2D sine-cosine position embedding
16
- # References:
17
- # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
18
- # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
19
- # MoCo v3: https://github.com/facebookresearch/moco-v3
20
- # --------------------------------------------------------
21
- def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
22
- """
23
- grid_size: int of the grid height and width
24
- return:
25
- pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
26
- """
27
- grid_h = np.arange(grid_size, dtype=np.float32)
28
- grid_w = np.arange(grid_size, dtype=np.float32)
29
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
30
- grid = np.stack(grid, axis=0)
31
-
32
- grid = grid.reshape([2, 1, grid_size, grid_size])
33
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
34
- if n_cls_token > 0:
35
- pos_embed = np.concatenate(
36
- [np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0
37
- )
38
- return pos_embed
39
-
40
-
41
- def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
42
- assert embed_dim % 2 == 0
43
-
44
- # use half of dimensions to encode grid_h
45
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
46
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
47
-
48
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
49
- return emb
50
-
51
-
52
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
53
- """
54
- embed_dim: output dimension for each position
55
- pos: a list of positions to be encoded: size (M,)
56
- out: (M, D)
57
- """
58
- assert embed_dim % 2 == 0
59
- omega = np.arange(embed_dim // 2, dtype=float)
60
- omega /= embed_dim / 2.0
61
- omega = 1.0 / 10000**omega # (D/2,)
62
-
63
- pos = pos.reshape(-1) # (M,)
64
- out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
65
-
66
- emb_sin = np.sin(out) # (M, D/2)
67
- emb_cos = np.cos(out) # (M, D/2)
68
-
69
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
70
- return emb
71
-
72
-
73
- # --------------------------------------------------------
74
- # Interpolate position embeddings for high-resolution
75
- # References:
76
- # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
77
- # DeiT: https://github.com/facebookresearch/deit
78
- # --------------------------------------------------------
79
- def interpolate_pos_embed(model, checkpoint_model):
80
- if "pos_embed" in checkpoint_model:
81
- pos_embed_checkpoint = checkpoint_model["pos_embed"]
82
- embedding_size = pos_embed_checkpoint.shape[-1]
83
- num_patches = model.patch_embed.num_patches
84
- num_extra_tokens = model.pos_embed.shape[-2] - num_patches
85
- # height (== width) for the checkpoint position embedding
86
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
87
- # height (== width) for the new position embedding
88
- new_size = int(num_patches**0.5)
89
- # class_token and dist_token are kept unchanged
90
- if orig_size != new_size:
91
- print(
92
- "Position interpolate from %dx%d to %dx%d"
93
- % (orig_size, orig_size, new_size, new_size)
94
- )
95
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
96
- # only the position tokens are interpolated
97
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
98
- pos_tokens = pos_tokens.reshape(
99
- -1, orig_size, orig_size, embedding_size
100
- ).permute(0, 3, 1, 2)
101
- pos_tokens = torch.nn.functional.interpolate(
102
- pos_tokens,
103
- size=(new_size, new_size),
104
- mode="bicubic",
105
- align_corners=False,
106
- )
107
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
108
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
109
- checkpoint_model["pos_embed"] = new_pos_embed
110
-
111
-
112
- # ----------------------------------------------------------
113
- # RoPE2D: RoPE implementation in 2D
114
- # ----------------------------------------------------------
115
-
116
- try:
117
- from models.curope import cuRoPE2D
118
-
119
- RoPE2D = cuRoPE2D
120
- except ImportError:
121
-
122
- class RoPE2D(torch.nn.Module):
123
- def __init__(self, freq=100.0, F0=1.0):
124
- super().__init__()
125
- self.base = freq
126
- self.F0 = F0
127
- self.cache = {}
128
-
129
- def get_cos_sin(self, D, seq_len, device, dtype):
130
- if (D, seq_len, device, dtype) not in self.cache:
131
- inv_freq = 1.0 / (
132
- self.base ** (torch.arange(0, D, 2).float().to(device) / D)
133
- )
134
- t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
135
- freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
136
- freqs = torch.cat((freqs, freqs), dim=-1)
137
- cos = freqs.cos() # (Seq, Dim)
138
- sin = freqs.sin()
139
- self.cache[D, seq_len, device, dtype] = (cos, sin)
140
- return self.cache[D, seq_len, device, dtype]
141
-
142
- @staticmethod
143
- def rotate_half(x):
144
- x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
145
- return torch.cat((-x2, x1), dim=-1)
146
-
147
- def apply_rope1d(self, tokens, pos1d, cos, sin):
148
- assert pos1d.ndim == 2
149
- cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
150
- sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
151
- return (tokens * cos) + (self.rotate_half(tokens) * sin)
152
-
153
- def forward(self, tokens, positions):
154
- """
155
- input:
156
- * tokens: batch_size x nheads x ntokens x dim
157
- * positions: batch_size x ntokens x 2 (y and x position of each token)
158
- output:
159
- * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
160
- """
161
- assert tokens.size(3) % 2 == 0, (
162
- "number of dimensions should be a multiple of two"
163
- )
164
- D = tokens.size(3) // 2
165
- assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2
166
- cos, sin = self.get_cos_sin(
167
- D, int(positions.max()) + 1, tokens.device, tokens.dtype
168
- )
169
- # split features into two along the feature dimension, and apply rope1d on each half
170
- y, x = tokens.chunk(2, dim=-1)
171
- y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
172
- x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
173
- tokens = torch.cat((y, x), dim=-1)
174
- return tokens
175
-
176
-
177
- # patch embedding
178
- class PositionGetter(object):
179
- """return positions of patches"""
180
-
181
- def __init__(self):
182
- self.cache_positions = {}
183
-
184
- def __call__(self, b, h, w, device):
185
- if (h, w) not in self.cache_positions:
186
- x = torch.arange(w, device=device)
187
- y = torch.arange(h, device=device)
188
- self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2)
189
- pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone()
190
- return pos
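
For reference, the fallback `RoPE2D` deleted above applies the rotate-half form of 1-D RoPE separately to the y-half and x-half of each token. A minimal 1-D sketch of that rotation and of its key property, that rotated dot-products depend only on relative position (assumes only `torch`; `base=100` mirrors the `rope100` default):

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rope1d(tok, pos, base=100.0):
    half = tok.shape[-1] // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, 2 * half, 2).float() / (2 * half)))
    ang = pos * inv_freq                              # one angle per feature pair
    cos, sin = torch.cat([ang, ang]).cos(), torch.cat([ang, ang]).sin()
    return tok * cos + rotate_half(tok) * sin

q, k = torch.randn(16), torch.randn(16)
# <rope(q, p), rope(k, p')> depends only on p - p': shifting both leaves it unchanged.
a = rope1d(q, torch.tensor(3.0)) @ rope1d(k, torch.tensor(7.0))
b = rope1d(q, torch.tensor(103.0)) @ rope1d(k, torch.tensor(107.0))
assert torch.allclose(a, b, atol=1e-4)
```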
 
mapanything/models/external/pi3/layers/transformer_head.py DELETED
@@ -1,98 +0,0 @@
1
- from functools import partial
2
-
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from torch.utils.checkpoint import checkpoint
6
-
7
- from mapanything.models.external.dinov2.layers import Mlp
8
- from mapanything.models.external.pi3.layers.attention import FlashAttentionRope
9
- from mapanything.models.external.pi3.layers.block import BlockRope
10
-
11
-
12
- class TransformerDecoder(nn.Module):
13
- def __init__(
14
- self,
15
- in_dim,
16
- out_dim,
17
- dec_embed_dim=512,
18
- depth=5,
19
- dec_num_heads=8,
20
- mlp_ratio=4,
21
- rope=None,
22
- need_project=True,
23
- use_checkpoint=False,
24
- ):
25
- super().__init__()
26
-
27
- self.projects = (
28
- nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
29
- )
30
- self.use_checkpoint = use_checkpoint
31
-
32
- self.blocks = nn.ModuleList(
33
- [
34
- BlockRope(
35
- dim=dec_embed_dim,
36
- num_heads=dec_num_heads,
37
- mlp_ratio=mlp_ratio,
38
- qkv_bias=True,
39
- proj_bias=True,
40
- ffn_bias=True,
41
- drop_path=0.0,
42
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
43
- act_layer=nn.GELU,
44
- ffn_layer=Mlp,
45
- init_values=None,
46
- qk_norm=False,
47
- # attn_class=MemEffAttentionRope,
48
- attn_class=FlashAttentionRope,
49
- rope=rope,
50
- )
51
- for _ in range(depth)
52
- ]
53
- )
54
-
55
- self.linear_out = nn.Linear(dec_embed_dim, out_dim)
56
-
57
- def forward(self, hidden, xpos=None):
58
- hidden = self.projects(hidden)
59
- for i, blk in enumerate(self.blocks):
60
- if self.use_checkpoint and self.training:
61
- hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
62
- else:
63
- hidden = blk(hidden, xpos=xpos)
64
- out = self.linear_out(hidden)
65
- return out
66
-
67
-
68
- class LinearPts3d(nn.Module):
69
- """
70
- Linear head for dust3r
71
- Each token outputs: - 16x16 3D points (+ confidence)
72
- """
73
-
74
- def __init__(
75
- self,
76
- patch_size,
77
- dec_embed_dim,
78
- output_dim=3,
79
- ):
80
- super().__init__()
81
- self.patch_size = patch_size
82
-
83
- self.proj = nn.Linear(dec_embed_dim, (output_dim) * self.patch_size**2)
84
-
85
- def forward(self, decout, img_shape):
86
- H, W = img_shape
87
- tokens = decout[-1]
88
- B, S, D = tokens.shape
89
-
90
- # extract 3D points
91
- feat = self.proj(tokens) # B,S,D
92
- feat = feat.transpose(-1, -2).view(
93
- B, -1, H // self.patch_size, W // self.patch_size
94
- )
95
- feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
96
-
97
- # permute + norm depth
98
- return feat.permute(0, 2, 3, 1)
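
Note for reference: the deleted `LinearPts3d` head maps each patch token to `patch_size**2 * output_dim` values and uses `F.pixel_shuffle` to rearrange them into a dense per-pixel map. A minimal shape sketch with stand-in sizes (assumes only `torch`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, H, W, patch, out_dim, dec_dim = 2, 56, 84, 14, 3, 64
S = (H // patch) * (W // patch)                       # one token per 14x14 patch
tokens = torch.randn(B, S, dec_dim)

proj = nn.Linear(dec_dim, out_dim * patch ** 2)
feat = proj(tokens)                                   # (B, S, 3 * 14 * 14)
feat = feat.transpose(-1, -2).view(B, -1, H // patch, W // patch)
feat = F.pixel_shuffle(feat, patch)                   # (B, 3, H, W)
print(feat.permute(0, 2, 3, 1).shape)                 # torch.Size([2, 56, 84, 3])
```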
 
mapanything/models/external/pi3/models/__init__.py DELETED
File without changes
mapanything/models/external/pi3/models/pi3.py DELETED
@@ -1,251 +0,0 @@
1
- from copy import deepcopy
2
- from functools import partial
3
-
4
- import torch
5
- import torch.nn as nn
6
- from huggingface_hub import PyTorchModelHubMixin
7
-
8
- from mapanything.models.external.dinov2.hub.backbones import dinov2_vitl14_reg
9
- from mapanything.models.external.dinov2.layers import Mlp
10
- from mapanything.models.external.pi3.layers.attention import FlashAttentionRope
11
- from mapanything.models.external.pi3.layers.block import BlockRope
12
- from mapanything.models.external.pi3.layers.camera_head import CameraHead
13
- from mapanything.models.external.pi3.layers.pos_embed import PositionGetter, RoPE2D
14
- from mapanything.models.external.pi3.layers.transformer_head import (
15
- LinearPts3d,
16
- TransformerDecoder,
17
- )
18
-
19
-
20
- def homogenize_points(
21
- points,
22
- ):
23
- """Convert batched points (xyz) to (xyz1)."""
24
- return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
25
-
26
-
27
- class Pi3(nn.Module, PyTorchModelHubMixin):
28
- def __init__(
29
- self,
30
- pos_type="rope100",
31
- decoder_size="large",
32
- ):
33
- super().__init__()
34
-
35
- # ----------------------
36
- # Encoder
37
- # ----------------------
38
- self.encoder = dinov2_vitl14_reg(pretrained=False)
39
- self.patch_size = 14
40
- del self.encoder.mask_token
41
-
42
- # ----------------------
43
- # Positional Encoding
44
- # ----------------------
45
- self.pos_type = pos_type if pos_type is not None else "none"
46
- self.rope = None
47
- if self.pos_type.startswith("rope"): # eg rope100
48
- if RoPE2D is None:
49
- raise ImportError(
50
- "Cannot find cuRoPE2D, please install it following the README instructions"
51
- )
52
- freq = float(self.pos_type[len("rope") :])
53
- self.rope = RoPE2D(freq=freq)
54
- self.position_getter = PositionGetter()
55
- else:
56
- raise NotImplementedError
57
-
58
- # ----------------------
59
- # Decoder
60
- # ----------------------
61
- if decoder_size == "small":
62
- dec_embed_dim = 384
63
- dec_num_heads = 6
64
- mlp_ratio = 4
65
- dec_depth = 24
66
- elif decoder_size == "base":
67
- dec_embed_dim = 768
68
- dec_num_heads = 12
69
- mlp_ratio = 4
70
- dec_depth = 24
71
- elif decoder_size == "large":
72
- dec_embed_dim = 1024
73
- dec_num_heads = 16
74
- mlp_ratio = 4
75
- dec_depth = 36
76
- else:
77
- raise NotImplementedError
78
- self.decoder = nn.ModuleList(
79
- [
80
- BlockRope(
81
- dim=dec_embed_dim,
82
- num_heads=dec_num_heads,
83
- mlp_ratio=mlp_ratio,
84
- qkv_bias=True,
85
- proj_bias=True,
86
- ffn_bias=True,
87
- drop_path=0.0,
88
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
89
- act_layer=nn.GELU,
90
- ffn_layer=Mlp,
91
- init_values=0.01,
92
- qk_norm=True,
93
- attn_class=FlashAttentionRope,
94
- rope=self.rope,
95
- )
96
- for _ in range(dec_depth)
97
- ]
98
- )
99
- self.dec_embed_dim = dec_embed_dim
100
-
101
- # ----------------------
102
- # Register_token
103
- # ----------------------
104
- num_register_tokens = 5
105
- self.patch_start_idx = num_register_tokens
106
- self.register_token = nn.Parameter(
107
- torch.randn(1, 1, num_register_tokens, self.dec_embed_dim)
108
- )
109
- nn.init.normal_(self.register_token, std=1e-6)
110
-
111
- # ----------------------
112
- # Local Points Decoder
113
- # ----------------------
114
- self.point_decoder = TransformerDecoder(
115
- in_dim=2 * self.dec_embed_dim,
116
- dec_embed_dim=1024,
117
- dec_num_heads=16,
118
- out_dim=1024,
119
- rope=self.rope,
120
- )
121
- self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)
122
-
123
- # ----------------------
124
- # Conf Decoder
125
- # ----------------------
126
- self.conf_decoder = deepcopy(self.point_decoder)
127
- self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)
128
-
129
- # ----------------------
130
- # Camera Pose Decoder
131
- # ----------------------
132
- self.camera_decoder = TransformerDecoder(
133
- in_dim=2 * self.dec_embed_dim,
134
- dec_embed_dim=1024,
135
- dec_num_heads=16, # 8
136
- out_dim=512,
137
- rope=self.rope,
138
- use_checkpoint=False,
139
- )
140
- self.camera_head = CameraHead(dim=512)
141
-
142
- # For ImageNet Normalize
143
- image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
144
- image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
145
-
146
- self.register_buffer("image_mean", image_mean)
147
- self.register_buffer("image_std", image_std)
148
-
149
- def decode(self, hidden, N, H, W):
150
- BN, hw, _ = hidden.shape
151
- B = BN // N
152
-
153
- final_output = []
154
-
155
- hidden = hidden.reshape(B * N, hw, -1)
156
-
157
- register_token = self.register_token.repeat(B, N, 1, 1).reshape(
158
- B * N, *self.register_token.shape[-2:]
159
- )
160
-
161
- # Concatenate special tokens with patch tokens
162
- hidden = torch.cat([register_token, hidden], dim=1)
163
- hw = hidden.shape[1]
164
-
165
- if self.pos_type.startswith("rope"):
166
- pos = self.position_getter(
167
- B * N, H // self.patch_size, W // self.patch_size, hidden.device
168
- )
169
-
170
- if self.patch_start_idx > 0:
171
- # do not use position embedding for special tokens (camera and register tokens)
172
- # so set pos to 0 for the special tokens
173
- pos = pos + 1
174
- pos_special = (
175
- torch.zeros(B * N, self.patch_start_idx, 2)
176
- .to(hidden.device)
177
- .to(pos.dtype)
178
- )
179
- pos = torch.cat([pos_special, pos], dim=1)
180
-
181
- for i in range(len(self.decoder)):
182
- blk = self.decoder[i]
183
-
184
- if i % 2 == 0:
185
- pos = pos.reshape(B * N, hw, -1)
186
- hidden = hidden.reshape(B * N, hw, -1)
187
- else:
188
- pos = pos.reshape(B, N * hw, -1)
189
- hidden = hidden.reshape(B, N * hw, -1)
190
-
191
- hidden = blk(hidden, xpos=pos)
192
-
193
- if i + 1 in [len(self.decoder) - 1, len(self.decoder)]:
194
- final_output.append(hidden.reshape(B * N, hw, -1))
195
-
196
- return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(
197
- B * N, hw, -1
198
- )
199
-
200
- def forward(self, imgs):
201
- imgs = (imgs - self.image_mean) / self.image_std
202
-
203
- B, N, _, H, W = imgs.shape
204
- patch_h, patch_w = H // 14, W // 14
205
-
206
- # encode by dinov2
207
- imgs = imgs.reshape(B * N, _, H, W)
208
- hidden = self.encoder(imgs, is_training=True)
209
-
210
- if isinstance(hidden, dict):
211
- hidden = hidden["x_norm_patchtokens"]
212
-
213
- hidden, pos = self.decode(hidden, N, H, W)
214
-
215
- point_hidden = self.point_decoder(hidden, xpos=pos)
216
- conf_hidden = self.conf_decoder(hidden, xpos=pos)
217
- camera_hidden = self.camera_decoder(hidden, xpos=pos)
218
-
219
- with torch.amp.autocast(device_type="cuda", enabled=False):
220
- # local points
221
- point_hidden = point_hidden.float()
222
- ret = self.point_head(
223
- [point_hidden[:, self.patch_start_idx :]], (H, W)
224
- ).reshape(B, N, H, W, -1)
225
- xy, z = ret.split([2, 1], dim=-1)
226
- z = torch.exp(z)
227
- local_points = torch.cat([xy * z, z], dim=-1)
228
-
229
- # confidence
230
- conf_hidden = conf_hidden.float()
231
- conf = self.conf_head(
232
- [conf_hidden[:, self.patch_start_idx :]], (H, W)
233
- ).reshape(B, N, H, W, -1)
234
-
235
- # camera
236
- camera_hidden = camera_hidden.float()
237
- camera_poses = self.camera_head(
238
- camera_hidden[:, self.patch_start_idx :], patch_h, patch_w
239
- ).reshape(B, N, 4, 4)
240
-
241
- # unproject local points using camera poses
242
- points = torch.einsum(
243
- "bnij, bnhwj -> bnhwi", camera_poses, homogenize_points(local_points)
244
- )[..., :3]
245
-
246
- return dict(
247
- points=points,
248
- local_points=local_points,
249
- conf=conf,
250
- camera_poses=camera_poses,
251
- )
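
For reference, the deleted `Pi3.forward` lifts camera-frame points to the world frame by homogenizing them and applying the predicted 4x4 poses with an einsum. A minimal sketch of that step, checked against the explicit `R @ p + t` form (assumes only `torch`; the pose here is a toy z-axis rotation plus translation):

```python
import math
import torch

B, N, H, W = 1, 2, 4, 5
local_points = torch.randn(B, N, H, W, 3)

c, s = math.cos(0.3), math.sin(0.3)
R = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
t = torch.randn(B, N, 3)
poses = torch.eye(4).repeat(B, N, 1, 1)
poses[..., :3, :3] = R
poses[..., :3, 3] = t

# Homogenize (xyz -> xyz1) and apply the per-view pose, as in Pi3.forward.
homog = torch.cat([local_points, torch.ones_like(local_points[..., :1])], dim=-1)
world = torch.einsum("bnij,bnhwj->bnhwi", poses, homog)[..., :3]

# Same transform written out explicitly: x_world = R x_cam + t.
expected = local_points @ R.T + t.view(B, N, 1, 1, 3)
assert torch.allclose(world, expected, atol=1e-5)
```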
 
mapanything/models/external/pow3r/__init__.py DELETED
@@ -1,860 +0,0 @@
1
- """
2
- Inference wrapper for Pow3R
3
- """
4
-
5
- import warnings
6
- from copy import deepcopy
7
-
8
- import pow3r.model.blocks # noqa
9
- import roma
10
- import torch
11
- import torch.nn as nn
12
- import tqdm
13
- from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
14
- from dust3r.image_pairs import make_pairs
15
- from dust3r.inference import check_if_same_size
16
- from dust3r.model import CroCoNet
17
- from dust3r.patch_embed import get_patch_embed as dust3r_patch_embed
18
- from dust3r.utils.device import collate_with_cat, to_cpu
19
- from dust3r.utils.misc import (
20
- fill_default_args,
21
- freeze_all_params,
22
- interleave,
23
- is_symmetrized,
24
- transpose_to_landscape,
25
- )
26
- from pow3r.model.blocks import Block, BlockInject, DecoderBlock, DecoderBlockInject, Mlp
27
- from pow3r.model.heads import head_factory
28
- from pow3r.model.inference import (
29
- add_depth,
30
- add_intrinsics,
31
- add_relpose,
32
- )
33
- from pow3r.model.patch_embed import get_patch_embed
34
-
35
- from mapanything.models.external.vggt.utils.rotation import mat_to_quat
36
- from mapanything.utils.geometry import (
37
- convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
38
- convert_z_depth_to_depth_along_ray,
39
- depthmap_to_camera_frame,
40
- get_rays_in_camera_frame,
41
- )
42
-
43
-
44
- class Pow3R(CroCoNet):
45
- """Two siamese encoders, followed by two decoders.
46
- The goal is to output 3d points directly, both images in view1's frame
47
- (hence the asymmetry).
48
- """
49
-
50
- def __init__(
51
- self,
52
- mode="embed",
53
- head_type="linear",
54
- patch_embed_cls="PatchEmbedDust3R",
55
- freeze="none",
56
- landscape_only=True,
57
- **croco_kwargs,
58
- ):
59
- # retrieve all default arguments using python magic
60
- self.croco_args = fill_default_args(croco_kwargs, super().__init__)
61
- super().__init__(**croco_kwargs)
62
- del self.mask_token # useless
63
- del self.prediction_head
64
-
65
- dec_dim, enc_dim = self.decoder_embed.weight.shape
66
- self.enc_embed_dim = enc_dim
67
- self.dec_embed_dim = dec_dim
68
-
69
- self.mode = mode
70
- # additional parameters in the encoder
71
- img_size = self.patch_embed.img_size
72
- patch_size = self.patch_embed.patch_size[0]
73
- self.patch_embed = dust3r_patch_embed(
74
- patch_embed_cls, img_size, patch_size, self.enc_embed_dim
75
- )
76
- self.patch_embed_rays = get_patch_embed(
77
- patch_embed_cls + "_Mlp",
78
- img_size,
79
- patch_size,
80
- self.enc_embed_dim,
81
- in_chans=3,
82
- )
83
- self.patch_embed_depth = get_patch_embed(
84
- patch_embed_cls + "_Mlp",
85
- img_size,
86
- patch_size,
87
- self.enc_embed_dim,
88
- in_chans=2,
89
- )
90
- self.pose_embed = Mlp(12, 4 * dec_dim, dec_dim)
91
-
92
- # additional parameters in the decoder
93
- self.dec_cls = "_cls" in self.mode
94
- self.dec_num_cls = 0
95
- if self.dec_cls:
96
- # use a CLS token in the decoder only
97
- self.mode = self.mode.replace("_cls", "")
98
- self.cls_token1 = nn.Parameter(torch.zeros((dec_dim,)))
99
- self.cls_token2 = nn.Parameter(torch.zeros((dec_dim,)))
100
- self.dec_num_cls = 1 # affects all blocks
101
-
102
- use_ln = "_ln" in self.mode # TODO remove?
103
- self.patch_ln = nn.LayerNorm(enc_dim) if use_ln else nn.Identity()
104
- self.dec1_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity()
105
- self.dec2_pre_ln = nn.LayerNorm(dec_dim) if use_ln else nn.Identity()
106
-
107
- self.dec_blocks2 = deepcopy(self.dec_blocks)
108
-
109
- # here we modify some of the blocks
110
- self.replace_some_blocks()
111
-
112
- self.set_downstream_head(head_type, landscape_only, **croco_kwargs)
113
- self.set_freeze(freeze)
114
-
115
- def replace_some_blocks(self):
116
- assert self.mode.startswith("inject") # inject[0,0.5]
117
- NewBlock = BlockInject
118
- DecoderNewBlock = DecoderBlockInject
119
-
120
- all_layers = {
121
- i / n
122
- for i in range(len(self.enc_blocks))
123
- for n in [len(self.enc_blocks), len(self.dec_blocks)]
124
- }
125
- which_layers = eval(self.mode[self.mode.find("[") :]) or all_layers
126
- assert isinstance(which_layers, (set, list))
127
-
128
- n = 0
129
- for i, block in enumerate(self.enc_blocks):
130
- if i / len(self.enc_blocks) in which_layers:
131
- block.__class__ = NewBlock
132
- block.init(self.enc_embed_dim)
133
- n += 1
134
- else:
135
- block.__class__ = Block
136
- assert n == len(which_layers), breakpoint()
137
-
138
- n = 0
139
- for i in range(len(self.dec_blocks)):
140
- for blocks in [self.dec_blocks, self.dec_blocks2]:
141
- block = blocks[i]
142
- if i / len(self.dec_blocks) in which_layers:
143
- block.__class__ = DecoderNewBlock
144
- block.init(self.dec_embed_dim)
145
- n += 1
146
- else:
147
- block.__class__ = DecoderBlock
148
- assert n == 2 * len(which_layers), breakpoint()
149
-
150
- @classmethod
151
- def from_pretrained(cls, pretrained_model_path, **kw):
152
- return _load_model(pretrained_model_path, device="cpu")
153
-
154
- def load_state_dict(self, ckpt, **kw):
155
- # duplicate all weights for the second decoder if not present
156
- new_ckpt = dict(ckpt)
157
- if not any(k.startswith("dec_blocks2") for k in ckpt):
158
- for key, value in ckpt.items():
159
- if key.startswith("dec_blocks"):
160
- new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value
161
- # remove layers that have different shapes
162
- cur_ckpt = self.state_dict()
163
- for key, val in ckpt.items():
164
- if key.startswith("downstream_head2.proj"):
165
- if key in cur_ckpt and cur_ckpt[key].shape != val.shape:
166
- print(f" (removing ckpt[{key}] because wrong shape)")
167
- del new_ckpt[key]
168
- return super().load_state_dict(new_ckpt, **kw)
169
-
170
- def set_freeze(self, freeze): # this is for use by downstream models
171
- self.freeze = freeze
172
- to_be_frozen = {
173
- "none": [],
174
- "encoder": [self.patch_embed, self.enc_blocks],
175
- }
176
- freeze_all_params(to_be_frozen[freeze])
177
-
178
- def set_prediction_head(self, *args, **kwargs):
179
- """No prediction head"""
180
- return
181
-
182
- def set_downstream_head(
183
- self,
184
- head_type,
185
- landscape_only,
186
- patch_size,
187
- img_size,
188
- mlp_ratio,
189
- dec_depth,
190
- **kw,
191
- ):
192
- assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, (
193
- f"{img_size=} must be multiple of {patch_size=}"
194
- )
195
-
196
- # split heads if different
197
- heads = head_type.split(";")
198
- assert len(heads) in (1, 2)
199
- head1_type, head2_type = (heads + heads)[:2]
200
-
201
- # allocate heads
202
- self.downstream_head1 = head_factory(head1_type, self)
203
- self.downstream_head2 = head_factory(head2_type, self)
204
-
205
- # magic wrapper
206
- self.head1 = transpose_to_landscape(
207
- self.downstream_head1, activate=landscape_only
208
- )
209
- self.head2 = transpose_to_landscape(
210
- self.downstream_head2, activate=landscape_only
211
- )
212
-
213
- def _encode_image(self, image, true_shape, rays=None, depth=None):
214
- # embed the image into patches (x has size B x Npatches x C)
215
- x, pos = self.patch_embed(image, true_shape=true_shape)
216
-
217
- if rays is not None: # B,3,H,W
218
- rays_emb, pos2 = self.patch_embed_rays(rays, true_shape=true_shape)
219
- assert (pos == pos2).all()
220
- if self.mode.startswith("embed"):
221
- x = x + rays_emb
222
- else:
223
- rays_emb = None
224
-
225
- if depth is not None: # B,2,H,W
226
- depth_emb, pos2 = self.patch_embed_depth(depth, true_shape=true_shape)
227
- assert (pos == pos2).all()
228
- if self.mode.startswith("embed"):
229
- x = x + depth_emb
230
- else:
231
- depth_emb = None
232
-
233
- x = self.patch_ln(x)
234
-
235
- # add positional embedding without cls token
236
- assert self.enc_pos_embed is None
237
-
238
- # now apply the transformer encoder and normalization
239
- for blk in self.enc_blocks:
240
- x = blk(x, pos, rays=rays_emb, depth=depth_emb)
241
-
242
- x = self.enc_norm(x)
243
- return x, pos
244
-
245
- def encode_symmetrized(self, view1, view2):
246
- img1 = view1["img"]
247
- img2 = view2["img"]
248
- B = img1.shape[0]
249
- # Recover true_shape when available, otherwise assume that the img shape is the true one
250
- shape1 = view1.get(
251
- "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1)
252
- )
253
- shape2 = view2.get(
254
- "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1)
255
- )
256
- # warning! maybe the images have different portrait/landscape orientations
257
-
258
- # privileged information
259
- rays1 = view1.get("known_rays", None)
260
- rays2 = view2.get("known_rays", None)
261
- depth1 = view1.get("known_depth", None)
262
- depth2 = view2.get("known_depth", None)
263
-
264
- if is_symmetrized(view1, view2):
265
- # only compute half of the forward pass (the batch is symmetrized)
266
- def hsub(x):
267
- return None if x is None else x[::2]
268
-
269
- feat1, pos1 = self._encode_image(
270
- img1[::2], shape1[::2], rays=hsub(rays1), depth=hsub(depth1)
271
- )
272
- feat2, pos2 = self._encode_image(
273
- img2[::2], shape2[::2], rays=hsub(rays2), depth=hsub(depth2)
274
- )
275
-
276
- feat1, feat2 = interleave(feat1, feat2)
277
- pos1, pos2 = interleave(pos1, pos2)
278
- else:
279
- feat1, pos1 = self._encode_image(img1, shape1, rays=rays1, depth=depth1)
280
- feat2, pos2 = self._encode_image(img2, shape2, rays=rays2, depth=depth2)
281
-
282
- return (shape1, shape2), (feat1, feat2), (pos1, pos2)
283
-
284
- def _decoder(self, f1, pos1, f2, pos2, relpose1=None, relpose2=None):
285
- final_output = [(f1, f2)] # before projection
286
-
287
- # project to decoder dim
288
- f1 = self.decoder_embed(f1)
289
- f2 = self.decoder_embed(f2)
290
-
291
- # add CLS token for the decoder
292
- if self.dec_cls:
293
- cls1 = self.cls_token1[None, None].expand(len(f1), 1, -1).clone()
294
- cls2 = self.cls_token2[None, None].expand(len(f2), 1, -1).clone()
295
-
296
- if relpose1 is not None: # shape = (B, 4, 4)
297
- pose_emb1 = self.pose_embed(relpose1[:, :3].flatten(1)).unsqueeze(1)
298
- if self.mode.startswith("embed"):
299
- if self.dec_cls:
300
- cls1 = cls1 + pose_emb1
301
- else:
302
- f1 = f1 + pose_emb1
303
- else:
304
- pose_emb1 = None
305
-
306
- if relpose2 is not None: # shape = (B, 4, 4)
307
- pose_emb2 = self.pose_embed(relpose2[:, :3].flatten(1)).unsqueeze(1)
308
- if self.mode.startswith("embed"):
309
- if self.dec_cls:
310
- cls2 = cls2 + pose_emb2
311
- else:
312
- f2 = f2 + pose_emb2
313
- else:
314
- pose_emb2 = None
315
-
316
- if self.dec_cls:
317
- f1, pos1 = cat_cls(cls1, f1, pos1)
318
- f2, pos2 = cat_cls(cls2, f2, pos2)
319
-
320
- f1 = self.dec1_pre_ln(f1)
321
- f2 = self.dec2_pre_ln(f2)
322
-
323
- final_output.append((f1, f2)) # to be removed later
324
- for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
325
- # img1 side
326
- f1, _ = blk1(
327
- *final_output[-1][::+1],
328
- pos1,
329
- pos2,
330
- relpose=pose_emb1,
331
- num_cls=self.dec_num_cls,
332
- )
333
- # img2 side
334
- f2, _ = blk2(
335
- *final_output[-1][::-1],
336
- pos2,
337
- pos1,
338
- relpose=pose_emb2,
339
- num_cls=self.dec_num_cls,
340
- )
341
- # store the result
342
- final_output.append((f1, f2))
343
-
344
- del final_output[1] # duplicate with final_output[0] (after decoder proj)
345
- if self.dec_cls: # remove cls token for decoder layers
346
- final_output[1:] = [(f1[:, 1:], f2[:, 1:]) for f1, f2 in final_output[1:]]
347
- # normalize last output
348
- final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
349
- return zip(*final_output)
350
-
351
- def _downstream_head(self, head_num, decout, img_shape):
352
- B, S, D = decout[-1].shape
353
- head = getattr(self, f"head{head_num}")
354
- return head(decout, img_shape)
355
-
356
- def forward(self, view1, view2):
357
- # encode the two images --> B,S,D
358
- (shape1, shape2), (feat1, feat2), (pos1, pos2) = self.encode_symmetrized(
359
- view1, view2
360
- )
361
-
362
- # combine all ref images into object-centric representation
363
- dec1, dec2 = self._decoder(
364
- feat1,
365
- pos1,
366
- feat2,
367
- pos2,
368
- relpose1=view1.get("known_pose"),
369
- relpose2=view2.get("known_pose"),
370
- )
371
- with torch.autocast("cuda", enabled=False):
372
- res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
373
- res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)
374
-
375
- res2["pts3d_in_other_view"] = res2.pop(
376
- "pts3d"
377
- ) # predict view2's pts3d in view1's frame
378
- return res1, res2
379
-
380
-
381
- def convert_release_dust3r_args(args):
382
- args.model = (
383
- args.model.replace("patch_embed_cls", "patch_embed")
384
- .replace("AsymmetricMASt3R", "AsymmetricCroCo3DStereo")
385
- .replace("PatchEmbedDust3R", "convManyAR")
386
- .replace(
387
- "pos_embed='RoPE100'",
388
- "enc_pos_embed='cuRoPE100', dec_pos_embed='cuRoPE100'",
389
- )
390
- )
391
- return args
392
-
393
-
394
- def _load_model(model_path, device):
395
- print("... loading model from", model_path)
396
- ckpt = torch.load(model_path, map_location="cpu")
397
- try:
398
- net = eval(
399
- ckpt["args"].model[:-1].replace("convManyAR", "convP")
400
- + ", landscape_only=False)"
401
- )
402
- except Exception:
403
- args = convert_release_dust3r_args(ckpt["args"])
404
- net = eval(
405
- args.model[:-1].replace("convManyAR", "convP") + ", landscape_only=False)"
406
- )
407
- ckpt["model"] = {
408
- k.replace("_downstream_head", "downstream_head"): v
409
- for k, v in ckpt["model"].items()
410
- }
411
- print(net.load_state_dict(ckpt["model"], strict=False))
412
- return net.to(device)
413
-
414
-
415
- def cat_cls(cls, tokens, pos):
416
- tokens = torch.cat((cls, tokens), dim=1)
417
- pos = torch.cat((-pos.new_ones(len(cls), 1, 2), pos), dim=1)
418
- return tokens, pos
419
-
420
-
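Both wrapper classes below consume the same list-of-dicts `views` input and a probability-gated `geometric_input_config`. As a concrete illustration of the expected structure (key names and shapes follow the wrappers' docstrings; every value here is a dummy placeholder, not meaningful data):

import torch

B, H, W = 1, 384, 512
geometric_input_config = {
    "overall_prob": 1.0,    # chance of using any geometric input at all
    "ray_dirs_prob": 1.0,   # gates the camera intrinsics
    "depth_prob": 1.0,      # gates the depth maps
    "cam_prob": 1.0,        # gates the relative camera poses
}
views = [
    {
        "img": torch.randn(B, 3, H, W),                        # DUSt3R-normalized image
        "data_norm_type": ["dust3r"],
        "camera_intrinsics": torch.eye(3)[None].repeat(B, 1, 1),
        "camera_pose": torch.eye(4)[None].repeat(B, 1, 1),     # OpenCV (RDF) cam2world
        "depthmap": torch.rand(B, H, W, 1),
    }
    for _ in range(2)
]

Setting a probability to 0 disables that prior, since the corresponding branch in the wrappers' forward passes never fires.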
421
- class Pow3RWrapper(torch.nn.Module):
422
- def __init__(
423
- self,
424
- name,
425
- ckpt_path,
426
- geometric_input_config,
427
- **kwargs,
428
- ):
429
- super().__init__()
430
- self.name = name
431
- self.ckpt_path = ckpt_path
432
- self.geometric_input_config = geometric_input_config
433
-
434
- # Init the model and load the checkpoint
435
- print(f"Loading checkpoint from {self.ckpt_path} ...")
436
- ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False)
437
- model = ckpt["definition"]
438
- print(f"Creating model = {model}")
439
- self.model = eval(model)
440
- print(self.model.load_state_dict(ckpt["weights"]))
441
-
442
- def forward(self, views):
443
- """
444
- Forward pass wrapper for Pow3R.
445
-
446
- Assumption:
447
- - The number of input views is 2.
448
-
449
- Args:
450
- views (List[dict]): List of dictionaries containing the input views' images and instance information.
451
- Length of the list should be 2.
452
- Each dictionary should contain the following keys:
453
- "img" (tensor): Image tensor of shape (B, C, H, W).
454
- "data_norm_type" (list): ["dust3r"]
455
- Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs:
456
- "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3).
457
- "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation.
458
- "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1).
459
-
460
- Returns:
461
- List[dict]: A list containing the final outputs for the two views. Length of the list will be 2.
462
- """
463
- # Check that the number of input views is 2
464
- assert len(views) == 2, "Pow3R requires 2 input views."
465
-
466
- # Check the data norm type
467
- data_norm_type = views[0]["data_norm_type"][0]
468
- assert data_norm_type == "dust3r", (
469
- "Pow3R expects a normalized image with the DUSt3R normalization scheme applied"
470
- )
471
-
472
- # Get the batch size per view, device and two views
473
- batch_size_per_view = views[0]["img"].shape[0]
474
- device = views[0]["img"].device
475
- view1, view2 = views
476
-
477
- # Decide if we need to use the geometric inputs
478
- if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]:
479
- # Decide if we need to use the camera intrinsics
480
- if (
481
- torch.rand(1, device=device)
482
- < self.geometric_input_config["ray_dirs_prob"]
483
- ):
484
- add_intrinsics(view1, view1.get("camera_intrinsics"))
485
- add_intrinsics(view2, view2.get("camera_intrinsics"))
486
-
487
- # Decide if we need to use the depth map
488
- if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]:
489
- depthmap1 = view1.get("depthmap")
490
- depthmap2 = view2.get("depthmap")
491
- if depthmap1 is not None:
492
- depthmap1 = depthmap1.squeeze(-1).to(device)
493
- if depthmap2 is not None:
494
- depthmap2 = depthmap2.squeeze(-1).to(device)
495
- add_depth(view1, depthmap1)
496
- add_depth(view2, depthmap2)
497
-
498
- # Decide if we need to use the camera pose
499
- if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]:
500
- cam1 = view1.get("camera_pose")
501
- cam2 = view2.get("camera_pose")
502
- add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1)
503
- add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1)
504
-
505
- # Get the model predictions
506
- preds = self.model(view1, view2)
507
-
508
- # Convert the output to MapAnything format
509
- with torch.autocast("cuda", enabled=False):
510
- res = []
511
- for view_idx in range(2):
512
- # Get the model predictions for the current view
513
- curr_view_pred = preds[view_idx]
514
-
515
- # For the first view
516
- if view_idx == 0:
517
- # Get the global frame and camera frame pointmaps
518
- global_pts = curr_view_pred["pts3d"]
519
- cam_pts = curr_view_pred["pts3d"]
520
- conf = curr_view_pred["conf"]
521
-
522
- # Get the ray directions and depth along ray
523
- depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True)
524
- ray_directions = cam_pts / depth_along_ray
525
-
526
- # Initialize identity camera pose
527
- cam_rot = torch.eye(3, device=device)
528
- cam_quat = mat_to_quat(cam_rot)
529
- cam_trans = torch.zeros(3, device=device)
530
- cam_quat = cam_quat.unsqueeze(0).repeat(batch_size_per_view, 1)
531
- cam_trans = cam_trans.unsqueeze(0).repeat(batch_size_per_view, 1)
532
- # For the second view
533
- elif view_idx == 1:
534
- # Get the global frame and camera frame pointmaps
535
- pred_global_pts = curr_view_pred["pts3d_in_other_view"]
536
- cam_pts = curr_view_pred["pts3d2"]
537
- conf = (curr_view_pred["conf"] * curr_view_pred["conf2"]).sqrt()
538
-
539
- # Get the ray directions and depth along ray
540
- depth_along_ray = torch.norm(cam_pts, dim=-1, keepdim=True)
541
- ray_directions = cam_pts / depth_along_ray
542
-
543
- # Compute the camera pose using the pointmaps
544
- cam_rot, cam_trans, scale = roma.rigid_points_registration(
545
- cam_pts.reshape(batch_size_per_view, -1, 3),
546
- pred_global_pts.reshape(batch_size_per_view, -1, 3),
547
- weights=conf.reshape(batch_size_per_view, -1),
548
- compute_scaling=True,
549
- )
550
- cam_quat = mat_to_quat(cam_rot)
551
-
552
- # Scale the predicted camera frame pointmap and compute the new global frame pointmap
553
- cam_pts = scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * cam_pts
554
- global_pts = cam_pts.reshape(
555
- batch_size_per_view, -1, 3
556
- ) @ cam_rot.permute(0, 2, 1) + cam_trans.unsqueeze(1)
557
- global_pts = global_pts.view(pred_global_pts.shape)
558
-
559
- # Append the result in MapAnything format
560
- res.append(
561
- {
562
- "pts3d": global_pts,
563
- "pts3d_cam": cam_pts,
564
- "ray_directions": ray_directions,
565
- "depth_along_ray": depth_along_ray,
566
- "cam_trans": cam_trans,
567
- "cam_quats": cam_quat,
568
- "conf": conf,
569
- }
570
- )
571
-
572
- return res
573
-
574
-
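The registration step in `Pow3RWrapper.forward` above recovers a scaled rigid transform between the predicted camera-frame and cross-view pointmaps via `roma.rigid_points_registration`, then re-applies it as `world = scale * cam_pts @ R^T + t`. A minimal self-contained check of that convention (random synthetic points, not model outputs):

import torch
import roma  # rotation / registration utilities, as used by the wrapper above

B, N = 2, 1024
cam_pts = torch.randn(B, N, 3)                         # points in the camera frame
R_gt = roma.rotvec_to_rotmat(torch.randn(B, 3))        # ground-truth rotation (B, 3, 3)
t_gt = torch.randn(B, 3)                               # ground-truth translation
s_gt = torch.rand(B) + 0.5                             # ground-truth positive scale

# Build "global" points with the same convention the wrapper uses to re-assemble them
global_pts = s_gt[:, None, None] * cam_pts @ R_gt.transpose(1, 2) + t_gt[:, None]

# Recover the similarity transform, optionally weighted by per-point confidence
conf = torch.ones(B, N)
R, t, s = roma.rigid_points_registration(
    cam_pts, global_pts, weights=conf, compute_scaling=True
)
recon = s[:, None, None] * cam_pts @ R.transpose(1, 2) + t[:, None]
print(torch.allclose(recon, global_pts, atol=1e-4))    # expected: True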
575
- class Pow3RBAWrapper(torch.nn.Module):
576
- def __init__(
577
- self,
578
- name,
579
- ckpt_path,
580
- geometric_input_config,
581
- scene_graph="complete",
582
- inference_batch_size=32,
583
- global_optim_schedule="cosine",
584
- global_optim_lr=0.01,
585
- global_optim_niter=300,
586
- **kwargs,
587
- ):
588
- super().__init__()
589
- self.name = name
590
- self.ckpt_path = ckpt_path
591
- self.geometric_input_config = geometric_input_config
592
- self.scene_graph = scene_graph
593
- self.inference_batch_size = inference_batch_size
594
- self.global_optim_schedule = global_optim_schedule
595
- self.global_optim_lr = global_optim_lr
596
- self.global_optim_niter = global_optim_niter
597
-
598
- # Init the model and load the checkpoint
599
- print(f"Loading checkpoint from {self.ckpt_path} ...")
600
- ckpt = torch.load(self.ckpt_path, map_location="cpu", weights_only=False)
601
- model = ckpt["definition"]
602
- print(f"Creating model = {model}")
603
- self.model = eval(model)
604
- print(self.model.load_state_dict(ckpt["weights"]))
605
-
606
- # Init the global aligner mode
607
- self.global_aligner_mode = GlobalAlignerMode.PointCloudOptimizer
608
-
609
- def infer_two_views(self, views):
610
- """
611
- Wrapper for Pow3R 2-View inference.
612
-
613
- Assumption:
614
- - The number of input views is 2.
615
-
616
- Args:
617
- views (List[dict]): List of dictionaries containing the input views' images and instance information.
618
- Length of the list should be 2.
619
- Each dictionary should contain the following keys:
620
- "img" (tensor): Image tensor of shape (B, C, H, W).
621
- "data_norm_type" (list): ["dust3r"]
622
- Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs:
623
- "camera_intrinsics" (tensor): Camera intrinsics. Tensor of shape (B, 3, 3).
624
- "camera_pose" (tensor): Camera pose. Tensor of shape (B, 4, 4). Camera pose is opencv (RDF) cam2world transformation.
625
- "depthmap" (tensor): Z Depth map. Tensor of shape (B, H, W, 1).
626
-
627
- Returns:
628
- List[dict]: A list containing the final outputs for the two views. Length of the list will be 2.
629
- """
630
- # Check that the number of input views is 2
631
- assert len(views) == 2, "Pow3R requires 2 input views."
632
-
633
- # Check the data norm type
634
- data_norm_type = views[0]["data_norm_type"][0]
635
- assert data_norm_type == "dust3r", (
636
- "Pow3R expects a normalized image with the DUSt3R normalization scheme applied"
637
- )
638
-
639
- # Get the device and two views
640
- device = views[0]["img"].device
641
- view1, view2 = views
642
-
643
- # Decide if we need to use the geometric inputs
644
- if torch.rand(1, device=device) < self.geometric_input_config["overall_prob"]:
645
- # Decide if we need to use the camera intrinsics
646
- if (
647
- torch.rand(1, device=device)
648
- < self.geometric_input_config["ray_dirs_prob"]
649
- ):
650
- add_intrinsics(view1, view1.get("camera_intrinsics"))
651
- add_intrinsics(view2, view2.get("camera_intrinsics"))
652
-
653
- # Decide if we need to use the depth map
654
- if torch.rand(1, device=device) < self.geometric_input_config["depth_prob"]:
655
- depthmap1 = view1.get("depthmap")
656
- depthmap2 = view2.get("depthmap")
657
- if depthmap1 is not None:
658
- depthmap1 = depthmap1.squeeze(-1).to(device)
659
- if depthmap2 is not None:
660
- depthmap2 = depthmap2.squeeze(-1).to(device)
661
- add_depth(view1, depthmap1)
662
- add_depth(view2, depthmap2)
663
-
664
- # Decide if we need to use the camera pose
665
- if torch.rand(1, device=device) < self.geometric_input_config["cam_prob"]:
666
- cam1 = view1.get("camera_pose")
667
- cam2 = view2.get("camera_pose")
668
- add_relpose(view1, cam2_to_world=cam2, cam1_to_world=cam1)
669
- add_relpose(view2, cam2_to_world=cam2, cam1_to_world=cam1)
670
-
671
- # Get the model predictions
672
- preds = self.model(view1, view2)
673
-
674
- return preds
675
-
676
- def loss_of_one_batch(self, batch, device):
677
- """
678
- Compute prediction for two views.
679
- """
680
- view1, view2 = batch
681
- ignore_keys = set(
682
- [
683
- "dataset",
684
- "label",
685
- "instance",
686
- "idx",
687
- "true_shape",
688
- "rng",
689
- "name",
690
- "data_norm_type",
691
- ]
692
- )
693
- for view in batch:
694
- for name in view.keys(): # pseudo_focal
695
- if name in ignore_keys:
696
- continue
697
- view[name] = view[name].to(device, non_blocking=True)
698
-
699
- pred1, pred2 = self.infer_two_views([view1, view2])
700
-
701
- result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2)
702
-
703
- return result
704
-
705
- @torch.no_grad()
706
- def inference(self, pairs, device, verbose=False):
707
- """
708
- Wrapper for multi-pair inference using Pow3R.
709
- """
710
- if verbose:
711
- print(f">> Inference with model on {len(pairs)} image pairs")
712
- result = []
713
-
714
- multiple_shapes = not (check_if_same_size(pairs))
715
- if multiple_shapes:
716
- self.inference_batch_size = 1
717
-
718
- for i in tqdm.trange(
719
- 0, len(pairs), self.inference_batch_size, disable=not verbose
720
- ):
721
- res = self.loss_of_one_batch(
722
- collate_with_cat(pairs[i : i + self.inference_batch_size]), device
723
- )
724
- result.append(to_cpu(res))
725
-
726
- result = collate_with_cat(result, lists=multiple_shapes)
727
-
728
- return result
729
-
730
- def forward(self, views):
731
- """
732
- Forward pass wrapper for Pow3R using the global aligner.
733
-
734
- Assumption:
735
- - The batch size of input views is 1.
736
-
737
- Args:
738
- views (List[dict]): List of dictionaries containing the input views' images and instance information.
739
- Each dictionary should contain the following keys, where B is the batch size and is 1:
740
- "img" (tensor): Image tensor of shape (B, C, H, W).
741
- "data_norm_type" (list): ["dust3r"]
- "camera_intrinsics", "camera_pose", "depthmap", "true_shape" (tensors): also required,
- since they are copied into the per-view dicts passed to the pairwise model.
742
-
743
- Returns:
744
- List[dict]: A list containing the final outputs for the input views.
745
- """
746
- # Check the batch size of input views
747
- batch_size_per_view, _, height, width = views[0]["img"].shape
748
- device = views[0]["img"].device
749
- num_views = len(views)
750
- assert batch_size_per_view == 1, (
751
- f"Batch size of input views should be 1, but got {batch_size_per_view}."
752
- )
753
-
754
- # Check the data norm type
755
- data_norm_type = views[0]["data_norm_type"][0]
756
- assert data_norm_type == "dust3r", (
757
- "Pow3R-BA expects a normalized image with the DUSt3R normalization scheme applied"
758
- )
759
-
760
- # Convert the input views to the expected input format
761
- images = []
762
- for view in views:
763
- images.append(
764
- dict(
765
- img=view["img"],
766
- camera_intrinsics=view["camera_intrinsics"],
767
- depthmap=view["depthmap"],
768
- camera_pose=view["camera_pose"],
769
- data_norm_type=view["data_norm_type"],
770
- true_shape=view["true_shape"],
771
- idx=len(images),
772
- instance=str(len(images)),
773
- )
774
- )
775
-
776
- # Make image pairs and run inference pair-wise
777
- pairs = make_pairs(
778
- images, scene_graph=self.scene_graph, prefilter=None, symmetrize=True
779
- )
780
- with warnings.catch_warnings():
781
- warnings.simplefilter("ignore", category=FutureWarning)
782
- output = self.inference(
783
- pairs,
784
- device,
785
- verbose=False,
786
- )
787
-
788
- # Global optimization
789
- with torch.enable_grad():
790
- scene = global_aligner(
791
- output, device=device, mode=self.global_aligner_mode, verbose=False
792
- )
793
- _ = scene.compute_global_alignment(
794
- init="mst",
795
- niter=self.global_optim_niter,
796
- schedule=self.global_optim_schedule,
797
- lr=self.global_optim_lr,
798
- )
799
-
800
- # Make sure scene is not None
801
- if scene is None:
802
- raise RuntimeError("Global optimization failed.")
803
-
804
- # Get the predictions
805
- intrinsics = scene.get_intrinsics()
806
- c2w_poses = scene.get_im_poses()
807
- depths = scene.get_depthmaps()
808
-
809
- # Convert the output to the MapAnything format
810
- with torch.autocast("cuda", enabled=False):
811
- res = []
812
- for view_idx in range(num_views):
813
- # Get the current view predictions
814
- curr_view_intrinsic = intrinsics[view_idx].unsqueeze(0)
815
- curr_view_pose = c2w_poses[view_idx].unsqueeze(0)
816
- curr_view_depth_z = depths[view_idx].unsqueeze(0)
817
-
818
- # Convert the pose to quaternions and translation
819
- curr_view_cam_translations = curr_view_pose[..., :3, 3]
820
- curr_view_cam_quats = mat_to_quat(curr_view_pose[..., :3, :3])
821
-
822
- # Get the camera frame pointmaps
823
- curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
824
- curr_view_depth_z, curr_view_intrinsic
825
- )
826
-
827
- # Convert the z depth to depth along ray
828
- curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
829
- curr_view_depth_z, curr_view_intrinsic
830
- )
831
- curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)
832
-
833
- # Get the ray directions on the unit sphere in the camera frame
834
- _, curr_view_ray_dirs = get_rays_in_camera_frame(
835
- curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
836
- )
837
-
838
- # Get the pointmaps
839
- curr_view_pts3d = (
840
- convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
841
- curr_view_ray_dirs,
842
- curr_view_depth_along_ray,
843
- curr_view_cam_translations,
844
- curr_view_cam_quats,
845
- )
846
- )
847
-
848
- # Append the outputs to the result list
849
- res.append(
850
- {
851
- "pts3d": curr_view_pts3d,
852
- "pts3d_cam": curr_view_pts3d_cam,
853
- "ray_directions": curr_view_ray_dirs,
854
- "depth_along_ray": curr_view_depth_along_ray,
855
- "cam_trans": curr_view_cam_translations,
856
- "cam_quats": curr_view_cam_quats,
857
- }
858
- )
859
-
860
- return res
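For reference, `Pow3RBAWrapper.forward` above follows the standard DUSt3R recipe: build symmetrized pairs, run the two-view model on every pair, then solve a global alignment. A condensed sketch of that flow, where `images` is a list of DUSt3R-style view dicts and `run_pairwise_inference` is a placeholder for the batched two-view inference performed by `Pow3RBAWrapper.inference`:

import torch
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

def align_scene(images, run_pairwise_inference, device="cuda"):
    # Pair every view with every other view (symmetrized "complete" graph)
    pairs = make_pairs(images, scene_graph="complete", prefilter=None, symmetrize=True)
    output = run_pairwise_inference(pairs, device)  # dict of view1/view2/pred1/pred2 batches

    # Optimize per-view depth, pose and intrinsics so all pairwise pointmaps agree
    with torch.enable_grad():
        scene = global_aligner(
            output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer, verbose=False
        )
        scene.compute_global_alignment(init="mst", niter=300, schedule="cosine", lr=0.01)

    return scene.get_intrinsics(), scene.get_im_poses(), scene.get_depthmaps()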
mapanything/models/external/vggt/__init__.py DELETED
@@ -1,186 +0,0 @@
1
- """
2
- Inference wrapper for VGGT
3
- """
4
-
5
- import torch
6
-
7
- from mapanything.models.external.vggt.models.vggt import VGGT
8
- from mapanything.models.external.vggt.utils.geometry import closed_form_inverse_se3
9
- from mapanything.models.external.vggt.utils.pose_enc import pose_encoding_to_extri_intri
10
- from mapanything.models.external.vggt.utils.rotation import mat_to_quat
11
- from mapanything.utils.geometry import (
12
- convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
13
- convert_z_depth_to_depth_along_ray,
14
- depthmap_to_camera_frame,
15
- get_rays_in_camera_frame,
16
- )
17
-
18
-
19
- class VGGTWrapper(torch.nn.Module):
20
- def __init__(
21
- self,
22
- name,
23
- torch_hub_force_reload,
24
- load_pretrained_weights=True,
25
- depth=24,
26
- num_heads=16,
27
- intermediate_layer_idx=[4, 11, 17, 23],
28
- load_custom_ckpt=False,
29
- custom_ckpt_path=None,
30
- ):
31
- super().__init__()
32
- self.name = name
33
- self.torch_hub_force_reload = torch_hub_force_reload
34
- self.load_custom_ckpt = load_custom_ckpt
35
- self.custom_ckpt_path = custom_ckpt_path
36
-
37
- if load_pretrained_weights:
38
- # Load pre-trained weights
39
- if not torch_hub_force_reload:
40
- # Initialize the 1B VGGT model from huggingface hub cache
41
- print("Loading facebook/VGGT-1B from huggingface cache ...")
42
- self.model = VGGT.from_pretrained(
43
- "facebook/VGGT-1B",
44
- )
45
- else:
46
- # Initialize the 1B VGGT model
47
- print("Re-downloading facebook/VGGT-1B ...")
48
- self.model = VGGT.from_pretrained(
49
- "facebook/VGGT-1B", force_download=True
50
- )
51
- else:
52
- # Load the VGGT class
53
- self.model = VGGT(
54
- depth=depth,
55
- num_heads=num_heads,
56
- intermediate_layer_idx=intermediate_layer_idx,
57
- )
58
-
59
- # Get the dtype for VGGT inference
60
- # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
61
- self.dtype = (
62
- torch.bfloat16
63
- if torch.cuda.get_device_capability()[0] >= 8
64
- else torch.float16
65
- )
66
-
67
- # Load custom checkpoint if requested
68
- if self.load_custom_ckpt:
69
- print(f"Loading checkpoint from {self.custom_ckpt_path} ...")
70
- assert self.custom_ckpt_path is not None, (
71
- "custom_ckpt_path must be provided if load_custom_ckpt is set to True"
72
- )
73
- custom_ckpt = torch.load(self.custom_ckpt_path, weights_only=False)
74
- print(self.model.load_state_dict(custom_ckpt, strict=True))
75
- del custom_ckpt # in case it occupies memory
76
-
77
- def forward(self, views):
78
- """
79
- Forward pass wrapper for VGGT
80
-
81
- Assumption:
82
- - All the input views have the same image shape.
83
-
84
- Args:
85
- views (List[dict]): List of dictionaries containing the input views' images and instance information.
86
- Each dictionary should contain the following keys:
87
- "img" (tensor): Image tensor of shape (B, C, H, W).
88
- "data_norm_type" (list): ["identity"]
89
-
90
- Returns:
91
- List[dict]: A list containing the final outputs for all N views.
92
- """
93
- # Get input shape of the images, number of views, and batch size per view
94
- batch_size_per_view, _, height, width = views[0]["img"].shape
95
- num_views = len(views)
96
-
97
- # Check the data norm type
98
- # VGGT expects a normalized image but without the DINOv2 mean and std applied ("identity")
99
- data_norm_type = views[0]["data_norm_type"][0]
100
- assert data_norm_type == "identity", (
101
- "VGGT expects a normalized image but without the DINOv2 mean and std applied"
102
- )
103
-
104
- # Concatenate the images to create a single (B, V, C, H, W) tensor
105
- img_list = [view["img"] for view in views]
106
- images = torch.stack(img_list, dim=1)
107
-
108
- # Run the VGGT aggregator
109
- with torch.autocast("cuda", dtype=self.dtype):
110
- aggregated_tokens_list, ps_idx = self.model.aggregator(images)
111
-
112
- # Run the Camera + Pose Branch of VGGT
113
- with torch.autocast("cuda", enabled=False):
114
- # Predict Cameras
115
- pose_enc = self.model.camera_head(aggregated_tokens_list)[-1]
116
- # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
117
- # Extrinsics Shape: (B, V, 3, 4)
118
- # Intrinsics Shape: (B, V, 3, 3)
119
- extrinsic, intrinsic = pose_encoding_to_extri_intri(
120
- pose_enc, images.shape[-2:]
121
- )
122
-
123
- # Predict Depth Maps
124
- # Depth Shape: (B, V, H, W, 1)
125
- # Depth Confidence Shape: (B, V, H, W)
126
- depth_map, depth_conf = self.model.depth_head(
127
- aggregated_tokens_list, images, ps_idx
128
- )
129
-
130
- # Convert the output to MapAnything format
131
- res = []
132
- for view_idx in range(num_views):
133
- # Get the extrinsics, intrinsics, depth map for the current view
134
- curr_view_extrinsic = extrinsic[:, view_idx, ...]
135
- curr_view_extrinsic = closed_form_inverse_se3(
136
- curr_view_extrinsic
137
- ) # Convert to cam2world
138
- curr_view_intrinsic = intrinsic[:, view_idx, ...]
139
- curr_view_depth_z = depth_map[:, view_idx, ...]
140
- curr_view_depth_z = curr_view_depth_z.squeeze(-1)
141
- curr_view_confidence = depth_conf[:, view_idx, ...]
142
-
143
- # Get the camera frame pointmaps
144
- curr_view_pts3d_cam, _ = depthmap_to_camera_frame(
145
- curr_view_depth_z, curr_view_intrinsic
146
- )
147
-
148
- # Convert the extrinsics to quaternions and translations
149
- curr_view_cam_translations = curr_view_extrinsic[..., :3, 3]
150
- curr_view_cam_quats = mat_to_quat(curr_view_extrinsic[..., :3, :3])
151
-
152
- # Convert the z depth to depth along ray
153
- curr_view_depth_along_ray = convert_z_depth_to_depth_along_ray(
154
- curr_view_depth_z, curr_view_intrinsic
155
- )
156
- curr_view_depth_along_ray = curr_view_depth_along_ray.unsqueeze(-1)
157
-
158
- # Get the ray directions on the unit sphere in the camera frame
159
- _, curr_view_ray_dirs = get_rays_in_camera_frame(
160
- curr_view_intrinsic, height, width, normalize_to_unit_sphere=True
161
- )
162
-
163
- # Get the pointmaps
164
- curr_view_pts3d = (
165
- convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
166
- curr_view_ray_dirs,
167
- curr_view_depth_along_ray,
168
- curr_view_cam_translations,
169
- curr_view_cam_quats,
170
- )
171
- )
172
-
173
- # Append the outputs to the result list
174
- res.append(
175
- {
176
- "pts3d": curr_view_pts3d,
177
- "pts3d_cam": curr_view_pts3d_cam,
178
- "ray_directions": curr_view_ray_dirs,
179
- "depth_along_ray": curr_view_depth_along_ray,
180
- "cam_trans": curr_view_cam_translations,
181
- "cam_quats": curr_view_cam_quats,
182
- "conf": curr_view_confidence,
183
- }
184
- )
185
-
186
- return res
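The output conversion in `VGGTWrapper.forward` leans on helpers from `mapanything.utils.geometry`. The sketch below re-derives the same quantities for a single unbatched view so the geometry is explicit: the z-depth is turned into depth along the ray, ray directions are normalized to the unit sphere, and the camera-frame points are lifted to the world frame. It is an illustrative re-implementation (pixel-center conventions may differ from the library helpers), not the library code itself.

import torch

def zdepth_to_pointmap(depth_z, K, R_c2w, t_c2w):
    """depth_z: (H, W); K: (3, 3); R_c2w: (3, 3); t_c2w: (3,)."""
    H, W = depth_z.shape
    v, u = torch.meshgrid(
        torch.arange(H, dtype=depth_z.dtype),
        torch.arange(W, dtype=depth_z.dtype),
        indexing="ij",
    )
    pix = torch.stack([u, v, torch.ones_like(u)], dim=-1)       # homogeneous pixel coords (H, W, 3)
    cam_dirs = pix @ torch.linalg.inv(K).T                      # K^-1 [u, v, 1] per pixel
    ray_dirs = cam_dirs / cam_dirs.norm(dim=-1, keepdim=True)   # unit-sphere ray directions
    depth_along_ray = depth_z * cam_dirs.norm(dim=-1)           # ||z * K^-1 [u, v, 1]||
    pts_cam = ray_dirs * depth_along_ray[..., None]             # camera-frame pointmap
    pts_world = pts_cam @ R_c2w.T + t_c2w                       # apply cam2world pose
    return ray_dirs, depth_along_ray, pts_world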
mapanything/models/external/vggt/heads/__init__.py DELETED
File without changes
mapanything/models/external/vggt/heads/camera_head.py DELETED
@@ -1,167 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
- import torch
9
- import torch.nn as nn
10
-
11
- from mapanything.models.external.vggt.heads.head_act import activate_pose
12
- from mapanything.models.external.vggt.layers import Mlp
13
- from mapanything.models.external.vggt.layers.block import Block
14
-
15
-
16
- class CameraHead(nn.Module):
17
- """
18
- CameraHead predicts camera parameters from token representations using iterative refinement.
19
-
20
- It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
21
- """
22
-
23
- def __init__(
24
- self,
25
- dim_in: int = 2048,
26
- trunk_depth: int = 4,
27
- pose_encoding_type: str = "absT_quaR_FoV",
28
- num_heads: int = 16,
29
- mlp_ratio: int = 4,
30
- init_values: float = 0.01,
31
- trans_act: str = "linear",
32
- quat_act: str = "linear",
33
- fl_act: str = "relu",  # Field-of-view activation: ensures FOV values are positive.
34
- ):
35
- super().__init__()
36
-
37
- if pose_encoding_type == "absT_quaR_FoV":
38
- self.target_dim = 9
39
- else:
40
- raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
41
-
42
- self.trans_act = trans_act
43
- self.quat_act = quat_act
44
- self.fl_act = fl_act
45
- self.trunk_depth = trunk_depth
46
-
47
- # Build the trunk using a sequence of transformer blocks.
48
- self.trunk = nn.Sequential(
49
- *[
50
- Block(
51
- dim=dim_in,
52
- num_heads=num_heads,
53
- mlp_ratio=mlp_ratio,
54
- init_values=init_values,
55
- )
56
- for _ in range(trunk_depth)
57
- ]
58
- )
59
-
60
- # Normalizations for camera token and trunk output.
61
- self.token_norm = nn.LayerNorm(dim_in)
62
- self.trunk_norm = nn.LayerNorm(dim_in)
63
-
64
- # Learnable empty camera pose token.
65
- self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
66
- self.embed_pose = nn.Linear(self.target_dim, dim_in)
67
-
68
- # Module for producing modulation parameters: shift, scale, and a gate.
69
- self.poseLN_modulation = nn.Sequential(
70
- nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)
71
- )
72
-
73
- # Adaptive layer normalization without affine parameters.
74
- self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
75
- self.pose_branch = Mlp(
76
- in_features=dim_in,
77
- hidden_features=dim_in // 2,
78
- out_features=self.target_dim,
79
- drop=0,
80
- )
81
-
82
- def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
83
- """
84
- Forward pass to predict camera parameters.
85
-
86
- Args:
87
- aggregated_tokens_list (list): List of token tensors from the network;
88
- the last tensor is used for prediction.
89
- num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
90
-
91
- Returns:
92
- list: A list of predicted camera encodings (post-activation) from each iteration.
93
- """
94
- # Use tokens from the last block for camera prediction.
95
- tokens = aggregated_tokens_list[-1]
96
-
97
- # Extract the camera tokens
98
- pose_tokens = tokens[:, :, 0]
99
- pose_tokens = self.token_norm(pose_tokens)
100
-
101
- pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
102
- return pred_pose_enc_list
103
-
104
- def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
105
- """
106
- Iteratively refine camera pose predictions.
107
-
108
- Args:
109
- pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
110
- num_iterations (int): Number of refinement iterations.
111
-
112
- Returns:
113
- list: List of activated camera encodings from each iteration.
114
- """
115
- B, S, C = pose_tokens.shape # S is expected to be 1.
116
- pred_pose_enc = None
117
- pred_pose_enc_list = []
118
-
119
- for _ in range(num_iterations):
120
- # Use a learned empty pose for the first iteration.
121
- if pred_pose_enc is None:
122
- module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
123
- else:
124
- # Detach the previous prediction to avoid backprop through time.
125
- pred_pose_enc = pred_pose_enc.detach()
126
- module_input = self.embed_pose(pred_pose_enc)
127
-
128
- # Generate modulation parameters and split them into shift, scale, and gate components.
129
- shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(
130
- 3, dim=-1
131
- )
132
-
133
- # Adaptive layer normalization and modulation.
134
- pose_tokens_modulated = gate_msa * modulate(
135
- self.adaln_norm(pose_tokens), shift_msa, scale_msa
136
- )
137
- pose_tokens_modulated = pose_tokens_modulated + pose_tokens
138
-
139
- pose_tokens_modulated = self.trunk(pose_tokens_modulated)
140
- # Compute the delta update for the pose encoding.
141
- pred_pose_enc_delta = self.pose_branch(
142
- self.trunk_norm(pose_tokens_modulated)
143
- )
144
-
145
- if pred_pose_enc is None:
146
- pred_pose_enc = pred_pose_enc_delta
147
- else:
148
- pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
149
-
150
- # Apply final activation functions for translation, quaternion, and field-of-view.
151
- activated_pose = activate_pose(
152
- pred_pose_enc,
153
- trans_act=self.trans_act,
154
- quat_act=self.quat_act,
155
- fl_act=self.fl_act,
156
- )
157
- pred_pose_enc_list.append(activated_pose)
158
-
159
- return pred_pose_enc_list
160
-
161
-
162
- def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
163
- """
164
- Modulate the input tensor using scaling and shifting parameters.
165
- """
166
- # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
167
- return x * (1 + scale) + shift
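The camera head above refines its pose encoding with adaLN-style modulation: the current estimate is embedded, mapped to shift/scale/gate terms, used to modulate the camera token, and the trunk then predicts a delta. A stripped-down sketch of that update loop (identity trunk, toy dimensions, and a zero initial encoding instead of the learned empty-pose token; all placeholders):

import torch
import torch.nn as nn

dim, target_dim, B = 64, 9, 2
embed_pose = nn.Linear(target_dim, dim)
to_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))
adaln_norm = nn.LayerNorm(dim, elementwise_affine=False)
trunk = nn.Identity()                    # stands in for the transformer trunk
delta_head = nn.Linear(dim, target_dim)  # stands in for the pose MLP

pose_tokens = torch.randn(B, 1, dim)     # camera tokens from the aggregator
pred = torch.zeros(B, 1, target_dim)     # running pose encoding

for _ in range(4):                       # iterative refinement
    shift, scale, gate = to_mod(embed_pose(pred.detach())).chunk(3, dim=-1)
    modulated = gate * (adaln_norm(pose_tokens) * (1 + scale) + shift) + pose_tokens
    pred = pred + delta_head(trunk(modulated))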
mapanything/models/external/vggt/heads/dpt_head.py DELETED
@@ -1,600 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
- # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
-
10
-
11
- from typing import List, Tuple, Union
12
-
13
- import torch
14
- import torch.nn as nn
15
-
16
- from .head_act import activate_head
17
- from .utils import create_uv_grid, position_grid_to_embed
18
-
19
-
20
- class DPTHead(nn.Module):
21
- """
22
- DPT Head for dense prediction tasks.
23
-
24
- This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
25
- (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
26
- backbone and produces dense predictions by fusing multi-scale features.
27
-
28
- Args:
29
- dim_in (int): Input dimension (channels).
30
- patch_size (int, optional): Patch size. Default is 14.
31
- output_dim (int, optional): Number of output channels. Default is 4.
32
- activation (str, optional): Activation type. Default is "inv_log".
33
- conf_activation (str, optional): Confidence activation type. Default is "expp1".
34
- features (int, optional): Feature channels for intermediate representations. Default is 256.
35
- out_channels (List[int], optional): Output channels for each intermediate layer.
36
- intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
37
- pos_embed (bool, optional): Whether to use positional embedding. Default is True.
38
- feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
39
- down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
40
- """
41
-
42
- def __init__(
43
- self,
44
- dim_in: int,
45
- patch_size: int = 14,
46
- output_dim: int = 4,
47
- activation: str = "inv_log",
48
- conf_activation: str = "expp1",
49
- features: int = 256,
50
- out_channels: List[int] = [256, 512, 1024, 1024],
51
- intermediate_layer_idx: List[int] = [4, 11, 17, 23],
52
- pos_embed: bool = True,
53
- feature_only: bool = False,
54
- down_ratio: int = 1,
55
- ) -> None:
56
- super(DPTHead, self).__init__()
57
- self.patch_size = patch_size
58
- self.activation = activation
59
- self.conf_activation = conf_activation
60
- self.pos_embed = pos_embed
61
- self.feature_only = feature_only
62
- self.down_ratio = down_ratio
63
- self.intermediate_layer_idx = intermediate_layer_idx
64
-
65
- self.norm = nn.LayerNorm(dim_in)
66
-
67
- # Projection layers for each output channel from tokens.
68
- self.projects = nn.ModuleList(
69
- [
70
- nn.Conv2d(
71
- in_channels=dim_in,
72
- out_channels=oc,
73
- kernel_size=1,
74
- stride=1,
75
- padding=0,
76
- )
77
- for oc in out_channels
78
- ]
79
- )
80
-
81
- # Resize layers for upsampling feature maps.
82
- self.resize_layers = nn.ModuleList(
83
- [
84
- nn.ConvTranspose2d(
85
- in_channels=out_channels[0],
86
- out_channels=out_channels[0],
87
- kernel_size=4,
88
- stride=4,
89
- padding=0,
90
- ),
91
- nn.ConvTranspose2d(
92
- in_channels=out_channels[1],
93
- out_channels=out_channels[1],
94
- kernel_size=2,
95
- stride=2,
96
- padding=0,
97
- ),
98
- nn.Identity(),
99
- nn.Conv2d(
100
- in_channels=out_channels[3],
101
- out_channels=out_channels[3],
102
- kernel_size=3,
103
- stride=2,
104
- padding=1,
105
- ),
106
- ]
107
- )
108
-
109
- self.scratch = _make_scratch(
110
- out_channels,
111
- features,
112
- expand=False,
113
- )
114
-
115
- # Attach additional modules to scratch.
116
- self.scratch.stem_transpose = None
117
- self.scratch.refinenet1 = _make_fusion_block(features)
118
- self.scratch.refinenet2 = _make_fusion_block(features)
119
- self.scratch.refinenet3 = _make_fusion_block(features)
120
- self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
121
-
122
- head_features_1 = features
123
- head_features_2 = 32
124
-
125
- if feature_only:
126
- self.scratch.output_conv1 = nn.Conv2d(
127
- head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
128
- )
129
- else:
130
- self.scratch.output_conv1 = nn.Conv2d(
131
- head_features_1,
132
- head_features_1 // 2,
133
- kernel_size=3,
134
- stride=1,
135
- padding=1,
136
- )
137
- conv2_in_channels = head_features_1 // 2
138
-
139
- self.scratch.output_conv2 = nn.Sequential(
140
- nn.Conv2d(
141
- conv2_in_channels,
142
- head_features_2,
143
- kernel_size=3,
144
- stride=1,
145
- padding=1,
146
- ),
147
- nn.ReLU(inplace=True),
148
- nn.Conv2d(
149
- head_features_2, output_dim, kernel_size=1, stride=1, padding=0
150
- ),
151
- )
152
-
153
- def forward(
154
- self,
155
- aggregated_tokens_list: List[torch.Tensor],
156
- images: torch.Tensor,
157
- patch_start_idx: int,
158
- frames_chunk_size: int = 8,
159
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
160
- """
161
- Forward pass through the DPT head, supports processing by chunking frames.
162
- Args:
163
- aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
164
- images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
165
- patch_start_idx (int): Starting index for patch tokens in the token sequence.
166
- Used to separate patch tokens from other tokens (e.g., camera or register tokens).
167
- frames_chunk_size (int, optional): Number of frames to process in each chunk.
168
- If None or larger than S, all frames are processed at once. Default: 8.
169
-
170
- Returns:
171
- Tensor or Tuple[Tensor, Tensor]:
172
- - If feature_only=True: Feature maps with shape [B, S, C, H, W]
173
- - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
174
- """
175
- B, S, _, H, W = images.shape
176
-
177
- # If frames_chunk_size is not specified or greater than S, process all frames at once
178
- if frames_chunk_size is None or frames_chunk_size >= S:
179
- return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
180
-
181
- # Otherwise, process frames in chunks to manage memory usage
182
- assert frames_chunk_size > 0
183
-
184
- # Process frames in batches
185
- all_preds = []
186
- all_conf = []
187
-
188
- for frames_start_idx in range(0, S, frames_chunk_size):
189
- frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
190
-
191
- # Process batch of frames
192
- if self.feature_only:
193
- chunk_output = self._forward_impl(
194
- aggregated_tokens_list,
195
- images,
196
- patch_start_idx,
197
- frames_start_idx,
198
- frames_end_idx,
199
- )
200
- all_preds.append(chunk_output)
201
- else:
202
- chunk_preds, chunk_conf = self._forward_impl(
203
- aggregated_tokens_list,
204
- images,
205
- patch_start_idx,
206
- frames_start_idx,
207
- frames_end_idx,
208
- )
209
- all_preds.append(chunk_preds)
210
- all_conf.append(chunk_conf)
211
-
212
- # Concatenate results along the sequence dimension
213
- if self.feature_only:
214
- return torch.cat(all_preds, dim=1)
215
- else:
216
- return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
217
-
218
- def _forward_impl(
219
- self,
220
- aggregated_tokens_list: List[torch.Tensor],
221
- images: torch.Tensor,
222
- patch_start_idx: int,
223
- frames_start_idx: int = None,
224
- frames_end_idx: int = None,
225
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
226
- """
227
- Implementation of the forward pass through the DPT head.
228
-
229
- This method processes a specific chunk of frames from the sequence.
230
-
231
- Args:
232
- aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
233
- images (Tensor): Input images with shape [B, S, 3, H, W].
234
- patch_start_idx (int): Starting index for patch tokens.
235
- frames_start_idx (int, optional): Starting index for frames to process.
236
- frames_end_idx (int, optional): Ending index for frames to process.
237
-
238
- Returns:
239
- Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
240
- """
241
- if frames_start_idx is not None and frames_end_idx is not None:
242
- images = images[:, frames_start_idx:frames_end_idx].contiguous()
243
-
244
- B, S, _, H, W = images.shape
245
-
246
- patch_h, patch_w = H // self.patch_size, W // self.patch_size
247
-
248
- out = []
249
- dpt_idx = 0
250
-
251
- for layer_idx in self.intermediate_layer_idx:
252
- x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
253
-
254
- # Select frames if processing a chunk
255
- if frames_start_idx is not None and frames_end_idx is not None:
256
- x = x[:, frames_start_idx:frames_end_idx]
257
-
258
- x = x.reshape(B * S, -1, x.shape[-1])
259
-
260
- x = self.norm(x)
261
-
262
- x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
263
-
264
- x = self.projects[dpt_idx](x)
265
- if self.pos_embed:
266
- x = self._apply_pos_embed(x, W, H)
267
- x = self.resize_layers[dpt_idx](x)
268
-
269
- out.append(x)
270
- dpt_idx += 1
271
-
272
- # Fuse features from multiple layers.
273
- out = self.scratch_forward(out)
274
- # Interpolate fused output to match target image resolution.
275
- out = custom_interpolate(
276
- out,
277
- (
278
- int(patch_h * self.patch_size / self.down_ratio),
279
- int(patch_w * self.patch_size / self.down_ratio),
280
- ),
281
- mode="bilinear",
282
- align_corners=True,
283
- )
284
-
285
- if self.pos_embed:
286
- out = self._apply_pos_embed(out, W, H)
287
-
288
- if self.feature_only:
289
- return out.view(B, S, *out.shape[1:])
290
-
291
- out = self.scratch.output_conv2(out)
292
- preds, conf = activate_head(
293
- out, activation=self.activation, conf_activation=self.conf_activation
294
- )
295
-
296
- preds = preds.view(B, S, *preds.shape[1:])
297
- conf = conf.view(B, S, *conf.shape[1:])
298
- return preds, conf
299
-
300
- def _apply_pos_embed(
301
- self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
302
- ) -> torch.Tensor:
303
- """
304
- Apply positional embedding to tensor x.
305
- """
306
- patch_w = x.shape[-1]
307
- patch_h = x.shape[-2]
308
- pos_embed = create_uv_grid(
309
- patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
310
- )
311
- pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
312
- pos_embed = pos_embed * ratio
313
- pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
314
- return x + pos_embed
315
-
316
- def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
317
- """
318
- Forward pass through the fusion blocks.
319
-
320
- Args:
321
- features (List[Tensor]): List of feature maps from different layers.
322
-
323
- Returns:
324
- Tensor: Fused feature map.
325
- """
326
- layer_1, layer_2, layer_3, layer_4 = features
327
-
328
- layer_1_rn = self.scratch.layer1_rn(layer_1)
329
- layer_2_rn = self.scratch.layer2_rn(layer_2)
330
- layer_3_rn = self.scratch.layer3_rn(layer_3)
331
- layer_4_rn = self.scratch.layer4_rn(layer_4)
332
-
333
- out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
334
- del layer_4_rn, layer_4
335
-
336
- out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
337
- del layer_3_rn, layer_3
338
-
339
- out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
340
- del layer_2_rn, layer_2
341
-
342
- out = self.scratch.refinenet1(out, layer_1_rn)
343
- del layer_1_rn, layer_1
344
-
345
- out = self.scratch.output_conv1(out)
346
- return out
347
-
348
-
349
- ################################################################################
350
- # Modules
351
- ################################################################################
352
-
353
-
354
- def _make_fusion_block(
355
- features: int, size: int = None, has_residual: bool = True, groups: int = 1
356
- ) -> nn.Module:
357
- return FeatureFusionBlock(
358
- features,
359
- nn.ReLU(inplace=True),
360
- deconv=False,
361
- bn=False,
362
- expand=False,
363
- align_corners=True,
364
- size=size,
365
- has_residual=has_residual,
366
- groups=groups,
367
- )
368
-
369
-
370
- def _make_scratch(
371
- in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
372
- ) -> nn.Module:
373
- scratch = nn.Module()
374
- out_shape1 = out_shape
375
- out_shape2 = out_shape
376
- out_shape3 = out_shape
377
- if len(in_shape) >= 4:
378
- out_shape4 = out_shape
379
-
380
- if expand:
381
- out_shape1 = out_shape
382
- out_shape2 = out_shape * 2
383
- out_shape3 = out_shape * 4
384
-         if len(in_shape) >= 4:
-             out_shape4 = out_shape * 8
-
-     scratch.layer1_rn = nn.Conv2d(
-         in_shape[0],
-         out_shape1,
-         kernel_size=3,
-         stride=1,
-         padding=1,
-         bias=False,
-         groups=groups,
-     )
-     scratch.layer2_rn = nn.Conv2d(
-         in_shape[1],
-         out_shape2,
-         kernel_size=3,
-         stride=1,
-         padding=1,
-         bias=False,
-         groups=groups,
-     )
-     scratch.layer3_rn = nn.Conv2d(
-         in_shape[2],
-         out_shape3,
-         kernel_size=3,
-         stride=1,
-         padding=1,
-         bias=False,
-         groups=groups,
-     )
-     if len(in_shape) >= 4:
-         scratch.layer4_rn = nn.Conv2d(
-             in_shape[3],
-             out_shape4,
-             kernel_size=3,
-             stride=1,
-             padding=1,
-             bias=False,
-             groups=groups,
-         )
-     return scratch
-
-
- class ResidualConvUnit(nn.Module):
-     """Residual convolution module."""
-
-     def __init__(self, features, activation, bn, groups=1):
-         """Init.
-
-         Args:
-             features (int): number of features
-         """
-         super().__init__()
-
-         self.bn = bn
-         self.groups = groups
-         self.conv1 = nn.Conv2d(
-             features,
-             features,
-             kernel_size=3,
-             stride=1,
-             padding=1,
-             bias=True,
-             groups=self.groups,
-         )
-         self.conv2 = nn.Conv2d(
-             features,
-             features,
-             kernel_size=3,
-             stride=1,
-             padding=1,
-             bias=True,
-             groups=self.groups,
-         )
-
-         self.norm1 = None
-         self.norm2 = None
-
-         self.activation = activation
-         self.skip_add = nn.quantized.FloatFunctional()
-
-     def forward(self, x):
-         """Forward pass.
-
-         Args:
-             x (tensor): input
-
-         Returns:
-             tensor: output
-         """
-
-         out = self.activation(x)
-         out = self.conv1(out)
-         if self.norm1 is not None:
-             out = self.norm1(out)
-
-         out = self.activation(out)
-         out = self.conv2(out)
-         if self.norm2 is not None:
-             out = self.norm2(out)
-
-         return self.skip_add.add(out, x)
-
-
- class FeatureFusionBlock(nn.Module):
-     """Feature fusion block."""
-
-     def __init__(
-         self,
-         features,
-         activation,
-         deconv=False,
-         bn=False,
-         expand=False,
-         align_corners=True,
-         size=None,
-         has_residual=True,
-         groups=1,
-     ):
-         """Init.
-
-         Args:
-             features (int): number of features
-         """
-         super(FeatureFusionBlock, self).__init__()
-
-         self.deconv = deconv
-         self.align_corners = align_corners
-         self.groups = groups
-         self.expand = expand
-         out_features = features
-         if self.expand:
-             out_features = features // 2
-
-         self.out_conv = nn.Conv2d(
-             features,
-             out_features,
-             kernel_size=1,
-             stride=1,
-             padding=0,
-             bias=True,
-             groups=self.groups,
-         )
-
-         if has_residual:
-             self.resConfUnit1 = ResidualConvUnit(
-                 features, activation, bn, groups=self.groups
-             )
-
-         self.has_residual = has_residual
-         self.resConfUnit2 = ResidualConvUnit(
-             features, activation, bn, groups=self.groups
-         )
-
-         self.skip_add = nn.quantized.FloatFunctional()
-         self.size = size
-
-     def forward(self, *xs, size=None):
-         """Forward pass.
-
-         Returns:
-             tensor: output
-         """
-         output = xs[0]
-
-         if self.has_residual:
-             res = self.resConfUnit1(xs[1])
-             output = self.skip_add.add(output, res)
-
-         output = self.resConfUnit2(output)
-
-         if (size is None) and (self.size is None):
-             modifier = {"scale_factor": 2}
-         elif size is None:
-             modifier = {"size": self.size}
-         else:
-             modifier = {"size": size}
-
-         output = custom_interpolate(
-             output, **modifier, mode="bilinear", align_corners=self.align_corners
-         )
-         output = self.out_conv(output)
-
-         return output
-
-
- def custom_interpolate(
-     x: torch.Tensor,
-     size: Tuple[int, int] = None,
-     scale_factor: float = None,
-     mode: str = "bilinear",
-     align_corners: bool = True,
- ) -> torch.Tensor:
-     """
-     Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
-     """
-     if size is None:
-         size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
-
-     INT_MAX = 1610612736
-
-     input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
-
-     if input_elements > INT_MAX:
-         chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
-         interpolated_chunks = [
-             nn.functional.interpolate(
-                 chunk, size=size, mode=mode, align_corners=align_corners
-             )
-             for chunk in chunks
-         ]
-         x = torch.cat(interpolated_chunks, dim=0)
-         return x.contiguous()
-     else:
-         return nn.functional.interpolate(
-             x, size=size, mode=mode, align_corners=align_corners
-         )
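The `custom_interpolate` helper above works around the INT_MAX element limit that `nn.functional.interpolate` hits on very large tensors by splitting the batch into chunks and upsampling each chunk separately. A minimal standalone sketch of that idea (the function name, default threshold, and example shapes are illustrative assumptions, not part of the deleted file):

```python
import torch
import torch.nn.functional as F


def chunked_bilinear_upsample(x, size, max_elements=1_610_612_736):
    """Upsample x (B, C, H, W) to `size`, splitting the batch so no single
    interpolate call produces more than `max_elements` output elements."""
    out_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
    if out_elements <= max_elements:
        return F.interpolate(x, size=size, mode="bilinear", align_corners=True)
    chunks = torch.chunk(x, chunks=out_elements // max_elements + 1, dim=0)
    return torch.cat(
        [F.interpolate(c, size=size, mode="bilinear", align_corners=True) for c in chunks],
        dim=0,
    ).contiguous()


# Example: upsample a small feature map by 2x (the normal, non-chunked path).
feats = torch.randn(2, 128, 64, 64)
print(chunked_bilinear_upsample(feats, (128, 128)).shape)  # torch.Size([2, 128, 128, 128])
```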
 
mapanything/models/external/vggt/heads/head_act.py DELETED
@@ -1,127 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
-
- import torch
- import torch.nn.functional as F
-
-
- def activate_pose(
-     pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"
- ):
-     """
-     Activate pose parameters with specified activation functions.
-
-     Args:
-         pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
-         trans_act: Activation type for translation component
-         quat_act: Activation type for quaternion component
-         fl_act: Activation type for focal length component
-
-     Returns:
-         Activated pose parameters tensor
-     """
-     T = pred_pose_enc[..., :3]
-     quat = pred_pose_enc[..., 3:7]
-     fl = pred_pose_enc[..., 7:]  # or fov
-
-     T = base_pose_act(T, trans_act)
-     quat = base_pose_act(quat, quat_act)
-     fl = base_pose_act(fl, fl_act)  # or fov
-
-     pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
-
-     return pred_pose_enc
-
-
- def base_pose_act(pose_enc, act_type="linear"):
-     """
-     Apply basic activation function to pose parameters.
-
-     Args:
-         pose_enc: Tensor containing encoded pose parameters
-         act_type: Activation type ("linear", "inv_log", "exp", "relu")
-
-     Returns:
-         Activated pose parameters
-     """
-     if act_type == "linear":
-         return pose_enc
-     elif act_type == "inv_log":
-         return inverse_log_transform(pose_enc)
-     elif act_type == "exp":
-         return torch.exp(pose_enc)
-     elif act_type == "relu":
-         return F.relu(pose_enc)
-     else:
-         raise ValueError(f"Unknown act_type: {act_type}")
-
-
- def activate_head(out, activation="norm_exp", conf_activation="expp1"):
-     """
-     Process network output to extract 3D points and confidence values.
-
-     Args:
-         out: Network output tensor (B, C, H, W)
-         activation: Activation type for 3D points
-         conf_activation: Activation type for confidence values
-
-     Returns:
-         Tuple of (3D points tensor, confidence tensor)
-     """
-     # Move channels from last dim to the 4th dimension => (B, H, W, C)
-     fmap = out.permute(0, 2, 3, 1)  # B,H,W,C expected
-
-     # Split into xyz (first C-1 channels) and confidence (last channel)
-     xyz = fmap[:, :, :, :-1]
-     conf = fmap[:, :, :, -1]
-
-     if activation == "norm_exp":
-         d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
-         xyz_normed = xyz / d
-         pts3d = xyz_normed * torch.expm1(d)
-     elif activation == "norm":
-         pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
-     elif activation == "exp":
-         pts3d = torch.exp(xyz)
-     elif activation == "relu":
-         pts3d = F.relu(xyz)
-     elif activation == "inv_log":
-         pts3d = inverse_log_transform(xyz)
-     elif activation == "xy_inv_log":
-         xy, z = xyz.split([2, 1], dim=-1)
-         z = inverse_log_transform(z)
-         pts3d = torch.cat([xy * z, z], dim=-1)
-     elif activation == "sigmoid":
-         pts3d = torch.sigmoid(xyz)
-     elif activation == "linear":
-         pts3d = xyz
-     else:
-         raise ValueError(f"Unknown activation: {activation}")
-
-     if conf_activation == "expp1":
-         conf_out = 1 + conf.exp()
-     elif conf_activation == "expp0":
-         conf_out = conf.exp()
-     elif conf_activation == "sigmoid":
-         conf_out = torch.sigmoid(conf)
-     else:
-         raise ValueError(f"Unknown conf_activation: {conf_activation}")
-
-     return pts3d, conf_out
-
-
- def inverse_log_transform(y):
-     """
-     Apply inverse log transform: sign(y) * (exp(|y|) - 1)
-
-     Args:
-         y: Input tensor
-
-     Returns:
-         Transformed tensor
-     """
-     return torch.sign(y) * (torch.expm1(torch.abs(y)))
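`inverse_log_transform` above inverts a signed log compression: if y = sign(x) * log(1 + |x|), then sign(y) * (exp(|y|) - 1) recovers x. A quick standalone round-trip check (the `log_transform` counterpart is an illustrative assumption, not part of the deleted file):

```python
import torch


def log_transform(x):
    # Signed log1p compression assumed as the forward counterpart.
    return torch.sign(x) * torch.log1p(torch.abs(x))


def inverse_log_transform(y):
    # Same formula as the deleted helper above.
    return torch.sign(y) * torch.expm1(torch.abs(y))


x = torch.tensor([-50.0, -0.5, 0.0, 2.0, 100.0])
assert torch.allclose(inverse_log_transform(log_transform(x)), x, rtol=1e-5, atol=1e-4)
```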
 
mapanything/models/external/vggt/heads/track_head.py DELETED
@@ -1,118 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import torch.nn as nn
-
- from .dpt_head import DPTHead
- from .track_modules.base_track_predictor import BaseTrackerPredictor
-
-
- class TrackHead(nn.Module):
-     """
-     Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
-     The tracking is performed iteratively, refining predictions over multiple iterations.
-     """
-
-     def __init__(
-         self,
-         dim_in,
-         patch_size=14,
-         features=128,
-         iters=4,
-         predict_conf=True,
-         stride=2,
-         corr_levels=7,
-         corr_radius=4,
-         hidden_size=384,
-     ):
-         """
-         Initialize the TrackHead module.
-
-         Args:
-             dim_in (int): Input dimension of tokens from the backbone.
-             patch_size (int): Size of image patches used in the vision transformer.
-             features (int): Number of feature channels in the feature extractor output.
-             iters (int): Number of refinement iterations for tracking predictions.
-             predict_conf (bool): Whether to predict confidence scores for tracked points.
-             stride (int): Stride value for the tracker predictor.
-             corr_levels (int): Number of correlation pyramid levels.
-             corr_radius (int): Radius for correlation computation, controlling the search area.
-             hidden_size (int): Size of hidden layers in the tracker network.
-         """
-         super().__init__()
-
-         self.patch_size = patch_size
-
-         # Feature extractor based on DPT architecture
-         # Processes tokens into feature maps for tracking
-         self.feature_extractor = DPTHead(
-             dim_in=dim_in,
-             patch_size=patch_size,
-             features=features,
-             feature_only=True,  # Only output features, no activation
-             down_ratio=2,  # Reduces spatial dimensions by factor of 2
-             pos_embed=False,
-         )
-
-         # Tracker module that predicts point trajectories
-         # Takes feature maps and predicts coordinates and visibility
-         self.tracker = BaseTrackerPredictor(
-             latent_dim=features,  # Match the output_dim of feature extractor
-             predict_conf=predict_conf,
-             stride=stride,
-             corr_levels=corr_levels,
-             corr_radius=corr_radius,
-             hidden_size=hidden_size,
-         )
-
-         self.iters = iters
-
-     def forward(
-         self,
-         aggregated_tokens_list,
-         images,
-         patch_start_idx,
-         query_points=None,
-         iters=None,
-     ):
-         """
-         Forward pass of the TrackHead.
-
-         Args:
-             aggregated_tokens_list (list): List of aggregated tokens from the backbone.
-             images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
-                 B = batch size, S = sequence length.
-             patch_start_idx (int): Starting index for patch tokens.
-             query_points (torch.Tensor, optional): Initial query points to track.
-                 If None, points are initialized by the tracker.
-             iters (int, optional): Number of refinement iterations. If None, uses self.iters.
-
-         Returns:
-             tuple:
-                 - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
-                 - vis_scores (torch.Tensor): Visibility scores for tracked points.
-                 - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
-         """
-         B, S, _, H, W = images.shape
-
-         # Extract features from tokens
-         # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
-         feature_maps = self.feature_extractor(
-             aggregated_tokens_list, images, patch_start_idx
-         )
-
-         # Use default iterations if not specified
-         if iters is None:
-             iters = self.iters
-
-         # Perform tracking using the extracted features
-         coord_preds, vis_scores, conf_scores = self.tracker(
-             query_points=query_points,
-             fmaps=feature_maps,
-             iters=iters,
-         )
-
-         return coord_preds, vis_scores, conf_scores
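For orientation, a rough shape walk-through of the forward pass above under assumed sizes (batch, frame count, resolution, and query count below are made up for illustration; with `down_ratio=2` the DPT features come out at half the image resolution):

```python
import torch

B, S, H, W = 2, 8, 518, 518        # assumed clip: 2 sequences of 8 frames
features, down_ratio = 128, 2      # DPTHead(feature_only=True, down_ratio=2)

images = torch.randn(B, S, 3, H, W)
# Stand-in for feature_extractor(aggregated_tokens_list, images, patch_start_idx):
feature_maps = torch.randn(B, S, features, H // down_ratio, W // down_ratio)

# Query points are given in image pixel coordinates; the tracker rescales them internally.
N = 64
query_points = torch.rand(B, N, 2) * torch.tensor([W - 1.0, H - 1.0])

print(feature_maps.shape)  # torch.Size([2, 8, 128, 259, 259])
print(query_points.shape)  # torch.Size([2, 64, 2])
```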
 
mapanything/models/external/vggt/heads/track_modules/__init__.py DELETED
@@ -1,5 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
 
mapanything/models/external/vggt/heads/track_modules/base_track_predictor.py DELETED
@@ -1,242 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- import torch
- import torch.nn as nn
- from einops import rearrange
-
- from .blocks import CorrBlock, EfficientUpdateFormer
- from .modules import Mlp
- from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d
-
-
- class BaseTrackerPredictor(nn.Module):
-     def __init__(
-         self,
-         stride=1,
-         corr_levels=5,
-         corr_radius=4,
-         latent_dim=128,
-         hidden_size=384,
-         use_spaceatt=True,
-         depth=6,
-         max_scale=518,
-         predict_conf=True,
-     ):
-         super(BaseTrackerPredictor, self).__init__()
-         """
-         The base template to create a track predictor
-
-         Modified from https://github.com/facebookresearch/co-tracker/
-         and https://github.com/facebookresearch/vggsfm
-         """
-
-         self.stride = stride
-         self.latent_dim = latent_dim
-         self.corr_levels = corr_levels
-         self.corr_radius = corr_radius
-         self.hidden_size = hidden_size
-         self.max_scale = max_scale
-         self.predict_conf = predict_conf
-
-         self.flows_emb_dim = latent_dim // 2
-
-         self.corr_mlp = Mlp(
-             in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
-             hidden_features=self.hidden_size,
-             out_features=self.latent_dim,
-         )
-
-         self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
-
-         self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
-
-         space_depth = depth if use_spaceatt else 0
-         time_depth = depth
-
-         self.updateformer = EfficientUpdateFormer(
-             space_depth=space_depth,
-             time_depth=time_depth,
-             input_dim=self.transformer_dim,
-             hidden_size=self.hidden_size,
-             output_dim=self.latent_dim + 2,
-             mlp_ratio=4.0,
-             add_space_attn=use_spaceatt,
-         )
-
-         self.fmap_norm = nn.LayerNorm(self.latent_dim)
-         self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
-
-         # A linear layer to update track feats at each iteration
-         self.ffeat_updater = nn.Sequential(
-             nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()
-         )
-
-         self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
-
-         if predict_conf:
-             self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
-
-     def forward(
-         self,
-         query_points,
-         fmaps=None,
-         iters=6,
-         return_feat=False,
-         down_ratio=1,
-         apply_sigmoid=True,
-     ):
-         """
-         query_points: B x N x 2, the number of batches, tracks, and xy
-         fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
-             note HH and WW is the size of feature maps instead of original images
-         """
-         B, N, D = query_points.shape
-         B, S, C, HH, WW = fmaps.shape
-
-         assert D == 2, "Input points must be 2D coordinates"
-
-         # apply a layernorm to fmaps here
-         fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
-         fmaps = fmaps.permute(0, 1, 4, 2, 3)
-
-         # Scale the input query_points because we may downsample the images
-         # by down_ratio or self.stride
-         # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
-         # its query_points should be query_points/4
-         if down_ratio > 1:
-             query_points = query_points / float(down_ratio)
-
-         query_points = query_points / float(self.stride)
-
-         # Init with coords as the query points
-         # It means the search will start from the position of query points at the reference frames
-         coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
-
-         # Sample/extract the features of the query points in the query frame
-         query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
-
-         # init track feats by query feats
-         track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C
-         # back up the init coords
-         coords_backup = coords.clone()
-
-         fcorr_fn = CorrBlock(
-             fmaps, num_levels=self.corr_levels, radius=self.corr_radius
-         )
-
-         coord_preds = []
-
-         # Iterative Refinement
-         for _ in range(iters):
-             # Detach the gradients from the last iteration
-             # (in my experience, not very important for performance)
-             coords = coords.detach()
-
-             fcorrs = fcorr_fn.corr_sample(track_feats, coords)
-
-             corr_dim = fcorrs.shape[3]
-             fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
-             fcorrs_ = self.corr_mlp(fcorrs_)
-
-             # Movement of current coords relative to query points
-             flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
-
-             flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
-
-             # (In my trials, it is also okay to just add the flows_emb instead of concat)
-             flows_emb = torch.cat(
-                 [flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1
-             )
-
-             track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(
-                 B * N, S, self.latent_dim
-             )
-
-             # Concatenate them as the input for the transformers
-             transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
-
-             # 2D positional embed
-             # TODO: this can be much simplified
-             pos_embed = get_2d_sincos_pos_embed(
-                 self.transformer_dim, grid_size=(HH, WW)
-             ).to(query_points.device)
-             sampled_pos_emb = sample_features4d(
-                 pos_embed.expand(B, -1, -1, -1), coords[:, 0]
-             )
-
-             sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(
-                 1
-             )
-
-             x = transformer_input + sampled_pos_emb
-
-             # Add the query ref token to the track feats
-             query_ref_token = torch.cat(
-                 [
-                     self.query_ref_token[:, 0:1],
-                     self.query_ref_token[:, 1:2].expand(-1, S - 1, -1),
-                 ],
-                 dim=1,
-             )
-             x = x + query_ref_token.to(x.device).to(x.dtype)
-
-             # B, N, S, C
-             x = rearrange(x, "(b n) s d -> b n s d", b=B)
-
-             # Compute the delta coordinates and delta track features
-             delta, _ = self.updateformer(x)
-
-             # BN, S, C
-             delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
-             delta_coords_ = delta[:, :, :2]
-             delta_feats_ = delta[:, :, 2:]
-
-             track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
-             delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
-
-             # Update the track features
-             track_feats_ = (
-                 self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
-             )
-
-             track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(
-                 0, 2, 1, 3
-             )  # BxSxNxC
-
-             # B x S x N x 2
-             coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
-
-             # Force coord0 as query
-             # because we assume the query points should not be changed
-             coords[:, 0] = coords_backup[:, 0]
-
-             # The predicted tracks are in the original image scale
-             if down_ratio > 1:
-                 coord_preds.append(coords * self.stride * down_ratio)
-             else:
-                 coord_preds.append(coords * self.stride)
-
-         # B, S, N
-         vis_e = self.vis_predictor(
-             track_feats.reshape(B * S * N, self.latent_dim)
-         ).reshape(B, S, N)
-         if apply_sigmoid:
-             vis_e = torch.sigmoid(vis_e)
-
-         if self.predict_conf:
-             conf_e = self.conf_predictor(
-                 track_feats.reshape(B * S * N, self.latent_dim)
-             ).reshape(B, S, N)
-             if apply_sigmoid:
-                 conf_e = torch.sigmoid(conf_e)
-         else:
-             conf_e = None
-
-         if return_feat:
-             return coord_preds, vis_e, track_feats, query_track_feat, conf_e
-         else:
-             return coord_preds, vis_e, conf_e
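Stripped of the correlation and transformer machinery, the refinement loop above follows the CoTracker/VGGSfM recipe: start every frame at the query coordinates, repeatedly predict a small per-frame correction, and keep the query frame pinned. A minimal sketch of just that control flow (the `predict_delta` callable is an illustrative stand-in for the correlation sampling plus update transformer, not part of the deleted file):

```python
import torch


def refine_tracks(coords0, predict_delta, iters=4):
    """coords0: (B, S, N, 2) initial coordinates (every frame set to the query points);
    predict_delta: callable mapping coords -> (B, S, N, 2) correction."""
    coords = coords0.clone()
    preds = []
    for _ in range(iters):
        coords = coords.detach()            # don't backprop through earlier iterations
        coords = coords + predict_delta(coords)
        coords[:, 0] = coords0[:, 0]        # frame 0 holds the query points, so keep it fixed
        preds.append(coords.clone())
    return preds


B, S, N = 1, 4, 16
coords0 = torch.rand(B, S, N, 2) * 100
preds = refine_tracks(coords0, lambda c: 0.1 * torch.randn_like(c))
print(len(preds), preds[-1].shape)  # 4 torch.Size([1, 4, 16, 2])
```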