import sys

sys.path.append('../')
sys.path.append("../submodules")
sys.path.append('../submodules/RoMa')

import time
from collections import defaultdict

from matplotlib import pyplot as plt
from PIL import Image
import torch
import numpy as np
# from tqdm import tqdm_notebook as tqdm
from tqdm import tqdm
from scipy.cluster.vq import kmeans, vq
from scipy.spatial.distance import cdist
import torch.nn.functional as F

from romatch import roma_outdoor, roma_indoor
from romatch.utils import get_tuple_transform_ops
from utils.sh_utils import RGB2SH
def pairwise_distances(matrix):
    """
    Computes the pairwise Euclidean distances between all vectors in the input matrix.

    Args:
        matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.

    Returns:
        torch.Tensor: Pairwise distance matrix of shape [N, N].
    """
    # torch.cdist with p=2 returns Euclidean (not squared) distances
    return torch.cdist(matrix, matrix, p=2)
def k_closest_vectors(matrix, k):
    """
    Finds the k closest vectors for each vector in the input matrix based on Euclidean distance.

    Args:
        matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
        k (int): Number of closest vectors to return for each vector.

    Returns:
        torch.Tensor: Indices of the k closest vectors for each vector, excluding the vector itself.
    """
    # Compute pairwise distances
    distances = pairwise_distances(matrix)

    # Set diagonal distances to infinity to exclude each vector from its own nearest neighbors
    distances.fill_diagonal_(float('inf'))

    # Get the indices of the k smallest distances (k closest vectors)
    _, indices = torch.topk(distances, k, largest=False, dim=1)
    return indices
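

# Minimal usage sketch (illustrative only, not part of the pipeline): find the 3 nearest
# cameras for each camera given their flattened 4x4 world-to-view matrices. The number of
# cameras and the random poses below are made up for demonstration.
def _example_k_closest_vectors():
    n_cams = 10
    fake_poses = torch.randn(n_cams, 16)          # stand-in for flattened camera matrices
    nn_idx = k_closest_vectors(fake_poses, k=3)   # shape: [n_cams, 3]
    print(nn_idx.shape, nn_idx[0])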
def select_cameras_kmeans(cameras, K):
    """
    Selects K cameras from a set using K-means clustering.

    Args:
        cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened.
        K: Number of clusters (cameras to select).

    Returns:
        selected_indices: List of indices of the cameras closest to the cluster centers.
    """
    # Ensure input is a NumPy array
    if not isinstance(cameras, np.ndarray):
        cameras = np.asarray(cameras)
    if cameras.shape[1] != 16:
        raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.")

    # Perform K-means clustering
    cluster_centers, _ = kmeans(cameras, K)

    # Assign each camera to a cluster and find distances to cluster centers
    cluster_assignments, _ = vq(cameras, cluster_centers)

    # Find the camera nearest to each cluster center
    selected_indices = []
    for k in range(K):
        cluster_members = cameras[cluster_assignments == k]
        distances = cdist([cluster_centers[k]], cluster_members)[0]
        nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)]
        selected_indices.append(nearest_camera_idx)
    return selected_indices
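

# Illustrative sketch (not in the original pipeline): pick 4 representative cameras out of
# 20 random ones. The poses here are synthetic; in the pipeline the rows come from
# `world_view_transform.flatten()` of the training cameras.
def _example_select_cameras_kmeans():
    rng = np.random.default_rng(0)
    fake_cameras = rng.normal(size=(20, 16)).astype(np.float64)  # scipy kmeans expects floats
    ref_indices = select_cameras_kmeans(fake_cameras, K=4)
    print(sorted(ref_indices))  # indices of the cameras nearest to the 4 cluster centers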
def compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, device="cuda", verbose=False, output_dict=None):
    """
    Computes the warp and confidence between two viewpoint cameras using the roma_model.

    Args:
        viewpoint_cam1: Source viewpoint camera.
        viewpoint_cam2: Target viewpoint camera.
        roma_model: Pre-trained RoMa model for correspondence matching.
        device: Device to run the computation on.
        verbose: If True, displays the images.
        output_dict: Optional dict that collects visualization figures.

    Returns:
        certainty: Confidence tensor.
        warp: Warp tensor.
        imB: Processed image B as a NumPy array.
    """
    if output_dict is None:
        output_dict = {}

    # Prepare images
    imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
    imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
    imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
    imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))

    if verbose:
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
        cax1 = ax[0].imshow(imA)
        ax[0].set_title("Image 1")
        cax2 = ax[1].imshow(imB)
        ax[1].set_title("Image 2")
        fig.colorbar(cax1, ax=ax[0])
        fig.colorbar(cax2, ax=ax[1])
        for axis in ax:
            axis.axis('off')
        # Save the figure into the dictionary
        output_dict['image_pair'] = fig

    # Transform images
    ws, hs = roma_model.w_resized, roma_model.h_resized
    test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
    im_A, im_B = test_transform((imA, imB))
    batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}

    # Forward pass through the RoMa model
    corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch)
    finest_scale = 1
    hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws)

    # Process certainty and warp
    certainty = corresps[finest_scale]["certainty"]
    im_A_to_im_B = corresps[finest_scale]["flow"]
    if roma_model.attenuate_cert:
        low_res_certainty = F.interpolate(
            corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
        )
        certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0)

    # Upsample predictions if needed
    if roma_model.upsample_preds:
        im_A_to_im_B = F.interpolate(
            im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
        )
        certainty = F.interpolate(
            certainty, size=(hs, ws), align_corners=False, mode="bilinear"
        )

    # Convert predictions to the final format: a dense grid of source coordinates
    # concatenated with the predicted target coordinates
    im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
    im_A_coords = torch.stack(torch.meshgrid(
        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
        indexing='ij'
    ), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1)
    warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
    certainty = certainty.sigmoid()
    return certainty[0, 0], warp[0], np.array(imB)
def resize_batch(tensors_3d, tensors_4d, target_shape):
    """
    Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions.

    Args:
        tensors_3d: Tensor of shape [B, H, W].
        tensors_4d: Tensor of shape [B, H, W, 4].
        target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions.

    Returns:
        resized_tensors_3d: Tensor of shape [B, target_H, target_W].
        resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4].
    """
    target_H, target_W = target_shape

    # Resize the [B, H, W] tensor by adding a channel dimension for interpolation
    resized_tensors_3d = F.interpolate(
        tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False
    ).squeeze(1)

    # Resize the [B, H, W, 4] tensor by moving channels to dim 1 and back
    resized_tensors_4d = F.interpolate(
        tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False
    ).permute(0, 2, 3, 1)

    return resized_tensors_3d, resized_tensors_4d
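

# Quick illustrative check (synthetic shapes, not from the pipeline): resize a batch of
# 3 certainty maps and 3 warps from 64x48 to 128x96 and confirm the output shapes.
def _example_resize_batch():
    cert = torch.rand(3, 64, 48)
    warps = torch.rand(3, 64, 48, 4)
    cert_r, warps_r = resize_batch(cert, warps, (128, 96))
    assert cert_r.shape == (3, 128, 96) and warps_r.shape == (3, 128, 96, 4)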
def aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model, source_idx, verbose=False, output_dict=None):
    """
    Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint.

    Args:
        viewpoint_stack: Stack of viewpoint cameras.
        closest_indices: Indices of the nearest neighbors for each viewpoint.
        roma_model: Pre-trained RoMa model.
        source_idx: Index of the source viewpoint.
        verbose: If True, displays intermediate results.
        output_dict: Optional dict that collects visualization figures.

    Returns:
        certainties_max: Aggregated maximum confidences.
        warps_max: Aggregated warps corresponding to the maximum confidences.
        certainties_max_idcs: Pixel-wise index of the neighbor image that gave the best match.
        imB_compound: List of the neighboring images.
        certainties_all_resized: Per-neighbor confidences resized to the target resolution.
        warps_all_resized: Per-neighbor warps resized to the target resolution.
    """
    if output_dict is None:
        output_dict = {}

    certainties_all, warps_all, imB_compound = [], [], []
    for nn in tqdm(closest_indices[source_idx]):
        viewpoint_cam1 = viewpoint_stack[source_idx]
        viewpoint_cam2 = viewpoint_stack[nn]
        certainty, warp, imB = compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, verbose=verbose, output_dict=output_dict)
        certainties_all.append(certainty)
        warps_all.append(warp)
        imB_compound.append(imB)

    certainties_all = torch.stack(certainties_all, dim=0)
    target_shape = imB_compound[0].shape[:2]

    if verbose:
        print("certainties_all.shape:", certainties_all.shape)
        print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape)
        print("target_shape:", target_shape)

    certainties_all_resized, warps_all_resized = resize_batch(certainties_all,
                                                              torch.stack(warps_all, dim=0),
                                                              target_shape)

    if verbose:
        print("warps_all_resized.shape:", warps_all_resized.shape)
        for n, cert in enumerate(certainties_all):
            fig, ax = plt.subplots()
            cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise Confidence")
            output_dict[f'certainty_{n}'] = fig
        for n, warp in enumerate(warps_all):
            fig, ax = plt.subplots()
            cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise warp")
            output_dict[f'warp_{n}'] = fig
        for n, cert in enumerate(certainties_all_resized):
            fig, ax = plt.subplots()
            cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise Confidence resized")
            output_dict[f'certainty_resized_{n}'] = fig
        for n, warp in enumerate(warps_all_resized):
            fig, ax = plt.subplots()
            cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise warp resized")
            output_dict[f'warp_resized_{n}'] = fig

    # For every pixel keep the most confident neighbor and gather the corresponding warp
    certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0)
    H, W = certainties_max.shape
    rows = torch.arange(H, device=certainties_max_idcs.device).unsqueeze(1)
    cols = torch.arange(W, device=certainties_max_idcs.device)
    warps_max = warps_all_resized[certainties_max_idcs, rows, cols]

    return certainties_max, warps_max, certainties_max_idcs, imB_compound, certainties_all_resized, warps_all_resized
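

# Small self-contained sketch (synthetic data, not from the pipeline) of the per-pixel
# "keep the most confident neighbor" gather used above: argmax over the neighbor axis,
# then advanced indexing to pull the matching warp for every pixel.
def _example_per_pixel_best_warp():
    n_neighbors, H, W = 4, 6, 5
    certainties = torch.rand(n_neighbors, H, W)
    warps = torch.rand(n_neighbors, H, W, 4)
    best_conf, best_idx = torch.max(certainties, dim=0)   # both [H, W]
    rows = torch.arange(H).unsqueeze(1)                    # [H, 1], broadcasts over W
    cols = torch.arange(W)                                 # [W]
    best_warps = warps[best_idx, rows, cols]               # [H, W, 4]
    assert best_conf.shape == (H, W) and best_warps.shape == (H, W, 4)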
def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model,
                                 verbose=False, output_dict=None):
    """
    Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).

    Args:
        imA: Source image as a NumPy array (H_A, W_A, C).
        imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...].
        certainties_max: Tensor of pixel-wise maximum confidences.
        certainties_max_idcs: Tensor of pixel-wise indices for the best matches.
        matches: Matches in normalized coordinates.
        roma_model: RoMa model instance for keypoint operations.
        verbose: If True, shows intermediate outputs and visualizes results.
        output_dict: Optional dict that collects visualization figures.

    Returns:
        kptsA_np: Keypoints in imA in normalized coordinates.
        kptsB_np: Keypoints in imB in normalized coordinates.
        kptsB_proj_matrices_idx: Per-keypoint index of the neighbor image that gave the best match.
        kptsA_color: Colors of keypoints in imA.
        kptsB_color: Colors of keypoints in imB based on certainties_max_idcs.
    """
    if output_dict is None:
        output_dict = {}

    H_A, W_A, _ = imA.shape
    H, W = certainties_max.shape

    # Convert matches to pixel coordinates
    kptsA, kptsB = roma_model.to_pixel_coordinates(
        matches, W_A, H_A, H, W
    )
    kptsA_np = kptsA.detach().cpu().numpy()
    kptsB_np = kptsB.detach().cpu().numpy()
    kptsA_np = kptsA_np[:, [1, 0]]

    if verbose:
        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(imA)
        ax.set_title("Reference image, imA")
        output_dict['reference_image'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(imB_compound[0])
        ax.set_title("Image to compare to, imB_compound[0]")
        output_dict['imB_compound'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imA))
        cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03)
        ax.set_title("Keypoints in imA")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict['kptsA'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imB_compound[0]))
        cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
        ax.set_title("Keypoints in imB")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict['kptsB'] = fig

    # Keypoints are in (row, column) format: the first value lies in [0, height], the second in [0, width]
    kptsA_np = kptsA.detach().cpu().numpy()
    kptsB_np = kptsB.detach().cpu().numpy()

    # Extract colors for keypoints in imA (vectorized)
    kptsA_x = np.round(kptsA_np[:, 0]).astype(int)
    kptsA_y = np.round(kptsA_np[:, 1]).astype(int)
    kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]

    # Create a composite image from imB_compound
    imB_compound_np = np.stack(imB_compound, axis=0)
    H_B, W_B, _ = imB_compound[0].shape

    # Extract colors for keypoints in imB using certainties_max_idcs
    imB_np = imB_compound_np[
        certainties_max_idcs.detach().cpu().numpy(),
        np.arange(H).reshape(-1, 1),
        np.arange(W)
    ]

    if verbose:
        print("imB_np.shape:", imB_np.shape)
        print("imB_np:", imB_np)
        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imB_np))
        cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
        ax.set_title("Keypoints on the per-pixel best-neighbor composite, imB_np")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict['imB_np_keypoints'] = fig

    kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
    kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
    certainties_max_idcs_np = certainties_max_idcs.detach().cpu().numpy()
    kptsB_proj_matrices_idx = certainties_max_idcs_np[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
    kptsB_color = imB_compound_np[kptsB_proj_matrices_idx, np.clip(kptsB_y, 0, H - 1), np.clip(kptsB_x, 0, W - 1)]

    # Normalize keypoints in both images to [-1, 1]
    kptsA_np[:, 0] = kptsA_np[:, 0] / H * 2.0 - 1.0
    kptsA_np[:, 1] = kptsA_np[:, 1] / W * 2.0 - 1.0
    kptsB_np[:, 0] = kptsB_np[:, 0] / W_B * 2.0 - 1.0
    kptsB_np[:, 1] = kptsB_np[:, 1] / H_B * 2.0 - 1.0

    return kptsA_np[:, [1, 0]], kptsB_np, kptsB_proj_matrices_idx, kptsA_color, kptsB_color
def prepare_tensor(input_array, device):
    """
    Converts an input array to a torch tensor, clones it, and detaches it for safe computation.

    Args:
        input_array (array-like): The input array to convert.
        device (str or torch.device): The device to move the tensor to.

    Returns:
        torch.Tensor: A detached tensor clone of the input array on the specified device.
    """
    if not isinstance(input_array, torch.Tensor):
        return torch.tensor(input_array, dtype=torch.float32, device=device)
    return input_array.clone().detach().to(device).to(torch.float32)
def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
    """
    Solves for a batch of 3D points given batches of projection matrices and corresponding image points.

    Parameters:
    - P1, P2: Tensors of projection matrices of size (batch_size, 4, 4) or (4, 4)
    - k1_x, k1_y: Tensors of shape (batch_size,)
    - k2_x, k2_y: Tensors of shape (batch_size,)

    Returns:
    - X: A tensor containing the 3D homogeneous coordinates, shape (batch_size, 4)
    - errors_proj1, errors_proj2: Per-point reprojection errors in both views (NumPy arrays)
    """
    EPS = 1e-4

    # Ensure inputs are tensors on the right device
    P1 = prepare_tensor(P1, device)
    P2 = prepare_tensor(P2, device)
    k1_x = prepare_tensor(k1_x, device)
    k1_y = prepare_tensor(k1_y, device)
    k2_x = prepare_tensor(k2_x, device)
    k2_y = prepare_tensor(k2_y, device)

    batch_size = k1_x.shape[0]

    # Expand P1 and P2 if they are not batched
    if P1.ndim == 2:
        P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
    if P2.ndim == 2:
        P2 = P2.unsqueeze(0).expand(batch_size, -1, -1)

    # Extract columns from P1 and P2
    P1_0 = P1[:, :, 0]  # Shape: (batch_size, 4)
    P1_1 = P1[:, :, 1]
    P1_2 = P1[:, :, 2]
    P2_0 = P2[:, :, 0]
    P2_1 = P2[:, :, 1]
    P2_2 = P2[:, :, 2]

    # Reshape kx and ky to (batch_size, 1)
    k1_x = k1_x.view(-1, 1)
    k1_y = k1_y.view(-1, 1)
    k2_x = k2_x.view(-1, 1)
    k2_y = k2_y.view(-1, 1)

    # Construct the linear equations for each batch element
    # For camera 1
    A1 = P1_0 - k1_x * P1_2  # Shape: (batch_size, 4)
    A2 = P1_1 - k1_y * P1_2
    # For camera 2
    A3 = P2_0 - k2_x * P2_2
    A4 = P2_1 - k2_y * P2_2

    # Stack the equations
    A = torch.stack([A1, A2, A3, A4], dim=1)  # Shape: (batch_size, 4, 4)

    # Right-hand side (constants)
    b = -A[:, :, 3]          # Shape: (batch_size, 4)
    A_reduced = A[:, :, :3]  # Coefficients of x, y, z

    # Solve using torch.linalg.lstsq (supports batching)
    X_xyz = torch.linalg.lstsq(A_reduced, b.unsqueeze(2)).solution.squeeze(2)  # Shape: (batch_size, 3)

    # Append 1 to get homogeneous coordinates
    ones = torch.ones((batch_size, 1), dtype=torch.float32, device=X_xyz.device)
    X = torch.cat([X_xyz, ones], dim=1)  # Shape: (batch_size, 4)

    # Compute the reprojection errors in both views
    seeked_splats_proj1 = (X.unsqueeze(1) @ P1).squeeze(1)
    seeked_splats_proj1 = seeked_splats_proj1 / (EPS + seeked_splats_proj1[:, [3]])
    seeked_splats_proj2 = (X.unsqueeze(1) @ P2).squeeze(1)
    seeked_splats_proj2 = seeked_splats_proj2 / (EPS + seeked_splats_proj2[:, [3]])
    proj1_target = torch.cat([k1_x, k1_y], dim=1)
    proj2_target = torch.cat([k2_x, k2_y], dim=1)
    errors_proj1 = torch.abs(seeked_splats_proj1[:, :2] - proj1_target).sum(1).detach().cpu().numpy()
    errors_proj2 = torch.abs(seeked_splats_proj2[:, :2] - proj2_target).sum(1).detach().cpu().numpy()

    return X, errors_proj1, errors_proj2
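

# Illustrative sanity check (synthetic two-view setup, not part of the pipeline). It follows
# the same row-vector convention as above (x_hom = X @ P, then divide by the last component):
# a pinhole camera at the origin and a second one shifted along x by a baseline `t`. All
# matrices and values below are invented for the demonstration.
def _example_triangulate_points():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)
    n, t = 16, 0.5
    # Ground-truth points in front of both cameras
    xyz = torch.stack([torch.rand(n) * 2 - 1, torch.rand(n) * 2 - 1, torch.rand(n) * 3 + 2], dim=1)
    # Row-vector 4x4 "projection" matrices whose outputs are (x, y, depth, depth)
    P1 = torch.tensor([[1., 0., 0., 0.],
                       [0., 1., 0., 0.],
                       [0., 0., 1., 1.],
                       [0., 0., 0., 0.]])
    P2 = P1.clone()
    P2[3, 0] = -t  # second camera shifted by `t` along x
    k1_x, k1_y = xyz[:, 0] / xyz[:, 2], xyz[:, 1] / xyz[:, 2]
    k2_x, k2_y = (xyz[:, 0] - t) / xyz[:, 2], xyz[:, 1] / xyz[:, 2]
    X, err1, err2 = triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device=device)
    print("max 3D error:", (X[:, :3].cpu() - xyz).abs().max().item())
    print("max reprojection errors:", err1.max(), err2.max())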
def select_best_keypoints(
        NNs_triangulated_points, NNs_errors_proj1, NNs_errors_proj2, device="cuda"):
    """
    Selects, for every keypoint, the triangulated point with the lowest reprojection error across all nearest neighbors.

    Args:
        NNs_triangulated_points: torch tensor with keypoint coordinates (num_nns, num_points, dim). dim can be arbitrary,
            usually 3 or 4 (for the homogeneous representation).
        NNs_errors_proj1: NumPy array with the projection error of each estimated keypoint on the reference frame (num_nns, num_points).
        NNs_errors_proj2: NumPy array with the projection error of each estimated keypoint on the neighbor frame (num_nns, num_points).

    Returns:
        NNs_triangulated_points_selected: The triangulated points with the best (lowest) error, shape (num_points, dim).
        The per-point minimum error over the neighbors.
    """
    # Score each candidate by the worse of its two reprojection errors
    NNs_errors_proj = np.maximum(NNs_errors_proj1, NNs_errors_proj2)

    # Convert the per-point argmin over neighbors to a PyTorch index tensor
    indices = torch.from_numpy(np.argmin(NNs_errors_proj, axis=0)).long().to(device)

    # Create an index tensor for the second dimension
    n_indices = torch.arange(NNs_triangulated_points.shape[1]).long().to(device)

    # Use advanced indexing to select elements
    NNs_triangulated_points_selected = NNs_triangulated_points[indices, n_indices, :]  # Shape: [num_points, dim]

    return NNs_triangulated_points_selected, np.min(NNs_errors_proj, axis=0)
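

# Toy illustration (random values, not from the pipeline): 3 neighbor hypotheses for 5 points;
# per point, keep the candidate whose worse reprojection error is smallest.
def _example_select_best_keypoints():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    candidates = torch.randn(3, 5, 4, device=device)   # (num_nns, num_points, dim)
    err1 = np.random.rand(3, 5)
    err2 = np.random.rand(3, 5)
    best_pts, best_err = select_best_keypoints(candidates, err1, err2, device=device)
    print(best_pts.shape, best_err.shape)               # (5, 4) and (5,)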
# def init_gaussians_with_corr_profiled(gaussians, scene, cfg, device, verbose=False, roma_model=None):
#     timings = defaultdict(list)  # To accumulate timings
#     if roma_model is None:
#         if cfg.roma_model == "indoors":
#             roma_model = roma_indoor(device=device)
#         else:
#             roma_model = roma_outdoor(device=device)
#         roma_model.upsample_preds = False
#         roma_model.symmetric = False
#     M = cfg.matches_per_ref
#     upper_thresh = roma_model.sample_thresh
#     scaling_factor = cfg.scaling_factor
#     expansion_factor = 1
#     keypoint_fit_error_tolerance = cfg.proj_err_tolerance
#     visualizations = {}
#     viewpoint_stack = scene.getTrainCameras().copy()
#     NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
#     NUM_NNS_PER_REFERENCE = min(cfg.nns_per_ref, len(viewpoint_stack))
#     viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
#     selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
#     selected_indices = sorted(selected_indices)
#     viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
#     closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
#     closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()
#     all_new_xyz = []
#     all_new_features_dc = []
#     all_new_features_rest = []
#     all_new_opacities = []
#     all_new_scaling = []
#     all_new_rotation = []
#     # Dummy first pass to initialize model
#     with torch.no_grad():
#         viewpoint_cam1 = viewpoint_stack[0]
#         viewpoint_cam2 = viewpoint_stack[1]
#         imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
#         imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
#         imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
#         imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
#         warp, certainty_warp = roma_model.match(imA, imB, device=device)
#         del warp, certainty_warp
#         torch.cuda.empty_cache()
#     # Main loop over source_idx
#     for source_idx in tqdm(sorted(selected_indices), desc="Profiling source frames"):
#         # =================== Step 1: Aggregate Confidences and Warps ===================
#         start = time.time()
#         viewpoint_cam1 = viewpoint_stack[source_idx]
#         viewpoint_cam2 = viewpoint_stack[closest_indices_selected[source_idx, 0]]
#         imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
#         imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
#         imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
#         imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
#         warp, certainty_warp = roma_model.match(imA, imB, device=device)
#         certainties_max, warps_max, certainties_max_idcs, imB_compound, certainties_all, warps_all = aggregate_confidences_and_warps(
#             viewpoint_stack=viewpoint_stack,
#             closest_indices=closest_indices_selected,
#             roma_model=roma_model,
#             source_idx=source_idx,
#             verbose=verbose,
#             output_dict=visualizations
#         )
#         certainties_max = certainty_warp
#         with torch.no_grad():
#             warps_all = warps.unsqueeze(0)
#         timings['aggregation_warp_certainty'].append(time.time() - start)
#         # =================== Step 2: Good Samples Selection ===================
#         start = time.time()
#         certainty = certainties_max.reshape(-1).clone()
#         certainty[certainty > upper_thresh] = 1
#         good_samples = torch.multinomial(certainty, num_samples=min(expansion_factor * M, len(certainty)), replacement=False)
#         timings['good_samples_selection'].append(time.time() - start)
#         # =================== Step 3: Triangulate Keypoints for Each NN ===================
#         reference_image_dict = {
#             "triangulated_points": [],
#             "triangulated_points_errors_proj1": [],
#             "triangulated_points_errors_proj2": []
#         }
#         start = time.time()
#         for NN_idx in range(len(warps_all)):
#             matches_NN = warps_all[NN_idx].reshape(-1, 4)[good_samples]
#             # Extract keypoints and colors
#             kptsA_np, kptsB_np, kptsB_proj_matrices_idcs, kptsA_color, kptsB_color = extract_keypoints_and_colors(
#                 imA, imB_compound, certainties_max, certainties_max_idcs, matches_NN, roma_model
#             )
#             proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
#             proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, NN_idx]].full_proj_transform
#             triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
#                 P1=torch.stack([proj_matrices_A] * M, axis=0),
#                 P2=torch.stack([proj_matrices_B] * M, axis=0),
#                 k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
#                 k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
#             reference_image_dict["triangulated_points"].append(triangulated_points)
#             reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
#             reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
#         timings['triangulation_per_NN'].append(time.time() - start)
#         # =================== Step 4: Select Best Triangulated Points ===================
#         start = time.time()
#         NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
#             NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
#             NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
#             NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
#         timings['select_best_keypoints'].append(time.time() - start)
#         # =================== Step 5: Create New Gaussians ===================
#         start = time.time()
#         viewpoint_cam1 = viewpoint_stack[source_idx]
#         N = len(NNs_triangulated_points_selected)
#         new_xyz = NNs_triangulated_points_selected[:, :-1]
#         all_new_xyz.append(new_xyz)
#         all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
#         all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
#         mask_bad_points = torch.tensor(
#             NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
#             dtype=torch.float32).unsqueeze(1).to(device)
#         all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
#         dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz, dim=1, ord=2)
#         all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
#         all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
#         timings['save_gaussians'].append(time.time() - start)
#     # =================== Final Densification Postfix ===================
#     start = time.time()
#     all_new_xyz = torch.cat(all_new_xyz, dim=0)
#     all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
#     new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
#     prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
#     gaussians.densification_postfix(
#         all_new_xyz[prune_mask].to(device),
#         all_new_features_dc[prune_mask].to(device),
#         torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
#         torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
#         torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
#         torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
#         new_tmp_radii[prune_mask].to(device)
#     )
#     timings['final_densification_postfix'].append(time.time() - start)
#     # =================== Print Profiling Results ===================
#     print("\n=== Profiling Summary (average per frame) ===")
#     for key, times in timings.items():
#         print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")
#     return viewpoint_stack, closest_indices_selected, visualizations
def extract_keypoints_and_colors_single(imA, imB, matches, roma_model, verbose=False, output_dict=None):
    """
    Extracts keypoints and corresponding colors from a source image (imA) and a single target image (imB).

    Args:
        imA: Source image as a NumPy array (H_A, W_A, C).
        imB: Target image as a NumPy array (H_B, W_B, C).
        matches: Matches in normalized coordinates (torch.Tensor).
        roma_model: RoMa model instance for keypoint operations.
        verbose: If True, outputs intermediate visualizations.
        output_dict: Optional dict that collects visualization figures.

    Returns:
        kptsA_np_norm: Keypoints in imA (normalized).
        kptsB_np_norm: Keypoints in imB (normalized).
        kptsA_color: Colors of keypoints in imA.
        kptsB_color: Colors of keypoints in imB.
    """
    if output_dict is None:
        output_dict = {}

    H_A, W_A, _ = imA.shape
    H_B, W_B, _ = imB.shape

    # Matches format: (N, 4) = (x1_norm, y1_norm, x2_norm, y2_norm)
    kptsA = matches[:, :2]  # [N, 2]
    kptsB = matches[:, 2:]  # [N, 2]

    # Scale normalized coordinates [-1, 1] to pixel coordinates
    kptsA_pix = torch.zeros_like(kptsA)
    kptsB_pix = torch.zeros_like(kptsB)
    kptsA_pix[:, 0] = (kptsA[:, 0] + 1) * (W_A - 1) / 2
    kptsA_pix[:, 1] = (kptsA[:, 1] + 1) * (H_A - 1) / 2
    kptsB_pix[:, 0] = (kptsB[:, 0] + 1) * (W_B - 1) / 2
    kptsB_pix[:, 1] = (kptsB[:, 1] + 1) * (H_B - 1) / 2

    kptsA_np = kptsA_pix.detach().cpu().numpy()
    kptsB_np = kptsB_pix.detach().cpu().numpy()

    # Extract colors at the rounded, clipped pixel locations
    kptsA_x = np.round(kptsA_np[:, 0]).astype(int)
    kptsA_y = np.round(kptsA_np[:, 1]).astype(int)
    kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
    kptsB_y = np.round(kptsB_np[:, 1]).astype(int)
    kptsA_color = imA[np.clip(kptsA_y, 0, H_A - 1), np.clip(kptsA_x, 0, W_A - 1)]
    kptsB_color = imB[np.clip(kptsB_y, 0, H_B - 1), np.clip(kptsB_x, 0, W_B - 1)]

    # Normalize keypoints into [-1, 1] for downstream triangulation
    kptsA_np_norm = np.zeros_like(kptsA_np)
    kptsB_np_norm = np.zeros_like(kptsB_np)
    kptsA_np_norm[:, 0] = kptsA_np[:, 0] / (W_A - 1) * 2.0 - 1.0
    kptsA_np_norm[:, 1] = kptsA_np[:, 1] / (H_A - 1) * 2.0 - 1.0
    kptsB_np_norm[:, 0] = kptsB_np[:, 0] / (W_B - 1) * 2.0 - 1.0
    kptsB_np_norm[:, 1] = kptsB_np[:, 1] / (H_B - 1) * 2.0 - 1.0

    return kptsA_np_norm, kptsB_np_norm, kptsA_color, kptsB_color
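

# Illustrative sketch (synthetic images and matches, not from the pipeline): the warp rows
# coming out of RoMa are (x1, y1, x2, y2) in [-1, 1]; here three such rows are mapped onto
# tiny random images to read out pixel colors. Note that `roma_model` is not used inside
# this helper, so None is passed.
def _example_extract_keypoints_and_colors_single():
    imA = (np.random.rand(32, 48, 3) * 255).astype(np.uint8)
    imB = (np.random.rand(32, 48, 3) * 255).astype(np.uint8)
    matches = torch.tensor([[-1., -1., 0., 0.],
                            [0., 0., 0.5, 0.5],
                            [1., 1., -0.5, -0.5]])
    kA, kB, cA, cB = extract_keypoints_and_colors_single(imA, imB, matches, roma_model=None)
    print(kA.shape, kB.shape, cA.shape, cB.shape)  # (3, 2) (3, 2) (3, 3) (3, 3)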
def init_gaussians_with_corr_profiled(gaussians, scene, cfg, device, verbose=False, roma_model=None):
    timings = defaultdict(list)
    if roma_model is None:
        if cfg.roma_model == "indoors":
            roma_model = roma_indoor(device=device)
        else:
            roma_model = roma_outdoor(device=device)
        roma_model.upsample_preds = False
        roma_model.symmetric = False
    M = cfg.matches_per_ref
    upper_thresh = roma_model.sample_thresh
    scaling_factor = cfg.scaling_factor
    expansion_factor = 1
    keypoint_fit_error_tolerance = cfg.proj_err_tolerance
    visualizations = {}
    viewpoint_stack = scene.getTrainCameras().copy()
    NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
    NUM_NNS_PER_REFERENCE = 1  # this variant matches each reference frame against a single neighbor

    viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
    selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
    selected_indices = sorted(selected_indices)
    closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
    closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()

    all_new_xyz = []
    all_new_features_dc = []
    all_new_features_rest = []
    all_new_opacities = []
    all_new_scaling = []
    all_new_rotation = []

    # Dummy first pass to initialize the model
    with torch.no_grad():
        viewpoint_cam1 = viewpoint_stack[0]
        viewpoint_cam2 = viewpoint_stack[1]
        imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
        imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
        imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
        imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
        warp, certainty_warp = roma_model.match(imA, imB, device=device)
        del warp, certainty_warp
        torch.cuda.empty_cache()

    # Main loop over the selected reference frames
    for source_idx in tqdm(sorted(selected_indices), desc="Profiling source frames"):
        # =================== Step 1: Compute Warp and Certainty ===================
        start = time.time()
        viewpoint_cam1 = viewpoint_stack[source_idx]
        NNs = closest_indices_selected.shape[1]
        viewpoint_cam2 = viewpoint_stack[closest_indices_selected[source_idx, np.random.randint(NNs)]]
        imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
        imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
        imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
        imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
        warp, certainty_warp = roma_model.match(imA, imB, device=device)
        certainties_max = certainty_warp  # single-pair certainty, sampled manually below
        timings['aggregation_warp_certainty'].append(time.time() - start)

        # =================== Step 2: Good Samples Selection ===================
        start = time.time()
        certainty = certainties_max.reshape(-1).clone()
        certainty[certainty > upper_thresh] = 1
        good_samples = torch.multinomial(certainty, num_samples=min(expansion_factor * M, len(certainty)), replacement=False)
        timings['good_samples_selection'].append(time.time() - start)

        # =================== Step 3: Triangulate Keypoints ===================
        reference_image_dict = {
            "triangulated_points": [],
            "triangulated_points_errors_proj1": [],
            "triangulated_points_errors_proj2": []
        }
        start = time.time()
        matches_NN = warp.reshape(-1, 4)[good_samples]

        # Convert matches to pixel coordinates and pick up per-keypoint colors
        kptsA_np, kptsB_np, kptsA_color, kptsB_color = extract_keypoints_and_colors_single(
            np.array(imA).astype(np.uint8),
            np.array(imB).astype(np.uint8),
            matches_NN,
            roma_model
        )
        proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
        proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, 0]].full_proj_transform
        triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
            P1=torch.stack([proj_matrices_A] * M, axis=0),
            P2=torch.stack([proj_matrices_B] * M, axis=0),
            k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
            k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])
        reference_image_dict["triangulated_points"].append(triangulated_points)
        reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
        reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
        timings['triangulation_per_NN'].append(time.time() - start)

        # =================== Step 4: Select Best Triangulated Points ===================
        start = time.time()
        NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
            NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
            NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
            NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
        timings['select_best_keypoints'].append(time.time() - start)

        # =================== Step 5: Create New Gaussians ===================
        start = time.time()
        viewpoint_cam1 = viewpoint_stack[source_idx]
        N = len(NNs_triangulated_points_selected)
        new_xyz = NNs_triangulated_points_selected[:, :-1]
        all_new_xyz.append(new_xyz)
        all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
        all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
        mask_bad_points = torch.tensor(
            NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
            dtype=torch.float32).unsqueeze(1).to(device)
        all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))
        dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz, dim=1, ord=2)
        all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
        all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
        timings['save_gaussians'].append(time.time() - start)

    # =================== Final Densification Postfix ===================
    start = time.time()
    all_new_xyz = torch.cat(all_new_xyz, dim=0)
    all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
    new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
    prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)
    gaussians.densification_postfix(
        all_new_xyz[prune_mask].to(device),
        all_new_features_dc[prune_mask].to(device),
        torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
        torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
        torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
        torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
        new_tmp_radii[prune_mask].to(device)
    )
    timings['final_densification_postfix'].append(time.time() - start)

    # =================== Print Profiling Results ===================
    print("\n=== Profiling Summary (average per frame) ===")
    for key, times in timings.items():
        print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")

    return viewpoint_stack, closest_indices_selected, visualizations
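

# Hedged usage sketch. The field names below mirror the attributes read from `cfg` above,
# but the concrete values are invented, and `gaussians`/`scene` must come from a 3DGS
# training setup, so the actual call is left commented out.
def _example_init_gaussians_config():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        roma_model="outdoors",    # or "indoors"
        matches_per_ref=5000,     # M: correspondences sampled per reference frame (example value)
        scaling_factor=0.001,     # controls the initial scale of the new Gaussians (example value)
        proj_err_tolerance=0.01,  # reprojection-error threshold for down-weighting bad points (example value)
        num_refs=20,              # number of k-means-selected reference frames (example value)
        nns_per_ref=3,            # used by the multi-neighbor variant kept above in comments
    )
    # With a trained `gaussians` model and a `scene` from the 3DGS codebase one would call:
    # viewpoint_stack, closest_indices, vis = init_gaussians_with_corr_profiled(
    #     gaussians, scene, cfg, device="cuda")
    return cfg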