import torch
import numpy as np


def depth_to_camera_coords(depthmap, camera_intrinsics):
    """
    Convert depth map to 3D camera coordinates.

    Args:
        depthmap (BxHxW tensor): Batch of depth maps
        camera_intrinsics (Bx3x3 tensor): Camera intrinsics matrix for each camera

    Returns:
        X_cam (BxHxWx3 tensor): 3D points in camera coordinates
        valid_mask (BxHxW tensor): Mask indicating valid depth pixels
    """
    B, H, W = depthmap.shape
    device = depthmap.device
    dtype = depthmap.dtype

    # Ensure intrinsics are float
    camera_intrinsics = camera_intrinsics.float()

    # Extract focal lengths and principal points
    fx = camera_intrinsics[:, 0, 0]  # (B,)
    fy = camera_intrinsics[:, 1, 1]  # (B,)
    cx = camera_intrinsics[:, 0, 2]  # (B,)
    cy = camera_intrinsics[:, 1, 2]  # (B,)

    # Generate pixel grid
    v_grid, u_grid = torch.meshgrid(
        torch.arange(H, dtype=dtype, device=device),
        torch.arange(W, dtype=dtype, device=device),
        indexing='ij'
    )

    # Reshape for broadcasting: (1, H, W)
    u_grid = u_grid.unsqueeze(0)
    v_grid = v_grid.unsqueeze(0)

    # Compute 3D camera coordinates via the pinhole model:
    #   X = (u - cx) * Z / fx
    #   Y = (v - cy) * Z / fy
    #   Z = depth
    z_cam = depthmap  # (B, H, W)
    x_cam = (u_grid - cx.view(B, 1, 1)) * z_cam / fx.view(B, 1, 1)
    y_cam = (v_grid - cy.view(B, 1, 1)) * z_cam / fy.view(B, 1, 1)

    # Stack to form (B, H, W, 3)
    X_cam = torch.stack([x_cam, y_cam, z_cam], dim=-1)

    # Valid depth mask
    valid_mask = depthmap > 0.0
    return X_cam, valid_mask
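

# Minimal usage sketch (illustrative, not part of the original file): the
# shapes and intrinsics below are assumed values. With an odd-sized image and
# the principal point at its center pixel, the unprojected center pixel lands
# on the optical axis (x = y = 0, z = depth).
def _example_depth_to_camera_coords():
    B, H, W = 2, 5, 7
    depth = torch.full((B, H, W), 2.0)
    K = torch.eye(3).repeat(B, 1, 1)
    K[:, 0, 0] = K[:, 1, 1] = 100.0                     # fx, fy
    K[:, 0, 2], K[:, 1, 2] = (W - 1) / 2, (H - 1) / 2   # cx, cy at image center
    pts, mask = depth_to_camera_coords(depth, K)
    assert pts.shape == (B, H, W, 3) and mask.all()
    assert torch.allclose(pts[0, H // 2, W // 2], torch.tensor([0.0, 0.0, 2.0]))

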
def depth_to_world_coords_points(
    depth_map: torch.Tensor, extrinsic: torch.Tensor, intrinsic: torch.Tensor, eps=1e-8
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Convert a batch of depth maps to world coordinates.

    Args:
        depth_map (torch.Tensor): (B, H, W) Depth map
        extrinsic (torch.Tensor): (B, 4, 4) Camera extrinsic matrix (camera-to-world transformation)
        intrinsic (torch.Tensor): (B, 3, 3) Camera intrinsic matrix
        eps (float): Minimum depth for a pixel to count as valid

    Returns:
        world_coords_points (torch.Tensor): (B, H, W, 3) World coordinates
        camera_points (torch.Tensor): (B, H, W, 3) Camera coordinates
        point_mask (torch.Tensor): (B, H, W) Valid depth mask
    """
    if depth_map is None:
        return None, None, None

    # Valid depth mask (B, H, W)
    point_mask = depth_map > eps

    # Convert depth map to camera coordinates (B, H, W, 3)
    camera_points, _ = depth_to_camera_coords(depth_map, intrinsic)

    # Apply extrinsic matrix (camera -> world)
    R_cam_to_world = extrinsic[:, :3, :3]  # (B, 3, 3)
    t_cam_to_world = extrinsic[:, :3, 3]   # (B, 3)

    # Transform: world_j = sum_i R[j, i] * cam_i + t_j, i.e.
    # (B, H, W, 3) x (B, 3, 3)^T + (B, 3) -> (B, H, W, 3)
    world_coords_points = torch.einsum(
        'bhwi,bji->bhwj', camera_points, R_cam_to_world
    ) + t_cam_to_world[:, None, None, :]

    return world_coords_points, camera_points, point_mask
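

# Illustrative sketch with assumed shapes and intrinsics: under an identity
# rotation, the world points are simply the camera points shifted by the
# camera-to-world translation.
def _example_depth_to_world_coords_points():
    B, H, W = 1, 5, 7
    depth = torch.rand(B, H, W) + 0.5   # strictly positive, so every pixel is valid
    K = torch.tensor([[[50.0, 0.0, 3.0],
                       [0.0, 50.0, 2.0],
                       [0.0, 0.0, 1.0]]])
    T = torch.eye(4).unsqueeze(0)       # camera-to-world with R = I
    T[:, :3, 3] = torch.tensor([1.0, -2.0, 0.5])
    world, cam, mask = depth_to_world_coords_points(depth, T, K)
    assert mask.all()
    assert torch.allclose(world, cam + T[:, :3, 3], atol=1e-6)

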
def closed_form_inverse_se3(se3: torch.Tensor) -> torch.Tensor:
    """
    Efficiently invert batched SE(3) matrices of shape (B, 4, 4).

    Args:
        se3 (torch.Tensor): (B, 4, 4) Transformation matrices

    Returns:
        out (torch.Tensor): (B, 4, 4) Inverse transformation matrices
    """
    assert se3.ndim == 3 and se3.shape[1:] == (4, 4), f"se3 must be (B, 4, 4), got {se3.shape}"

    R = se3[:, :3, :3]  # (B, 3, 3)
    t = se3[:, :3, 3]   # (B, 3)

    # For a rigid transform T = [R | t], the inverse is [R^T | -R^T t]
    Rt = R.transpose(1, 2)                                # (B, 3, 3)
    t_inv = -torch.bmm(Rt, t.unsqueeze(-1)).squeeze(-1)   # (B, 3)

    out = se3.new_zeros(se3.shape)
    out[:, :3, :3] = Rt
    out[:, :3, 3] = t_inv
    out[:, 3, 3] = 1.0
    return out
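

# Quick illustrative check (assumed values): a rigid transform composed with
# its closed-form inverse recovers the identity. The rotation is built from a
# skew-symmetric generator via the matrix exponential, so it is orthonormal.
def _example_closed_form_inverse_se3():
    skew = torch.tensor([[0.0, -0.3, 0.2],
                         [0.3, 0.0, -0.1],
                         [-0.2, 0.1, 0.0]])
    T = torch.eye(4).unsqueeze(0).clone()
    T[0, :3, :3] = torch.matrix_exp(skew)
    T[0, :3, 3] = torch.tensor([1.0, 2.0, 3.0])
    identity = torch.bmm(T, closed_form_inverse_se3(T))
    assert torch.allclose(identity, torch.eye(4).expand(1, 4, 4), atol=1e-5)

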
def create_pixel_coordinate_grid(num_frames, height, width):
    """
    Creates a grid of pixel coordinates and frame indices for all frames.

    Returns:
        points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3)
            with x, y coordinates and frame indices
    """
    # Create coordinate grids for a single frame
    y_grid, x_grid = np.indices((height, width), dtype=np.float32)
    x_grid = x_grid[np.newaxis, :, :]
    y_grid = y_grid[np.newaxis, :, :]

    # Broadcast to all frames
    x_coords = np.broadcast_to(x_grid, (num_frames, height, width))
    y_coords = np.broadcast_to(y_grid, (num_frames, height, width))

    # Create frame indices and broadcast
    f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis]
    f_coords = np.broadcast_to(f_idx, (num_frames, height, width))

    # Stack coordinates and frame indices
    points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)
    return points_xyf
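

# Illustrative usage with assumed sizes: each (x, y, f) triplet records the
# pixel column, pixel row, and frame index, which is convenient when flattening
# per-frame point maps into a single indexed point list.
def _example_create_pixel_coordinate_grid():
    points_xyf = create_pixel_coordinate_grid(num_frames=3, height=4, width=5)
    assert points_xyf.shape == (3, 4, 5, 3)
    # Channel 2 stores the frame index, constant within each frame
    assert np.all(points_xyf[2, :, :, 2] == 2.0)
    # Channel 0 is the x (column) coordinate, running 0 .. width-1
    assert np.all(points_xyf[:, :, 4, 0] == 4.0)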