import torch
import numpy as np


def depth_to_camera_coords(depthmap, camera_intrinsics):
    """
    Convert depth map to 3D camera coordinates.

    Args:
        depthmap (BxHxW tensor): Batch of depth maps
        camera_intrinsics (Bx3x3 tensor): Camera intrinsics matrix for each camera

    Returns:
        X_cam (BxHxWx3 tensor): 3D points in camera coordinates
        valid_mask (BxHxW tensor): Mask indicating valid depth pixels
    """
    B, H, W = depthmap.shape
    device = depthmap.device
    dtype = depthmap.dtype

    # Cast intrinsics to the depth map's dtype so downstream ops don't mix precisions
    camera_intrinsics = camera_intrinsics.to(dtype)

    # Extract focal lengths and principal points
    fx = camera_intrinsics[:, 0, 0]  # (B,)
    fy = camera_intrinsics[:, 1, 1]  # (B,)
    cx = camera_intrinsics[:, 0, 2]  # (B,)
    cy = camera_intrinsics[:, 1, 2]  # (B,)

    # Generate pixel grid
    v_grid, u_grid = torch.meshgrid(
        torch.arange(H, dtype=dtype, device=device),
        torch.arange(W, dtype=dtype, device=device),
        indexing='ij',
    )

    # Reshape for broadcasting: (1, H, W)
    u_grid = u_grid.unsqueeze(0)
    v_grid = v_grid.unsqueeze(0)

    # Compute 3D camera coordinates via the pinhole model:
    #   X = (u - cx) * Z / fx
    #   Y = (v - cy) * Z / fy
    #   Z = depth
    z_cam = depthmap  # (B, H, W)
    x_cam = (u_grid - cx.view(B, 1, 1)) * z_cam / fx.view(B, 1, 1)
    y_cam = (v_grid - cy.view(B, 1, 1)) * z_cam / fy.view(B, 1, 1)

    # Stack to form (B, H, W, 3)
    X_cam = torch.stack([x_cam, y_cam, z_cam], dim=-1)

    # Valid depth mask
    valid_mask = depthmap > 0.0

    return X_cam, valid_mask
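

# A minimal usage sketch (not part of the original file): synthetic depths and a
# simple pinhole intrinsic, just to illustrate the expected shapes. The focal
# length and principal point below are illustrative assumptions.
def _example_depth_to_camera_coords():
    B, H, W = 2, 4, 6
    depth = torch.rand(B, H, W) + 0.5  # strictly positive depths
    K = torch.tensor([[100.0, 0.0, W / 2.0],
                      [0.0, 100.0, H / 2.0],
                      [0.0, 0.0, 1.0]]).repeat(B, 1, 1)
    X_cam, valid_mask = depth_to_camera_coords(depth, K)
    assert X_cam.shape == (B, H, W, 3) and valid_mask.all()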


def depth_to_world_coords_points(
    depth_map: torch.Tensor, extrinsic: torch.Tensor, intrinsic: torch.Tensor, eps=1e-8
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Convert a batch of depth maps to world coordinates.

    Args:
        depth_map (torch.Tensor): (B, H, W) Depth map
        extrinsic (torch.Tensor): (B, 4, 4) Camera extrinsic matrix (camera-to-world transformation)
        intrinsic (torch.Tensor): (B, 3, 3) Camera intrinsic matrix
        eps (float): Minimum depth value treated as valid

    Returns:
        world_coords_points (torch.Tensor): (B, H, W, 3) World coordinates
        camera_points (torch.Tensor): (B, H, W, 3) Camera coordinates
        point_mask (torch.Tensor): (B, H, W) Valid depth mask
    """
    if depth_map is None:
        return None, None, None

    # Valid depth mask (B, H, W)
    point_mask = depth_map > eps

    # Convert depth map to camera coordinates (B, H, W, 3)
    camera_points, _ = depth_to_camera_coords(depth_map, intrinsic)

    # Apply extrinsic matrix (camera -> world):
    # world = R @ cam + t, batched over B and broadcast over H, W
    R_cam_to_world = extrinsic[:, :3, :3]  # (B, 3, 3)
    t_cam_to_world = extrinsic[:, :3, 3]   # (B, 3)
    world_coords_points = (
        torch.einsum('bhwi,bji->bhwj', camera_points, R_cam_to_world)
        + t_cam_to_world[:, None, None, :]
    )

    return world_coords_points, camera_points, point_mask
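

# A round-trip sanity sketch (assumed inputs, not from the original file): with an
# identity camera-to-world extrinsic, world coordinates must equal camera
# coordinates, and all strictly positive depths should be marked valid.
def _example_depth_to_world_coords_points():
    B, H, W = 1, 4, 6
    depth = torch.rand(B, H, W) + 0.5
    K = torch.tensor([[50.0, 0.0, W / 2.0],
                      [0.0, 50.0, H / 2.0],
                      [0.0, 0.0, 1.0]]).repeat(B, 1, 1)
    E = torch.eye(4).repeat(B, 1, 1)  # identity camera-to-world pose
    world, cam, mask = depth_to_world_coords_points(depth, E, K)
    assert torch.allclose(world, cam) and mask.all()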


def closed_form_inverse_se3(se3: torch.Tensor) -> torch.Tensor:
    """
    Efficiently invert batched SE(3) matrices of shape (B, 4, 4).

    Args:
        se3 (torch.Tensor): (B, 4, 4) Transformation matrices

    Returns:
        out (torch.Tensor): (B, 4, 4) Inverse transformation matrices
    """
    assert se3.ndim == 3 and se3.shape[1:] == (4, 4), f"se3 must be (B, 4, 4), got {se3.shape}"

    R = se3[:, :3, :3]  # (B, 3, 3)
    t = se3[:, :3, 3]   # (B, 3)

    # For T = [R | t], the inverse is [R^T | -R^T t]
    Rt = R.transpose(1, 2)                               # (B, 3, 3)
    t_inv = -torch.bmm(Rt, t.unsqueeze(-1)).squeeze(-1)  # (B, 3)

    out = se3.new_zeros(se3.shape)
    out[:, :3, :3] = Rt
    out[:, :3, 3] = t_inv
    out[:, 3, 3] = 1.0
    return out
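

# A quick self-check sketch (synthetic poses, illustrative only): composing random
# SE(3) transforms with their closed-form inverses should yield the identity.
# Rotations are drawn via QR decomposition of Gaussian matrices, with the
# determinant sign fixed so each block is a proper rotation.
def _example_closed_form_inverse_se3():
    B = 3
    Q, _ = torch.linalg.qr(torch.randn(B, 3, 3))
    Q = Q * torch.sign(torch.linalg.det(Q)).view(B, 1, 1)  # enforce det = +1
    T = torch.eye(4).repeat(B, 1, 1)
    T[:, :3, :3] = Q
    T[:, :3, 3] = torch.randn(B, 3)
    identity = torch.bmm(T, closed_form_inverse_se3(T))
    assert torch.allclose(identity, torch.eye(4).expand(B, 4, 4), atol=1e-5)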


def create_pixel_coordinate_grid(num_frames, height, width):
    """
    Create a grid of pixel coordinates and frame indices for all frames.

    Args:
        num_frames (int): Number of frames
        height (int): Image height in pixels
        width (int): Image width in pixels

    Returns:
        points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3)
            holding x, y coordinates and frame indices
    """
    # Create coordinate grids for a single frame
    y_grid, x_grid = np.indices((height, width), dtype=np.float32)
    x_grid = x_grid[np.newaxis, :, :]
    y_grid = y_grid[np.newaxis, :, :]

    # Broadcast to all frames
    x_coords = np.broadcast_to(x_grid, (num_frames, height, width))
    y_coords = np.broadcast_to(y_grid, (num_frames, height, width))

    # Create frame indices and broadcast
    f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis]
    f_coords = np.broadcast_to(f_idx, (num_frames, height, width))

    # Stack coordinates and frame indices
    points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)
    return points_xyf
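

# A spot-check sketch (not part of the original file): entry [f, y, x] of the
# returned grid should store the coordinates (x, y, f) as float32.
def _example_create_pixel_coordinate_grid():
    points_xyf = create_pixel_coordinate_grid(num_frames=2, height=3, width=4)
    assert points_xyf.shape == (2, 3, 4, 3)
    assert np.array_equal(points_xyf[1, 2, 3], np.array([3.0, 2.0, 1.0], dtype=np.float32))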