import torch
import numpy as np


def depth_to_camera_coords(depthmap, camera_intrinsics):
    """
    Convert depth map to 3D camera coordinates.
    
    Args:
        depthmap (BxHxW tensor): Batch of depth maps
        camera_intrinsics (Bx3x3 tensor): Camera intrinsics matrix for each camera
        
    Returns:
        X_cam (BxHxWx3 tensor): 3D points in camera coordinates
        valid_mask (BxHxW tensor): Mask indicating valid depth pixels
    """
    B, H, W = depthmap.shape
    device = depthmap.device
    dtype = depthmap.dtype
    
    # Cast intrinsics to the depth map's dtype so all downstream arithmetic
    # (and the final torch.stack) runs in a single dtype; a hard .float()
    # would break half- or double-precision depth maps
    camera_intrinsics = camera_intrinsics.to(dtype)
    
    # Extract focal lengths and principal points
    fx = camera_intrinsics[:, 0, 0]  # (B,)
    fy = camera_intrinsics[:, 1, 1]  # (B,)
    cx = camera_intrinsics[:, 0, 2]  # (B,)
    cy = camera_intrinsics[:, 1, 2]  # (B,)
    
    # Generate pixel grid
    v_grid, u_grid = torch.meshgrid(
        torch.arange(H, dtype=dtype, device=device),
        torch.arange(W, dtype=dtype, device=device),
        indexing='ij'
    )
    
    # Reshape for broadcasting: (1, H, W)
    u_grid = u_grid.unsqueeze(0)
    v_grid = v_grid.unsqueeze(0)
    
    # Compute 3D camera coordinates
    # X = (u - cx) * Z / fx
    # Y = (v - cy) * Z / fy
    # Z = depth
    z_cam = depthmap  # (B, H, W)
    x_cam = (u_grid - cx.view(B, 1, 1)) * z_cam / fx.view(B, 1, 1)
    y_cam = (v_grid - cy.view(B, 1, 1)) * z_cam / fy.view(B, 1, 1)
    
    # Stack to form (B, H, W, 3)
    X_cam = torch.stack([x_cam, y_cam, z_cam], dim=-1)
    
    # Valid depth mask
    valid_mask = depthmap > 0.0
    
    return X_cam, valid_mask
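
# Example usage (a minimal sketch with synthetic inputs; shapes and intrinsics
# are illustrative only):
#   depth = torch.rand(2, 4, 5) + 0.5                     # (B, H, W)
#   K = torch.eye(3).repeat(2, 1, 1)
#   K[:, 0, 0] = K[:, 1, 1] = 100.0                       # fx, fy
#   K[:, 0, 2], K[:, 1, 2] = 2.0, 1.5                     # cx, cy
#   pts, mask = depth_to_camera_coords(depth, K)          # pts: (2, 4, 5, 3)
#   assert pts.shape == (2, 4, 5, 3) and mask.all()
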

def depth_to_world_coords_points(
    depth_map: torch.Tensor, extrinsic: torch.Tensor, intrinsic: torch.Tensor, eps=1e-8
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Convert a batch of depth maps to world coordinates.

    Args:
        depth_map (torch.Tensor): (B, H, W) Depth map
        extrinsic (torch.Tensor): (B, 4, 4) Camera extrinsic matrix (camera-to-world transformation)
        intrinsic (torch.Tensor): (B, 3, 3) Camera intrinsic matrix
        eps (float): Minimum depth for a pixel to be considered valid

    Returns:
        world_coords_points (torch.Tensor): (B, H, W, 3) World coordinates
        camera_points (torch.Tensor): (B, H, W, 3) Camera coordinates
        point_mask (torch.Tensor): (B, H, W) Valid depth mask

    Returns (None, None, None) if depth_map is None.
    """
    if depth_map is None:
        return None, None, None

    # Valid depth mask (B, H, W)
    point_mask = depth_map > eps

    # Convert depth map to camera coordinates (B, H, W, 3)
    camera_points, _ = depth_to_camera_coords(depth_map, intrinsic)

    # Apply extrinsic matrix (camera -> world)
    R_cam_to_world = extrinsic[:, :3, :3]   # (B, 3, 3)
    t_cam_to_world = extrinsic[:, :3, 3]    # (B, 3)

    # Transform (B, H, W, 3) x (B, 3, 3)^T + (B, 3) -> (B, H, W, 3)
    world_coords_points = torch.einsum('bhwi,bji->bhwj', camera_points, R_cam_to_world) + t_cam_to_world[:, None, None, :]

    return world_coords_points, camera_points, point_mask
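
# Example (a minimal sketch with synthetic inputs; with an identity extrinsic,
# world and camera coordinates coincide):
#   depth = torch.rand(2, 4, 5) + 0.5
#   K = torch.eye(3).repeat(2, 1, 1); K[:, 0, 0] = K[:, 1, 1] = 100.0
#   E = torch.eye(4).repeat(2, 1, 1)                      # camera-to-world
#   world, cam, mask = depth_to_world_coords_points(depth, E, K)
#   assert torch.allclose(world, cam)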


def closed_form_inverse_se3(se3: torch.Tensor) -> torch.Tensor:
    """
    Efficiently invert batched SE(3) matrices of shape (B, 4, 4).

    Args:
        se3 (torch.Tensor): (B, 4, 4) Transformation matrices

    Returns:
        out (torch.Tensor): (B, 4, 4) Inverse transformation matrices
    """
    assert se3.ndim == 3 and se3.shape[1:] == (4, 4), f"se3 must be (B, 4, 4), got {se3.shape}"
    R = se3[:, :3, :3]        # (B, 3, 3)
    t = se3[:, :3, 3]         # (B, 3)
    Rt = R.transpose(1, 2)    # (B, 3, 3)
    t_inv = -torch.bmm(Rt, t.unsqueeze(-1)).squeeze(-1)  # (B, 3)
    out = se3.new_zeros(se3.shape)
    out[:, :3, :3] = Rt
    out[:, :3, 3] = t_inv
    out[:, 3, 3] = 1.0
    return out
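
# Example (a minimal sketch; an SE(3) transform composed with its inverse
# should recover the identity up to numerical precision):
#   T = torch.eye(4).repeat(2, 1, 1)
#   T[:, :3, 3] = torch.randn(2, 3)                       # random translations
#   T_inv = closed_form_inverse_se3(T)
#   assert torch.allclose(torch.bmm(T_inv, T), torch.eye(4).repeat(2, 1, 1))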


def create_pixel_coordinate_grid(num_frames, height, width):
    """
    Creates a grid of pixel coordinates and frame indices for all frames.
    Returns:
        tuple: A tuple containing:
            - points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3)
                                            with x, y coordinates and frame indices
    """
    # Create coordinate grids for a single frame
    y_grid, x_grid = np.indices((height, width), dtype=np.float32)
    x_grid = x_grid[np.newaxis, :, :]
    y_grid = y_grid[np.newaxis, :, :]

    # Broadcast to all frames
    x_coords = np.broadcast_to(x_grid, (num_frames, height, width))
    y_coords = np.broadcast_to(y_grid, (num_frames, height, width))

    # Create frame indices and broadcast
    f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis]
    f_coords = np.broadcast_to(f_idx, (num_frames, height, width))

    # Stack coordinates and frame indices
    points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)

    return points_xyf
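

if __name__ == "__main__":
    # Minimal smoke test with synthetic data; shapes and intrinsics are
    # illustrative only, not tied to any real camera setup.
    B, H, W = 2, 4, 5
    depth = torch.rand(B, H, W) + 0.5
    K = torch.eye(3).repeat(B, 1, 1)
    K[:, 0, 0] = K[:, 1, 1] = 100.0
    K[:, 0, 2], K[:, 1, 2] = W / 2.0, H / 2.0
    E = torch.eye(4).repeat(B, 1, 1)

    world, cam, mask = depth_to_world_coords_points(depth, E, K)
    assert world.shape == (B, H, W, 3) and mask.all()
    assert torch.allclose(world, cam)  # identity extrinsic: world == camera

    E_inv = closed_form_inverse_se3(E)
    assert torch.allclose(torch.bmm(E_inv, E), torch.eye(4).repeat(B, 1, 1))

    points_xyf = create_pixel_coordinate_grid(B, H, W)
    assert points_xyf.shape == (B, H, W, 3)
    print("All smoke tests passed.")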