# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import glob
import time
import threading
import argparse
from typing import List, Optional

import numpy as np
import torch
from tqdm.auto import tqdm
import viser
import viser.transforms as viser_tf
import cv2

try:
    import onnxruntime
except ImportError:
    print("onnxruntime not found. Sky segmentation may not work.")

from visual_util import segment_sky, download_file_from_url
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
from vggt.utils.pose_enc import pose_encoding_to_extri_intri


def viser_wrapper(
    pred_dict: dict,
    port: int = 8080,
    init_conf_threshold: float = 50.0,  # represents percentage (e.g., 50 means filter lowest 50%)
    use_point_map: bool = False,
    background_mode: bool = False,
    mask_sky: bool = False,
    image_folder: Optional[str] = None,
):
    """
    Visualize predicted 3D points and camera poses with viser.

    Args:
        pred_dict (dict):
            {
                "images": (S, 3, H, W) - Input images,
                "world_points": (S, H, W, 3),
                "world_points_conf": (S, H, W),
                "depth": (S, H, W, 1),
                "depth_conf": (S, H, W),
                "extrinsic": (S, 3, 4),
                "intrinsic": (S, 3, 3),
            }
        port (int): Port number for the viser server.
        init_conf_threshold (float): Initial percentage of low-confidence points to filter out.
        use_point_map (bool): If True, visualize the predicted world points directly;
            otherwise unproject the depth maps with the camera parameters.
        background_mode (bool): Whether to run the server in a background thread.
        mask_sky (bool): Whether to apply sky segmentation to filter out sky points.
        image_folder (str): Path to the folder containing input images.
    """
| print(f"Starting viser server on port {port}") | |
| server = viser.ViserServer(host="0.0.0.0", port=port) | |
| server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") | |
| # Unpack prediction dict | |
| images = pred_dict["images"] # (S, 3, H, W) | |
| world_points_map = pred_dict["world_points"] # (S, H, W, 3) | |
| conf_map = pred_dict["world_points_conf"] # (S, H, W) | |
| depth_map = pred_dict["depth"] # (S, H, W, 1) | |
| depth_conf = pred_dict["depth_conf"] # (S, H, W) | |
| extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4) | |
| intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3) | |
| # Compute world points from depth if not using the precomputed point map | |
| if not use_point_map: | |
| world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam) | |
| conf = depth_conf | |
| else: | |
| world_points = world_points_map | |
| conf = conf_map | |
| # Apply sky segmentation if enabled | |
| if mask_sky and image_folder is not None: | |
| conf = apply_sky_segmentation(conf, image_folder) | |
| # Convert images from (S, 3, H, W) to (S, H, W, 3) | |
| # Then flatten everything for the point cloud | |
| colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3) | |
| S, H, W, _ = world_points.shape | |
| # Flatten | |
| points = world_points.reshape(-1, 3) | |
| colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8) | |
| conf_flat = conf.reshape(-1) | |
| cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam) # shape (S, 4, 4) typically | |
| # For convenience, we store only (3,4) portion | |
| cam_to_world = cam_to_world_mat[:, :3, :] | |
| # Compute scene center and recenter | |
| scene_center = np.mean(points, axis=0) | |
| points_centered = points - scene_center | |
| cam_to_world[..., -1] -= scene_center | |
| # Store frame indices so we can filter by frame | |
| frame_indices = np.repeat(np.arange(S), H * W) | |
| # Build the viser GUI | |
| gui_show_frames = server.gui.add_checkbox( | |
| "Show Cameras", | |
| initial_value=True, | |
| ) | |
| # Now the slider represents percentage of points to filter out | |
| gui_points_conf = server.gui.add_slider( | |
| "Confidence Percent", | |
| min=0, | |
| max=100, | |
| step=0.1, | |
| initial_value=init_conf_threshold, | |
| ) | |
| gui_frame_selector = server.gui.add_dropdown( | |
| "Show Points from Frames", | |
| options=["All"] + [str(i) for i in range(S)], | |
| initial_value="All", | |
| ) | |
| # Create the main point cloud handle | |
| # Compute the threshold value as the given percentile | |
| init_threshold_val = np.percentile(conf_flat, init_conf_threshold) | |
| init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1) | |
| point_cloud = server.scene.add_point_cloud( | |
| name="viser_pcd", | |
| points=points_centered[init_conf_mask], | |
| colors=colors_flat[init_conf_mask], | |
| point_size=0.001, | |
| point_shape="circle", | |
| ) | |
| # We will store references to frames & frustums so we can toggle visibility | |
| frames: List[viser.FrameHandle] = [] | |
| frustums: List[viser.CameraFrustumHandle] = [] | |
| def visualize_frames(extrinsics: np.ndarray, images_: np.ndarray) -> None: | |
| """ | |
| Add camera frames and frustums to the scene. | |
| extrinsics: (S, 3, 4) | |
| images_: (S, 3, H, W) | |
| """ | |
| # Clear any existing frames or frustums | |
| for f in frames: | |
| f.remove() | |
| frames.clear() | |
| for fr in frustums: | |
| fr.remove() | |
| frustums.clear() | |

        # Optionally attach a callback that sets the viewpoint to the chosen camera
        def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        img_ids = range(S)
        for img_id in tqdm(img_ids):
            cam2world_3x4 = extrinsics[img_id]
            T_world_camera = viser_tf.SE3.from_matrix(cam2world_3x4)

            # Add a small frame axis
            frame_axis = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(frame_axis)

            # Convert the image for the frustum
            img = images_[img_id]  # shape (3, H, W)
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            h, w = img.shape[:2]

            # If you want the correct FOV from intrinsics, do something like:
            #   fx = intrinsics_cam[img_id, 0, 0]
            #   fov = 2 * np.arctan2(h / 2, fx)
            # For demonstration, we pick a simple approximate FOV:
            fy = 1.1 * h
            fov = 2 * np.arctan2(h / 2, fy)

            # Add the frustum
            frustum_cam = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum",
                fov=fov,
                aspect=w / h,
                scale=0.05,
                image=img,
                line_width=1.0,
            )
            frustums.append(frustum_cam)
            attach_callback(frustum_cam, frame_axis)

    def update_point_cloud() -> None:
        """Update the point cloud based on current GUI selections."""
        # Compute the threshold value from the current percentage
        current_percentage = gui_points_conf.value
        threshold_val = np.percentile(conf_flat, current_percentage)

        print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")

        conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)

        if gui_frame_selector.value == "All":
            frame_mask = np.ones_like(conf_mask, dtype=bool)
        else:
            selected_idx = int(gui_frame_selector.value)
            frame_mask = frame_indices == selected_idx

        combined_mask = conf_mask & frame_mask
        point_cloud.points = points_centered[combined_mask]
        point_cloud.colors = colors_flat[combined_mask]

    @gui_points_conf.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_frame_selector.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_show_frames.on_update
    def _(_) -> None:
        """Toggle visibility of camera frames and frustums."""
        for f in frames:
            f.visible = gui_show_frames.value
        for fr in frustums:
            fr.visible = gui_show_frames.value

    # Add the camera frames to the scene
    visualize_frames(cam_to_world, images)

    print("Starting viser server...")
    # If background_mode is True, spawn a daemon thread so the main thread can continue.
    if background_mode:

        def server_loop():
            while True:
                time.sleep(0.001)

        thread = threading.Thread(target=server_loop, daemon=True)
        thread.start()
    else:
        while True:
            time.sleep(0.01)

    return server
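

# Example (illustrative sketch only, not part of the demo): viser_wrapper just needs a dict of
# numpy arrays with the shapes documented in its docstring, so it can be exercised with synthetic
# data. The shapes and values below are assumptions for illustration, not real model outputs.
#
#   S, H, W = 2, 4, 6
#   dummy_pred = {
#       "images": np.random.rand(S, 3, H, W).astype(np.float32),
#       "world_points": np.random.rand(S, H, W, 3).astype(np.float32),
#       "world_points_conf": np.ones((S, H, W), dtype=np.float32),
#       "depth": np.random.rand(S, H, W, 1).astype(np.float32),
#       "depth_conf": np.ones((S, H, W), dtype=np.float32),
#       "extrinsic": np.tile(np.eye(3, 4, dtype=np.float32), (S, 1, 1)),
#       "intrinsic": np.tile(np.eye(3, dtype=np.float32), (S, 1, 1)),
#   }
#   viser_wrapper(dummy_pred, port=8080, use_point_map=True, background_mode=True)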


# Helper functions for sky segmentation
def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:
    """
    Apply sky segmentation to confidence scores.

    Args:
        conf (np.ndarray): Confidence scores with shape (S, H, W)
        image_folder (str): Path to the folder containing input images

    Returns:
        np.ndarray: Updated confidence scores with sky regions masked out
    """
    S, H, W = conf.shape
    sky_masks_dir = image_folder.rstrip("/") + "_sky_masks"
    os.makedirs(sky_masks_dir, exist_ok=True)

    # Download skyseg.onnx if it doesn't exist
    if not os.path.exists("skyseg.onnx"):
        print("Downloading skyseg.onnx...")
        download_file_from_url(
            "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx"
        )

    skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
    image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
    sky_mask_list = []

    print("Generating sky masks...")
    for i, image_path in enumerate(tqdm(image_files[:S])):  # Limit to the number of images in the batch
        image_name = os.path.basename(image_path)
        mask_filepath = os.path.join(sky_masks_dir, image_name)

        if os.path.exists(mask_filepath):
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        else:
            sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)

        # Resize mask to match H×W if needed
        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
            sky_mask = cv2.resize(sky_mask, (W, H))

        sky_mask_list.append(sky_mask)

    # Convert the list to a numpy array with shape (S, H, W)
    sky_mask_array = np.array(sky_mask_list)
    # Apply the sky mask to the confidence scores
    sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
    conf = conf * sky_mask_binary

    print("Sky segmentation applied successfully")
    return conf
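

# Illustrative note (assumed toy values, not from the demo): the masking above is elementwise,
# so sky pixels get zero confidence and are later dropped by the (conf_flat > 1e-5) check in
# update_point_cloud. For example:
#
#   conf = np.array([[[5.0, 3.0]]])              # (S=1, H=1, W=2)
#   sky_mask = np.array([[[255, 0]]])            # grayscale mask: nonzero = keep, 0 = sky
#   conf * (sky_mask > 0.1).astype(np.float32)   # -> array([[[5.0, 0.0]]])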


parser = argparse.ArgumentParser(description="VGGT demo with viser for 3D visualization")
parser.add_argument(
    "--image_folder", type=str, default="examples/kitchen/images/", help="Path to folder containing images"
)
parser.add_argument("--use_point_map", action="store_true", help="Use point map instead of depth-based points")
parser.add_argument("--background_mode", action="store_true", help="Run the viser server in background mode")
parser.add_argument("--port", type=int, default=8080, help="Port number for the viser server")
parser.add_argument(
    "--conf_threshold", type=float, default=25.0, help="Initial percentage of low-confidence points to filter out"
)
parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")


def main():
    """
    Main function for the VGGT demo with viser for 3D visualization.

    This function:
    1. Loads the VGGT model
    2. Processes input images from the specified folder
    3. Runs inference to generate 3D points and camera poses
    4. Optionally applies sky segmentation to filter out sky points
    5. Visualizes the results using viser

    Command-line arguments:
    --image_folder: Path to folder containing input images
    --use_point_map: Use point map instead of depth-based points
    --background_mode: Run the viser server in background mode
    --port: Port number for the viser server
    --conf_threshold: Initial percentage of low-confidence points to filter out
    --mask_sky: Apply sky segmentation to filter out sky points
    """
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print("Initializing and loading VGGT model...")
    # model = VGGT.from_pretrained("facebook/VGGT-1B")
    model = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))

    model.eval()
    model = model.to(device)

    # Use the provided image folder path
    print(f"Loading images from {args.image_folder}...")
    image_names = glob.glob(os.path.join(args.image_folder, "*"))
    print(f"Found {len(image_names)} images")

    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    print("Running inference...")
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            predictions = model(images)
| print("Converting pose encoding to extrinsic and intrinsic matrices...") | |
| extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) | |
| predictions["extrinsic"] = extrinsic | |
| predictions["intrinsic"] = intrinsic | |
| print("Processing model outputs...") | |
| for key in predictions.keys(): | |
| if isinstance(predictions[key], torch.Tensor): | |
| predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension and convert to numpy | |
| if args.use_point_map: | |
| print("Visualizing 3D points from point map") | |
| else: | |
| print("Visualizing 3D points by unprojecting depth map by cameras") | |

    if args.mask_sky:
        print("Sky segmentation enabled - will filter out sky points")

    print("Starting viser visualization...")
    viser_server = viser_wrapper(
        predictions,
        port=args.port,
        init_conf_threshold=args.conf_threshold,
        use_point_map=args.use_point_map,
        background_mode=args.background_mode,
        mask_sky=args.mask_sky,
        image_folder=args.image_folder,
    )
    print("Visualization complete")


if __name__ == "__main__":
    main()
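
# Example invocations (the file name "demo_viser.py" is an assumption; adjust to however this
# script is saved in your checkout). Flags correspond to the argparse options defined above:
#
#   python demo_viser.py --image_folder examples/kitchen/images/ --conf_threshold 25.0 --port 8080
#   python demo_viser.py --image_folder path/to/your/images --mask_sky --background_mode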