| """ | |
| Utilities for geometry operations. | |
| References: DUSt3R, MoGe | |
| """ | |
| from numbers import Number | |
| from typing import Tuple, Union | |
| import einops as ein | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from mapanything.utils.misc import invalid_to_zeros | |
| from mapanything.utils.warnings import no_warnings | |

def depthmap_to_camera_frame(depthmap, intrinsics):
    """
    Convert depth image to a pointcloud in camera frame.
    Args:
        - depthmap: HxW or BxHxW torch tensor
        - intrinsics: 3x3 or Bx3x3 torch tensor
    Returns:
        pointmap in camera frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
    """
    # Add batch dimension if not present
    if depthmap.dim() == 2:
        depthmap = depthmap.unsqueeze(0)
        intrinsics = intrinsics.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    batch_size, height, width = depthmap.shape
    device = depthmap.device

    # Compute 3D point in camera frame associated with each pixel
    x_grid, y_grid = torch.meshgrid(
        torch.arange(width, device=device).float(),
        torch.arange(height, device=device).float(),
        indexing="xy",
    )
    x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
    y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)

    fx = intrinsics[:, 0, 0].view(-1, 1, 1)
    fy = intrinsics[:, 1, 1].view(-1, 1, 1)
    cx = intrinsics[:, 0, 2].view(-1, 1, 1)
    cy = intrinsics[:, 1, 2].view(-1, 1, 1)

    depth_z = depthmap
    xx = (x_grid - cx) * depth_z / fx
    yy = (y_grid - cy) * depth_z / fy
    pts3d_cam = torch.stack((xx, yy, depth_z), dim=-1)

    # Compute mask of valid non-zero depth pixels
    valid_mask = depthmap > 0.0

    # Remove batch dimension if it was added
    if squeeze_batch_dim:
        pts3d_cam = pts3d_cam.squeeze(0)
        valid_mask = valid_mask.squeeze(0)
    return pts3d_cam, valid_mask
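
# Illustrative usage sketch (added for clarity; not part of the original
# module). Assumes a hypothetical 480x640 depth map and a simple pinhole
# intrinsics matrix:
#
#   depth = torch.rand(480, 640)
#   K = torch.tensor([[500.0, 0.0, 320.0],
#                     [0.0, 500.0, 240.0],
#                     [0.0, 0.0, 1.0]])
#   pts_cam, valid = depthmap_to_camera_frame(depth, K)
#   # pts_cam.shape == (480, 640, 3); valid.shape == (480, 640)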

def depthmap_to_world_frame(depthmap, intrinsics, camera_pose=None):
    """
    Convert depth image to a pointcloud in world frame.
    Args:
        - depthmap: HxW or BxHxW torch tensor
        - intrinsics: 3x3 or Bx3x3 torch tensor
        - camera_pose: 4x4 or Bx4x4 torch tensor
    Returns:
        pointmap in world frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
    """
    pts3d_cam, valid_mask = depthmap_to_camera_frame(depthmap, intrinsics)

    if camera_pose is not None:
        # Add batch dimension if not present
        if camera_pose.dim() == 2:
            camera_pose = camera_pose.unsqueeze(0)
            pts3d_cam = pts3d_cam.unsqueeze(0)
            squeeze_batch_dim = True
        else:
            squeeze_batch_dim = False

        # Convert points from camera frame to world frame
        pts3d_cam_homo = torch.cat(
            [pts3d_cam, torch.ones_like(pts3d_cam[..., :1])], dim=-1
        )
        pts3d_world = ein.einsum(
            camera_pose, pts3d_cam_homo, "b i k, b h w k -> b h w i"
        )
        pts3d_world = pts3d_world[..., :3]

        # Remove batch dimension if it was added
        if squeeze_batch_dim:
            pts3d_world = pts3d_world.squeeze(0)
    else:
        pts3d_world = pts3d_cam
    return pts3d_world, valid_mask
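
# Usage sketch (illustrative, continuing the hypothetical depth/K example
# above): a cam2world pose with identity rotation simply translates the
# camera-frame points.
#
#   pose = torch.eye(4)
#   pose[:3, 3] = torch.tensor([1.0, 0.0, 0.0])  # camera 1 m along world x
#   pts_world, valid = depthmap_to_world_frame(depth, K, camera_pose=pose)
#   # pts_world == pts_cam + [1, 0, 0] for every pixel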

def transform_pts3d(pts3d, transformation):
    """
    Transform 3D points using a 4x4 transformation matrix.
    Args:
        - pts3d: HxWx3 or BxHxWx3 torch tensor
        - transformation: 4x4 or Bx4x4 torch tensor
    Returns:
        transformed points (HxWx3 or BxHxWx3 tensor)
    """
    # Add batch dimension if not present
    if pts3d.dim() == 3:
        pts3d = pts3d.unsqueeze(0)
        transformation = transformation.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    # Convert points to homogeneous coordinates
    pts3d_homo = torch.cat([pts3d, torch.ones_like(pts3d[..., :1])], dim=-1)

    # Transform points
    transformed_pts3d = ein.einsum(
        transformation, pts3d_homo, "b i k, b h w k -> b h w i"
    )
    transformed_pts3d = transformed_pts3d[..., :3]

    # Remove batch dimension if it was added
    if squeeze_batch_dim:
        transformed_pts3d = transformed_pts3d.squeeze(0)
    return transformed_pts3d

def project_pts3d_to_image(pts3d, intrinsics, return_z_dim):
    """
    Project 3D points to the image plane (assumes pinhole camera model with no distortion).
    Args:
        - pts3d: HxWx3 or BxHxWx3 torch tensor
        - intrinsics: 3x3 or Bx3x3 torch tensor
        - return_z_dim: bool, whether to keep the third (z) dimension of the projected points
    Returns:
        projected points (HxWx2 or BxHxWx2 tensor; HxWx3 or BxHxWx3 if return_z_dim is True)
    """
    # Add batch dimension if not present
    if pts3d.dim() == 3:
        pts3d = pts3d.unsqueeze(0)
        intrinsics = intrinsics.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    # Project points to image plane
    projected_pts2d = ein.einsum(intrinsics, pts3d, "b i k, b h w k -> b h w i")
    projected_pts2d[..., :2] /= projected_pts2d[..., 2].unsqueeze(-1).clamp(min=1e-6)

    # Remove the z dimension if not required
    if not return_z_dim:
        projected_pts2d = projected_pts2d[..., :2]

    # Remove batch dimension if it was added
    if squeeze_batch_dim:
        projected_pts2d = projected_pts2d.squeeze(0)
    return projected_pts2d
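
# Round-trip sketch (illustrative): back-projecting a depth map and
# re-projecting the points recovers the pixel grid wherever depth > 0.
#
#   pts_cam, valid = depthmap_to_camera_frame(depth, K)
#   uv = project_pts3d_to_image(pts_cam, K, return_z_dim=False)
#   # uv[..., 0] ~= column index and uv[..., 1] ~= row index at valid pixels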

def get_rays_in_camera_frame(intrinsics, height, width, normalize_to_unit_sphere):
    """
    Convert camera intrinsics to a raymap (ray origins + directions) in camera frame.
    Note: Currently only supports pinhole camera model.
    Args:
        - intrinsics: 3x3 or Bx3x3 torch tensor
        - height: int
        - width: int
        - normalize_to_unit_sphere: bool
    Returns:
        - ray_origins: (HxWx3 or BxHxWx3) tensor
        - ray_directions: (HxWx3 or BxHxWx3) tensor
    """
    # Add batch dimension if not present
    if intrinsics.dim() == 2:
        intrinsics = intrinsics.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    batch_size = intrinsics.shape[0]
    device = intrinsics.device

    # Compute rays in camera frame associated with each pixel
    x_grid, y_grid = torch.meshgrid(
        torch.arange(width, device=device).float(),
        torch.arange(height, device=device).float(),
        indexing="xy",
    )
    x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
    y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)

    fx = intrinsics[:, 0, 0].view(-1, 1, 1)
    fy = intrinsics[:, 1, 1].view(-1, 1, 1)
    cx = intrinsics[:, 0, 2].view(-1, 1, 1)
    cy = intrinsics[:, 1, 2].view(-1, 1, 1)

    ray_origins = torch.zeros((batch_size, height, width, 3), device=device)
    xx = (x_grid - cx) / fx
    yy = (y_grid - cy) / fy
    ray_directions = torch.stack((xx, yy, torch.ones_like(xx)), dim=-1)

    # Normalize ray directions to unit sphere if required (else rays will lie on unit plane)
    if normalize_to_unit_sphere:
        ray_directions = ray_directions / torch.norm(
            ray_directions, dim=-1, keepdim=True
        )

    # Remove batch dimension if it was added
    if squeeze_batch_dim:
        ray_origins = ray_origins.squeeze(0)
        ray_directions = ray_directions.squeeze(0)
    return ray_origins, ray_directions
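
# Usage sketch (illustrative, using the hypothetical K from above):
#
#   origins, dirs = get_rays_in_camera_frame(
#       K, 480, 640, normalize_to_unit_sphere=False
#   )
#   # origins are all zeros (camera center); dirs[..., 2] == 1 everywhere,
#   # i.e. the ray directions lie on the z = 1 unit plane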

def get_rays_in_world_frame(
    intrinsics, height, width, normalize_to_unit_sphere, camera_pose=None
):
    """
    Convert camera intrinsics (and optionally a camera pose) to a raymap (ray origins + directions).
    Rays are returned in the world frame when camera_pose is provided, otherwise in the camera frame.
    Note: Currently only supports pinhole camera model.
    Args:
        - intrinsics: 3x3 or Bx3x3 torch tensor
        - height: int
        - width: int
        - normalize_to_unit_sphere: bool
        - camera_pose: 4x4 or Bx4x4 torch tensor
    Returns:
        - ray_origins: (HxWx3 or BxHxWx3) tensor
        - ray_directions: (HxWx3 or BxHxWx3) tensor
    """
    # Get rays in camera frame
    ray_origins, ray_directions = get_rays_in_camera_frame(
        intrinsics, height, width, normalize_to_unit_sphere
    )

    if camera_pose is not None:
        # Add batch dimension if not present
        if camera_pose.dim() == 2:
            camera_pose = camera_pose.unsqueeze(0)
            ray_origins = ray_origins.unsqueeze(0)
            ray_directions = ray_directions.unsqueeze(0)
            squeeze_batch_dim = True
        else:
            squeeze_batch_dim = False

        # Convert rays from camera frame to world frame
        ray_origins_homo = torch.cat(
            [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
        )
        ray_directions_homo = torch.cat(
            [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
        )
        ray_origins_world = ein.einsum(
            camera_pose, ray_origins_homo, "b i k, b h w k -> b h w i"
        )
        ray_directions_world = ein.einsum(
            camera_pose, ray_directions_homo, "b i k, b h w k -> b h w i"
        )
        ray_origins_world = ray_origins_world[..., :3]
        ray_directions_world = ray_directions_world[..., :3]

        # Remove batch dimension if it was added
        if squeeze_batch_dim:
            ray_origins_world = ray_origins_world.squeeze(0)
            ray_directions_world = ray_directions_world.squeeze(0)
    else:
        ray_origins_world = ray_origins
        ray_directions_world = ray_directions
    return ray_origins_world, ray_directions_world

def recover_pinhole_intrinsics_from_ray_directions(
    ray_directions, use_geometric_calculation=False
):
    """
    Recover pinhole camera intrinsics from ray directions, supporting both batched and non-batched inputs.
    Args:
        ray_directions: Tensor of shape [H, W, 3] or [B, H, W, 3] containing unit normalized ray directions
        use_geometric_calculation: If True, force the direct geometric calculation
            (also selected automatically for high-resolution inputs)
    Returns:
        Camera intrinsics (fx, fy, cx, cy) assembled into a 3x3 or Bx3x3 tensor
    """
    # Add batch dimension if not present
    if ray_directions.dim() == 3:  # [H, W, 3]
        ray_directions = ray_directions.unsqueeze(0)  # [1, H, W, 3]
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    batch_size, height, width, _ = ray_directions.shape
    device = ray_directions.device

    # Create pixel coordinate grid
    x_grid, y_grid = torch.meshgrid(
        torch.arange(width, device=device).float(),
        torch.arange(height, device=device).float(),
        indexing="xy",
    )
    # Expand grid for all batches
    x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)  # [B, H, W]
    y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)  # [B, H, W]

    # Determine if high resolution or not
    is_high_res = height * width > 1000000

    if is_high_res or use_geometric_calculation:
        # For high-resolution cases, use direct geometric calculation
        # Define key points
        center_h, center_w = height // 2, width // 2
        quarter_w, three_quarter_w = width // 4, 3 * width // 4
        quarter_h, three_quarter_h = height // 4, 3 * height // 4

        # Get rays at key points
        center_rays = ray_directions[:, center_h, center_w, :].clone()  # [B, 3]
        left_rays = ray_directions[:, center_h, quarter_w, :].clone()  # [B, 3]
        right_rays = ray_directions[:, center_h, three_quarter_w, :].clone()  # [B, 3]
        top_rays = ray_directions[:, quarter_h, center_w, :].clone()  # [B, 3]
        bottom_rays = ray_directions[:, three_quarter_h, center_w, :].clone()  # [B, 3]

        # Normalize rays to have dz = 1
        center_rays = center_rays / center_rays[:, 2].unsqueeze(1)  # [B, 3]
        left_rays = left_rays / left_rays[:, 2].unsqueeze(1)  # [B, 3]
        right_rays = right_rays / right_rays[:, 2].unsqueeze(1)  # [B, 3]
        top_rays = top_rays / top_rays[:, 2].unsqueeze(1)  # [B, 3]
        bottom_rays = bottom_rays / bottom_rays[:, 2].unsqueeze(1)  # [B, 3]

        # Calculate fx directly (vectorized across batch)
        fx_left = (quarter_w - center_w) / (left_rays[:, 0] - center_rays[:, 0])
        fx_right = (three_quarter_w - center_w) / (right_rays[:, 0] - center_rays[:, 0])
        fx = (fx_left + fx_right) / 2  # Average for robustness

        # Calculate cx
        cx = center_w - fx * center_rays[:, 0]

        # Calculate fy and cy
        fy_top = (quarter_h - center_h) / (top_rays[:, 1] - center_rays[:, 1])
        fy_bottom = (three_quarter_h - center_h) / (
            bottom_rays[:, 1] - center_rays[:, 1]
        )
        fy = (fy_top + fy_bottom) / 2
        cy = center_h - fy * center_rays[:, 1]
    else:
        # For standard resolution, use regression with sampling for efficiency
        # Sample a grid of points (but more dense than for high-res)
        step_h = max(1, height // 50)
        step_w = max(1, width // 50)
        h_indices = torch.arange(0, height, step_h, device=device)
        w_indices = torch.arange(0, width, step_w, device=device)

        # Extract subset of coordinates
        x_sampled = x_grid[:, h_indices[:, None], w_indices[None, :]]  # [B, H', W']
        y_sampled = y_grid[:, h_indices[:, None], w_indices[None, :]]  # [B, H', W']
        rays_sampled = ray_directions[
            :, h_indices[:, None], w_indices[None, :], :
        ]  # [B, H', W', 3]

        # Reshape for linear regression
        x_flat = x_sampled.reshape(batch_size, -1)  # [B, N]
        y_flat = y_sampled.reshape(batch_size, -1)  # [B, N]

        # Extract ray direction components
        dx = rays_sampled[..., 0].reshape(batch_size, -1)  # [B, N]
        dy = rays_sampled[..., 1].reshape(batch_size, -1)  # [B, N]
        dz = rays_sampled[..., 2].reshape(batch_size, -1)  # [B, N]

        # Compute ratios for linear regression
        ratio_x = dx / dz  # [B, N]
        ratio_y = dy / dz  # [B, N]

        # Since torch.linalg.lstsq doesn't support batched input, we'll use a different approach
        # For x-direction: x = cx + fx * (dx/dz)
        # We can solve this using normal equations: A^T A x = A^T b
        # Create design matrices
        ones = torch.ones_like(x_flat)  # [B, N]
        A_x = torch.stack([ones, ratio_x], dim=2)  # [B, N, 2]
        b_x = x_flat.unsqueeze(2)  # [B, N, 1]

        # Compute A^T A and A^T b for each batch
        ATA_x = torch.bmm(A_x.transpose(1, 2), A_x)  # [B, 2, 2]
        ATb_x = torch.bmm(A_x.transpose(1, 2), b_x)  # [B, 2, 1]

        # Solve the system for each batch
        solution_x = torch.linalg.solve(ATA_x, ATb_x).squeeze(2)  # [B, 2]
        cx, fx = solution_x[:, 0], solution_x[:, 1]

        # Repeat for y-direction
        A_y = torch.stack([ones, ratio_y], dim=2)  # [B, N, 2]
        b_y = y_flat.unsqueeze(2)  # [B, N, 1]
        ATA_y = torch.bmm(A_y.transpose(1, 2), A_y)  # [B, 2, 2]
        ATb_y = torch.bmm(A_y.transpose(1, 2), b_y)  # [B, 2, 1]
        solution_y = torch.linalg.solve(ATA_y, ATb_y).squeeze(2)  # [B, 2]
        cy, fy = solution_y[:, 0], solution_y[:, 1]

    # Create intrinsics matrices
    batch_size = fx.shape[0]
    intrinsics = torch.zeros(batch_size, 3, 3, device=ray_directions.device)

    # Fill in the intrinsics matrices
    intrinsics[:, 0, 0] = fx  # focal length x
    intrinsics[:, 1, 1] = fy  # focal length y
    intrinsics[:, 0, 2] = cx  # principal point x
    intrinsics[:, 1, 2] = cy  # principal point y
    intrinsics[:, 2, 2] = 1.0  # bottom-right element is always 1

    # Remove batch dimension if it was added
    if squeeze_batch_dim:
        intrinsics = intrinsics.squeeze(0)
    return intrinsics
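
# Round-trip sketch (illustrative): rays generated from a pinhole K should
# recover that K, up to floating-point error from the sampled regression.
#
#   _, dirs = get_rays_in_camera_frame(K, 480, 640, normalize_to_unit_sphere=True)
#   K_rec = recover_pinhole_intrinsics_from_ray_directions(dirs)
#   # K_rec ~= K (3x3, since the unbatched input is squeezed back)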

def transform_rays(ray_origins, ray_directions, transformation):
    """
    Transform 6D rays (ray origins and ray directions) using a 4x4 transformation matrix.
    Args:
        - ray_origins: HxWx3 or BxHxWx3 torch tensor
        - ray_directions: HxWx3 or BxHxWx3 torch tensor
        - transformation: 4x4 or Bx4x4 torch tensor
    Returns:
        transformed ray_origins (HxWx3 or BxHxWx3 tensor) and ray_directions (HxWx3 or BxHxWx3 tensor)
    """
    # Add batch dimension if not present
    if ray_origins.dim() == 3:
        ray_origins = ray_origins.unsqueeze(0)
        ray_directions = ray_directions.unsqueeze(0)
        transformation = transformation.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    # Transform ray origins (as points) and directions (as vectors)
    ray_origins_homo = torch.cat(
        [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
    )
    ray_directions_homo = torch.cat(
        [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
    )
    transformed_ray_origins = ein.einsum(
        transformation, ray_origins_homo, "b i k, b h w k -> b h w i"
    )
    transformed_ray_directions = ein.einsum(
        transformation, ray_directions_homo, "b i k, b h w k -> b h w i"
    )
    transformed_ray_origins = transformed_ray_origins[..., :3]
    transformed_ray_directions = transformed_ray_directions[..., :3]

    # Remove batch dimension if it was added
    if squeeze_batch_dim:
        transformed_ray_origins = transformed_ray_origins.squeeze(0)
        transformed_ray_directions = transformed_ray_directions.squeeze(0)
    return transformed_ray_origins, transformed_ray_directions

def convert_z_depth_to_depth_along_ray(z_depth, intrinsics):
    """
    Convert z-depth image to depth along camera rays.
    Args:
        - z_depth: HxW or BxHxW torch tensor
        - intrinsics: 3x3 or Bx3x3 torch tensor
    Returns:
        - depth_along_ray: HxW or BxHxW torch tensor
    """
    # Add batch dimension if not present
    if z_depth.dim() == 2:
        z_depth = z_depth.unsqueeze(0)
        intrinsics = intrinsics.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    # Get rays in camera frame
    batch_size, height, width = z_depth.shape
    _, ray_directions = get_rays_in_camera_frame(
        intrinsics, height, width, normalize_to_unit_sphere=False
    )

    # Compute depth along ray
    pts3d_cam = z_depth[..., None] * ray_directions
    depth_along_ray = torch.norm(pts3d_cam, dim=-1)

    # Remove batch dimension if it was added
    if squeeze_batch_dim:
        depth_along_ray = depth_along_ray.squeeze(0)
    return depth_along_ray
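
# Sanity-check sketch (illustrative): with unit-plane ray directions
# ((u - cx)/fx, (v - cy)/fy, 1), the conversion is per pixel
#
#   d_ray = convert_z_depth_to_depth_along_ray(depth, K)
#   # d_ray == depth * sqrt(((u - cx)/fx)**2 + ((v - cy)/fy)**2 + 1)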

def convert_raymap_z_depth_quats_to_pointmap(ray_origins, ray_directions, depth, quats):
    """
    Convert raymap (ray origins + directions on unit plane), z-depth and
    unit quaternions (representing rotation) to a pointmap in world frame.
    Args:
        - ray_origins: (HxWx3 or BxHxWx3) torch tensor
        - ray_directions: (HxWx3 or BxHxWx3) torch tensor
        - depth: (HxWx1 or BxHxWx1) torch tensor
        - quats: (HxWx4 or BxHxWx4) torch tensor (unit quaternions and notation is (x, y, z, w))
    Returns:
        - pointmap: (HxWx3 or BxHxWx3) torch tensor
    """
    # Add batch dimension if not present
    if ray_origins.dim() == 3:
        ray_origins = ray_origins.unsqueeze(0)
        ray_directions = ray_directions.unsqueeze(0)
        depth = depth.unsqueeze(0)
        quats = quats.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    batch_size, height, width, _ = depth.shape
    device = depth.device

    # Normalize the quaternions to ensure they are unit quaternions
    quats = quats / torch.norm(quats, dim=-1, keepdim=True)

    # Convert quaternions to pixel-wise rotation matrices
    qx, qy, qz, qw = quats[..., 0], quats[..., 1], quats[..., 2], quats[..., 3]
    rot_mat = (
        torch.stack(
            [
                qw**2 + qx**2 - qy**2 - qz**2,
                2 * (qx * qy - qw * qz),
                2 * (qw * qy + qx * qz),
                2 * (qw * qz + qx * qy),
                qw**2 - qx**2 + qy**2 - qz**2,
                2 * (qy * qz - qw * qx),
                2 * (qx * qz - qw * qy),
                2 * (qw * qx + qy * qz),
                qw**2 - qx**2 - qy**2 + qz**2,
            ],
            dim=-1,
        )
        .reshape(batch_size, height, width, 3, 3)
        .to(device)
    )

    # Compute 3D points in local camera frame
    pts3d_local = depth * ray_directions

    # Rotate the local points using the quaternions
    rotated_pts3d_local = ein.einsum(
        rot_mat, pts3d_local, "b h w i k, b h w k -> b h w i"
    )

    # Compute 3D point in world frame associated with each pixel
    pts3d = ray_origins + rotated_pts3d_local

    # Remove batch dimension if it was added
    if squeeze_batch_dim:
        pts3d = pts3d.squeeze(0)
    return pts3d

def quaternion_to_rotation_matrix(quat):
    """
    Convert a quaternion into a 3x3 rotation matrix.
    Args:
        - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    Returns:
        - rot_matrix: 3x3 or Bx3x3 torch tensor
    """
    if quat.dim() == 1:
        quat = quat.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    # Ensure the quaternion is normalized
    quat = quat / quat.norm(dim=1, keepdim=True)
    x, y, z, w = quat.unbind(dim=1)

    # Compute the rotation matrix elements
    xx = x * x
    yy = y * y
    zz = z * z
    xy = x * y
    xz = x * z
    yz = y * z
    wx = w * x
    wy = w * y
    wz = w * z

    # Construct the rotation matrix
    rot_matrix = torch.stack(
        [
            1 - 2 * (yy + zz),
            2 * (xy - wz),
            2 * (xz + wy),
            2 * (xy + wz),
            1 - 2 * (xx + zz),
            2 * (yz - wx),
            2 * (xz - wy),
            2 * (yz + wx),
            1 - 2 * (xx + yy),
        ],
        dim=1,
    ).view(-1, 3, 3)

    # Squeeze batch dimension if it was unsqueezed
    if squeeze_batch_dim:
        rot_matrix = rot_matrix.squeeze(0)
    return rot_matrix

def rotation_matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.
    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))

    # Convert from rijk to ijkr
    out = out[..., [1, 2, 3, 0]]
    out = standardize_quaternion(out)
    return out
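
# Round-trip sketch (illustrative): the identity quaternion maps to the
# identity matrix and back.
#
#   R = quaternion_to_rotation_matrix(torch.tensor([0.0, 0.0, 0.0, 1.0]))
#   # R == torch.eye(3)
#   q = rotation_matrix_to_quaternion(R)
#   # q == (0, 0, 0, 1) in (x, y, z, w) order after standardization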

def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    if torch.is_grad_enabled():
        ret[positive_mask] = torch.sqrt(x[positive_mask])
    else:
        ret = torch.where(positive_mask, torch.sqrt(x), ret)
    return ret

def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.
    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).
    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)

def quaternion_inverse(quat):
    """
    Compute the inverse of a quaternion.
    Args:
        - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    Returns:
        - inv_quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    """
    # Unsqueeze batch dimension if not present
    if quat.dim() == 1:
        quat = quat.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    # Compute the inverse (conjugate divided by squared norm)
    quat_conj = quat.clone()
    quat_conj[:, :3] = -quat_conj[:, :3]
    quat_norm = torch.sum(quat * quat, dim=1, keepdim=True)
    inv_quat = quat_conj / quat_norm

    # Squeeze batch dimension if it was unsqueezed
    if squeeze_batch_dim:
        inv_quat = inv_quat.squeeze(0)
    return inv_quat

def quaternion_multiply(q1, q2):
    """
    Multiply two quaternions.
    Args:
        - q1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
        - q2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    Returns:
        - qm: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
    """
    # Unsqueeze batch dimension if not present
    if q1.dim() == 1:
        q1 = q1.unsqueeze(0)
        q2 = q2.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    # Unbind the quaternions
    x1, y1, z1, w1 = q1.unbind(dim=1)
    x2, y2, z2, w2 = q2.unbind(dim=1)

    # Compute the product
    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2

    # Stack the components
    qm = torch.stack([x, y, z, w], dim=1)

    # Squeeze batch dimension if it was unsqueezed
    if squeeze_batch_dim:
        qm = qm.squeeze(0)
    return qm
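
# Sanity-check sketch (illustrative): a quaternion multiplied by its inverse
# yields the identity quaternion.
#
#   q = torch.tensor([0.5, 0.5, 0.5, 0.5])  # 120 deg rotation about (1,1,1)
#   q_id = quaternion_multiply(q, quaternion_inverse(q))
#   # q_id ~= (0, 0, 0, 1)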

def transform_pose_using_quats_and_trans_2_to_1(quats1, trans1, quats2, trans2):
    """
    Transform quats and translation of pose2 from absolute frame (pose2 to world) to relative frame (pose2 to pose1).
    Args:
        - quats1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
        - trans1: 3 or Bx3 torch tensor
        - quats2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
        - trans2: 3 or Bx3 torch tensor
    Returns:
        - quats: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
        - trans: 3 or Bx3 torch tensor
    """
    # Unsqueeze batch dimension if not present
    if quats1.dim() == 1:
        quats1 = quats1.unsqueeze(0)
        trans1 = trans1.unsqueeze(0)
        quats2 = quats2.unsqueeze(0)
        trans2 = trans2.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    # Compute the inverse of view1's pose
    inv_quats1 = quaternion_inverse(quats1)
    R1_inv = quaternion_to_rotation_matrix(inv_quats1)
    t1_inv = -1 * ein.einsum(R1_inv, trans1, "b i j, b j -> b i")

    # Transform view2's pose to view1's frame
    quats = quaternion_multiply(inv_quats1, quats2)
    trans = ein.einsum(R1_inv, trans2, "b i j, b j -> b i") + t1_inv

    # Squeeze batch dimension if it was unsqueezed
    if squeeze_batch_dim:
        quats = quats.squeeze(0)
        trans = trans.squeeze(0)
    return quats, trans

def convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
    ray_directions, depth_along_ray, pose_trans, pose_quats
):
    """
    Convert ray directions, depth along ray, pose translation, and
    unit quaternions (representing pose rotation) to a pointmap in world frame.
    Args:
        - ray_directions: (HxWx3 or BxHxWx3) torch tensor
        - depth_along_ray: (HxWx1 or BxHxWx1) torch tensor
        - pose_trans: (3 or Bx3) torch tensor
        - pose_quats: (4 or Bx4) torch tensor (unit quaternions and notation is (x, y, z, w))
    Returns:
        - pointmap: (HxWx3 or BxHxWx3) torch tensor
    """
    # Add batch dimension if not present
    if ray_directions.dim() == 3:
        ray_directions = ray_directions.unsqueeze(0)
        depth_along_ray = depth_along_ray.unsqueeze(0)
        pose_trans = pose_trans.unsqueeze(0)
        pose_quats = pose_quats.unsqueeze(0)
        squeeze_batch_dim = True
    else:
        squeeze_batch_dim = False

    batch_size, height, width, _ = depth_along_ray.shape
    device = depth_along_ray.device

    # Normalize the quaternions to ensure they are unit quaternions
    pose_quats = pose_quats / torch.norm(pose_quats, dim=-1, keepdim=True)

    # Convert quaternions to rotation matrices (B x 3 x 3)
    rot_mat = quaternion_to_rotation_matrix(pose_quats)

    # Get pose matrix (B x 4 x 4)
    pose_mat = torch.eye(4, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
    pose_mat[:, :3, :3] = rot_mat
    pose_mat[:, :3, 3] = pose_trans

    # Compute 3D points in local camera frame
    pts3d_local = depth_along_ray * ray_directions

    # Compute 3D points in world frame
    pts3d_homo = torch.cat([pts3d_local, torch.ones_like(pts3d_local[..., :1])], dim=-1)
    pts3d_world = ein.einsum(pose_mat, pts3d_homo, "b i k, b h w k -> b h w i")
    pts3d_world = pts3d_world[..., :3]

    # Remove batch dimension if it was added
    if squeeze_batch_dim:
        pts3d_world = pts3d_world.squeeze(0)
    return pts3d_world

def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """
    Generate a coordinate grid of shape (H,W,2) or (H,W,3) if homogeneous=True.
    Args:
        W (int): Width of the grid
        H (int): Height of the grid
        device (torch.device, optional): Device to place the grid on. If None, uses numpy arrays
        origin (tuple, optional): Origin coordinates (x,y) for the grid. Default is (0,0)
        unsqueeze (int, optional): Dimension to unsqueeze in the output tensors
        cat_dim (int, optional): Dimension to concatenate the x,y coordinates. If None, returns tuple
        homogeneous (bool, optional): If True, adds a third dimension of ones to make homogeneous coordinates
        **arange_kw: Additional keyword arguments passed to np.arange or torch.arange
    Returns:
        numpy.ndarray or torch.Tensor: Coordinate grid where:
            - output[j,i,0] = i + origin[0] (x-coordinate)
            - output[j,i,1] = j + origin[1] (y-coordinate)
            - output[j,i,2] = 1 (if homogeneous=True)
    """
    if device is None:
        # numpy
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
    else:
        # torch
        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        meshgrid, stack = torch.meshgrid, torch.stack

        def ones(*a):
            return torch.ones(*a, device=device)

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    grid = meshgrid(tw, th, indexing="xy")
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid

def geotrf(Trf, pts, ncol=None, norm=False):
    """
    Apply a geometric transformation to a set of 3-D points.
    Args:
        Trf: 3x3 or 4x4 projection matrix (typically a Homography) or batch of matrices
            with shape (B, 3, 3) or (B, 4, 4)
        pts: numpy/torch/tuple of coordinates with shape (..., 2) or (..., 3)
        ncol: int, number of columns of the result (2 or 3)
        norm: float, if not 0, the result is projected on the z=norm plane
            (homogeneous normalization)
    Returns:
        Array or tensor of projected points with the same type as input and shape (..., ncol)
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # Adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # Optimized code
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]
        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)
    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /=, it will lead to a bug
        if norm != 1:
            pts *= norm
    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
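
# Usage sketch (illustrative): for batched 4x4 transforms on BxHxWx3 points,
# geotrf takes the optimized einsum path and matches transform_pts3d.
#
#   pts = torch.rand(2, 480, 640, 3)
#   T = torch.eye(4).expand(2, 4, 4)
#   out = geotrf(T, pts)
#   # out == pts (identity transform), same result as transform_pts3d(pts, T)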

def inv(mat):
    """
    Invert a torch or numpy matrix
    """
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")

def closed_form_pose_inverse(
    pose_matrices, rotation_matrices=None, translation_vectors=None
):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 pose matrix in a batch.
    If `rotation_matrices` and `translation_vectors` are provided, they must correspond to the rotation and translation
    components of `pose_matrices`. Otherwise, they will be extracted from `pose_matrices`.
    Args:
        pose_matrices: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        rotation_matrices (optional): Nx3x3 array or tensor of rotation matrices.
        translation_vectors (optional): Nx3x1 array or tensor of translation vectors.
    Returns:
        Inverted SE3 matrices with the same type and device as input `pose_matrices`.
    Shapes:
        pose_matrices: (N, 4, 4)
        rotation_matrices: (N, 3, 3)
        translation_vectors: (N, 3, 1)
    """
    # Check if pose_matrices is a numpy array or a torch tensor
    is_numpy = isinstance(pose_matrices, np.ndarray)

    # Validate shapes
    if pose_matrices.shape[-2:] != (4, 4) and pose_matrices.shape[-2:] != (3, 4):
        raise ValueError(
            f"pose_matrices must be of shape (N,4,4) or (N,3,4), got {pose_matrices.shape}."
        )

    # Extract rotation_matrices and translation_vectors if not provided
    if rotation_matrices is None:
        rotation_matrices = pose_matrices[:, :3, :3]
    if translation_vectors is None:
        translation_vectors = pose_matrices[:, :3, 3:]

    # Compute the inverse of input SE3 matrices
    if is_numpy:
        rotation_transposed = np.transpose(rotation_matrices, (0, 2, 1))
        new_translation = -np.matmul(rotation_transposed, translation_vectors)
        inverted_matrix = np.tile(np.eye(4), (len(rotation_matrices), 1, 1))
    else:
        rotation_transposed = rotation_matrices.transpose(1, 2)
        new_translation = -torch.bmm(rotation_transposed, translation_vectors)
        inverted_matrix = torch.eye(4, 4)[None].repeat(len(rotation_matrices), 1, 1)
        inverted_matrix = inverted_matrix.to(rotation_matrices.dtype).to(
            rotation_matrices.device
        )

    inverted_matrix[:, :3, :3] = rotation_transposed
    inverted_matrix[:, :3, 3:] = new_translation
    return inverted_matrix
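
# Sanity-check sketch (illustrative): composing a pose with its closed-form
# inverse gives the identity.
#
#   poses = torch.eye(4).repeat(2, 1, 1)
#   poses[:, :3, 3] = torch.tensor([1.0, 2.0, 3.0])
#   inv_poses = closed_form_pose_inverse(poses)
#   # torch.matmul(poses, inv_poses) ~= identity for every batch element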

def relative_pose_transformation(trans_01, trans_02):
    r"""
    Function that computes the relative homogenous transformation from a
    reference transformation :math:`T_1^{0} = \begin{bmatrix} R_1 & t_1 \\
    \mathbf{0} & 1 \end{bmatrix}` to destination :math:`T_2^{0} =
    \begin{bmatrix} R_2 & t_2 \\ \mathbf{0} & 1 \end{bmatrix}`.
    The relative transformation is computed as follows:
    .. math::
        T_1^{2} = (T_0^{1})^{-1} \cdot T_0^{2}
    Arguments:
        trans_01 (torch.Tensor): reference transformation tensor of shape
            :math:`(N, 4, 4)` or :math:`(4, 4)`.
        trans_02 (torch.Tensor): destination transformation tensor of shape
            :math:`(N, 4, 4)` or :math:`(4, 4)`.
    Shape:
        - Output: :math:`(N, 4, 4)` or :math:`(4, 4)`.
    Returns:
        torch.Tensor: the relative transformation between the transformations.
    Example::
        >>> trans_01 = torch.eye(4)  # 4x4
        >>> trans_02 = torch.eye(4)  # 4x4
        >>> trans_12 = relative_pose_transformation(trans_01, trans_02)  # 4x4
    """
    if not torch.is_tensor(trans_01):
        raise TypeError(
            "Input trans_01 type is not a torch.Tensor. Got {}".format(type(trans_01))
        )
    if not torch.is_tensor(trans_02):
        raise TypeError(
            "Input trans_02 type is not a torch.Tensor. Got {}".format(type(trans_02))
        )
    if trans_01.dim() not in (2, 3) or trans_01.shape[-2:] != (4, 4):
        raise ValueError(
            "Input must be of the shape Nx4x4 or 4x4. Got {}".format(trans_01.shape)
        )
    if trans_02.dim() not in (2, 3) or trans_02.shape[-2:] != (4, 4):
        raise ValueError(
            "Input must be of the shape Nx4x4 or 4x4. Got {}".format(trans_02.shape)
        )
    if not trans_01.dim() == trans_02.dim():
        raise ValueError(
            "Input number of dims must match. Got {} and {}".format(
                trans_01.dim(), trans_02.dim()
            )
        )

    # Convert to Nx4x4 if inputs are 4x4
    squeeze_batch_dim = False
    if trans_01.dim() == 2:
        trans_01 = trans_01.unsqueeze(0)
        trans_02 = trans_02.unsqueeze(0)
        squeeze_batch_dim = True

    # Compute inverse of trans_01 using closed form
    trans_10 = closed_form_pose_inverse(trans_01)

    # Compose transformations using matrix multiplication
    trans_12 = torch.matmul(trans_10, trans_02)

    # Remove batch dimension if it was added
    if squeeze_batch_dim:
        trans_12 = trans_12.squeeze(0)
    return trans_12

def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
    """
    Args:
        - depth (BxHxW array): depth map
        - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
    Returns:
        pointmap of absolute coordinates (BxHxWx3 array)
    """
    if len(depth.shape) == 4:
        B, H, W, n = depth.shape
    else:
        B, H, W = depth.shape
        n = None

    if len(pseudo_focal.shape) == 3:  # [B,H,W]
        pseudo_focalx = pseudo_focaly = pseudo_focal
    elif len(pseudo_focal.shape) == 4:  # [B,2,H,W] or [B,1,H,W]
        pseudo_focalx = pseudo_focal[:, 0]
        if pseudo_focal.shape[1] == 2:
            pseudo_focaly = pseudo_focal[:, 1]
        else:
            pseudo_focaly = pseudo_focalx
    else:
        raise NotImplementedError("Error, unknown input focal shape format.")
    assert pseudo_focalx.shape == depth.shape[:3]
    assert pseudo_focaly.shape == depth.shape[:3]
    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]

    # set principal point
    if pp is None:
        grid_x = grid_x - (W - 1) / 2
        grid_y = grid_y - (H - 1) / 2
    else:
        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]

    if n is None:
        pts3d = torch.empty((B, H, W, 3), device=depth.device)
        pts3d[..., 0] = depth * grid_x / pseudo_focalx
        pts3d[..., 1] = depth * grid_y / pseudo_focaly
        pts3d[..., 2] = depth
    else:
        pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
        pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
        pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
        pts3d[..., 2, :] = depth
    return pts3d

def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
    Returns:
        pointmap in camera coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Compute 3D ray associated with each pixel
    # Strong assumption: there are no skew terms
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0
    if pseudo_focal is None:
        fu = camera_intrinsics[0, 0]
        fv = camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z_cam = depthmap
    x_cam = (u - cu) * z_cam / fu
    y_cam = (v - cv) * z_cam / fv
    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)

    # Mask for valid coordinates
    valid_mask = depthmap > 0.0
    return X_cam, valid_mask

def depthmap_to_absolute_camera_coordinates(
    depthmap, camera_intrinsics, camera_pose, **kw
):
    """
    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a 4x3 or 4x4 cam2world matrix
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    X_world = X_cam  # default
    if camera_pose is not None:
        R_cam2world = camera_pose[:3, :3]
        t_cam2world = camera_pose[:3, 3]

        # Express in absolute (world) coordinates
        X_world = (
            np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
        )
    return X_world, valid_mask

def get_absolute_pointmaps_and_rays_info(depthmap, camera_intrinsics, camera_pose, **kw):
    """
    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a 4x3 or 4x4 cam2world matrix
    Returns:
        pointmap of absolute coordinates (HxWx3 array),
        a mask specifying valid pixels,
        ray origins of absolute coordinates (HxWx3 array),
        ray directions of absolute coordinates (HxWx3 array),
        depth along ray (HxWx1 array),
        ray directions of camera/local coordinates (HxWx3 array),
        pointmap of camera/local coordinates (HxWx3 array).
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Compute 3D ray associated with each pixel
    # Strong assumption: pinhole & there are no skew terms
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0
    fu = camera_intrinsics[0, 0]
    fv = camera_intrinsics[1, 1]
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    # Get the rays on the unit plane
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    x_cam = (u - cu) / fu
    y_cam = (v - cv) / fv
    z_cam = np.ones_like(x_cam)
    ray_dirs_cam_on_unit_plane = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(
        np.float32
    )

    # Compute the 3d points in the local camera coordinate system
    pts_cam = depthmap[..., None] * ray_dirs_cam_on_unit_plane

    # Get the depth along the ray and compute the ray directions on the unit sphere
    depth_along_ray = np.linalg.norm(pts_cam, axis=-1, keepdims=True)
    ray_directions_cam = ray_dirs_cam_on_unit_plane / np.linalg.norm(
        ray_dirs_cam_on_unit_plane, axis=-1, keepdims=True
    )

    # Mask for valid coordinates
    valid_mask = depthmap > 0.0

    # Get the ray origins in absolute coordinates and the ray directions in absolute coordinates
    ray_origins_world = np.zeros_like(ray_directions_cam)
    ray_directions_world = ray_directions_cam
    pts_world = pts_cam
    if camera_pose is not None:
        R_cam2world = camera_pose[:3, :3]
        t_cam2world = camera_pose[:3, 3]

        # Express in absolute coordinates
        ray_origins_world = ray_origins_world + t_cam2world[None, None, :]
        ray_directions_world = np.einsum(
            "ik, vuk -> vui", R_cam2world, ray_directions_cam
        )
        pts_world = ray_origins_world + ray_directions_world * depth_along_ray

    return (
        pts_world,
        valid_mask,
        ray_origins_world,
        ray_directions_world,
        depth_along_ray,
        ray_directions_cam,
        pts_cam,
    )

def adjust_camera_params_for_rotation(camera_params, original_size, k):
    """
    Adjust camera parameters for rotation.
    Args:
        camera_params: Camera parameters [fx, fy, cx, cy, ...]
        original_size: Original image size as (width, height)
        k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)
    Returns:
        Adjusted camera parameters
    """
    fx, fy, cx, cy = camera_params[:4]
    width, height = original_size

    if k % 4 == 1:  # 90 degrees counter-clockwise
        new_fx, new_fy = fy, fx
        new_cx, new_cy = height - cy, cx
    elif k % 4 == 2:  # 180 degrees
        new_fx, new_fy = fx, fy
        new_cx, new_cy = width - cx, height - cy
    elif k % 4 == 3:  # 90 degrees clockwise (270 counter-clockwise)
        new_fx, new_fy = fy, fx
        new_cx, new_cy = cy, width - cx
    else:  # No rotation
        return camera_params

    adjusted_params = [new_fx, new_fy, new_cx, new_cy]
    if len(camera_params) > 4:
        adjusted_params.extend(camera_params[4:])
    return adjusted_params

def adjust_pose_for_rotation(pose, k):
    """
    Adjust camera pose for rotation.
    Args:
        pose: 4x4 camera pose matrix (camera-to-world, OpenCV convention - X right, Y down, Z forward)
        k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)
    Returns:
        Adjusted 4x4 camera pose matrix
    """
    # Create rotation matrices for different rotations
    if k % 4 == 1:  # 90 degrees counter-clockwise
        rot_transform = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
    elif k % 4 == 2:  # 180 degrees
        rot_transform = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]])
    elif k % 4 == 3:  # 90 degrees clockwise (270 counter-clockwise)
        rot_transform = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
    else:  # No rotation
        return pose

    # Apply the transformation to a copy so the caller's pose is not mutated in place
    adjusted_pose = pose.copy()
    adjusted_pose[:3, :3] = adjusted_pose[:3, :3] @ rot_transform.T
    return adjusted_pose

def crop_to_aspect_ratio(image, depth, camera_params, target_ratio=1.5):
    """
    Crop image and depth to the largest possible crop with the target aspect ratio,
    keeping the left side if the image is wider than the target and the bottom if it is taller.
    Args:
        image: PIL image
        depth: Depth map as numpy array
        camera_params: Camera parameters [fx, fy, cx, cy, ...]
        target_ratio: Target width/height ratio
    Returns:
        Cropped image, cropped depth, adjusted camera parameters
    """
    width, height = image.size
    fx, fy, cx, cy = camera_params[:4]
    current_ratio = width / height

    if abs(current_ratio - target_ratio) < 1e-6:
        # Already at target ratio
        return image, depth, camera_params

    if current_ratio > target_ratio:
        # Image is wider than target ratio, crop width
        new_width = int(height * target_ratio)
        left = 0
        right = new_width

        # Crop image
        cropped_image = image.crop((left, 0, right, height))

        # Crop depth
        if len(depth.shape) == 3:
            cropped_depth = depth[:, left:right, :]
        else:
            cropped_depth = depth[:, left:right]

        # Adjust camera parameters
        new_cx = cx - left
        adjusted_params = [fx, fy, new_cx, cy] + list(camera_params[4:])
    else:
        # Image is taller than target ratio, crop height
        new_height = int(width / target_ratio)
        top = max(0, height - new_height)
        bottom = height

        # Crop image
        cropped_image = image.crop((0, top, width, bottom))

        # Crop depth
        if len(depth.shape) == 3:
            cropped_depth = depth[top:bottom, :, :]
        else:
            cropped_depth = depth[top:bottom, :]

        # Adjust camera parameters
        new_cy = cy - top
        adjusted_params = [fx, fy, cx, new_cy] + list(camera_params[4:])

    return cropped_image, cropped_depth, adjusted_params

def colmap_to_opencv_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    K = K.copy()
    K[0, 2] -= 0.5
    K[1, 2] -= 0.5
    return K

def opencv_to_colmap_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    K = K.copy()
    K[0, 2] += 0.5
    K[1, 2] += 0.5
    return K
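
# Round-trip sketch (illustrative): the two conversions only shift the
# principal point by half a pixel and are exact inverses of each other.
#
#   K_colmap = np.array([[500.0, 0.0, 320.5],
#                        [0.0, 500.0, 240.5],
#                        [0.0, 0.0, 1.0]])
#   K_cv = colmap_to_opencv_intrinsics(K_colmap)   # cx, cy become 320.0, 240.0
#   K_back = opencv_to_colmap_intrinsics(K_cv)     # K_back == K_colmap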

def normalize_depth_using_non_zero_pixels(depth, return_norm_factor=False):
    """
    Normalize the depth by the average depth of non-zero depth pixels.
    Args:
        depth (torch.Tensor): Depth tensor of size [B, H, W, 1].
        return_norm_factor (bool): If True, also return the per-batch norm factor.
    Returns:
        normalized_depth (torch.Tensor): Normalized depth tensor.
        norm_factor (torch.Tensor): Norm factor tensor of size B (only if return_norm_factor is True).
    """
    assert depth.ndim == 4 and depth.shape[3] == 1

    # Calculate the sum and count of non-zero depth pixels for each batch
    valid_depth_mask = depth > 0
    valid_sum = torch.sum(depth * valid_depth_mask, dim=(1, 2, 3))
    valid_count = torch.sum(valid_depth_mask, dim=(1, 2, 3))

    # Calculate the norm factor
    norm_factor = valid_sum / (valid_count + 1e-8)
    while norm_factor.ndim < depth.ndim:
        norm_factor.unsqueeze_(-1)

    # Normalize the depth by the norm factor
    norm_factor = norm_factor.clip(min=1e-8)
    normalized_depth = depth / norm_factor

    # Create the output tuple
    output = (
        (normalized_depth, norm_factor.squeeze(-1).squeeze(-1).squeeze(-1))
        if return_norm_factor
        else normalized_depth
    )
    return output
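
# Usage sketch (illustrative): the norm factor is the mean of the non-zero
# pixels, so those pixels map to ~1 after normalization.
#
#   depth = torch.zeros(1, 4, 4, 1)
#   depth[0, :2] = 2.0  # half the pixels valid, all with depth 2
#   norm_depth, factor = normalize_depth_using_non_zero_pixels(
#       depth, return_norm_factor=True
#   )
#   # factor ~= 2.0; norm_depth is ~1 at valid pixels, 0 elsewhere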

def normalize_pose_translations(pose_translations, return_norm_factor=False):
    """
    Normalize the pose translations by the average norm of the non-zero pose translations.
    Args:
        pose_translations (torch.Tensor): Pose translations tensor of size [B, V, 3]. B is the batch size, V is the number of views.
        return_norm_factor (bool): If True, also return the per-batch norm factor.
    Returns:
        normalized_pose_translations (torch.Tensor): Normalized pose translations tensor of size [B, V, 3].
        norm_factor (torch.Tensor): Norm factor tensor of size B (only if return_norm_factor is True).
    """
    assert pose_translations.ndim == 3 and pose_translations.shape[2] == 3

    # Compute distance of all pose translations to origin
    pose_translations_dis = pose_translations.norm(dim=-1)  # [B, V]
    non_zero_pose_translations_dis = pose_translations_dis > 0  # [B, V]

    # Calculate the average norm of the translations across all views (considering only views with non-zero translations)
    sum_of_all_views_pose_translations = pose_translations_dis.sum(dim=1)  # [B]
    count_of_all_views_with_non_zero_pose_translations = (
        non_zero_pose_translations_dis.sum(dim=1)
    )  # [B]
    norm_factor = sum_of_all_views_pose_translations / (
        count_of_all_views_with_non_zero_pose_translations + 1e-8
    )  # [B]

    # Normalize the pose translations by the norm factor
    norm_factor = norm_factor.clip(min=1e-8)
    normalized_pose_translations = pose_translations / norm_factor.unsqueeze(
        -1
    ).unsqueeze(-1)

    # Create the output tuple
    output = (
        (normalized_pose_translations, norm_factor)
        if return_norm_factor
        else normalized_pose_translations
    )
    return output

def normalize_multiple_pointclouds(
    pts_list, valid_masks=None, norm_mode="avg_dis", ret_factor=False
):
    """
    Normalize multiple point clouds using a joint normalization strategy.
    Args:
        pts_list: List of point clouds, each with shape (..., H, W, 3) or (B, H, W, 3)
        valid_masks: Optional list of masks indicating valid points in each point cloud
        norm_mode: String in format "{norm}_{dis}" where:
            - norm: Normalization strategy (currently only "avg" is supported)
            - dis: Distance transformation ("dis" for raw distance, "log1p" for log(1+distance),
              "warp-log1p" to warp points using log distance)
        ret_factor: If True, return the normalization factor as the last element in the result list
    Returns:
        List of normalized point clouds with the same shapes as inputs.
        If ret_factor is True, the last element is the normalization factor.
    """
    assert all(pts.ndim >= 3 and pts.shape[-1] == 3 for pts in pts_list)
    if valid_masks is not None:
        assert len(pts_list) == len(valid_masks)
    norm_mode, dis_mode = norm_mode.split("_")

    # Gather all points together (joint normalization)
    nan_pts_list = [
        invalid_to_zeros(pts, valid_masks[i], ndim=3)
        if valid_masks
        else invalid_to_zeros(pts, None, ndim=3)
        for i, pts in enumerate(pts_list)
    ]
    all_pts = torch.cat([nan_pts for nan_pts, _ in nan_pts_list], dim=1)
    nnz_list = [nnz for _, nnz in nan_pts_list]

    # Compute distance to origin
    all_dis = all_pts.norm(dim=-1)
    if dis_mode == "dis":
        pass  # do nothing
    elif dis_mode == "log1p":
        all_dis = torch.log1p(all_dis)
    elif dis_mode == "warp-log1p":
        # Warp input points before normalizing them
        log_dis = torch.log1p(all_dis)
        warp_factor = log_dis / all_dis.clip(min=1e-8)
        for i, pts in enumerate(pts_list):
            H, W = pts.shape[1:-1]
            pts_list[i] = pts * warp_factor[:, i * (H * W) : (i + 1) * (H * W)].view(
                -1, H, W, 1
            )
        all_dis = log_dis
    else:
        raise ValueError(f"bad {dis_mode=}")

    # Compute normalization factor
    norm_factor = all_dis.sum(dim=1) / (sum(nnz_list) + 1e-8)
    norm_factor = norm_factor.clip(min=1e-8)
    while norm_factor.ndim < pts_list[0].ndim:
        norm_factor.unsqueeze_(-1)

    # Normalize points
    res = [pts / norm_factor for pts in pts_list]
    if ret_factor:
        res.append(norm_factor)
    return res

def apply_log_to_norm(input_data):
    """
    Normalize the input data and apply a logarithmic transformation based on the normalization factor.
    Args:
        input_data (torch.Tensor): The input tensor to be normalized and transformed.
    Returns:
        torch.Tensor: The transformed tensor after normalization and logarithmic scaling.
    """
    org_d = input_data.norm(dim=-1, keepdim=True)
    input_data = input_data / org_d.clip(min=1e-8)
    input_data = input_data * torch.log1p(org_d)
    return input_data

def angle_diff_vec3(v1, v2, eps=1e-12):
    """
    Compute angle difference between 3D vectors.
    Args:
        v1: torch.Tensor of shape (..., 3)
        v2: torch.Tensor of shape (..., 3)
        eps: Small epsilon value for numerical stability
    Returns:
        torch.Tensor: Angle differences in radians
    """
    cross_norm = torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps
    dot_prod = (v1 * v2).sum(dim=-1)
    return torch.atan2(cross_norm, dot_prod)
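
# Usage sketch (illustrative): orthogonal unit vectors are pi/2 apart. The
# atan2(|v1 x v2|, v1 . v2) form is more stable than acos of the dot product
# for nearly parallel vectors.
#
#   ex = torch.tensor([[1.0, 0.0, 0.0]])
#   ey = torch.tensor([[0.0, 1.0, 0.0]])
#   angle_diff_vec3(ex, ey)  # ~= pi / 2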
| def angle_diff_vec3_numpy(v1: np.ndarray, v2: np.ndarray, eps: float = 1e-12): | |
| """ | |
| Compute angle difference between 3D vectors using NumPy. | |
| Args: | |
| v1 (np.ndarray): First vector of shape (..., 3) | |
| v2 (np.ndarray): Second vector of shape (..., 3) | |
| eps (float, optional): Small epsilon value for numerical stability. Defaults to 1e-12. | |
| Returns: | |
| np.ndarray: Angle differences in radians | |
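    Example (illustrative; parallel vectors have ~zero angle):
        >>> v = np.array([1.0, 0.0, 0.0])
        >>> round(float(angle_diff_vec3_numpy(v, v)), 6)
        0.0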
| """ | |
| return np.arctan2( | |
| np.linalg.norm(np.cross(v1, v2, axis=-1), axis=-1) + eps, (v1 * v2).sum(axis=-1) | |
| ) | |


def points_to_normals(
    point: np.ndarray, mask: np.ndarray = None, edge_threshold: float = None
) -> np.ndarray:
    """
    Calculate normal map from point map. Value range is [-1, 1].
    Args:
        point (np.ndarray): shape (height, width, 3), point map
        mask (optional, np.ndarray): shape (height, width), dtype=bool. Mask of valid depth pixels. Defaults to None.
        edge_threshold (optional, float): threshold for the angle (in degrees) between the normal and the view direction. Defaults to None.
    Returns:
        normal (np.ndarray): shape (height, width, 3), normal map.
        normal_mask (np.ndarray): shape (height, width), dtype=bool. Only returned when a mask is provided.
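    Example (illustrative; checks only the output shape):
        >>> pts = np.random.rand(32, 32, 3)
        >>> points_to_normals(pts).shape
        (32, 32, 3)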
| """ | |
| height, width = point.shape[-3:-1] | |
| has_mask = mask is not None | |
| if mask is None: | |
| mask = np.ones_like(point[..., 0], dtype=bool) | |
| mask_pad = np.zeros((height + 2, width + 2), dtype=bool) | |
| mask_pad[1:-1, 1:-1] = mask | |
| mask = mask_pad | |
| pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype) | |
| pts[1:-1, 1:-1, :] = point | |
| up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :] | |
| left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :] | |
| down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :] | |
| right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :] | |
| normal = np.stack( | |
| [ | |
| np.cross(up, left, axis=-1), | |
| np.cross(left, down, axis=-1), | |
| np.cross(down, right, axis=-1), | |
| np.cross(right, up, axis=-1), | |
| ] | |
| ) | |
| normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12) | |
| valid = ( | |
| np.stack( | |
| [ | |
| mask[:-2, 1:-1] & mask[1:-1, :-2], | |
| mask[1:-1, :-2] & mask[2:, 1:-1], | |
| mask[2:, 1:-1] & mask[1:-1, 2:], | |
| mask[1:-1, 2:] & mask[:-2, 1:-1], | |
| ] | |
| ) | |
| & mask[None, 1:-1, 1:-1] | |
| ) | |
| if edge_threshold is not None: | |
| view_angle = angle_diff_vec3_numpy(pts[None, 1:-1, 1:-1, :], normal) | |
| view_angle = np.minimum(view_angle, np.pi - view_angle) | |
| valid = valid & (view_angle < np.deg2rad(edge_threshold)) | |
| normal = (normal * valid[..., None]).sum(axis=0) | |
| normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12) | |
| if has_mask: | |
| normal_mask = valid.any(axis=0) | |
| normal = np.where(normal_mask[..., None], normal, 0) | |
| return normal, normal_mask | |
| else: | |
| return normal | |


def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
    """
    Create a sliding window view of the input array along a specified axis.
    This function creates a memory-efficient view of the input array with sliding windows
    of the specified size and stride. The window dimension is appended to the end of the
    output array's shape. This is useful for operations like convolution, pooling, or
    any analysis that requires examining local neighborhoods in the data.
    Args:
        x (np.ndarray): Input array with shape (..., axis_size, ...)
        window_size (int): Size of the sliding window
        stride (int): Stride of the sliding window (step size between consecutive windows)
        axis (int, optional): Axis to perform sliding window over. Defaults to -1 (last axis)
    Returns:
        np.ndarray: View of the input array with shape (..., n_windows, ..., window_size),
            where n_windows = (axis_size - window_size + 1) // stride
    Raises:
        AssertionError: If window_size is larger than the size of the specified axis
    Example:
        >>> x = np.array([1, 2, 3, 4, 5, 6])
        >>> sliding_window_1d(x, window_size=3, stride=2)
        array([[1, 2, 3],
               [3, 4, 5]])
    """
    assert x.shape[axis] >= window_size, (
        f"window_size ({window_size}) is larger than axis_size ({x.shape[axis]})"
    )
    axis = axis % x.ndim
    shape = (
        *x.shape[:axis],
        (x.shape[axis] - window_size + 1) // stride,
        *x.shape[axis + 1 :],
        window_size,
    )
    strides = (
        *x.strides[:axis],
        stride * x.strides[axis],
        *x.strides[axis + 1 :],
        x.strides[axis],
    )
    x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
    return x_sliding


def sliding_window_nd(
    x: np.ndarray,
    window_size: Tuple[int, ...],
    stride: Tuple[int, ...],
    axis: Tuple[int, ...],
) -> np.ndarray:
    """
    Create sliding windows along multiple dimensions of the input array.
    This function applies sliding_window_1d sequentially along multiple axes to create
    N-dimensional sliding windows. This is useful for operations that need to examine
    local neighborhoods in multiple dimensions simultaneously.
    Args:
        x (np.ndarray): Input array
        window_size (Tuple[int, ...]): Size of the sliding window for each axis
        stride (Tuple[int, ...]): Stride of the sliding window for each axis
        axis (Tuple[int, ...]): Axes to perform sliding window over
    Returns:
        np.ndarray: Array with sliding windows along the specified dimensions.
            The window dimensions are appended to the end of the shape.
    Note:
        The length of window_size, stride, and axis tuples must be equal.
    Example:
        >>> x = np.random.rand(10, 10)
        >>> windows = sliding_window_nd(x, window_size=(3, 3), stride=(2, 2), axis=(-2, -1))
        >>> # Creates 3x3 sliding windows with stride 2 in both dimensions
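        >>> windows.shape  # shape check (added for illustration): (10 - 3 + 1) // 2 = 4 windows per axis
        (4, 4, 3, 3)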
| """ | |
| axis = [axis[i] % x.ndim for i in range(len(axis))] | |
| for i in range(len(axis)): | |
| x = sliding_window_1d(x, window_size[i], stride[i], axis[i]) | |
| return x | |


def sliding_window_2d(
    x: np.ndarray,
    window_size: Union[int, Tuple[int, int]],
    stride: Union[int, Tuple[int, int]],
    axis: Tuple[int, int] = (-2, -1),
) -> np.ndarray:
    """
    Create 2D sliding windows over the input array.
    Convenience function for creating 2D sliding windows, commonly used for image
    processing operations like convolution, pooling, or patch extraction.
    Args:
        x (np.ndarray): Input array
        window_size (Union[int, Tuple[int, int]]): Size of the 2D sliding window.
            If int, same size is used for both dimensions.
        stride (Union[int, Tuple[int, int]]): Stride of the 2D sliding window.
            If int, same stride is used for both dimensions.
        axis (Tuple[int, int], optional): Two axes to perform sliding window over.
            Defaults to (-2, -1) (last two dimensions).
    Returns:
        np.ndarray: Array with 2D sliding windows. The window dimensions (height, width)
            are appended to the end of the shape.
    Example:
        >>> image = np.random.rand(100, 100)
        >>> patches = sliding_window_2d(image, window_size=8, stride=4)
        >>> # Creates 8x8 patches with stride 4 from the image
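        >>> patches.shape  # shape check (added for illustration): (100 - 8 + 1) // 4 = 23 windows per axis
        (23, 23, 8, 8)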
| """ | |
| if isinstance(window_size, int): | |
| window_size = (window_size, window_size) | |
| if isinstance(stride, int): | |
| stride = (stride, stride) | |
| return sliding_window_nd(x, window_size, stride, axis) | |


def max_pool_1d(
    x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1
):
    """
    Perform 1D max pooling on the input array.
    Max pooling reduces the dimensionality of the input by taking the maximum value
    within each sliding window. This is commonly used in neural networks and signal
    processing for downsampling and feature extraction.
    Args:
        x (np.ndarray): Input array
        kernel_size (int): Size of the pooling kernel
        stride (int): Stride of the pooling operation
        padding (int, optional): Amount of padding to add on both sides. Defaults to 0.
        axis (int, optional): Axis to perform max pooling over. Defaults to -1.
    Returns:
        np.ndarray: Max pooled array with reduced size along the specified axis
    Note:
        - For floating point arrays, padding is done with np.nan values
        - For integer arrays, padding is done with the minimum value of the dtype
        - np.nanmax is used to handle NaN values in the computation
    Example:
        >>> x = np.array([1, 3, 2, 4, 5, 1, 2])
        >>> max_pool_1d(x, kernel_size=3, stride=2)
        array([3, 5])
| """ | |
| axis = axis % x.ndim | |
| if padding > 0: | |
| fill_value = np.nan if x.dtype.kind == "f" else np.iinfo(x.dtype).min | |
| padding_arr = np.full( | |
| (*x.shape[:axis], padding, *x.shape[axis + 1 :]), | |
| fill_value=fill_value, | |
| dtype=x.dtype, | |
| ) | |
| x = np.concatenate([padding_arr, x, padding_arr], axis=axis) | |
| a_sliding = sliding_window_1d(x, kernel_size, stride, axis) | |
| max_pool = np.nanmax(a_sliding, axis=-1) | |
| return max_pool | |


def max_pool_nd(
    x: np.ndarray,
    kernel_size: Tuple[int, ...],
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    axis: Tuple[int, ...],
) -> np.ndarray:
    """
    Perform N-dimensional max pooling on the input array.
    This function applies max_pool_1d sequentially along multiple axes to perform
    multi-dimensional max pooling. This is useful for downsampling multi-dimensional
    data while preserving the most important features.
    Args:
        x (np.ndarray): Input array
        kernel_size (Tuple[int, ...]): Size of the pooling kernel for each axis
        stride (Tuple[int, ...]): Stride of the pooling operation for each axis
        padding (Tuple[int, ...]): Amount of padding for each axis
        axis (Tuple[int, ...]): Axes to perform max pooling over
    Returns:
        np.ndarray: Max pooled array with reduced size along the specified axes
    Note:
        The length of kernel_size, stride, padding, and axis tuples must be equal.
        Max pooling is applied sequentially along each axis in the order specified.
    Example:
        >>> x = np.random.rand(10, 10, 10)
        >>> pooled = max_pool_nd(x, kernel_size=(2, 2, 2), stride=(2, 2, 2),
        ...                      padding=(0, 0, 0), axis=(-3, -2, -1))
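        >>> pooled.shape  # shape check (added for illustration): (10 - 2 + 1) // 2 = 4 windows per axis
        (4, 4, 4)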
        >>> # Downsamples the volume with 2x2x2 max pooling (4, not 5, windows per axis)
| """ | |
| for i in range(len(axis)): | |
| x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i]) | |
| return x | |


def max_pool_2d(
    x: np.ndarray,
    kernel_size: Union[int, Tuple[int, int]],
    stride: Union[int, Tuple[int, int]],
    padding: Union[int, Tuple[int, int]],
    axis: Tuple[int, int] = (-2, -1),
):
    """
    Perform 2D max pooling on the input array.
    Convenience function for 2D max pooling, commonly used in computer vision
    and image processing for downsampling images while preserving important features.
    Args:
        x (np.ndarray): Input array
        kernel_size (Union[int, Tuple[int, int]]): Size of the 2D pooling kernel.
            If int, same size is used for both dimensions.
        stride (Union[int, Tuple[int, int]]): Stride of the 2D pooling operation.
            If int, same stride is used for both dimensions.
        padding (Union[int, Tuple[int, int]]): Amount of padding for both dimensions.
            If int, same padding is used for both dimensions.
        axis (Tuple[int, int], optional): Two axes to perform max pooling over.
            Defaults to (-2, -1) (last two dimensions).
    Returns:
        np.ndarray: 2D max pooled array with reduced size along the specified axes
    Example:
        >>> image = np.random.rand(64, 64)
        >>> pooled = max_pool_2d(image, kernel_size=2, stride=2, padding=0)
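        >>> pooled.shape  # shape check (added for illustration): (64 - 2 + 1) // 2 = 31 windows per axis
        (31, 31)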
        >>> # Downsamples the image with 2x2 max pooling (31x31, not 32x32, given the window-count formula)
| """ | |
| if isinstance(kernel_size, Number): | |
| kernel_size = (kernel_size, kernel_size) | |
| if isinstance(stride, Number): | |
| stride = (stride, stride) | |
| if isinstance(padding, Number): | |
| padding = (padding, padding) | |
| axis = tuple(axis) | |
| return max_pool_nd(x, kernel_size, stride, padding, axis) | |


def depth_edge(
    depth: np.ndarray,
    atol: float = None,
    rtol: float = None,
    kernel_size: int = 3,
    mask: np.ndarray = None,
) -> np.ndarray:
    """
    Compute the edge mask of a depth map. An edge is a pixel whose neighborhood contains a large difference in depth.
    Args:
        depth (np.ndarray): shape (..., height, width), linear depth map
        atol (float): absolute tolerance on the local depth range
        rtol (float): relative tolerance on the local depth range
        kernel_size (int): size of the local neighborhood. Defaults to 3.
        mask (np.ndarray, optional): shape (..., height, width), dtype=bool. Mask of valid depth pixels. Defaults to None.
    Returns:
        edge (np.ndarray): shape (..., height, width), dtype=bool
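    Example (illustrative; a depth jump between columns flags an edge band):
        >>> depth = np.full((3, 3), 1.0)
        >>> depth[:, 2] = 5.0
        >>> depth_edge(depth, atol=1.0)
        array([[False,  True,  True],
               [False,  True,  True],
               [False,  True,  True]])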
| """ | |
| if mask is None: | |
| diff = max_pool_2d( | |
| depth, kernel_size, stride=1, padding=kernel_size // 2 | |
| ) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) | |
| else: | |
| diff = max_pool_2d( | |
| np.where(mask, depth, -np.inf), | |
| kernel_size, | |
| stride=1, | |
| padding=kernel_size // 2, | |
| ) + max_pool_2d( | |
| np.where(mask, -depth, -np.inf), | |
| kernel_size, | |
| stride=1, | |
| padding=kernel_size // 2, | |
| ) | |
| edge = np.zeros_like(depth, dtype=bool) | |
| if atol is not None: | |
| edge |= diff > atol | |
| if rtol is not None: | |
| edge |= diff / depth > rtol | |
| return edge | |


def depth_aliasing(
    depth: np.ndarray,
    atol: float = None,
    rtol: float = None,
    kernel_size: int = 3,
    mask: np.ndarray = None,
) -> np.ndarray:
| """ | |
| Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. | |
| Args: | |
| depth (np.ndarray): shape (..., height, width), linear depth map | |
| atol (float): absolute tolerance | |
| rtol (float): relative tolerance | |
| Returns: | |
| edge (np.ndarray): shape (..., height, width) of dtype torch.bool | |
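    Example (illustrative; the middle pixel is close to neither local extremum):
        >>> depth = np.array([[1.0, 3.0, 5.0]])
        >>> depth_aliasing(depth, atol=1.0)
        array([[False,  True, False]])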
| """ | |
| if mask is None: | |
| diff_max = ( | |
| max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth | |
| ) | |
| diff_min = ( | |
| max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth | |
| ) | |
| else: | |
| diff_max = ( | |
| max_pool_2d( | |
| np.where(mask, depth, -np.inf), | |
| kernel_size, | |
| stride=1, | |
| padding=kernel_size // 2, | |
| ) | |
| - depth | |
| ) | |
| diff_min = ( | |
| max_pool_2d( | |
| np.where(mask, -depth, -np.inf), | |
| kernel_size, | |
| stride=1, | |
| padding=kernel_size // 2, | |
| ) | |
| + depth | |
| ) | |
| diff = np.minimum(diff_max, diff_min) | |
| edge = np.zeros_like(depth, dtype=bool) | |
| if atol is not None: | |
| edge |= diff > atol | |
| if rtol is not None: | |
| edge |= diff / depth > rtol | |
| return edge | |


def normals_edge(
    normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None
) -> np.ndarray:
| """ | |
| Compute the edge mask from normal map. | |
| Args: | |
| normal (np.ndarray): shape (..., height, width, 3), normal map | |
| tol (float): tolerance in degrees | |
| Returns: | |
| edge (np.ndarray): shape (..., height, width) of dtype torch.bool | |
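    Example (illustrative; a constant normal map has no edges):
        >>> normals = np.zeros((4, 4, 3))
        >>> normals[..., 2] = 1.0
        >>> bool(normals_edge(normals, tol=5.0).any())
        False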
| """ | |
    assert normals.ndim >= 3 and normals.shape[-1] == 3, (
        "normals should be of shape (..., height, width, 3)"
    )
    normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)
    padding = kernel_size // 2
    normals_window = sliding_window_2d(
        np.pad(
            normals,
            (
                *([(0, 0)] * (normals.ndim - 3)),
                (padding, padding),
                (padding, padding),
                (0, 0),
            ),
            mode="edge",
        ),
        window_size=kernel_size,
        stride=1,
        axis=(-3, -2),
    )
    if mask is None:
        # Clip dot products to [-1, 1] so floating point error cannot push arccos into NaN
        angle_diff = np.arccos(
            ((normals[..., None, None] * normals_window).sum(axis=-3)).clip(-1.0, 1.0)
        ).max(axis=(-2, -1))
    else:
        mask_window = sliding_window_2d(
            np.pad(
                mask,
                # Pad the last two (spatial) axes; unlike normals, the mask has no channel axis
                (*([(0, 0)] * (mask.ndim - 2)), (padding, padding), (padding, padding)),
                mode="edge",
            ),
            window_size=kernel_size,
            stride=1,
            axis=(-2, -1),
        )
        # Clip dot products to [-1, 1] so floating point error cannot push arccos into NaN
        angle_diff = np.where(
            mask_window,
            np.arccos(
                ((normals[..., None, None] * normals_window).sum(axis=-3)).clip(-1.0, 1.0)
            ),
            0,
        ).max(axis=(-2, -1))
    angle_diff = max_pool_2d(
        angle_diff, kernel_size, stride=1, padding=kernel_size // 2
    )
    edge = angle_diff > np.deg2rad(tol)
    return edge