import numpy as np
import torch
def m_dot(
transform: torch.Tensor,
points: torch.Tensor | list,
maintain_shape: bool = False,
) -> torch.Tensor | list:
"""
Apply batch matrix multiplication between transform matrices and points.
Args:
        transform: Batch of transformation matrices of shape [..., 3, 3] or [..., 4, 4]
points: Batch of points [..., N, 3] or a list of points
maintain_shape: If True, preserves the original shape of points
Returns:
Transformed points with shape [..., N, 3] or a list of transformed points
"""
if isinstance(points, list):
return [m_dot(t, p, maintain_shape) for t, p in zip(transform, points)]
    # Store original shape and the batch dimensions shared with the transforms
    orig_shape = points.shape
    batch_dims = transform.shape[:-2]
# Reshape to standard batch format
transform_flat = transform.reshape(-1, transform.shape[-2], transform.shape[-1])
points_flat = points.reshape(transform_flat.shape[0], -1, points.shape[-1])
# Apply transformation
pts = torch.bmm(
transform_flat[:, :3, :3],
points_flat[..., :3].permute(0, 2, 1).to(transform_flat.dtype),
).permute(0, 2, 1)
if transform.shape[-1] == 4:
pts = pts + transform_flat[:, :3, 3].unsqueeze(1)
    # Restore original shape; the output always has 3 channels, even when the
    # input points were homogeneous with 4
    if maintain_shape:
        return pts.reshape(*orig_shape[:-1], 3)
    return pts.reshape(*batch_dims, -1, 3)
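
# Usage sketch for m_dot (illustrative; the shapes below are assumptions, not
# requirements of the API):
#
#   pose = torch.eye(4).repeat(2, 1, 1)   # [B, 4, 4] rigid transforms
#   pts = torch.rand(2, 100, 3)           # [B, N, 3] points
#   out = m_dot(pose, pts)                # [B, 100, 3] transformed points
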
def m_unproject(
    depth: torch.Tensor | float,
    intrinsic: torch.Tensor,
    cam2world: torch.Tensor | None = None,
    img_grid: torch.Tensor | None = None,
    valid: torch.Tensor | float | None = None,
    H: int | None = None,
    W: int | None = None,
    img_feats: torch.Tensor | None = None,
    maintain_shape: bool = False,
) -> torch.Tensor | list:
"""
Unproject 2D image points with depth values to 3D points in camera or world space.
Args:
depth: Depth values, either a tensor of shape ...xHxW or a float value
intrinsic: Camera intrinsic matrix of shape ...x3x3
cam2world: Optional camera-to-world transformation matrix of shape ...x4x4
img_grid: Optional pre-computed image grid. If None, will be created
valid: Optional mask for valid depth values or minimum depth threshold
H: Image height (required if depth is a scalar)
W: Image width (required if depth is a scalar)
img_feats: Optional image features to append to 3D points
maintain_shape: If True, preserves the original shape of points
    Returns:
        3D points in camera or world space, with optional features appended.
        When a valid mask is applied, the result is a (possibly nested) list
        of ragged point tensors instead of a single tensor
"""
# Get device and shape information from intrinsic matrix
device = intrinsic.device
pre_shape = intrinsic.shape[:-2] # Batch dimensions
# Validate inputs
    if isinstance(depth, (int, float)) and (H is None or W is None):
        raise ValueError("H and W must be provided if depth is a scalar")
# Determine image dimensions from depth if not provided
if isinstance(depth, torch.Tensor) and H is None:
H, W = depth.shape[-2:]
    # Create image grid if not provided; _create_image_grid already appends
    # the homogeneous coordinate, so the grid has shape HxWx3
    if img_grid is None:
        img_grid = _create_image_grid(H, W, device)
# Expand img_grid to match batch dimensions of intrinsic
if img_grid.dim() <= intrinsic.dim():
img_grid = img_grid.unsqueeze(0)
img_grid = img_grid.expand(*pre_shape, *img_grid.shape[-3:])
# Handle valid mask or minimum depth threshold
depth_mask = None
if valid is not None:
if isinstance(valid, float):
# Create mask for minimum depth value
depth_mask = depth > valid
elif isinstance(valid, torch.Tensor):
depth_mask = valid
# Apply mask to image grid and other inputs
img_grid = masking(img_grid, depth_mask, dim=intrinsic.dim())
if not isinstance(depth, (int, float)):
depth = masking(depth, depth_mask, dim=intrinsic.dim() - 1)
if img_feats is not None:
img_feats = masking(img_feats, depth_mask, dim=intrinsic.dim() - 1)
# Unproject 2D points to 3D camera space
cam_pts: torch.Tensor = m_dot(
m_inverse_intrinsics(intrinsic),
img_grid[..., [1, 0, 2]],
maintain_shape=True,
)
    # Scale by depth values; depth may be a scalar, a tensor, or a (possibly
    # nested) list of tensors after masking
    def _unsqueeze_last(d):
        # Add a trailing axis so depth broadcasts against the 3 point channels
        if isinstance(d, list):
            return [_unsqueeze_last(x) for x in d]
        return d.unsqueeze(-1)

    if isinstance(depth, (int, float)):
        cam_pts = mult(cam_pts, depth)
    else:
        cam_pts = mult(cam_pts, _unsqueeze_last(depth))
# Transform to world space if cam2world is provided
if cam2world is not None:
cam_pts = m_dot(cam2world, cam_pts, maintain_shape=True)
# Append image features if provided
if img_feats is not None:
if isinstance(cam_pts, list):
if isinstance(cam_pts[0], list):
# Handle nested list case
result = []
for batch_idx, batch in enumerate(cam_pts):
batch_result = []
for view_idx, view in enumerate(batch):
batch_result.append(
torch.cat([view, img_feats[batch_idx][view_idx]], -1)
)
result.append(batch_result)
cam_pts = result
else:
# Handle single list case
cam_pts = [
torch.cat([pts, feats], -1)
for pts, feats in zip(cam_pts, img_feats)
]
else:
# Handle tensor case
cam_pts = torch.cat([cam_pts, img_feats], -1)
    if maintain_shape or isinstance(cam_pts, list):
        # Masked outputs are ragged lists and cannot be reshaped
        return cam_pts
    # Flatten the spatial dimensions into a single point dimension, keeping
    # any appended feature channels
    return cam_pts.reshape(*pre_shape, -1, cam_pts.shape[-1])
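
# Usage sketch for m_unproject (illustrative; the intrinsics and resolution
# are placeholder values):
#
#   K = torch.tensor([[500.0, 0.0, 160.0],
#                     [0.0, 500.0, 120.0],
#                     [0.0, 0.0, 1.0]]).repeat(2, 1, 1)  # [B, 3, 3]
#   depth = torch.ones(2, 240, 320)                      # [B, H, W]
#   pts = m_unproject(depth, K)                          # [B, H*W, 3]
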
def m_project(
world_pts: torch.Tensor,
intrinsic: torch.Tensor,
world2cam: torch.Tensor | None = None,
maintain_shape: bool = False,
) -> torch.Tensor:
"""
Project 3D world points to 2D image coordinates.
Args:
world_pts: 3D points in world coordinates
intrinsic: Camera intrinsic matrix
world2cam: Optional transformation from world to camera coordinates
maintain_shape: If True, preserves the original shape of points
    Returns:
        Image points with coordinates in (img_y, img_x, z) order
"""
# Transform points from world to camera space if world2cam is provided
cam_pts: torch.Tensor = world_pts
if world2cam is not None:
cam_pts = m_dot(world2cam, world_pts, maintain_shape=maintain_shape)
# Get shapes to properly expand intrinsics
shared_dims = intrinsic.shape[:-2]
extra_dims = cam_pts.shape[len(shared_dims) : -1]
# Expand intrinsics to match cam_pts shape
expanded_intrinsic = intrinsic.view(*shared_dims, *([1] * len(extra_dims)), 3, 3)
expanded_intrinsic = expanded_intrinsic.expand(*shared_dims, *extra_dims, 3, 3)
    # Project points from camera space to image space; clamp |z| to avoid
    # division by zero for points near the camera plane
    depth_abs = cam_pts[..., 2].abs().clamp(min=1e-5)
return torch.stack(
[
expanded_intrinsic[..., 1, 1] * cam_pts[..., 1] / depth_abs
+ expanded_intrinsic[..., 1, 2],
expanded_intrinsic[..., 0, 0] * cam_pts[..., 0] / depth_abs
+ expanded_intrinsic[..., 0, 2],
cam_pts[..., 2],
],
-1,
)
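
# Usage sketch for m_project (illustrative): without world2cam the points are
# assumed to already be in camera space, and z should stay away from zero:
#
#   K = torch.eye(3).repeat(2, 1, 1)        # [B, 3, 3]
#   cam_pts = torch.rand(2, 100, 3) + 1.0   # [B, N, 3] with z in [1, 2)
#   yxz = m_project(cam_pts, K)             # [B, N, 3] in (img_y, img_x, z)
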
def in_image(
image_pts: torch.Tensor | list,
H: int,
W: int,
    min_depth: float | None = 0.0,
) -> torch.Tensor | list:
"""
Check if image points are within the image boundaries.
Args:
image_pts: Image points in pixel coordinates
H: Image height
W: Image width
min_depth: Minimum valid depth
Returns:
Boolean mask indicating which points are within the image
"""
    if isinstance(image_pts, list):
        return [in_image(pts, H, W, min_depth=min_depth) for pts in image_pts]
in_image_mask = (
torch.all(image_pts >= 0, -1)
& (image_pts[..., 0] < H)
& (image_pts[..., 1] < W)
)
if (min_depth is not None) and image_pts.shape[-1] == 3:
in_image_mask &= image_pts[..., 2] > min_depth
return in_image_mask
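
# Usage sketch for in_image (illustrative):
#
#   yxz = torch.tensor([[10.0, 20.0, 1.0],   # inside a 240x320 image
#                       [-1.0, 5.0, 2.0]])   # negative row -> outside
#   mask = in_image(yxz, H=240, W=320)       # tensor([True, False])
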
def _create_image_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
"""
Create a coordinate grid for image pixels.
Args:
H: Image height
W: Image width
device: Computation device
Returns:
Image grid with shape HxWx3 (last dimension is homogeneous)
"""
y_coords = torch.arange(H, device=device)
x_coords = torch.arange(W, device=device)
# Use meshgrid with indexing="ij" for correct orientation
y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij")
# Stack coordinates and add homogeneous coordinate
img_grid = torch.stack([y_grid, x_grid, torch.ones_like(y_grid)], dim=-1)
return img_grid
def masking(
X: torch.Tensor | list,
mask: torch.Tensor | list,
dim: int = 3,
) -> torch.Tensor | list:
"""
Apply a Boolean mask to tensor or list elements.
Handles nested structures by recursively applying the mask.
Args:
X: Input tensor or list to be masked
mask: Boolean mask to apply
        dim: Tensors with at least this many dimensions (and all lists) are
            iterated along their first axis recursively; tensors with fewer
            dimensions are indexed by the mask directly
Returns:
Masked tensor or list with the same structure as input
"""
if isinstance(X, list) or (isinstance(X, torch.Tensor) and X.dim() >= dim):
return [masking(x, m, dim) for x, m in zip(X, mask)]
return X[mask]
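
# Usage sketch for masking (illustrative): with dim=3, a [B, N, 3] tensor is
# split along its batch axis and each [N, 3] slice is indexed by its own mask
# row, yielding a ragged list:
#
#   X = torch.rand(2, 5, 3)
#   mask = torch.rand(2, 5) > 0.5
#   out = masking(X, mask, dim=3)   # list of 2 tensors, each [n_i, 3]
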
def m_inverse_intrinsics(intrinsics: torch.Tensor) -> torch.Tensor:
"""
Compute the inverse of camera intrinsics matrices analytically.
This is much faster than using torch.inverse() for intrinsics matrices.
The intrinsics matrix has the form:
K = [fx s cx]
[0 fy cy]
[0 0 1]
And its inverse is:
K^-1 = [1/fx -s/(fx*fy) (s*cy-cx*fy)/(fx*fy)]
[0 1/fy -cy/fy ]
[0 0 1 ]
Args:
intrinsics: Camera intrinsics matrices of shape [..., 3, 3]
Returns:
Inverse intrinsics matrices of shape [..., 3, 3]
"""
# Extract the components of the intrinsics matrix
fx = intrinsics[..., 0, 0]
s = intrinsics[..., 0, 1] # skew, usually 0
cx = intrinsics[..., 0, 2]
fy = intrinsics[..., 1, 1]
cy = intrinsics[..., 1, 2]
# Create output tensor with same shape and device
inv_intrinsics = torch.zeros_like(intrinsics)
# Compute the inverse analytically
inv_intrinsics[..., 0, 0] = 1.0 / fx
inv_intrinsics[..., 0, 1] = -s / (fx * fy)
inv_intrinsics[..., 0, 2] = (s * cy - cx * fy) / (fx * fy)
inv_intrinsics[..., 1, 1] = 1.0 / fy
inv_intrinsics[..., 1, 2] = -cy / fy
inv_intrinsics[..., 2, 2] = 1.0
return inv_intrinsics
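
# Sanity check for m_inverse_intrinsics (illustrative): the analytic inverse
# should agree with torch.linalg.inv for a generic intrinsics matrix:
#
#   K = torch.tensor([[500.0, 0.0, 320.0],
#                     [0.0, 500.0, 240.0],
#                     [0.0, 0.0, 1.0]])
#   assert torch.allclose(m_inverse_intrinsics(K), torch.linalg.inv(K))
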
def mult(
A: torch.Tensor | np.ndarray | list | float | int,
B: torch.Tensor | np.ndarray | list | float | int,
) -> torch.Tensor | np.ndarray | list | float | int:
"""
Multiply two objects with support for lists, tensors, arrays, and scalars.
Handles nested structures by recursively applying multiplication.
Args:
A: First operand (tensor, array, list, or scalar)
B: Second operand (tensor, array, list, or scalar)
Returns:
Result of multiplication with the same structure as inputs
"""
if isinstance(A, list) and isinstance(B, (int, float)):
return [mult(a, B) for a in A]
if isinstance(B, list) and isinstance(A, (int, float)):
return [mult(A, b) for b in B]
if isinstance(A, list) and isinstance(B, list):
return [mult(a, b) for a, b in zip(A, B)]
return A * B
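

if __name__ == "__main__":
    # Minimal round-trip smoke test (illustrative; the intrinsics, resolution,
    # and constant depth below are arbitrary placeholder values).
    K = torch.tensor(
        [[500.0, 0.0, 160.0], [0.0, 500.0, 120.0], [0.0, 0.0, 1.0]]
    ).repeat(2, 1, 1)
    depth = torch.full((2, 240, 320), 2.0)
    # Unproject to camera space, keeping the [B, H, W, 3] layout
    cam_pts = m_unproject(depth, K, maintain_shape=True)
    # Project back: rows/columns should match the original pixel grid
    yxz = m_project(cam_pts, K)
    grid = _create_image_grid(240, 320, yxz.device).to(yxz.dtype)
    assert torch.allclose(
        yxz[..., :2], grid[..., :2].expand(2, 240, 320, 2), atol=1e-3
    )
    print("m_unproject/m_project round trip: OK")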