import spaces
import torch
import os
import shutil
import tempfile
import argparse
import gradio as gr
import sys
import io
import contextlib
import subprocess
from PIL import Image
import numpy as np
from hydra import initialize, compose
import hydra
from omegaconf import OmegaConf
import time
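
# RoMa ships as a git submodule; install it at app startup so that the matcher
# package can be imported later in __main__.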
def install_submodules():
    subprocess.check_call(['pip', 'install', './submodules/RoMa'])

STATIC_FILE_SERVING_FOLDER = "./served_files"
MODEL_PATH = None
os.makedirs(STATIC_FILE_SERVING_FOLDER, exist_ok=True)
trainer = None
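
# Tee mirrors every write to all wrapped streams, so training output can go to
# the real stdout and an in-memory buffer at the same time.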
class Tee(io.TextIOBase):
    def __init__(self, *streams):
        self.streams = streams

    def write(self, data):
        for stream in self.streams:
            stream.write(data)
        return len(data)

    def flush(self):
        for stream in self.streams:
            stream.flush()
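
# Run `func` while duplicating its stdout into a StringIO buffer. Returns the
# function's result together with the captured log text (displayed in the UI).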
def capture_logs(func, *args, **kwargs):
    log_capture_string = io.StringIO()
    tee = Tee(sys.__stdout__, log_capture_string)
    with contextlib.redirect_stdout(tee):
        result = func(*args, **kwargs)
    return result, log_capture_string.getvalue()

# Training Pipeline
def run_training_pipeline(scene_dir,
                          num_ref_views=16,
                          num_corrs_per_view=20000,
                          num_steps=1_000,
                          mode_toggle="Ours (EDGS)"):
    with initialize(config_path="./configs", version_base="1.1"):
        cfg = compose(config_name="train")

    scene_name = os.path.basename(scene_dir)
    model_output_dir = f"./outputs/{scene_name}_trained"

    cfg.wandb.mode = "disabled"
    cfg.gs.dataset.model_path = model_output_dir
    cfg.gs.dataset.source_path = scene_dir
    cfg.gs.dataset.images = "images"
    cfg.gs.opt.TEST_CAM_IDX_TO_LOG = 12
    cfg.train.gs_epochs = 30000
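    # EDGS settings: effectively disable opacity resets, skip densification, and
    # seed the scene from dense 2D-2D correspondences instead
    # (init_wC presumably stands for "initialization with correspondences").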
    if mode_toggle == "Ours (EDGS)":
        cfg.gs.opt.opacity_reset_interval = 1_000_000
        cfg.train.reduce_opacity = True
        cfg.train.no_densify = True
        cfg.train.max_lr = True
        cfg.init_wC.use = True
        cfg.init_wC.matches_per_ref = num_corrs_per_view
        cfg.init_wC.nns_per_ref = 1
        cfg.init_wC.num_refs = num_ref_views
        cfg.init_wC.add_SfM_init = False
        cfg.init_wC.scaling_factor = 0.00077 * 2.

    set_seed(cfg.seed)
    os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)

    global trainer
    global MODEL_PATH
    generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
    trainer = EDGSTrainer(GS=generator3dgs,
                          training_config=cfg.gs.opt,
                          device=cfg.device,
                          log_wandb=cfg.wandb.mode != 'disabled')

    # Disable evaluation and saving
    trainer.saving_iterations = []
    trainer.evaluate_iterations = []

    # Initialize
    trainer.timer.start()
    start_time = time.time()
    trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
    time_for_init = time.time() - start_time
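
    # Build a smooth fly-through over a few selected training cameras for the
    # preview video; appending the reversed path makes the clip loop back to
    # its starting point.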
    viewpoint_cams = trainer.GS.scene.getTrainCameras()
    path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
                                                          n_selected=6,
                                                          n_points_per_segment=30,
                                                          closed=False)
    path_cameras = path_cameras + path_cameras[::-1]

    path_renderings = []
    idx = 0
    # Visualize after init
    for _ in range(120):
        with torch.no_grad():
            viewpoint_cam = path_cameras[idx]
            idx = (idx + 1) % len(path_cameras)
            render_pkg = trainer.GS(viewpoint_cam)
            image = render_pkg["render"]
            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
            image_np = (image_np * 255).astype(np.uint8)
            path_renderings.append(put_text_on_image(img=image_np,
                                                     text=f"Init stage.\nTime:{time_for_init:.3f}s. "))
    path_renderings = path_renderings + [put_text_on_image(img=image_np, text=f"Start fitting.\nTime:{time_for_init:.3f}s. ")] * 30

    # Train and save visualizations during training.
    start_time = time.time()
    for _ in range(int(num_steps // 10)):
        with torch.no_grad():
            viewpoint_cam = path_cameras[idx]
            idx = (idx + 1) % len(path_cameras)
            render_pkg = trainer.GS(viewpoint_cam)
            image = render_pkg["render"]
            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
            image_np = (image_np * 255).astype(np.uint8)
            path_renderings.append(put_text_on_image(
                img=image_np,
                text=f"Fitting stage.\nTime:{time_for_init + time.time() - start_time:.3f}s. "))
        cfg.train.gs_epochs = 10
        trainer.train(cfg.train)
        print(f"Time elapsed: {(time_for_init + time.time() - start_time):.2f}s.")
        # if (cfg.init_wC.use == False) and (time_for_init + time.time() - start_time) > 60:
        #     break

    final_time = time.time()
    # Add a static frame to highlight that fitting is done.
    path_renderings += [put_text_on_image(
        img=image_np, text=f"Done.\nTime:{time_for_init + final_time - start_time:.3f}s. ")] * 30
    # Final rendering at the end.
    for _ in range(len(path_cameras)):
        with torch.no_grad():
            viewpoint_cam = path_cameras[idx]
            idx = (idx + 1) % len(path_cameras)
            render_pkg = trainer.GS(viewpoint_cam)
            image = render_pkg["render"]
            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
            image_np = (image_np * 255).astype(np.uint8)
            path_renderings.append(put_text_on_image(img=image_np,
                                                     text=f"Final result.\nTime:{time_for_init + final_time - start_time:.3f}s. "))

    trainer.save_model()
    final_video_path = os.path.join(STATIC_FILE_SERVING_FOLDER, f"{scene_name}_final.mp4")
    save_numpy_frames_as_mp4(frames=path_renderings, output_path=final_video_path, fps=30, center_crop=0.85)
    MODEL_PATH = cfg.gs.dataset.model_path
    ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
    shutil.copy(ply_path, os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"))
    return final_video_path, ply_path
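
# One-shot variant: runs preprocessing and training in a single call (useful
# programmatically); the UI below wires the two-stage preprocess/start flow instead.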
# Gradio Interface
def gradio_interface(input_path, num_ref_views, num_corrs, num_steps):
    images, scene_dir = run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024)
    shutil.copytree(scene_dir, STATIC_FILE_SERVING_FOLDER + '/scene_colmaped', dirs_exist_ok=True)
    (final_video_path, ply_path), log_output = capture_logs(run_training_pipeline,
                                                            scene_dir,
                                                            num_ref_views,
                                                            num_corrs,
                                                            num_steps)
    images_rgb = [img[:, :, ::-1] for img in images]
    return images_rgb, final_video_path, scene_dir, ply_path, log_output

# Render Functions (operate on the globally trained model)
def render_all_views(scene_dir):
    viewpoint_cams = trainer.GS.scene.getTrainCameras()
    path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
                                                          n_selected=8,
                                                          n_points_per_segment=60,
                                                          closed=False)
    path_cameras = path_cameras + path_cameras[::-1]

    path_renderings = []
    with torch.no_grad():
        for viewpoint_cam in path_cameras:
            render_pkg = trainer.GS(viewpoint_cam)
            image = render_pkg["render"]
            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
            image_np = (image_np * 255).astype(np.uint8)
            path_renderings.append(image_np)
    save_numpy_frames_as_mp4(frames=path_renderings,
                             output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4"),
                             fps=30,
                             center_crop=0.85)
    return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4")
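
# Orbit the scene on a circular camera path for a turntable-style rendering.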
def render_circular_path(scene_dir):
    viewpoint_cams = trainer.GS.scene.getTrainCameras()
    path_cameras = generate_circular_camera_path(existing_cameras=viewpoint_cams,
                                                 N=240,
                                                 radius_scale=0.65,
                                                 d=0)

    path_renderings = []
    with torch.no_grad():
        for viewpoint_cam in path_cameras:
            render_pkg = trainer.GS(viewpoint_cam)
            image = render_pkg["render"]
            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
            image_np = (image_np * 255).astype(np.uint8)
            path_renderings.append(image_np)
    save_numpy_frames_as_mp4(frames=path_renderings,
                             output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4"),
                             fps=30,
                             center_crop=0.85)
    return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4")

# Download Functions
def download_cameras():
    path = os.path.join(MODEL_PATH, "cameras.json")
    return f"[📥 Download Cameras.json](file={path})"

def download_model():
    path = os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply")
    return f"[📥 Download Pretrained Model (.ply)](file={path})"

# Full pipeline helpers
def run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024):
    tmpdirname = tempfile.mkdtemp()
    scene_dir = os.path.join(tmpdirname, "scene")
    os.makedirs(scene_dir, exist_ok=True)
    selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
    run_colmap_on_scene(scene_dir)
    return selected_frames, scene_dir

# Preprocess Input
def process_input(input_path, num_ref_views, output_dir, max_size=1024):
    # Accept either a directory of images or a video file.
    if isinstance(input_path, (str, os.PathLike)) and os.path.isdir(input_path):
        frames = []
        for img_file in sorted(os.listdir(input_path)):
            if img_file.lower().endswith(('jpg', 'jpeg', 'png')):
                img = Image.open(os.path.join(input_path, img_file)).convert('RGB')
                img.thumbnail((max_size, max_size))
                frames.append(np.array(img))
    else:
        frames = read_video_frames(video_input=input_path, max_size=max_size)
    frames_scores = preprocess_frames(frames)
    selected_frames_indices = select_optimal_frames(scores=frames_scores, k=min(num_ref_views, len(frames)))
    selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]
    save_frames_to_scene_dir(frames=selected_frames, scene_dir=output_dir)
    return selected_frames
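
# Mirrors run_full_pipeline; kept as the handler behind the "Preprocess Input"
# button in the two-stage UI flow.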
def preprocess_input(input_path, num_ref_views, max_size=1024):
    tmpdirname = tempfile.mkdtemp()
    scene_dir = os.path.join(tmpdirname, "scene")
    os.makedirs(scene_dir, exist_ok=True)
    selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
    run_colmap_on_scene(scene_dir)
    return selected_frames, scene_dir

def start_training(scene_dir, num_ref_views, num_corrs, num_steps):
    return capture_logs(run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps)

# Gradio App
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown("""
            ## <span style='font-size: 20px;'>🚀 EDGS: Eliminating Densification for Efficient Convergence of 3DGS</span>
            🌐 <a href='https://compvis.github.io/EDGS' target='_blank'>Project Page</a>
            """, elem_id="header")
| gr.Markdown(""" | |
| ### <span style='font-size: 22px;'>π οΈ How to Use This Demo</span> | |
| 1. Upload a **front-facing video** or **a folder of images** of a **static** scene. | |
| 2. Use the sliders to configure the number of reference views, correspondences, and optimization steps. | |
| 3. First press on preprocess Input to extract frames from video(for videos) and COLMAP frames. | |
| 4.Then click **π Start Reconstruction** to actually launch the reconstruction pipeline. | |
| 5. Watch the training visualization and explore the 3D model. | |
| βΌοΈ **If you see nothing in the 3D model viewer**, try rotating or zooming β sometimes the initial camera orientation is off. | |
| β Best for scenes with small camera motion. | |
| β For full 360Β° or large-scale scenes, we recommend the Colab version (see project page). | |
| """, elem_id="quickstart") | |

    scene_dir_state = gr.State()
    ply_model_state = gr.State()

    with gr.Row():
        with gr.Column(scale=2):
            input_file = gr.File(label="Upload Video or Images",
                                 file_types=[".mp4", ".avi", ".mov", ".png", ".jpg", ".jpeg"],
                                 file_count="multiple")
            gr.Examples(
                examples=[
                    [["assets/examples/video_bakery.mp4"]],
                    [["assets/examples/video_flowers.mp4"]],
                    [["assets/examples/video_fruits.mp4"]],
                    [["assets/examples/video_plant.mp4"]],
                    [["assets/examples/video_salad.mp4"]],
                    [["assets/examples/video_tram.mp4"]],
                    [["assets/examples/video_tulips.mp4"]]
                ],
                inputs=[input_file],
                label="🎞️ Alternatively, try an Example Video",
                examples_per_page=4
            )
            ref_slider = gr.Slider(4, 32, value=16, step=1, label="Number of Reference Views")
            corr_slider = gr.Slider(5000, 30000, value=20000, step=1000, label="Correspondences per Reference View")
            fit_steps_slider = gr.Slider(100, 5000, value=400, step=100, label="Number of optimization steps")
            preprocess_button = gr.Button("📸 Preprocess Input")
            start_button = gr.Button("🚀 Start Reconstruction", interactive=False)
            gallery = gr.Gallery(label="Selected Reference Views", columns=4, height=300)

        with gr.Column(scale=3):
            gr.Markdown("### 🎞️ Training Visualization")
            video_output = gr.Video(label="Training Video", autoplay=True)
            render_all_views_button = gr.Button("🎥 Render All-Views Path")
            render_circular_path_button = gr.Button("🎥 Render Circular Path")
            rendered_video_output = gr.Video(label="Rendered Video", autoplay=True)

        with gr.Column(scale=5):
            gr.Markdown("### 🌐 Final 3D Model")
            model3d_viewer = gr.Model3D(label="3D Model Viewer")
            gr.Markdown("### 📦 Output Files")
            with gr.Row(height=50):
                with gr.Column():
                    # gr.Markdown(value=f"[📥 Download .ply](file/point_cloud_final.ply)")
                    download_cameras_button = gr.Button("📥 Download Cameras.json")
                    download_cameras_file = gr.File(label="📄 Cameras.json")
                with gr.Column():
                    download_model_button = gr.Button("📥 Download Pretrained Model (.ply)")
                    download_model_file = gr.File(label="📄 Pretrained Model (.ply)")

    log_output_box = gr.Textbox(label="🖥️ Log", lines=10, interactive=False)

    def on_preprocess_click(input_file, num_ref_views):
        images, scene_dir = preprocess_input(input_file, num_ref_views)
        return gr.update(value=[x[..., ::-1] for x in images]), scene_dir, gr.update(interactive=True)

    def on_start_click(scene_dir, num_ref_views, num_corrs, num_steps):
        (video_path, ply_path), logs = start_training(scene_dir, num_ref_views, num_corrs, num_steps)
        return video_path, ply_path, logs

    preprocess_button.click(
        fn=on_preprocess_click,
        inputs=[input_file, ref_slider],
        outputs=[gallery, scene_dir_state, start_button]
    )
    start_button.click(
        fn=on_start_click,
        inputs=[scene_dir_state, ref_slider, corr_slider, fit_steps_slider],
        outputs=[video_output, model3d_viewer, log_output_box]
    )
    render_all_views_button.click(fn=render_all_views, inputs=[scene_dir_state], outputs=[rendered_video_output])
    render_circular_path_button.click(fn=render_circular_path, inputs=[scene_dir_state], outputs=[rendered_video_output])
    download_cameras_button.click(fn=lambda: os.path.join(MODEL_PATH, "cameras.json"), inputs=[], outputs=[download_cameras_file])
    download_model_button.click(fn=lambda: os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"), inputs=[], outputs=[download_model_file])
| gr.Markdown(""" | |
| --- | |
| ### <span style='font-size: 20px;'>π Detailed Overview</span> | |
| If you uploaded a video, it will be automatically cut into a smaller number of frames (default: 16). | |
| The model pipeline: | |
| 1. π§ Runs PyCOLMAP to estimate camera intrinsics & poses (~3β7 seconds for <16 images). | |
| 2. π Computes 2D-2D correspondences between views. More correspondences generally improve quality. | |
| 3. π§ Optimizes a 3D Gaussian Splatting model for several steps. | |
| ### π₯ Training Visualization | |
| You will see a visualization of the entire training process in the "Training Video" pane. | |
| ### π Rendering & 3D Model | |
| - Render the scene from a circular path of novel views. | |
| - Or from camera views close to the original input. | |
| The 3D model is shown in the right viewer. You can explore it interactively: | |
| - On PC: WASD keys, arrow keys, and mouse clicks | |
| - On mobile: pan and pinch to zoom | |
| π Note: the 3D viewer takes a few extra seconds (~5s) to display after training ends. | |
| --- | |
| Preloaded models coming soon. (TODO) | |
| """, elem_id="details") | |

if __name__ == "__main__":
    install_submodules()
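
    # Project imports are deferred until after the submodule install above so
    # that the runtime-installed RoMa package resolves correctly.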
    from source.utils_aux import set_seed
    from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
    from source.trainer import EDGSTrainer
    from source.visualization import generate_circular_camera_path, save_numpy_frames_as_mp4, generate_fully_smooth_cameras_with_tsp, put_text_on_image
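
    # RoMa provides the dense 2D-2D matches consumed by init_wC. It is built on
    # CPU and then moved to the GPU; upsampled predictions and symmetric matching
    # are disabled, presumably to reduce matching time and memory.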
    # Init RoMa model:
    sys.path.append('../submodules/RoMa')
    from romatch import roma_outdoor, roma_indoor
    roma_model = roma_indoor(device="cpu")
    roma_model = roma_model.to("cuda")
    roma_model.upsample_preds = False
    roma_model.symmetric = False

    demo.launch(share=True)