import numpy as np
import torch


def m_dot(
transform: torch.Tensor,
points: torch.Tensor | list,
maintain_shape: bool = False,
) -> torch.Tensor | list:
"""
Apply batch matrix multiplication between transform matrices and points.
Args:
transform: Batch of transformation matrices [..., 3/4, 3/4]
points: Batch of points [..., N, 3] or a list of points
maintain_shape: If True, preserves the original shape of points
Returns:
Transformed points with shape [..., N, 3] or a list of transformed points
"""
if isinstance(points, list):
return [m_dot(t, p, maintain_shape) for t, p in zip(transform, points)]
# Store original shape and flatten batch dimensions
orig_shape = points.shape
    batch_dims = transform.shape[:-2]  # batch dims shared by transform and points
# Reshape to standard batch format
transform_flat = transform.reshape(-1, transform.shape[-2], transform.shape[-1])
points_flat = points.reshape(transform_flat.shape[0], -1, points.shape[-1])
# Apply transformation
pts = torch.bmm(
transform_flat[:, :3, :3],
points_flat[..., :3].permute(0, 2, 1).to(transform_flat.dtype),
).permute(0, 2, 1)
if transform.shape[-1] == 4:
pts = pts + transform_flat[:, :3, 3].unsqueeze(1)
# Restore original shape
if maintain_shape:
return pts.reshape(orig_shape)
else:
return pts.reshape(*batch_dims, -1, 3)
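
# Usage sketch for m_dot (the shapes below are made-up toy values): identity
# transforms leave the points unchanged.
def _example_m_dot() -> None:
    transform = torch.eye(4).expand(2, 4, 4)  # two identity 4x4 transforms
    points = torch.rand(2, 5, 3)  # two batches of five 3D points
    out = m_dot(transform, points)
    assert out.shape == (2, 5, 3)
    assert torch.allclose(out, points)
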
def m_unproject(
    depth: torch.Tensor | float,
    intrinsic: torch.Tensor,
    cam2world: torch.Tensor | None = None,
    img_grid: torch.Tensor | None = None,
    valid: torch.Tensor | float | None = None,
    H: int | None = None,
    W: int | None = None,
    img_feats: torch.Tensor | None = None,
    maintain_shape: bool = False,
) -> torch.Tensor | list:
"""
Unproject 2D image points with depth values to 3D points in camera or world space.
Args:
depth: Depth values, either a tensor of shape ...xHxW or a float value
intrinsic: Camera intrinsic matrix of shape ...x3x3
cam2world: Optional camera-to-world transformation matrix of shape ...x4x4
img_grid: Optional pre-computed image grid. If None, will be created
valid: Optional mask for valid depth values or minimum depth threshold
H: Image height (required if depth is a scalar)
W: Image width (required if depth is a scalar)
img_feats: Optional image features to append to 3D points
maintain_shape: If True, preserves the original shape of points
Returns:
3D points in camera or world space, with optional features appended
"""
# Get device and shape information from intrinsic matrix
device = intrinsic.device
pre_shape = intrinsic.shape[:-2] # Batch dimensions
# Validate inputs
    if isinstance(depth, (int, float)) and (H is None or W is None):
        raise ValueError("H and W must be provided if depth is a scalar")
# Determine image dimensions from depth if not provided
if isinstance(depth, torch.Tensor) and H is None:
H, W = depth.shape[-2:]
    # Create image grid if not provided (already homogeneous, shape HxWx3)
    if img_grid is None:
        img_grid = _create_image_grid(H, W, device)
    # Append a homogeneous coordinate only to externally supplied (y, x) grids
    if img_grid.shape[-1] == 2:
        img_grid = torch.cat([img_grid, torch.ones_like(img_grid[..., :1])], -1)
    # Expand img_grid to match the batch dimensions of intrinsic
    while img_grid.dim() < intrinsic.dim() + 1:
        img_grid = img_grid.unsqueeze(0)
    img_grid = img_grid.expand(*pre_shape, *img_grid.shape[-3:])
    # Handle valid mask or minimum depth threshold
    depth_mask = None
    if valid is not None:
        if isinstance(valid, (int, float)):
            # Create mask for minimum depth value
            depth_mask = depth > valid
        elif isinstance(valid, torch.Tensor):
            depth_mask = valid
    # Apply mask to image grid and other inputs (skip when no mask was built)
    if depth_mask is not None:
        img_grid = masking(img_grid, depth_mask, dim=intrinsic.dim())
        if not isinstance(depth, (int, float)):
            depth = masking(depth, depth_mask, dim=intrinsic.dim() - 1)
        if img_feats is not None:
            img_feats = masking(img_feats, depth_mask, dim=intrinsic.dim() - 1)
    # Unproject 2D points to 3D camera space; the grid stores (y, x, 1), so
    # reorder to (x, y, 1) before applying the inverse intrinsics
    cam_pts: torch.Tensor = m_dot(
        m_inverse_intrinsics(intrinsic),
        img_grid[..., [1, 0, 2]],
        maintain_shape=True,
    )
    # Scale by depth values (scalars and ragged lists are handled explicitly)
    if isinstance(depth, (int, float)):
        cam_pts = mult(cam_pts, depth)
    elif isinstance(depth, list):
        cam_pts = mult(cam_pts, [d.unsqueeze(-1) for d in depth])
    else:
        cam_pts = mult(cam_pts, depth.unsqueeze(-1))
# Transform to world space if cam2world is provided
if cam2world is not None:
cam_pts = m_dot(cam2world, cam_pts, maintain_shape=True)
# Append image features if provided
if img_feats is not None:
if isinstance(cam_pts, list):
if isinstance(cam_pts[0], list):
# Handle nested list case
result = []
for batch_idx, batch in enumerate(cam_pts):
batch_result = []
for view_idx, view in enumerate(batch):
batch_result.append(
torch.cat([view, img_feats[batch_idx][view_idx]], -1)
)
result.append(batch_result)
cam_pts = result
else:
# Handle single list case
cam_pts = [
torch.cat([pts, feats], -1)
for pts, feats in zip(cam_pts, img_feats)
]
else:
# Handle tensor case
cam_pts = torch.cat([cam_pts, img_feats], -1)
    if maintain_shape:
        return cam_pts
    # Masked outputs are ragged lists and cannot be reshaped
    if isinstance(cam_pts, list):
        return cam_pts
    # Flatten the spatial dimensions; last dim is 3 plus any appended features
    return cam_pts.reshape(*pre_shape, -1, cam_pts.shape[-1])
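
# Usage sketch for m_unproject with assumed toy intrinsics (fx = fy = 100,
# principal point at the image center); a constant depth of 1 lifts every
# pixel onto a plane one unit in front of the camera.
def _example_m_unproject() -> None:
    intrinsic = torch.tensor(
        [[100.0, 0.0, 32.0], [0.0, 100.0, 24.0], [0.0, 0.0, 1.0]]
    ).expand(1, 3, 3)
    depth = torch.ones(1, 48, 64)
    cam2world = torch.eye(4).expand(1, 4, 4)
    pts = m_unproject(depth, intrinsic, cam2world=cam2world)
    assert pts.shape == (1, 48 * 64, 3)
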
def m_project(
world_pts: torch.Tensor,
intrinsic: torch.Tensor,
world2cam: torch.Tensor | None = None,
maintain_shape: bool = False,
) -> torch.Tensor:
"""
Project 3D world points to 2D image coordinates.
Args:
world_pts: 3D points in world coordinates
intrinsic: Camera intrinsic matrix
world2cam: Optional transformation from world to camera coordinates
maintain_shape: If True, preserves the original shape of points
Returns:
        Image points of shape [..., 3], with coordinates in (img_y, img_x, z) order
"""
# Transform points from world to camera space if world2cam is provided
cam_pts: torch.Tensor = world_pts
if world2cam is not None:
cam_pts = m_dot(world2cam, world_pts, maintain_shape=maintain_shape)
# Get shapes to properly expand intrinsics
shared_dims = intrinsic.shape[:-2]
extra_dims = cam_pts.shape[len(shared_dims) : -1]
# Expand intrinsics to match cam_pts shape
expanded_intrinsic = intrinsic.view(*shared_dims, *([1] * len(extra_dims)), 3, 3)
expanded_intrinsic = expanded_intrinsic.expand(*shared_dims, *extra_dims, 3, 3)
    # Project points from camera space to image space; clamp |z| away from
    # zero to avoid division blow-ups for points near the image plane
    depth_abs = cam_pts[..., 2].abs().clamp(min=1e-5)
return torch.stack(
[
expanded_intrinsic[..., 1, 1] * cam_pts[..., 1] / depth_abs
+ expanded_intrinsic[..., 1, 2],
expanded_intrinsic[..., 0, 0] * cam_pts[..., 0] / depth_abs
+ expanded_intrinsic[..., 0, 2],
cam_pts[..., 2],
],
-1,
)
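
# Projection sketch with the same assumed toy intrinsics: a point on the
# optical axis lands on the principal point (cx, cy) = (32, 24), returned in
# (img_y, img_x, z) order.
def _example_m_project() -> None:
    intrinsic = torch.tensor(
        [[100.0, 0.0, 32.0], [0.0, 100.0, 24.0], [0.0, 0.0, 1.0]]
    )
    cam_pts = torch.tensor([[0.0, 0.0, 2.0]])
    img_pts = m_project(cam_pts, intrinsic)
    assert torch.allclose(img_pts, torch.tensor([[24.0, 32.0, 2.0]]))
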
def in_image(
image_pts: torch.Tensor | list,
H: int,
W: int,
    min_depth: float | None = 0.0,
) -> torch.Tensor | list:
"""
Check if image points are within the image boundaries.
Args:
image_pts: Image points in pixel coordinates
H: Image height
W: Image width
min_depth: Minimum valid depth
Returns:
Boolean mask indicating which points are within the image
"""
    if isinstance(image_pts, list):
        return [in_image(pts, H, W, min_depth=min_depth) for pts in image_pts]
in_image_mask = (
torch.all(image_pts >= 0, -1)
& (image_pts[..., 0] < H)
& (image_pts[..., 1] < W)
)
if (min_depth is not None) and image_pts.shape[-1] == 3:
in_image_mask &= image_pts[..., 2] > min_depth
return in_image_mask
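
# Sketch for in_image on a toy 48x64 image: the second point is outside the
# height bound and the third is behind the camera (z <= min_depth).
def _example_in_image() -> None:
    pts = torch.tensor(
        [[10.0, 10.0, 1.0], [50.0, 10.0, 1.0], [10.0, 10.0, -1.0]]
    )
    assert in_image(pts, H=48, W=64).tolist() == [True, False, False]
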
def _create_image_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
"""
Create a coordinate grid for image pixels.
Args:
H: Image height
W: Image width
device: Computation device
Returns:
Image grid with shape HxWx3 (last dimension is homogeneous)
"""
y_coords = torch.arange(H, device=device)
x_coords = torch.arange(W, device=device)
# Use meshgrid with indexing="ij" for correct orientation
y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij")
# Stack coordinates and add homogeneous coordinate
img_grid = torch.stack([y_grid, x_grid, torch.ones_like(y_grid)], dim=-1)
return img_grid
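
# Sketch for _create_image_grid: each pixel stores (y, x, 1), so entry [1, 2]
# of a 2x3 grid reads [1, 2, 1].
def _example_create_image_grid() -> None:
    grid = _create_image_grid(2, 3, torch.device("cpu"))
    assert grid.shape == (2, 3, 3)
    assert grid[1, 2].tolist() == [1, 2, 1]
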
def masking(
X: torch.Tensor | list,
mask: torch.Tensor | list,
dim: int = 3,
) -> torch.Tensor | list:
"""
Apply a Boolean mask to tensor or list elements.
Handles nested structures by recursively applying the mask.
Args:
X: Input tensor or list to be masked
mask: Boolean mask to apply
dim: Dimension threshold for recursive processing
Returns:
Masked tensor or list with the same structure as input
"""
if isinstance(X, list) or (isinstance(X, torch.Tensor) and X.dim() >= dim):
return [masking(x, m, dim) for x, m in zip(X, mask)]
return X[mask]
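
# Sketch for masking: a per-batch Boolean mask yields a ragged list, since
# each batch element may keep a different number of points.
def _example_masking() -> None:
    X = torch.rand(2, 4, 3)
    mask = torch.tensor([[True, True, False, False], [True, False, False, False]])
    out = masking(X, mask, dim=3)
    assert [o.shape[0] for o in out] == [2, 1]
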
def m_inverse_intrinsics(intrinsics: torch.Tensor) -> torch.Tensor:
"""
Compute the inverse of camera intrinsics matrices analytically.
This is much faster than using torch.inverse() for intrinsics matrices.
The intrinsics matrix has the form:
K = [fx s cx]
[0 fy cy]
[0 0 1]
And its inverse is:
K^-1 = [1/fx -s/(fx*fy) (s*cy-cx*fy)/(fx*fy)]
[0 1/fy -cy/fy ]
[0 0 1 ]
Args:
intrinsics: Camera intrinsics matrices of shape [..., 3, 3]
Returns:
Inverse intrinsics matrices of shape [..., 3, 3]
"""
# Extract the components of the intrinsics matrix
fx = intrinsics[..., 0, 0]
s = intrinsics[..., 0, 1] # skew, usually 0
cx = intrinsics[..., 0, 2]
fy = intrinsics[..., 1, 1]
cy = intrinsics[..., 1, 2]
# Create output tensor with same shape and device
inv_intrinsics = torch.zeros_like(intrinsics)
# Compute the inverse analytically
inv_intrinsics[..., 0, 0] = 1.0 / fx
inv_intrinsics[..., 0, 1] = -s / (fx * fy)
inv_intrinsics[..., 0, 2] = (s * cy - cx * fy) / (fx * fy)
inv_intrinsics[..., 1, 1] = 1.0 / fy
inv_intrinsics[..., 1, 2] = -cy / fy
inv_intrinsics[..., 2, 2] = 1.0
return inv_intrinsics
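
# Sanity-check sketch: the analytic inverse agrees with torch.inverse on a
# made-up intrinsics matrix with nonzero skew.
def _example_m_inverse_intrinsics() -> None:
    K = torch.tensor([[120.0, 0.5, 31.5], [0.0, 110.0, 23.5], [0.0, 0.0, 1.0]])
    assert torch.allclose(m_inverse_intrinsics(K), torch.inverse(K), atol=1e-5)
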
def mult(
A: torch.Tensor | np.ndarray | list | float | int,
B: torch.Tensor | np.ndarray | list | float | int,
) -> torch.Tensor | np.ndarray | list | float | int:
"""
Multiply two objects with support for lists, tensors, arrays, and scalars.
Handles nested structures by recursively applying multiplication.
Args:
A: First operand (tensor, array, list, or scalar)
B: Second operand (tensor, array, list, or scalar)
Returns:
Result of multiplication with the same structure as inputs
"""
    if isinstance(A, list) and not isinstance(B, list):
        # Broadcast the non-list operand (scalar, tensor, or array) over A
        return [mult(a, B) for a in A]
    if isinstance(B, list) and not isinstance(A, list):
        return [mult(A, b) for b in B]
    if isinstance(A, list) and isinstance(B, list):
        return [mult(a, b) for a, b in zip(A, B)]
    return A * B
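
# Sketch for mult: scalar broadcast over a ragged list of tensors, which a
# plain `*` cannot do.
def _example_mult() -> None:
    ragged = [torch.ones(2, 3), torch.ones(5, 3)]
    scaled = mult(ragged, 2.0)
    assert all(torch.all(s == 2.0) for s in scaled)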