from __future__ import annotations

from typing import Optional

import matplotlib
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from plyfile import PlyData, PlyElement


def signed_log1p_inverse(x):
    """
    Computes the inverse of signed_log1p: x = sign(x) * (exp(abs(x)) - 1).

    Args:
        x (torch.Tensor or numpy.ndarray): Input tensor (output of signed_log1p).

    Returns:
        torch.Tensor or numpy.ndarray: Original tensor x.
    """
    if isinstance(x, torch.Tensor):
        return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
    elif isinstance(x, np.ndarray):
        return np.sign(x) * (np.exp(np.abs(x)) - 1)
    else:
        raise TypeError("Input must be a torch.Tensor or numpy.ndarray")
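
# Illustrative sketch (assumption, not part of the original module): the forward
# transform that the function above inverts. Given the inverse formula, the forward
# mapping would be y = sign(x) * log(1 + |x|), e.g.:
#
#   def signed_log1p(x):
#       return np.sign(x) * np.log1p(np.abs(x))
#
#   y = signed_log1p(np.array([-5.0, 0.0, 5.0]))
#   np.allclose(signed_log1p_inverse(y), [-5.0, 0.0, 5.0])  # True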


def colorize_depth(depth, cmap="Spectral"):
    # Normalize valid (positive) depths so that nearer points map to larger values.
    min_d, max_d = (depth[depth > 0]).min(), (depth[depth > 0]).max()
    depth = (max_d - depth) / (max_d - min_d)
    cm = matplotlib.colormaps[cmap]
    depth = depth.clip(0, 1)
    depth = cm(depth, bytes=False)[..., 0:3]
    return depth


def save_ply(pointmap, image, output_file, downsample=20, mask=None):
    _, h, w, _ = pointmap.shape
    image = image[:, :h, :w]
    pointmap = pointmap[:, :h, :w]

    points = pointmap.reshape(-1, 3)  # (T*H*W, 3)
    colors = image.reshape(-1, 3)  # (T*H*W, 3)

    if mask is not None:
        points = points[mask.reshape(-1)]
        colors = colors[mask.reshape(-1)]

    # Randomly keep 1/downsample of the points to keep the PLY file small.
    indices = np.random.choice(
        colors.shape[0], int(colors.shape[0] / downsample), replace=False
    )
    points = points[indices]
    colors = colors[indices]

    vertices = []
    for p, c in zip(points, colors):
        vertex = (p[0], p[1], p[2], int(c[0]), int(c[1]), int(c[2]))
        vertices.append(vertex)

    vertex_dtype = np.dtype(
        [
            ("x", "f4"),
            ("y", "f4"),
            ("z", "f4"),
            ("red", "u1"),
            ("green", "u1"),
            ("blue", "u1"),
        ]
    )
    vertex_array = np.array(vertices, dtype=vertex_dtype)
    ply_element = PlyElement.describe(vertex_array, "vertex")
    PlyData([ply_element], text=True).write(output_file)
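
# Illustrative usage sketch (assumption; the array contents are made up):
#
#   pointmap = np.random.randn(4, 60, 80, 3)          # (t, h, w, 3) world-space points
#   rgb = np.random.rand(4, 60, 80, 3) * 255          # per-pixel colors in [0, 255]
#   save_ply(pointmap, rgb, "scene.ply", downsample=20)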


def fov_to_focal(fovx, fovy, h, w):
    # Note: fovx/fovy are half-angles here (see raymap_to_poses, which returns
    # arctan(W / (2 * focal))), so focal = (w / 2) / tan(fovx).
    focal_x = w * 0.5 / np.tan(fovx)
    focal_y = h * 0.5 / np.tan(fovy)
    focal = (focal_x + focal_y) / 2
    return focal


def get_rays(pose, h, w, focal=None, fovx=None, fovy=None):
    pose = torch.from_numpy(pose).float()

    # Pixel grid in image space, one copy per frame.
    x, y = torch.meshgrid(
        torch.arange(w),
        torch.arange(h),
        indexing="xy",
    )
    x = x.flatten().unsqueeze(0).repeat(pose.shape[0], 1)
    y = y.flatten().unsqueeze(0).repeat(pose.shape[0], 1)

    cx = w * 0.5
    cy = h * 0.5
    intrinsics, focal = get_intrinsics(pose.shape[0], h, w, fovx, fovy, focal)
    focal = torch.from_numpy(focal).float()

    # Ray directions in camera space (through pixel centers), rotated into world space.
    camera_dirs = F.pad(
        torch.stack(
            [
                (x - cx + 0.5) / focal.unsqueeze(-1),
                (y - cy + 0.5) / focal.unsqueeze(-1),
            ],
            dim=-1,
        ),
        (0, 1),
        value=1.0,
    )  # [t, hw, 3]

    pose = pose.to(dtype=camera_dirs.dtype)
    rays_d = camera_dirs @ pose[:, :3, :3].transpose(1, 2)  # [t, hw, 3]
    rays_o = pose[:, :3, 3].unsqueeze(1).expand_as(rays_d)  # [t, hw, 3]

    rays_o = rays_o.view(pose.shape[0], h, w, 3)
    rays_d = rays_d.view(pose.shape[0], h, w, 3)
    return rays_o.float().numpy(), rays_d.float().numpy(), intrinsics


def get_intrinsics(batch_size, h, w, fovx=None, fovy=None, focal=None):
    if focal is None:
        # fovx/fovy are half-angles (see fov_to_focal).
        focal_x = w * 0.5 / np.tan(fovx)
        focal_y = h * 0.5 / np.tan(fovy)
        focal = (focal_x + focal_y) / 2

    cx = w * 0.5
    cy = h * 0.5
    intrinsics = np.zeros((batch_size, 3, 3))
    intrinsics[:, 0, 0] = focal
    intrinsics[:, 1, 1] = focal
    intrinsics[:, 0, 2] = cx
    intrinsics[:, 1, 2] = cy
    intrinsics[:, 2, 2] = 1.0
    return intrinsics, focal
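
# Illustrative usage sketch (assumption; poses and sizes are made up):
#
#   poses = np.tile(np.eye(4), (8, 1, 1))              # (t, 4, 4) camera-to-world
#   fov_x = np.full(8, 0.5)                            # per-frame half-angle FoVs
#   fov_y = np.full(8, 0.4)
#   focal = fov_to_focal(fov_x, fov_y, 480, 640)       # (t,)
#   rays_o, rays_d, K = get_rays(poses, 480, 640, focal)
#   # rays_o, rays_d: (t, 480, 640, 3); K: (t, 3, 3)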


def save_pointmap(
    rgb,
    disparity,
    raymap,
    save_file,
    vae_downsample_scale=8,
    camera_pose=None,
    ray_o_scale_inv=1.0,
    max_depth=1e2,
    save_full_pcd_videos=False,
    smooth_camera=False,
    smooth_method="kalman",  # or "simple"
    **kwargs,
):
    """
    Args:
        rgb (numpy.ndarray): Shape of (t, h, w, 3), range [0, 1]
        disparity (numpy.ndarray): Shape of (t, h, w), range [0, 1]
        raymap (numpy.ndarray): Shape of (t, 6, h // 8, w // 8)
        ray_o_scale_inv (float, optional): A `ray_o` scale constant. Defaults to 1.0.
    """
    rgb = np.clip(rgb, 0, 1) * 255

    pointmap_dict = postprocess_pointmap(
        disparity,
        raymap,
        vae_downsample_scale,
        camera_pose,
        ray_o_scale_inv=ray_o_scale_inv,
        smooth_camera=smooth_camera,
        smooth_method=smooth_method,
        **kwargs,
    )

    save_ply(
        pointmap_dict["pointmap"],
        rgb,
        save_file,
        mask=(pointmap_dict["depth"] < max_depth),
    )

    if save_full_pcd_videos:
        pcd_dict = {
            "points": pointmap_dict["pointmap"],
            "colors": rgb,
            "intrinsics": pointmap_dict["intrinsics"],
            "poses": pointmap_dict["camera_pose"],
            "depths": pointmap_dict["depth"],
        }
        np.save(save_file.replace(".ply", "_pcd.npy"), pcd_dict)

    return pointmap_dict
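
# Illustrative end-to-end sketch (assumption; shapes mirror the docstring, the
# array contents are made up random data):
#
#   rgb = np.random.rand(8, 256, 384, 3)               # (t, h, w, 3) in [0, 1]
#   disparity = np.random.rand(8, 256, 384)            # (t, h, w) in [0, 1]
#   raymap = np.random.randn(8, 6, 32, 48)             # (t, 6, h // 8, w // 8)
#   out = save_pointmap(rgb, disparity, raymap, "scene.ply")
#   out["pointmap"].shape                               # (8, 256, 384, 3)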


def raymap_to_poses(
    raymap, camera_pose=None, ray_o_scale_inv=1.0, return_intrinsics=True
):
    ts = raymap.shape[0]
    if (not return_intrinsics) and (camera_pose is not None):
        return camera_pose, None, None

    # Undo the signed_log1p compression applied to the ray-origin channels.
    raymap[:, 3:] = signed_log1p_inverse(raymap[:, 3:])

    # Extract ray origins and directions
    ray_o = (
        rearrange(raymap[:, 3:], "t c h w -> t h w c") * ray_o_scale_inv
    )  # [T, H, W, C]
    ray_d = rearrange(raymap[:, :3], "t c h w -> t h w c")  # [T, H, W, C]

    # Compute orientation and directions
    orient = ray_o.reshape(ts, -1, 3).mean(axis=1)  # [T, 3]
    image_orient = (ray_o + ray_d).reshape(ts, -1, 3).mean(axis=1)  # [T, 3]
    Focal = np.linalg.norm(image_orient - orient, axis=-1)  # [T]
    Z_Dir = image_orient - orient  # [T, 3]

    # Compute the width (W) and field of view (FoV_x)
    W_Left = ray_d[:, :, :1, :].reshape(ts, -1, 3).mean(axis=1)
    W_Right = ray_d[:, :, -1:, :].reshape(ts, -1, 3).mean(axis=1)
    W = W_Right - W_Left
    W_real = (
        np.linalg.norm(np.cross(W, Z_Dir), axis=-1)
        / (raymap.shape[-1] - 1)
        * raymap.shape[-1]
    )
    Fov_x = np.arctan(W_real / (2 * Focal))

    # Compute the height (H) and field of view (FoV_y)
    H_Up = ray_d[:, :1, :, :].reshape(ts, -1, 3).mean(axis=1)
    H_Down = ray_d[:, -1:, :, :].reshape(ts, -1, 3).mean(axis=1)
    H = H_Up - H_Down
    H_real = (
        np.linalg.norm(np.cross(H, Z_Dir), axis=-1)
        / (raymap.shape[-2] - 1)
        * raymap.shape[-2]
    )
    Fov_y = np.arctan(H_real / (2 * Focal))

    # Compute X, Y, and Z directions for the camera (orthonormalized)
    X_Dir = W_Right - W_Left
    Y_Dir = np.cross(Z_Dir, X_Dir)
    X_Dir = np.cross(Y_Dir, Z_Dir)
    X_Dir /= np.linalg.norm(X_Dir, axis=-1, keepdims=True)
    Y_Dir /= np.linalg.norm(Y_Dir, axis=-1, keepdims=True)
    Z_Dir /= np.linalg.norm(Z_Dir, axis=-1, keepdims=True)

    # Create the camera-to-world (camera_pose) transformation matrix
    if camera_pose is None:
        camera_pose = np.zeros((ts, 4, 4))
        camera_pose[:, :3, 0] = X_Dir
        camera_pose[:, :3, 1] = Y_Dir
        camera_pose[:, :3, 2] = Z_Dir
        camera_pose[:, :3, 3] = orient
        camera_pose[:, 3, 3] = 1.0

    return camera_pose, Fov_x, Fov_y


def postprocess_pointmap(
    disparity,
    raymap,
    vae_downsample_scale=8,
    camera_pose=None,
    focal=None,
    ray_o_scale_inv=1.0,
    smooth_camera=False,
    smooth_method="simple",
    **kwargs,
):
    """
    Args:
        disparity (numpy.ndarray): Shape of (t, h, w), range [0, 1]
        raymap (numpy.ndarray): Shape of (t, 6, h // 8, w // 8)
        ray_o_scale_inv (float, optional): A `ray_o` scale constant. Defaults to 1.0.
    """
    depth = np.clip(1.0 / np.clip(disparity, 1e-3, 1), 0, 1e8)

    # FoVs are only needed when no focal length is given.
    camera_pose, fov_x, fov_y = raymap_to_poses(
        raymap,
        camera_pose=camera_pose,
        ray_o_scale_inv=ray_o_scale_inv,
        return_intrinsics=(focal is None),
    )
    if focal is None:
        focal = fov_to_focal(
            fov_x,
            fov_y,
            int(raymap.shape[2] * vae_downsample_scale),
            int(raymap.shape[3] * vae_downsample_scale),
        )

    if smooth_camera:
        # Check if the sequence is static
        is_static, trans_diff, rot_diff = detect_static_sequence(camera_pose)

        if is_static:
            print(
                f"Detected static/near-static sequence (trans_diff={trans_diff:.6f}, rot_diff={rot_diff:.6f})"
            )
            # Apply stronger smoothing for static sequences
            camera_pose = adaptive_pose_smoothing(camera_pose, trans_diff, rot_diff)
        else:
            if smooth_method == "simple":
                camera_pose = smooth_poses(
                    camera_pose, window_size=5, method="gaussian"
                )
            elif smooth_method == "kalman":
                camera_pose = smooth_trajectory(camera_pose, window_size=5)

    ray_o, ray_d, intrinsics = get_rays(
        camera_pose,
        int(raymap.shape[2] * vae_downsample_scale),
        int(raymap.shape[3] * vae_downsample_scale),
        focal,
    )

    # Lift per-pixel depth along the ray direction and offset by the ray origin.
    pointmap = depth[..., None] * ray_d + ray_o

    return {
        "pointmap": pointmap,
        "camera_pose": camera_pose,
        "intrinsics": intrinsics,
        "ray_o": ray_o,
        "ray_d": ray_d,
        "depth": depth,
    }


def detect_static_sequence(poses, threshold=0.01):
    """Detect if the camera sequence is static based on pose differences."""
    translations = poses[:, :3, 3]
    rotations = poses[:, :3, :3]

    # Compute translation differences
    trans_diff = np.linalg.norm(translations[1:] - translations[:-1], axis=1).mean()

    # Compute rotation differences (using the matrix Frobenius norm)
    rot_diff = np.linalg.norm(rotations[1:] - rotations[:-1], axis=(1, 2)).mean()

    return trans_diff < threshold and rot_diff < threshold, trans_diff, rot_diff


def adaptive_pose_smoothing(poses, trans_diff, rot_diff, base_window=5):
    """Apply adaptive smoothing based on motion magnitude."""
    # Increase the window size for low-motion sequences (capped at 41 frames).
    motion_magnitude = trans_diff + rot_diff
    adaptive_window = min(
        41, max(base_window, int(base_window * (0.1 / max(motion_magnitude, 1e-6))))
    )
    # smooth_poses requires an odd window size.
    if adaptive_window % 2 == 0:
        adaptive_window += 1

    # Apply stronger smoothing for low motion
    poses_smooth = smooth_poses(poses, window_size=adaptive_window, method="gaussian")
    return poses_smooth


def get_pixel(H, W):
    # Get 2D pixels (u, v) for image_a in cam_a pixel space
    u_a, v_a = np.meshgrid(np.arange(W), np.arange(H))
    # u_a = np.flip(u_a, axis=1)
    # v_a = np.flip(v_a, axis=0)
    pixels_a = np.stack(
        [u_a.flatten() + 0.5, v_a.flatten() + 0.5, np.ones_like(u_a.flatten())], axis=0
    )
    return pixels_a


def project(depth, intrinsic, pose):
    H, W = depth.shape
    pixel = get_pixel(H, W).astype(np.float32)

    # Back-project pixels to camera space, then transform into world space with the pose.
    points = (np.linalg.inv(intrinsic) @ pixel) * depth.reshape(-1)
    points = pose[:3, :4] @ np.concatenate(
        [points, np.ones((1, points.shape[1]))], axis=0
    )
    points = points.T.reshape(H, W, 3)
    return points
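
# Illustrative usage sketch (assumption; intrinsics and depth are made up):
#
#   depth = np.ones((480, 640), dtype=np.float32)        # per-pixel depth
#   K = np.array([[500.0, 0, 320], [0, 500.0, 240], [0, 0, 1]])
#   pose = np.eye(4)                                      # camera-to-world
#   world_points = project(depth, K, pose)                # (480, 640, 3)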


def depth_edge(
    depth: torch.Tensor,
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
    kernel_size: int = 3,
    mask: Optional[torch.Tensor] = None,
) -> torch.BoolTensor:
    """
    Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth.

    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        atol (float): absolute tolerance
        rtol (float): relative tolerance
        kernel_size (int): size of the neighborhood window
        mask (torch.Tensor): optional validity mask with the same spatial shape

    Returns:
        edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
    """
    is_numpy = isinstance(depth, np.ndarray)
    if is_numpy:
        depth = torch.from_numpy(depth)
    if isinstance(mask, np.ndarray):
        mask = torch.from_numpy(mask)

    shape = depth.shape
    depth = depth.reshape(-1, 1, *shape[-2:])
    if mask is not None:
        mask = mask.reshape(-1, 1, *shape[-2:])

    # Local depth range = max-pool(depth) + max-pool(-depth) over the kernel window.
    if mask is None:
        diff = F.max_pool2d(
            depth, kernel_size, stride=1, padding=kernel_size // 2
        ) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
    else:
        diff = F.max_pool2d(
            torch.where(mask, depth, -torch.inf),
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
        ) + F.max_pool2d(
            torch.where(mask, -depth, -torch.inf),
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
        )

    edge = torch.zeros_like(depth, dtype=torch.bool)
    if atol is not None:
        edge |= diff > atol
    if rtol is not None:
        edge |= (diff / depth).nan_to_num_() > rtol
    edge = edge.reshape(*shape)

    if is_numpy:
        return edge.numpy()
    return edge
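
# Illustrative usage sketch (assumption; `depth` is any (h, w) depth array and the
# tolerance is made up):
#
#   edge = depth_edge(depth, rtol=0.05, kernel_size=3)   # True where the local depth
#   valid = ~edge                                         # range exceeds 5% of the depth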


def align_rigid(
    p,
    q,
    weights,
):
    """Compute a rigid transformation that, when applied to p, minimizes the weighted
    squared distance between transformed points in p and points in q. See "Least-Squares
    Rigid Motion Using SVD" by Olga Sorkine-Hornung and Michael Rabinovich for more
    details (https://igl.ethz.ch/projects/ARAP/svd_rot.pdf).

    Args:
        p (torch.Tensor): (batch, n, 3) source points
        q (torch.Tensor): (batch, n, 3) target points
        weights (torch.Tensor): (batch, n) per-point weights
    """
    device = p.device
    dtype = p.dtype
    batch, _, _ = p.shape

    # 1. Compute the centroids of both point sets.
    weights_normalized = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)
    p_centroid = (weights_normalized[..., None] * p).sum(dim=-2)
    q_centroid = (weights_normalized[..., None] * q).sum(dim=-2)

    # 2. Compute the centered vectors.
    p_centered = p - p_centroid[..., None, :]
    q_centered = q - q_centroid[..., None, :]

    # 3. Compute the 3x3 covariance matrix.
    covariance = (q_centered * weights[..., None]).transpose(-1, -2) @ p_centered

    # 4. Compute the singular value decomposition and then the rotation.
    u, _, vt = torch.linalg.svd(covariance)
    s = torch.eye(3, dtype=dtype, device=device)
    s = s.expand((batch, 3, 3)).contiguous()
    s[..., 2, 2] = (u.det() * vt.det()).sign()
    rotation = u @ s @ vt

    # 5. Compute the optimal scale.
    scale = (
        (torch.einsum("b i j, b k j -> b k i", rotation, p_centered) * q_centered).sum(
            -1
        )
        * weights
    ).sum(-1) / ((p_centered**2).sum(-1) * weights).sum(-1)
    # scale = (torch.einsum("b i j, b k j -> b k i", rotation, p_centered) * q_centered).sum([-1, -2]) / (p_centered**2).sum([-1, -2])

    # 6. Compute the optimal translation.
    translation = q_centroid - torch.einsum(
        "b i j, b j -> b i", rotation, p_centroid * scale[:, None]
    )

    return rotation, translation, scale
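
# Illustrative usage sketch (assumption; the ground-truth transform is made up):
#
#   import math
#   p = torch.randn(1, 1000, 3)                                      # source points
#   c, s_ = math.cos(0.3), math.sin(0.3)
#   R_gt = torch.tensor([[c, -s_, 0.0], [s_, c, 0.0], [0.0, 0.0, 1.0]])[None]
#   q = 2.0 * p @ R_gt.transpose(1, 2) + 1.0                         # scaled, rotated, shifted copy
#   w = torch.ones(1, 1000)
#   R, t, s = align_rigid(p, q, w)                                   # R ≈ R_gt, s ≈ 2.0, t ≈ [1, 1, 1]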


def align_camera_extrinsics(
    cameras_src: torch.Tensor,  # Bx3x4 tensor representing [R | t]
    cameras_tgt: torch.Tensor,  # Bx3x4 tensor representing [R | t]
    estimate_scale: bool = True,
    eps: float = 1e-9,
):
    """
    Align the source camera extrinsics to the target camera extrinsics.
    NOTE Assumes OPENCV convention.

    Args:
        cameras_src (torch.Tensor): Bx3x4 tensor representing [R | t] for source cameras.
        cameras_tgt (torch.Tensor): Bx3x4 tensor representing [R | t] for target cameras.
        estimate_scale (bool, optional): Whether to estimate the scale factor. Default is True.
        eps (float, optional): Small value to avoid division by zero. Default is 1e-9.

    Returns:
        align_t_R (torch.Tensor): 1x3x3 rotation matrix for alignment.
        align_t_T (torch.Tensor): 1x3 translation vector for alignment.
        align_t_s (float): Scaling factor for alignment.
    """
    R_src = cameras_src[:, :, :3]  # Extract the rotation matrices from [R | t]
    R_tgt = cameras_tgt[:, :, :3]  # Extract the rotation matrices from [R | t]

    RRcov = torch.bmm(R_tgt.transpose(2, 1), R_src).mean(0)
    U, _, V = torch.svd(RRcov)
    align_t_R = V @ U.t()

    T_src = cameras_src[:, :, 3]  # Extract the translation vectors from [R | t]
    T_tgt = cameras_tgt[:, :, 3]  # Extract the translation vectors from [R | t]

    A = torch.bmm(T_src[:, None], R_src)[:, 0]
    B = torch.bmm(T_tgt[:, None], R_src)[:, 0]
    Amu = A.mean(0, keepdim=True)
    Bmu = B.mean(0, keepdim=True)

    if estimate_scale and A.shape[0] > 1:
        # Get the scaling component by matching the covariances
        # of centered A and centered B
        Ac = A - Amu
        Bc = B - Bmu
        align_t_s = (Ac * Bc).mean() / (Ac**2).mean().clamp(eps)
    else:
        # Set the scale to identity
        align_t_s = 1.0

    # Get the translation as the difference between the means of A and B
    align_t_T = Bmu - align_t_s * Amu

    align_t_R = align_t_R[None]
    return align_t_R, align_t_T, align_t_s


def apply_transformation(
    cameras_src: torch.Tensor,  # Bx3x4 tensor representing [R | t]
    align_t_R: torch.Tensor,  # 1x3x3 rotation matrix
    align_t_T: torch.Tensor,  # 1x3 translation vector
    align_t_s: float,  # Scaling factor
    return_extri: bool = True,
) -> torch.Tensor:
    """
    Align and transform the source cameras using the provided rotation, translation, and scaling factors.
    NOTE Assumes OPENCV convention.

    Args:
        cameras_src (torch.Tensor): Bx3x4 tensor representing [R | t] for source cameras.
        align_t_R (torch.Tensor): 1x3x3 rotation matrix for alignment.
        align_t_T (torch.Tensor): 1x3 translation vector for alignment.
        align_t_s (float): Scaling factor for alignment.

    Returns:
        If `return_extri` is True, a Bx3x4 tensor of aligned extrinsics [R | t]. Otherwise:
            aligned_R (torch.Tensor): Bx3x3 tensor representing the aligned rotation matrices.
            aligned_T (torch.Tensor): Bx3 tensor representing the aligned translation vectors.
    """
    R_src = cameras_src[:, :, :3]
    T_src = cameras_src[:, :, 3]

    aligned_R = torch.bmm(R_src, align_t_R.expand(R_src.shape[0], 3, 3))

    # Apply the translation alignment to the source translations
    align_t_T_expanded = align_t_T[..., None].repeat(R_src.shape[0], 1, 1)
    transformed_T = torch.bmm(R_src, align_t_T_expanded)[..., 0]
    aligned_T = transformed_T + T_src * align_t_s

    if return_extri:
        extri = torch.cat([aligned_R, aligned_T.unsqueeze(-1)], dim=-1)
        return extri

    return aligned_R, aligned_T
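
# Illustrative usage sketch (assumption; `pred_extri` and `gt_extri` are hypothetical
# Bx3x4 OpenCV-convention [R | t] extrinsics):
#
#   align_R, align_T, align_s = align_camera_extrinsics(pred_extri, gt_extri)
#   pred_aligned = apply_transformation(pred_extri, align_R, align_T, align_s)
#   # pred_aligned is Bx3x4 and lives in the same coordinate frame as gt_extri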


def slerp(q1, q2, t):
    """Spherical Linear Interpolation between quaternions.

    Args:
        q1: (4,) first quaternion
        q2: (4,) second quaternion
        t: float between 0 and 1
    Returns:
        (4,) interpolated quaternion
    """
    # Compute the cosine of the angle between the two vectors
    dot = np.sum(q1 * q2)

    # If the dot product is negative, slerp won't take the shorter path.
    # Fix by negating one of the input quaternions.
    if dot < 0.0:
        q2 = -q2
        dot = -dot

    # Threshold for using linear interpolation instead of spherical
    DOT_THRESHOLD = 0.9995
    if dot > DOT_THRESHOLD:
        # If the inputs are too close for comfort, linearly interpolate
        # and normalize the result
        result = q1 + t * (q2 - q1)
        return result / np.linalg.norm(result)

    # Compute the angle between the quaternions
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)

    # Compute interpolation factors
    theta = theta_0 * t
    sin_theta = np.sin(theta)

    s0 = np.cos(theta) - dot * sin_theta / sin_theta_0
    s1 = sin_theta / sin_theta_0

    return (s0 * q1) + (s1 * q2)


def interpolate_poses(pose1, pose2, weight):
    """Interpolate between two camera poses with weight.

    Args:
        pose1: (4, 4) first camera pose
        pose2: (4, 4) second camera pose
        weight: float between 0 and 1, weight for pose1 (1-weight for pose2)
    Returns:
        (4, 4) interpolated pose
    """
    from scipy.spatial.transform import Rotation as R

    # Extract rotations and translations
    R1 = R.from_matrix(pose1[:3, :3])
    R2 = R.from_matrix(pose2[:3, :3])
    t1 = pose1[:3, 3]
    t2 = pose2[:3, 3]

    # Get quaternions
    q1 = R1.as_quat()
    q2 = R2.as_quat()

    # Interpolate rotation using our slerp implementation
    q_interp = slerp(q1, q2, 1 - weight)  # 1 - weight because weight is for pose1
    R_interp = R.from_quat(q_interp)

    # Linear interpolation for translation
    t_interp = weight * t1 + (1 - weight) * t2

    # Construct interpolated pose
    pose_interp = np.eye(4)
    pose_interp[:3, :3] = R_interp.as_matrix()
    pose_interp[:3, 3] = t_interp

    return pose_interp
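
# Illustrative usage sketch (assumption; `pose_a` and `pose_b` are hypothetical
# (4, 4) camera poses):
#
#   mid_pose = interpolate_poses(pose_a, pose_b, weight=0.5)   # halfway pose
#   near_a = interpolate_poses(pose_a, pose_b, weight=0.9)     # mostly pose_a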


def smooth_poses(poses, window_size=5, method="gaussian"):
    """Smooth camera poses temporally.

    Args:
        poses: (N, 4, 4) camera poses
        window_size: int, must be odd number
        method: str, 'gaussian' or 'savgol' or 'ma'
    Returns:
        (N, 4, 4) smoothed poses
    """
    from scipy.ndimage import gaussian_filter1d
    from scipy.signal import savgol_filter
    from scipy.spatial.transform import Rotation as R

    assert window_size % 2 == 1, "window_size must be odd"
    N = poses.shape[0]
    smoothed = np.zeros_like(poses)

    # Extract translations and quaternions
    translations = poses[:, :3, 3]
    rotations = R.from_matrix(poses[:, :3, :3])
    quats = rotations.as_quat()  # (N, 4)

    # Ensure consistent quaternion signs to prevent interpolation artifacts
    for i in range(1, N):
        if np.dot(quats[i], quats[i - 1]) < 0:
            quats[i] = -quats[i]

    # Smooth translations and quaternions
    if method == "gaussian":
        sigma = window_size / 6.0  # approximately 99.7% of the weight within the window
        smoothed_trans = gaussian_filter1d(translations, sigma, axis=0, mode="nearest")
        smoothed_quats = gaussian_filter1d(quats, sigma, axis=0, mode="nearest")
    elif method == "savgol":
        # Savitzky-Golay filter: polynomial fitting
        poly_order = min(window_size - 1, 3)
        smoothed_trans = savgol_filter(
            translations, window_size, poly_order, axis=0, mode="nearest"
        )
        smoothed_quats = savgol_filter(
            quats, window_size, poly_order, axis=0, mode="nearest"
        )
    elif method == "ma":
        # Simple moving average
        kernel = np.ones(window_size) / window_size
        smoothed_trans = np.array(
            [np.convolve(translations[:, i], kernel, mode="same") for i in range(3)]
        ).T
        smoothed_quats = np.array(
            [np.convolve(quats[:, i], kernel, mode="same") for i in range(4)]
        ).T
    else:
        raise ValueError(f"Unknown smoothing method: {method}")

    # Normalize quaternions
    smoothed_quats /= np.linalg.norm(smoothed_quats, axis=1, keepdims=True)

    # Reconstruct poses
    smoothed_rots = R.from_quat(smoothed_quats).as_matrix()
    for i in range(N):
        smoothed[i] = np.eye(4)
        smoothed[i, :3, :3] = smoothed_rots[i]
        smoothed[i, :3, 3] = smoothed_trans[i]

    return smoothed
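
# Illustrative usage sketch (assumption; the jittery trajectory is made up):
#
#   poses = np.tile(np.eye(4), (30, 1, 1))
#   poses[:, 0, 3] = np.linspace(0, 1, 30) + np.random.randn(30) * 0.01  # noisy x-translation
#   smoothed = smooth_poses(poses, window_size=5, method="gaussian")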


def smooth_trajectory(poses, window_size=5):
    """Smooth camera trajectory using Kalman filter.

    Args:
        poses: (N, 4, 4) camera poses
        window_size: int, window size for initial smoothing
    Returns:
        (N, 4, 4) smoothed poses
    """
    from filterpy.kalman import KalmanFilter
    from scipy.spatial.transform import Rotation as R

    N = poses.shape[0]

    # Initialize Kalman filter for position and velocity
    kf = KalmanFilter(dim_x=6, dim_z=3)  # 3D position and velocity
    dt = 1.0  # assume uniform time steps

    # State transition matrix (constant-velocity model)
    kf.F = np.array(
        [
            [1, 0, 0, dt, 0, 0],
            [0, 1, 0, 0, dt, 0],
            [0, 0, 1, 0, 0, dt],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 1],
        ]
    )

    # Measurement matrix (observe position only)
    kf.H = np.array([[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0]])

    # Measurement noise
    kf.R *= 0.1
    # Process noise
    kf.Q *= 0.1
    # Initial state uncertainty
    kf.P *= 1.0

    # Extract translations and rotations
    translations = poses[:, :3, 3]
    rotations = R.from_matrix(poses[:, :3, :3])
    quats = rotations.as_quat()

    # First pass: simple smoothing for initial estimates
    smoothed = smooth_poses(poses, window_size, method="gaussian")
    smooth_trans = smoothed[:, :3, 3]

    # Second pass: Kalman filter for the trajectory
    filtered_trans = np.zeros_like(translations)
    kf.x = np.zeros(6)
    kf.x[:3] = smooth_trans[0]
    filtered_trans[0] = smooth_trans[0]

    # Forward pass
    for i in range(1, N):
        kf.predict()
        kf.update(smooth_trans[i])
        filtered_trans[i] = kf.x[:3]

    # Window-based smoothing for rotations using sign-consistent weighted quaternion averaging
    window_half = window_size // 2
    smoothed_quats = np.zeros_like(quats)

    for i in range(N):
        start_idx = max(0, i - window_half)
        end_idx = min(N, i + window_half + 1)
        weights = np.exp(
            -0.5 * ((np.arange(start_idx, end_idx) - i) / (window_half / 2)) ** 2
        )
        weights /= weights.sum()

        # Weighted average of nearby quaternions
        avg_quat = np.zeros(4)
        for j, w in zip(range(start_idx, end_idx), weights):
            if np.dot(quats[j], quats[i]) < 0:
                avg_quat += w * -quats[j]
            else:
                avg_quat += w * quats[j]
        smoothed_quats[i] = avg_quat / np.linalg.norm(avg_quat)

    # Reconstruct final smoothed poses
    final_smoothed = np.zeros_like(poses)
    smoothed_rots = R.from_quat(smoothed_quats).as_matrix()
    for i in range(N):
        final_smoothed[i] = np.eye(4)
        final_smoothed[i, :3, :3] = smoothed_rots[i]
        final_smoothed[i, :3, 3] = filtered_trans[i]

    return final_smoothed


def compute_scale(prediction, target, mask):
    # Least-squares scale s minimizing || mask * (s * prediction - target) ||^2.
    if isinstance(prediction, np.ndarray):
        prediction = torch.from_numpy(prediction).float()
    if isinstance(target, np.ndarray):
        target = torch.from_numpy(target).float()
    if isinstance(mask, np.ndarray):
        mask = torch.from_numpy(mask).bool()

    numerator = torch.sum(mask * prediction * target, (1, 2))
    denominator = torch.sum(mask * prediction * prediction, (1, 2))
    scale = torch.zeros_like(numerator)
    valid = (denominator != 0).nonzero()
    scale[valid] = numerator[valid] / denominator[valid]

    return scale.item()
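
# Illustrative usage sketch (assumption; arrays are made up, batch size 1 because
# the function returns a Python float via .item()):
#
#   pred = np.random.rand(1, 64, 64).astype(np.float32)
#   tgt = 3.0 * pred
#   m = np.ones((1, 64, 64), dtype=bool)
#   compute_scale(pred, tgt, m)    # ≈ 3.0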