import numpy as np
import torch


def m_dot(
    transform: torch.Tensor,
    points: torch.Tensor | list,
    maintain_shape: bool = False,
) -> torch.Tensor | list:
    """
    Apply batch matrix multiplication between transform matrices and points.

    Args:
        transform: Batch of transformation matrices [..., 3, 3] or [..., 4, 4]
        points: Batch of points [..., *dims, 3] (e.g. [..., H, W, 3]) or a
            list of per-batch point tensors
        maintain_shape: If True, preserves the original shape of points

    Returns:
        Transformed points with the point dimensions flattened to [..., N, 3]
        (unless maintain_shape is set), or a list of transformed points
    """
    if isinstance(points, list):
        return [m_dot(t, p, maintain_shape) for t, p in zip(transform, points)]

    # Store original shape and flatten batch dimensions
    orig_shape = points.shape
    batch_dims = points.shape[:-3]

    # Reshape to standard batch format
    transform_flat = transform.reshape(-1, transform.shape[-2], transform.shape[-1])
    points_flat = points.reshape(transform_flat.shape[0], -1, points.shape[-1])

    # Apply transformation
    pts = torch.bmm(
        transform_flat[:, :3, :3],
        points_flat[..., :3].permute(0, 2, 1).to(transform_flat.dtype),
    ).permute(0, 2, 1)

    # Add the translation component for homogeneous 4x4 transforms
    if transform.shape[-1] == 4:
        pts = pts + transform_flat[:, :3, 3].unsqueeze(1)

    # Restore original shape
    if maintain_shape:
        return pts.reshape(orig_shape)
    else:
        return pts.reshape(*batch_dims, -1, 3)
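

# A minimal usage sketch (illustrative, not part of the module's API): apply a
# batch of 4x4 camera poses to per-view point maps. The shapes and the
# identity poses here are assumptions chosen only for the demo.
def _example_m_dot() -> None:
    cam2world = torch.eye(4).repeat(2, 1, 1)  # [2, 4, 4] identity poses
    pts = torch.rand(2, 8, 8, 3)              # [2, H, W, 3] point maps
    flat = m_dot(cam2world, pts)              # point dims flattened to [2, 64, 3]
    kept = m_dot(cam2world, pts, maintain_shape=True)  # stays [2, 8, 8, 3]
    assert flat.shape == (2, 64, 3) and kept.shape == (2, 8, 8, 3)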


def m_unproject(
    depth: torch.Tensor | float,
    intrinsic: torch.Tensor,
    cam2world: torch.Tensor | None = None,
    img_grid: torch.Tensor | None = None,
    valid: torch.Tensor | float | None = None,
    H: int | None = None,
    W: int | None = None,
    img_feats: torch.Tensor | None = None,
    maintain_shape: bool = False,
) -> torch.Tensor | list:
    """
    Unproject 2D image points with depth values to 3D points in camera or world space.

    Args:
        depth: Depth values, either a tensor of shape ...xHxW or a float value
        intrinsic: Camera intrinsic matrix of shape ...x3x3
        cam2world: Optional camera-to-world transformation matrix of shape ...x4x4
        img_grid: Optional pre-computed image grid. If None, will be created
        valid: Optional mask for valid depth values or minimum depth threshold
        H: Image height (required if depth is a scalar)
        W: Image width (required if depth is a scalar)
        img_feats: Optional image features to append to 3D points
        maintain_shape: If True, preserves the original shape of points

    Returns:
        3D points in camera or world space, with optional features appended;
        a (possibly nested) list of per-view tensors when a valid mask is used
    """
    # Get device and shape information from intrinsic matrix
    device = intrinsic.device
    pre_shape = intrinsic.shape[:-2]  # Batch dimensions

    # Validate inputs
    if isinstance(depth, (int, float)) and (H is None or W is None):
        raise ValueError("H and W must be provided if depth is a scalar")

    # Determine image dimensions from depth if not provided
    if isinstance(depth, torch.Tensor) and H is None:
        H, W = depth.shape[-2:]

    # Create image grid if not provided
    if img_grid is None:
        # Coordinate grid of shape HxWx3; _create_image_grid already appends
        # the homogeneous 1 as the last channel, so no extra concatenation
        img_grid = _create_image_grid(H, W, device)

    # Expand img_grid to match batch dimensions of intrinsic
    if img_grid.dim() <= intrinsic.dim():
        img_grid = img_grid.unsqueeze(0)
        img_grid = img_grid.expand(*pre_shape, *img_grid.shape[-3:])

    # Handle valid mask or minimum depth threshold
    depth_mask = None
    if valid is not None:
        if isinstance(valid, float):
            # Create mask for minimum depth value
            depth_mask = depth > valid
        elif isinstance(valid, torch.Tensor):
            depth_mask = valid

        # Apply mask to image grid and other inputs
        img_grid = masking(img_grid, depth_mask, dim=intrinsic.dim())
        if not isinstance(depth, (int, float)):
            depth = masking(depth, depth_mask, dim=intrinsic.dim() - 1)
        if img_feats is not None:
            img_feats = masking(img_feats, depth_mask, dim=intrinsic.dim() - 1)

    # Unproject 2D points to 3D camera space
    cam_pts: torch.Tensor = m_dot(
        m_inverse_intrinsics(intrinsic),
        img_grid[..., [1, 0, 2]],
        maintain_shape=True,
    )
    # Scale rays by depth; depth may be a scalar, a per-pixel tensor, or a
    # list of per-view tensors after masking
    if isinstance(depth, (int, float)):
        cam_pts = mult(cam_pts, depth)
    elif isinstance(depth, list):
        cam_pts = mult(cam_pts, [d.unsqueeze(-1) for d in depth])
    else:
        cam_pts = mult(cam_pts, depth.unsqueeze(-1))

    # Transform to world space if cam2world is provided
    if cam2world is not None:
        cam_pts = m_dot(cam2world, cam_pts, maintain_shape=True)

    # Append image features if provided
    if img_feats is not None:
        if isinstance(cam_pts, list):
            if isinstance(cam_pts[0], list):
                # Handle nested list case
                result = []
                for batch_idx, batch in enumerate(cam_pts):
                    batch_result = []
                    for view_idx, view in enumerate(batch):
                        batch_result.append(
                            torch.cat([view, img_feats[batch_idx][view_idx]], -1)
                        )
                    result.append(batch_result)
                cam_pts = result
            else:
                # Handle single list case
                cam_pts = [
                    torch.cat([pts, feats], -1)
                    for pts, feats in zip(cam_pts, img_feats)
                ]
        else:
            # Handle tensor case
            cam_pts = torch.cat([cam_pts, img_feats], -1)

    if maintain_shape:
        return cam_pts

    # Flatten the point dimensions, keeping any appended feature channels
    if isinstance(cam_pts, list):
        return cam_pts
    return cam_pts.reshape(*pre_shape, -1, cam_pts.shape[-1])
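

# A minimal usage sketch (assumed shapes and intrinsics, not part of the
# original module): unproject a constant depth map to camera-space points
# with a simple pinhole intrinsic matrix.
def _example_m_unproject() -> None:
    intrinsic = torch.tensor(
        [[100.0, 0.0, 32.0], [0.0, 100.0, 32.0], [0.0, 0.0, 1.0]]
    ).expand(2, 3, 3)
    depth = torch.ones(2, 64, 64)
    cam_pts = m_unproject(depth, intrinsic)  # [2, 64*64, 3] camera-space points
    assert cam_pts.shape == (2, 4096, 3)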


def m_project(
    world_pts: torch.Tensor,
    intrinsic: torch.Tensor,
    world2cam: torch.Tensor | None = None,
    maintain_shape: bool = False,
) -> torch.Tensor:
    """
    Project 3D world points to 2D image coordinates.

    Args:
        world_pts: 3D points in world coordinates
        intrinsic: Camera intrinsic matrix
        world2cam: Optional transformation from world to camera coordinates
        maintain_shape: If True, preserves the original shape of points

    Returns:
        Image points with coordinates in (img_y, img_x, z) order
    """
    # Transform points from world to camera space if world2cam is provided
    cam_pts: torch.Tensor = world_pts
    if world2cam is not None:
        cam_pts = m_dot(world2cam, world_pts, maintain_shape=maintain_shape)

    # Get shapes to properly expand intrinsics
    shared_dims = intrinsic.shape[:-2]
    extra_dims = cam_pts.shape[len(shared_dims) : -1]

    # Expand intrinsics to match cam_pts shape
    expanded_intrinsic = intrinsic.view(*shared_dims, *([1] * len(extra_dims)), 3, 3)
    expanded_intrinsic = expanded_intrinsic.expand(*shared_dims, *extra_dims, 3, 3)

    # Project from camera to image space; clamp |z| away from zero so points
    # on the image plane do not cause division blow-ups
    depth_abs = cam_pts[..., 2].abs().clamp(min=1e-5)
    return torch.stack(
        [
            expanded_intrinsic[..., 1, 1] * cam_pts[..., 1] / depth_abs
            + expanded_intrinsic[..., 1, 2],
            expanded_intrinsic[..., 0, 0] * cam_pts[..., 0] / depth_abs
            + expanded_intrinsic[..., 0, 2],
            cam_pts[..., 2],
        ],
        -1,
    )
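

# A round-trip sketch (illustrative values): points unprojected at a fixed
# depth should project back to their original (y, x) pixel coordinates.
def _example_m_project() -> None:
    intrinsic = torch.tensor(
        [[100.0, 0.0, 32.0], [0.0, 100.0, 32.0], [0.0, 0.0, 1.0]]
    ).expand(1, 3, 3)
    depth = torch.full((1, 64, 64), 2.0)
    cam_pts = m_unproject(depth, intrinsic)  # [1, 4096, 3]
    img_pts = m_project(cam_pts, intrinsic)  # [1, 4096, 3] in (y, x, z) order
    grid = _create_image_grid(64, 64, torch.device("cpu")).reshape(-1, 3)
    assert torch.allclose(img_pts[0, :, :2], grid[:, :2].float(), atol=1e-4)
    assert torch.allclose(img_pts[0, :, 2], torch.full((4096,), 2.0))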


def in_image(
    image_pts: torch.Tensor | list,
    H: int,
    W: int,
    min_depth: float = 0.0,
) -> torch.Tensor | list:
    """
    Check if image points are within the image boundaries.

    Args:
        image_pts: Image points in (y, x[, z]) pixel coordinates
        H: Image height
        W: Image width
        min_depth: Minimum valid depth

    Returns:
        Boolean mask indicating which points are within the image
    """
    if isinstance(image_pts, list):
        return [in_image(pts, H, W, min_depth=min_depth) for pts in image_pts]

    in_image_mask = (
        torch.all(image_pts >= 0, -1)
        & (image_pts[..., 0] < H)
        & (image_pts[..., 1] < W)
    )
    if (min_depth is not None) and image_pts.shape[-1] == 3:
        in_image_mask &= image_pts[..., 2] > min_depth
    return in_image_mask
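

# A small sketch (hypothetical points) of filtering projected points that
# fall inside a 32x32 image with positive depth.
def _example_in_image() -> None:
    pts = torch.tensor(
        [[10.0, 20.0, 1.0], [-1.0, 5.0, 1.0], [3.0, 4.0, -2.0]]
    )  # (y, x, z) per point
    mask = in_image(pts, H=32, W=32)
    # Only the first point is in bounds with depth above min_depth
    assert mask.tolist() == [True, False, False]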


def _create_image_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
    """
    Create a coordinate grid for image pixels.

    Args:
        H: Image height
        W: Image width
        device: Computation device

    Returns:
        Image grid with shape HxWx3 (last dimension is homogeneous)
    """
    y_coords = torch.arange(H, device=device)
    x_coords = torch.arange(W, device=device)

    # Use meshgrid with indexing="ij" for correct orientation
    y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij")

    # Stack coordinates and add homogeneous coordinate
    img_grid = torch.stack([y_grid, x_grid, torch.ones_like(y_grid)], dim=-1)

    return img_grid
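

# A tiny sketch of the grid layout: each pixel holds (y, x, 1), so indexing
# at (row, col) returns those coordinates plus the homogeneous 1.
def _example_create_image_grid() -> None:
    grid = _create_image_grid(2, 3, torch.device("cpu"))
    assert grid.shape == (2, 3, 3)
    assert grid[1, 2].tolist() == [1, 2, 1]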


def masking(
    X: torch.Tensor | list,
    mask: torch.Tensor | list,
    dim: int = 3,
) -> torch.Tensor | list:
    """
    Apply a Boolean mask to tensor or list elements.
    Handles nested structures by recursively applying the mask.

    Args:
        X: Input tensor or list to be masked
        mask: Boolean mask to apply
        dim: Recursion threshold; tensors with dim() >= dim (and lists) are
            iterated along their first axis, while lower-dimensional tensors
            are indexed by the mask directly

    Returns:
        Masked tensor or list with the same structure as input
    """
    if isinstance(X, list) or (isinstance(X, torch.Tensor) and X.dim() >= dim):
        return [masking(x, m, dim) for x, m in zip(X, mask)]
    return X[mask]
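

# A minimal sketch (made-up mask) showing how masking returns a ragged list
# of per-batch tensors once the batch dimension is peeled off.
def _example_masking() -> None:
    pts = torch.rand(2, 4, 3)
    mask = torch.tensor([[True, False, True, True], [False, True, False, False]])
    kept = masking(pts, mask)  # list of [3, 3] and [1, 3] tensors
    assert [k.shape for k in kept] == [(3, 3), (1, 3)]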


def m_inverse_intrinsics(intrinsics: torch.Tensor) -> torch.Tensor:
    """
    Compute the inverse of camera intrinsics matrices analytically.
    This is much faster than using torch.inverse() for intrinsics matrices.

    The intrinsics matrix has the form:
    K = [fx   s   cx]
        [0    fy  cy]
        [0    0   1 ]

    And its inverse is:
    K^-1 = [1/fx  -s/(fx*fy)  (s*cy - cx*fy)/(fx*fy)]
           [0      1/fy       -cy/fy                ]
           [0      0           1                    ]

    Args:
        intrinsics: Camera intrinsics matrices of shape [..., 3, 3]

    Returns:
        Inverse intrinsics matrices of shape [..., 3, 3]
    """
    # Extract the components of the intrinsics matrix
    fx = intrinsics[..., 0, 0]
    s = intrinsics[..., 0, 1]  # skew, usually 0
    cx = intrinsics[..., 0, 2]
    fy = intrinsics[..., 1, 1]
    cy = intrinsics[..., 1, 2]

    # Create output tensor with same shape and device
    inv_intrinsics = torch.zeros_like(intrinsics)

    # Compute the inverse analytically
    inv_intrinsics[..., 0, 0] = 1.0 / fx
    inv_intrinsics[..., 0, 1] = -s / (fx * fy)
    inv_intrinsics[..., 0, 2] = (s * cy - cx * fy) / (fx * fy)
    inv_intrinsics[..., 1, 1] = 1.0 / fy
    inv_intrinsics[..., 1, 2] = -cy / fy
    inv_intrinsics[..., 2, 2] = 1.0

    return inv_intrinsics
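

# A quick sanity sketch: the analytic inverse should match torch.linalg.inv
# for an arbitrary (made-up) intrinsics matrix, including a nonzero skew.
def _example_m_inverse_intrinsics() -> None:
    K = torch.tensor([[500.0, 2.0, 320.0], [0.0, 480.0, 240.0], [0.0, 0.0, 1.0]])
    assert torch.allclose(m_inverse_intrinsics(K), torch.linalg.inv(K), atol=1e-5)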


def mult(
    A: torch.Tensor | np.ndarray | list | float | int,
    B: torch.Tensor | np.ndarray | list | float | int,
) -> torch.Tensor | np.ndarray | list | float | int:
    """
    Multiply two objects with support for lists, tensors, arrays, and scalars.
    Handles nested structures by recursively applying multiplication.

    Args:
        A: First operand (tensor, array, list, or scalar)
        B: Second operand (tensor, array, list, or scalar)

    Returns:
        Result of multiplication with the same structure as the inputs;
        non-list operands fall through to elementwise A * B
    """
    if isinstance(A, list) and isinstance(B, (int, float)):
        return [mult(a, B) for a in A]
    if isinstance(B, list) and isinstance(A, (int, float)):
        return [mult(A, b) for b in B]
    if isinstance(A, list) and isinstance(B, list):
        return [mult(a, b) for a, b in zip(A, B)]
    return A * B
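

# A short sketch of the list-broadcasting behavior: a scalar is applied to
# every tensor in a list, and two lists are multiplied pairwise.
def _example_mult() -> None:
    scaled = mult([torch.ones(3), torch.ones(2)], 2.0)
    assert scaled[1].tolist() == [2.0, 2.0]
    paired = mult([torch.ones(2)], [torch.full((2,), 3.0)])
    assert paired[0].tolist() == [3.0, 3.0]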