import numpy as np
import torch


def m_dot(
    transform: torch.Tensor,
    points: torch.Tensor | list,
    maintain_shape: bool = False,
) -> torch.Tensor | list:
    """
    Apply batch matrix multiplication between transform matrices and points.

    Args:
        transform: Batch of transformation matrices [..., 3/4, 3/4]
        points: Batch of points [..., N, 3] or a list of points
        maintain_shape: If True, preserves the original shape of points

    Returns:
        Transformed points with shape [..., N, 3] or a list of transformed points
    """
    if isinstance(points, list):
        return [m_dot(t, p, maintain_shape) for t, p in zip(transform, points)]

    # Store original shape and flatten batch dimensions
    orig_shape = points.shape
    batch_dims = points.shape[:-3]

    # Reshape to standard batch format
    transform_flat = transform.reshape(-1, transform.shape[-2], transform.shape[-1])
    points_flat = points.reshape(transform_flat.shape[0], -1, points.shape[-1])

    # Apply rotation (upper-left 3x3 block of the transform)
    pts = torch.bmm(
        transform_flat[:, :3, :3],
        points_flat[..., :3].permute(0, 2, 1).to(transform_flat.dtype),
    ).permute(0, 2, 1)

    # Add translation if the transform is homogeneous (4x4)
    if transform.shape[-1] == 4:
        pts = pts + transform_flat[:, :3, 3].unsqueeze(1)

    # Restore original shape; the output always has 3 channels, even if the
    # input points carried a homogeneous fourth coordinate
    if maintain_shape:
        return pts.reshape(*orig_shape[:-1], 3)
    return pts.reshape(*batch_dims, -1, 3)
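
# Usage sketch for m_dot (shapes below are illustrative assumptions): with the
# default maintain_shape=False, extra point dimensions are flattened into one.
#
#   T = torch.eye(4).expand(2, 4, 4)             # batch of 2 rigid transforms
#   pts = torch.rand(2, 8, 6, 3)                 # per-batch 8x6 grid of points
#   m_dot(T, pts).shape                          # torch.Size([2, 48, 3])
#   m_dot(T, pts, maintain_shape=True).shape     # torch.Size([2, 8, 6, 3])
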
def m_unproject(
    depth: torch.Tensor | float,
    intrinsic: torch.Tensor,
    cam2world: torch.Tensor | None = None,
    img_grid: torch.Tensor | None = None,
    valid: torch.Tensor | float | None = None,
    H: int | None = None,
    W: int | None = None,
    img_feats: torch.Tensor | None = None,
    maintain_shape: bool = False,
) -> torch.Tensor | list:
    """
    Unproject 2D image points with depth values to 3D points in camera or world space.

    Args:
        depth: Depth values, either a tensor of shape ...xHxW or a float value
        intrinsic: Camera intrinsic matrix of shape ...x3x3
        cam2world: Optional camera-to-world transformation matrix of shape ...x4x4
        img_grid: Optional pre-computed image grid. If None, will be created
        valid: Optional mask for valid depth values or minimum depth threshold
        H: Image height (required if depth is a scalar)
        W: Image width (required if depth is a scalar)
        img_feats: Optional image features to append to 3D points
        maintain_shape: If True, preserves the original shape of points

    Returns:
        3D points in camera or world space, with optional features appended.
        If masking produces ragged results, a (possibly nested) list is returned.
    """
    # Get device and shape information from intrinsic matrix
    device = intrinsic.device
    pre_shape = intrinsic.shape[:-2]  # Batch dimensions

    # Validate inputs
    if isinstance(depth, (int, float)) and (H is None or W is None):
        raise ValueError("H and W must be provided if depth is a scalar")

    # Determine image dimensions from depth if not provided
    if isinstance(depth, torch.Tensor) and H is None:
        H, W = depth.shape[-2:]

    # Create image grid if not provided; _create_image_grid already appends
    # the homogeneous coordinate, giving shape HxWx3
    if img_grid is None:
        img_grid = _create_image_grid(H, W, device)

    # Expand img_grid to match batch dimensions of intrinsic
    if img_grid.dim() <= intrinsic.dim():
        img_grid = img_grid.unsqueeze(0)
    img_grid = img_grid.expand(*pre_shape, *img_grid.shape[-3:])

    # Handle valid mask or minimum depth threshold
    depth_mask = None
    if valid is not None:
        if isinstance(valid, (int, float)):
            # Create mask for minimum depth value
            depth_mask = depth > valid
        elif isinstance(valid, torch.Tensor):
            depth_mask = valid

    # Apply mask to image grid and other inputs
    if depth_mask is not None:
        img_grid = masking(img_grid, depth_mask, dim=intrinsic.dim())
        if not isinstance(depth, (int, float)):
            depth = masking(depth, depth_mask, dim=intrinsic.dim() - 1)
        if img_feats is not None:
            img_feats = masking(img_feats, depth_mask, dim=intrinsic.dim() - 1)
    # Unproject 2D points to 3D camera space. The image grid is (y, x, 1), so
    # reorder to (x, y, 1) before applying the inverse intrinsics
    cam_pts: torch.Tensor | list = m_dot(
        m_inverse_intrinsics(intrinsic),
        img_grid[..., [1, 0, 2]],
        maintain_shape=True,
    )

    # Scale by depth values; depth may be a scalar, a tensor, or a (possibly
    # nested) list after masking
    def _unsqueeze_last(d):
        if isinstance(d, list):
            return [_unsqueeze_last(x) for x in d]
        return d.unsqueeze(-1)

    if isinstance(depth, (int, float)):
        cam_pts = mult(cam_pts, depth)
    else:
        cam_pts = mult(cam_pts, _unsqueeze_last(depth))

    # Transform to world space if cam2world is provided
    if cam2world is not None:
        cam_pts = m_dot(cam2world, cam_pts, maintain_shape=True)

    # Append image features if provided
    if img_feats is not None:
        if isinstance(cam_pts, list):
            if isinstance(cam_pts[0], list):
                # Handle nested list case
                result = []
                for batch_idx, batch in enumerate(cam_pts):
                    batch_result = []
                    for view_idx, view in enumerate(batch):
                        batch_result.append(
                            torch.cat([view, img_feats[batch_idx][view_idx]], -1)
                        )
                    result.append(batch_result)
                cam_pts = result
            else:
                # Handle single list case
                cam_pts = [
                    torch.cat([pts, feats], -1)
                    for pts, feats in zip(cam_pts, img_feats)
                ]
        else:
            # Handle tensor case
            cam_pts = torch.cat([cam_pts, img_feats], -1)

    if maintain_shape or isinstance(cam_pts, list):
        return cam_pts

    # Flatten to [..., N, C]; C is 3 plus any appended feature channels
    return cam_pts.reshape(*pre_shape, -1, cam_pts.shape[-1])
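
# Usage sketch for m_unproject (intrinsics and sizes are illustrative
# assumptions): lift a depth map to camera-space points.
#
#   K = torch.tensor([[100.0, 0.0, 32.0],
#                     [0.0, 100.0, 24.0],
#                     [0.0, 0.0, 1.0]])
#   depth = torch.ones(48, 64)                        # flat plane at z = 1
#   m_unproject(depth, K).shape                       # torch.Size([3072, 3])
#   m_unproject(depth, K, maintain_shape=True).shape  # torch.Size([48, 64, 3])
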
def m_project(
    world_pts: torch.Tensor,
    intrinsic: torch.Tensor,
    world2cam: torch.Tensor | None = None,
    maintain_shape: bool = False,
) -> torch.Tensor:
    """
    Project 3D world points to 2D image coordinates.

    Args:
        world_pts: 3D points in world coordinates
        intrinsic: Camera intrinsic matrix
        world2cam: Optional transformation from world to camera coordinates
        maintain_shape: If True, preserves the original shape of points

    Returns:
        Image points with coordinates in img_y,img_x,z order
    """
    # Transform points from world to camera space if world2cam is provided
    cam_pts: torch.Tensor = world_pts
    if world2cam is not None:
        cam_pts = m_dot(world2cam, world_pts, maintain_shape=maintain_shape)

    # Get shapes to properly expand intrinsics
    shared_dims = intrinsic.shape[:-2]
    extra_dims = cam_pts.shape[len(shared_dims) : -1]

    # Expand intrinsics to match cam_pts shape
    expanded_intrinsic = intrinsic.view(*shared_dims, *([1] * len(extra_dims)), 3, 3)
    expanded_intrinsic = expanded_intrinsic.expand(*shared_dims, *extra_dims, 3, 3)

    # Project points from camera space to image space. The absolute value keeps
    # the division stable; points behind the camera can still be filtered later
    # via the returned z channel (e.g. with in_image)
    depth_abs = cam_pts[..., 2].abs().clamp(min=1e-5)
    return torch.stack(
        [
            expanded_intrinsic[..., 1, 1] * cam_pts[..., 1] / depth_abs
            + expanded_intrinsic[..., 1, 2],
            expanded_intrinsic[..., 0, 0] * cam_pts[..., 0] / depth_abs
            + expanded_intrinsic[..., 0, 2],
            cam_pts[..., 2],
        ],
        -1,
    )
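
# Roundtrip sketch (values are illustrative assumptions): projecting
# unprojected points recovers the pixel grid in (img_y, img_x, z) order.
#
#   K = torch.tensor([[100.0, 0.0, 32.0],
#                     [0.0, 100.0, 24.0],
#                     [0.0, 0.0, 1.0]])
#   depth = torch.full((48, 64), 2.0)
#   cam_pts = m_unproject(depth, K)      # [3072, 3] camera-space points
#   img_pts = m_project(cam_pts, K)      # [3072, 3]; y/x match the pixel grid
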
def in_image(
    image_pts: torch.Tensor | list,
    H: int,
    W: int,
    min_depth: float | None = 0.0,
) -> torch.Tensor | list:
    """
    Check if image points are within the image boundaries.

    Args:
        image_pts: Image points in (img_y, img_x) or (img_y, img_x, z) order
        H: Image height
        W: Image width
        min_depth: Minimum valid depth, or None to skip the depth check

    Returns:
        Boolean mask indicating which points are within the image
    """
    if isinstance(image_pts, list):
        return [in_image(pts, H, W, min_depth=min_depth) for pts in image_pts]

    in_image_mask = (
        torch.all(image_pts >= 0, -1)
        & (image_pts[..., 0] < H)
        & (image_pts[..., 1] < W)
    )
    if (min_depth is not None) and image_pts.shape[-1] == 3:
        in_image_mask &= image_pts[..., 2] > min_depth
    return in_image_mask
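
# Usage sketch for in_image (values are illustrative assumptions):
#
#   img_pts = torch.tensor([[10.0, 20.0, 1.0],   # inside a 48x64 image
#                           [-1.0, 20.0, 1.0],   # y out of bounds
#                           [10.0, 20.0, 0.0]])  # fails the min_depth check
#   in_image(img_pts, H=48, W=64)                # tensor([True, False, False])
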
def _create_image_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
    """
    Create a coordinate grid for image pixels.

    Args:
        H: Image height
        W: Image width
        device: Computation device

    Returns:
        Image grid with shape HxWx3 (last dimension is homogeneous)
    """
    y_coords = torch.arange(H, device=device)
    x_coords = torch.arange(W, device=device)

    # Use meshgrid with indexing="ij" so rows vary with y and columns with x
    y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij")

    # Stack coordinates and append the homogeneous coordinate
    img_grid = torch.stack([y_grid, x_grid, torch.ones_like(y_grid)], dim=-1)
    return img_grid
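
# Sketch of the grid layout for a 2x3 image (illustrative):
#
#   _create_image_grid(2, 3, torch.device("cpu"))
#   # [[[0, 0, 1], [0, 1, 1], [0, 2, 1]],
#   #  [[1, 0, 1], [1, 1, 1], [1, 2, 1]]]
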
def masking(
    X: torch.Tensor | list,
    mask: torch.Tensor | list,
    dim: int = 3,
) -> torch.Tensor | list:
    """
    Apply a Boolean mask to tensor or list elements.

    Handles nested structures by recursively applying the mask. Because each
    element may keep a different number of entries, batched inputs come back
    as (possibly nested) lists rather than tensors.

    Args:
        X: Input tensor or list to be masked
        mask: Boolean mask to apply
        dim: Dimension threshold for recursive processing

    Returns:
        Masked tensor or list with the same structure as input
    """
    if isinstance(X, list) or (isinstance(X, torch.Tensor) and X.dim() >= dim):
        return [masking(x, m, dim) for x, m in zip(X, mask)]
    return X[mask]
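
# Usage sketch for masking (shapes are illustrative assumptions): a batched
# mask may keep a different count per element, hence the list output.
#
#   pts = torch.rand(2, 5, 3)
#   mask = torch.tensor([[True, True, False, False, False],
#                        [True, False, False, False, True]])
#   masking(pts, mask)    # [Tensor of shape [2, 3], Tensor of shape [2, 3]]
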
def m_inverse_intrinsics(intrinsics: torch.Tensor) -> torch.Tensor:
    """
    Compute the inverse of camera intrinsics matrices analytically.

    This is much faster than using torch.inverse() for intrinsics matrices.
    The intrinsics matrix has the form:

        K = [fx  s  cx]
            [ 0 fy  cy]
            [ 0  0   1]

    And its inverse is:

        K^-1 = [1/fx  -s/(fx*fy)  (s*cy - cx*fy)/(fx*fy)]
               [   0        1/fy                  -cy/fy]
               [   0           0                       1]

    Args:
        intrinsics: Camera intrinsics matrices of shape [..., 3, 3]

    Returns:
        Inverse intrinsics matrices of shape [..., 3, 3]
    """
    # Extract the components of the intrinsics matrix
    fx = intrinsics[..., 0, 0]
    s = intrinsics[..., 0, 1]  # skew, usually 0
    cx = intrinsics[..., 0, 2]
    fy = intrinsics[..., 1, 1]
    cy = intrinsics[..., 1, 2]

    # Create output tensor with same shape and device
    inv_intrinsics = torch.zeros_like(intrinsics)

    # Compute the inverse analytically
    inv_intrinsics[..., 0, 0] = 1.0 / fx
    inv_intrinsics[..., 0, 1] = -s / (fx * fy)
    inv_intrinsics[..., 0, 2] = (s * cy - cx * fy) / (fx * fy)
    inv_intrinsics[..., 1, 1] = 1.0 / fy
    inv_intrinsics[..., 1, 2] = -cy / fy
    inv_intrinsics[..., 2, 2] = 1.0

    return inv_intrinsics
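
# Sanity-check sketch (values are illustrative assumptions): the analytic
# inverse should match the general-purpose solver.
#
#   K = torch.tensor([[100.0, 2.0, 32.0],
#                     [0.0, 110.0, 24.0],
#                     [0.0, 0.0, 1.0]])
#   torch.testing.assert_close(m_inverse_intrinsics(K), torch.linalg.inv(K))
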
def mult(
    A: torch.Tensor | np.ndarray | list | float | int,
    B: torch.Tensor | np.ndarray | list | float | int,
) -> torch.Tensor | np.ndarray | list | float | int:
    """
    Multiply two objects with support for lists, tensors, arrays, and scalars.

    Handles nested structures by recursively applying multiplication.

    Args:
        A: First operand (tensor, array, list, or scalar)
        B: Second operand (tensor, array, list, or scalar)

    Returns:
        Result of multiplication with the same structure as inputs
    """
    if isinstance(A, list) and isinstance(B, (int, float)):
        return [mult(a, B) for a in A]
    if isinstance(B, list) and isinstance(A, (int, float)):
        return [mult(A, b) for b in B]
    if isinstance(A, list) and isinstance(B, list):
        return [mult(a, b) for a, b in zip(A, B)]
    return A * B
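
# Usage sketch for mult (values are illustrative assumptions): scalars
# broadcast over lists, and paired lists multiply element-wise.
#
#   mult([torch.ones(2), torch.ones(3)], 2.0)       # [tensor([2., 2.]), tensor([2., 2., 2.])]
#   mult([torch.ones(2)], [torch.full((2,), 3.0)])  # [tensor([3., 3.])]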