| """ | |
| Utilities for saving images, depths, normals, point clouds, and Gaussian splat data. | |
| tencent | |
| """ | |
import json
import os
from io import BytesIO
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from plyfile import PlyData, PlyElement

def save_camera_params(extrinsics, intrinsics, target_dir):
    """
    Save camera parameters (extrinsics and intrinsics) in JSON format.

    Args:
        extrinsics: numpy array, shape [N, 4, 4] - extrinsic matrices for N cameras
        intrinsics: numpy array, shape [N, 3, 3] - intrinsic matrices for N cameras
        target_dir: str - directory to save the parameters

    Returns:
        str: path to the saved file
    """
    camera_data = {
        "num_cameras": int(extrinsics.shape[0]),
        "extrinsics": [],
        "intrinsics": []
    }
    # Convert each camera's parameters to list format
    for i in range(extrinsics.shape[0]):
        camera_data["extrinsics"].append({
            "camera_id": i,
            "matrix": extrinsics[i].tolist()  # [4, 4] -> nested list
        })
        camera_data["intrinsics"].append({
            "camera_id": i,
            "matrix": intrinsics[i].tolist()  # [3, 3] -> nested list
        })

    # Save as JSON file
    camera_params_path = os.path.join(target_dir, "camera_params.json")
    with open(camera_params_path, "w") as f:
        json.dump(camera_data, f, indent=2)
    return camera_params_path

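# Usage sketch (hypothetical inputs, chosen only to show the expected shapes;
# note that target_dir must already exist):
#
#   extrinsics = np.tile(np.eye(4, dtype=np.float64), (2, 1, 1))  # [2, 4, 4]
#   intrinsics = np.tile(np.eye(3, dtype=np.float64), (2, 1, 1))  # [2, 3, 3]
#   path = save_camera_params(extrinsics, intrinsics, "./out")
#   # -> "./out/camera_params.json" with per-camera matrices as nested lists
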
def save_image_png(path: Path, image_tensor: torch.Tensor) -> None:
    # image_tensor: [H, W, 3], values in [0, 1]
    # Clamp before the uint8 cast so out-of-range values saturate instead of wrapping.
    img = (image_tensor.detach().cpu().clamp(0, 1) * 255.0).to(torch.uint8).numpy()
    Image.fromarray(img).save(str(path))

def save_depth_png(path: Path, depth_tensor: torch.Tensor) -> None:
    # depth_tensor: [H, W]
    # Min-max normalize to [0, 1] for visualization; the epsilon guards against
    # a constant depth map (max == min).
    d = depth_tensor.detach()
    d = d - d.min()
    d = d / (d.max() + 1e-9)
    img = (d.clamp(0, 1) * 255.0).to(torch.uint8).cpu().numpy()
    Image.fromarray(img, mode="L").save(str(path))

def save_depth_npy(path: Path, depth_tensor: torch.Tensor) -> None:
    # depth_tensor: [H, W]
    # Save actual depth values in numpy format
    d = depth_tensor.detach().cpu().numpy()
    np.save(str(path), d)

def save_normal_png(path: Path, normal_hwc: torch.Tensor) -> None:
    # normal_hwc: [H, W, 3], in [-1, 1]
    # Map [-1, 1] -> [0, 1] before quantizing to 8-bit RGB.
    n = (normal_hwc.detach().cpu() + 1.0) * 0.5
    img = (n.clamp(0, 1) * 255.0).to(torch.uint8).numpy()
    Image.fromarray(img).save(str(path))

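# Usage sketch for the per-frame savers above (synthetic tensors, shown only
# to illustrate the expected shapes and value ranges):
#
#   rgb = torch.rand(64, 64, 3)              # [H, W, 3] in [0, 1]
#   depth = torch.rand(64, 64) * 10.0        # [H, W], metric depth
#   normal = torch.rand(64, 64, 3) * 2 - 1   # [H, W, 3] in [-1, 1]
#   save_image_png(Path("rgb.png"), rgb)
#   save_depth_png(Path("depth_vis.png"), depth)  # normalized visualization
#   save_depth_npy(Path("depth.npy"), depth)      # raw values, loadable via np.load
#   save_normal_png(Path("normal.png"), normal)
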
def save_scene_ply(path: Path,
                   points_xyz: torch.Tensor,
                   point_colors: torch.Tensor,
                   valid_mask: torch.Tensor = None) -> None:
    """Save a colored point cloud to PLY format.

    points_xyz: [..., 3] positions; point_colors: [..., 3] colors in [0, 255];
    valid_mask: optional boolean mask selecting points to keep.
    """
    pts = points_xyz.detach().cpu().to(torch.float32).numpy().reshape(-1, 3)
    # Colors are cast directly to uint8, so they must already hold 0-255 values
    colors = point_colors.detach().cpu().to(torch.uint8).numpy().reshape(-1, 3)

    # Filter out invalid points (NaN, Inf)
    if valid_mask is None:
        valid_mask = np.isfinite(pts).all(axis=1)
    else:
        valid_mask = valid_mask.detach().cpu().numpy().reshape(-1)
    pts = pts[valid_mask]
    colors = colors[valid_mask]

    # Fall back to a single white point so the written PLY is never empty
    if len(pts) == 0:
        pts = np.array([[0, 0, 0]], dtype=np.float32)
        colors = np.array([[255, 255, 255]], dtype=np.uint8)

    # Create structured vertex array for plyfile
    vertex_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4"),
                    ("red", "u1"), ("green", "u1"), ("blue", "u1")]
    vertex_elements = np.empty(len(pts), dtype=vertex_dtype)
    vertex_elements["x"] = pts[:, 0]
    vertex_elements["y"] = pts[:, 1]
    vertex_elements["z"] = pts[:, 2]
    vertex_elements["red"] = colors[:, 0]
    vertex_elements["green"] = colors[:, 1]
    vertex_elements["blue"] = colors[:, 2]

    # Write PLY file
    PlyData([PlyElement.describe(vertex_elements, "vertex")]).write(str(path))

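# Usage sketch (random data; note that point_colors must already hold 0-255
# values, since save_scene_ply casts them straight to uint8):
#
#   pts = torch.randn(1000, 3)
#   cols = torch.randint(0, 256, (1000, 3))
#   save_scene_ply(Path("scene.ply"), pts, cols)
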
def save_points_ply(path: Path, pts_np: np.ndarray, cols_np: np.ndarray) -> None:
    """Save point cloud to PLY format from numpy arrays"""
    vertex_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4"),
                    ("red", "u1"), ("green", "u1"), ("blue", "u1")]
    vertex_elements = np.empty(len(pts_np), dtype=vertex_dtype)
    vertex_elements["x"] = pts_np[:, 0]
    vertex_elements["y"] = pts_np[:, 1]
    vertex_elements["z"] = pts_np[:, 2]
    vertex_elements["red"] = cols_np[:, 0]
    vertex_elements["green"] = cols_np[:, 1]
    vertex_elements["blue"] = cols_np[:, 2]
    # Write PLY file
    PlyData([PlyElement.describe(vertex_elements, "vertex")]).write(str(path))

def save_gs_ply(path: Path,
                means: torch.Tensor,
                scales: torch.Tensor,
                rotations: torch.Tensor,
                rgbs: torch.Tensor,
                opacities: torch.Tensor) -> None:
    """
    Export Gaussian splat data to PLY format.

    Args:
        path: Output PLY file path
        means: Gaussian centers [N, 3]
        scales: Gaussian scales [N, 3]
        rotations: Gaussian rotations as quaternions [N, 4]
        rgbs: RGB colors [N, 3]
        opacities: Opacity values [N]
    """
    # Drop outlier Gaussians whose largest scale exceeds the 95th percentile
    scale_threshold = torch.quantile(scales.max(dim=-1)[0], 0.95, dim=0)
    filter_mask = scales.max(dim=-1)[0] <= scale_threshold

    # Apply the filter to all tensors
    means = means[filter_mask].reshape(-1, 3)
    scales = scales[filter_mask].reshape(-1, 3)
    rotations = rotations[filter_mask].reshape(-1, 4)
    rgbs = rgbs[filter_mask].reshape(-1, 3)
    opacities = opacities[filter_mask].reshape(-1)

    # Attribute names follow the standard 3DGS PLY layout:
    # position, (unused) normal, DC color, opacity, log-scale, quaternion
    attributes = ["x", "y", "z", "nx", "ny", "nz"]
    for i in range(3):
        attributes.append(f"f_dc_{i}")
    attributes.append("opacity")
    for i in range(3):
        attributes.append(f"scale_{i}")
    for i in range(4):
        attributes.append(f"rot_{i}")

    # Prepare PLY data structure
    dtype_full = [(attribute, "f4") for attribute in attributes]
    elements = np.empty(means.shape[0], dtype=dtype_full)

    # Concatenate all attributes column-wise: [N, 3+3+3+1+3+4] = [N, 17]
    attributes_data = (
        means.float().detach().cpu().numpy(),
        torch.zeros_like(means).float().detach().cpu().numpy(),  # placeholder normals
        rgbs.detach().cpu().contiguous().numpy(),
        opacities[..., None].detach().cpu().numpy(),
        scales.log().detach().cpu().numpy(),  # scales are stored in log space
        rotations.detach().cpu().numpy(),
    )
    attributes_data = np.concatenate(attributes_data, axis=1)
    elements[:] = list(map(tuple, attributes_data))

    # Write to PLY file
    PlyData([PlyElement.describe(elements, "vertex")]).write(str(path))

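# Usage sketch (random Gaussians; scales must be positive and linear, since
# save_gs_ply takes their log before writing, and rotations are unit quaternions):
#
#   N = 1000
#   means = torch.randn(N, 3)
#   scales = torch.rand(N, 3) * 0.05 + 1e-4
#   rotations = torch.nn.functional.normalize(torch.randn(N, 4), dim=-1)
#   rgbs = torch.rand(N, 3)
#   opacities = torch.rand(N)
#   save_gs_ply(Path("gaussians.ply"), means, scales, rotations, rgbs, opacities)
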
def convert_gs_to_ply(means, scales, rotations, rgbs, opacities):
    """
    Convert Gaussian splat data to an in-memory PlyData object.

    Args:
        means: Gaussian centers [N, 3]
        scales: Gaussian scales [N, 3]
        rotations: Gaussian rotations as quaternions [N, 4]
        rgbs: RGB colors [N, 3]
        opacities: Opacity values [N]

    Returns:
        PlyData: the Gaussians packed in the same vertex layout as save_gs_ply
    """
    # Drop outlier Gaussians whose largest scale exceeds the 90th percentile
    scale_threshold = torch.quantile(scales.max(dim=-1)[0], 0.90, dim=0)
    filter_mask = scales.max(dim=-1)[0] <= scale_threshold

    # Apply the filter to all tensors
    means = means[filter_mask].reshape(-1, 3)
    scales = scales[filter_mask].reshape(-1, 3)
    rotations = rotations[filter_mask].reshape(-1, 4)
    rgbs = rgbs[filter_mask].reshape(-1, 3)
    opacities = opacities[filter_mask].reshape(-1)

    # Construct attribute names (same layout as save_gs_ply)
    attributes = ["x", "y", "z", "nx", "ny", "nz"]
    for i in range(3):
        attributes.append(f"f_dc_{i}")
    attributes.append("opacity")
    for i in range(3):
        attributes.append(f"scale_{i}")
    for i in range(4):
        attributes.append(f"rot_{i}")

    # Prepare PLY data structure
    dtype_full = [(attribute, "f4") for attribute in attributes]
    elements = np.empty(means.shape[0], dtype=dtype_full)

    # Concatenate all attributes column-wise
    attributes_data = (
        means.float().detach().cpu().numpy(),
        torch.zeros_like(means).float().detach().cpu().numpy(),  # placeholder normals
        rgbs.detach().cpu().contiguous().numpy(),
        opacities[..., None].detach().cpu().numpy(),
        scales.log().detach().cpu().numpy(),
        rotations.detach().cpu().numpy(),
    )
    attributes_data = np.concatenate(attributes_data, axis=1)
    elements[:] = list(map(tuple, attributes_data))
    plydata = PlyData([PlyElement.describe(elements, "vertex")])
    return plydata

def process_ply_to_splat(plydata, output_path):
    """Pack a Gaussian-splat PLY into the compact binary .splat format
    (32 bytes per Gaussian: 3x f32 position, 3x f32 scale, 4x u8 RGBA, 4x u8 quaternion)."""
    vert = plydata["vertex"]
    # Sort by volume (product of linear scales) times sigmoid(opacity),
    # largest contributions first
    sorted_indices = np.argsort(
        -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
        / (1 + np.exp(-vert["opacity"]))
    )
    SH_C0 = 0.28209479177387814  # zeroth-order spherical harmonic coefficient
    buffer = BytesIO()
    for idx in sorted_indices:
        v = plydata["vertex"][idx]
        position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
        # Scales are stored in log space in the PLY; exponentiate to recover them
        scales = np.exp(
            np.array(
                [v["scale_0"], v["scale_1"], v["scale_2"]],
                dtype=np.float32,
            )
        )
        rot = np.array(
            [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
            dtype=np.float32,
        )
        # Decode the DC spherical harmonics to RGB; sigmoid maps opacity to [0, 1]
        color = np.array(
            [
                0.5 + SH_C0 * v["f_dc_0"],
                0.5 + SH_C0 * v["f_dc_1"],
                0.5 + SH_C0 * v["f_dc_2"],
                1 / (1 + np.exp(-v["opacity"])),
            ]
        )
        buffer.write(position.tobytes())
        buffer.write(scales.tobytes())
        buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
        # Normalize the quaternion and quantize each component to [0, 255]
        buffer.write(
            ((rot / np.linalg.norm(rot)) * 128 + 128)
            .clip(0, 255)
            .astype(np.uint8)
            .tobytes()
        )
    value = buffer.getvalue()
    with open(output_path, "wb") as f:
        f.write(value)
    return output_path

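if __name__ == "__main__":
    # Smoke test on synthetic data (a minimal sketch, not the original pipeline):
    # build random Gaussians, convert them to an in-memory PLY, then pack the
    # PLY into the binary .splat format.
    N = 500
    means = torch.randn(N, 3)
    scales = torch.rand(N, 3) * 0.05 + 1e-4
    rotations = torch.nn.functional.normalize(torch.randn(N, 4), dim=-1)
    rgbs = torch.rand(N, 3)
    opacities = torch.rand(N)
    plydata = convert_gs_to_ply(means, scales, rotations, rgbs, opacities)
    out_path = process_ply_to_splat(plydata, "demo.splat")
    print(f"wrote {out_path}: {os.path.getsize(out_path)} bytes")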