Spaces:
Runtime error
Runtime error
| ''' | |
| python cli_all_app.py --input_img_path 战场原.webp --preset_traj "orbit" "spiral" "lemniscate" "zoom-in" "zoom-out" "dolly zoom-in" "dolly zoom-out" "move-forward" "move-backward" "move-up" "move-down" "move-left" "move-right" --output_dir 战场原 | |
| ''' | |
| import copy | |
| import json | |
| import os | |
| import os.path as osp | |
| import queue | |
| import secrets | |
| import threading | |
| import time | |
| from datetime import datetime | |
| from glob import glob | |
| from pathlib import Path | |
| from typing import Literal, List | |
| import imageio.v3 as iio | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import tyro | |
| import viser | |
| import viser.transforms as vt | |
| from einops import rearrange | |
| from seva.eval import ( | |
| IS_TORCH_NIGHTLY, | |
| chunk_input_and_test, | |
| create_transforms_simple, | |
| infer_prior_stats, | |
| run_one_scene, | |
| transform_img_and_K, | |
| ) | |
| from seva.geometry import ( | |
| DEFAULT_FOV_RAD, | |
| get_default_intrinsics, | |
| get_preset_pose_fov, | |
| normalize_scene, | |
| ) | |
| from seva.model import SGMWrapper | |
| from seva.modules.autoencoder import AutoEncoder | |
| from seva.modules.conditioner import CLIPConditioner | |
| from seva.modules.preprocessor import Dust3rPipeline | |
| from seva.sampling import DDPMDiscretization, DiscreteDenoiser | |
| from seva.utils import load_model | |
| device = "cuda:0" | |
| # Constants. | |
| WORK_DIR = "work_dirs/demo_gr" | |
| MAX_SESSIONS = 1 | |
| if IS_TORCH_NIGHTLY: | |
| COMPILE = True | |
| os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1" | |
| os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1" | |
| else: | |
| COMPILE = False | |
| # Shared global variables across sessions. | |
| DUST3R = Dust3rPipeline(device=device) # type: ignore | |
| MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device) | |
| AE = AutoEncoder(chunk_size=1).to(device) | |
| CONDITIONER = CLIPConditioner().to(device) | |
| DISCRETIZATION = DDPMDiscretization() | |
| DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device) | |
| VERSION_DICT = { | |
| "H": 576, | |
| "W": 576, | |
| "T": 21, | |
| "C": 4, | |
| "f": 8, | |
| "options": {}, | |
| } | |
| SERVERS = {} | |
| ABORT_EVENTS = {} | |
| if COMPILE: | |
| MODEL = torch.compile(MODEL) | |
| CONDITIONER = torch.compile(CONDITIONER) | |
| AE = torch.compile(AE) | |
| class SevaRenderer(object): | |
| def __init__(self): | |
| self.gui_state = None | |
| def preprocess(self, input_img_path: str) -> dict: | |
| # Simply hardcode these such that aspect ratio is always kept and | |
| # shorter side is resized to 576. This is only to make GUI option fewer | |
| # though, changing it still works. | |
| shorter: int = 576 | |
| # Has to be 64 multiple for the network. | |
| shorter = round(shorter / 64) * 64 | |
| # Assume `Basic` demo mode: just hardcode the camera parameters and ignore points. | |
| input_imgs = torch.as_tensor( | |
| iio.imread(input_img_path) / 255.0, dtype=torch.float32 | |
| )[None, ..., :3] | |
| input_imgs = transform_img_and_K( | |
| input_imgs.permute(0, 3, 1, 2), | |
| shorter, | |
| K=None, | |
| size_stride=64, | |
| )[0].permute(0, 2, 3, 1) | |
| input_Ks = get_default_intrinsics( | |
| aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1] | |
| ) | |
| input_c2ws = torch.eye(4)[None] | |
| # Simulate a small time interval such that gradio can update | |
| # propgress properly. | |
| time.sleep(0.1) | |
| return { | |
| "input_imgs": input_imgs, | |
| "input_Ks": input_Ks, | |
| "input_c2ws": input_c2ws, | |
| "input_wh": (input_imgs.shape[2], input_imgs.shape[1]), | |
| "points": [np.zeros((0, 3))], | |
| "point_colors": [np.zeros((0, 3))], | |
| "scene_scale": 1.0, | |
| } | |
| def render( | |
| self, | |
| preprocessed: dict, | |
| seed: int, | |
| chunk_strategy: str, | |
| cfg: float, | |
| preset_traj: Literal[ | |
| "orbit", | |
| "spiral", | |
| "lemniscate", | |
| "zoom-in", | |
| "zoom-out", | |
| "dolly zoom-in", | |
| "dolly zoom-out", | |
| "move-forward", | |
| "move-backward", | |
| "move-up", | |
| "move-down", | |
| "move-left", | |
| "move-right", | |
| ], | |
| num_frames: int, | |
| zoom_factor: float | None, | |
| camera_scale: float, | |
| output_dir: str, | |
| ) -> str: | |
| # Generate a unique render name based on the input image filename and preset_traj | |
| input_img_name = osp.splitext(osp.basename(preprocessed["input_img_path"]))[0] | |
| render_name = f"{input_img_name}_{preset_traj}" | |
| render_dir = osp.join(output_dir, render_name) | |
| input_imgs, input_Ks, input_c2ws, (W, H) = ( | |
| preprocessed["input_imgs"], | |
| preprocessed["input_Ks"], | |
| preprocessed["input_c2ws"], | |
| preprocessed["input_wh"], | |
| ) | |
| num_inputs = len(input_imgs) | |
| assert num_inputs == 1 | |
| input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype) | |
| target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset( | |
| preprocessed, preset_traj, num_frames, zoom_factor | |
| ) | |
| all_c2ws = torch.cat([input_c2ws, target_c2ws], 0) | |
| all_Ks = ( | |
| torch.cat([input_Ks, target_Ks], 0) | |
| * input_Ks.new_tensor([W, H, 1])[:, None] | |
| ) | |
| num_targets = len(target_c2ws) | |
| input_indices = list(range(num_inputs)) | |
| target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist() | |
| # Get anchor cameras. | |
| T = VERSION_DICT["T"] | |
| version_dict = copy.deepcopy(VERSION_DICT) | |
| num_anchors = infer_prior_stats( | |
| T, | |
| num_inputs, | |
| num_total_frames=num_targets, | |
| version_dict=version_dict, | |
| ) | |
| # infer_prior_stats modifies T in-place. | |
| T = version_dict["T"] | |
| assert isinstance(num_anchors, int) | |
| anchor_indices = np.linspace( | |
| num_inputs, | |
| num_inputs + num_targets - 1, | |
| num_anchors, | |
| ).tolist() | |
| anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]] | |
| anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]] | |
| # Create image conditioning. | |
| all_imgs_np = ( | |
| F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy() | |
| * 255.0 | |
| ).astype(np.uint8) | |
| image_cond = { | |
| "img": all_imgs_np, | |
| "input_indices": input_indices, | |
| "prior_indices": anchor_indices, | |
| } | |
| # Create camera conditioning (K is unnormalized). | |
| camera_cond = { | |
| "c2w": all_c2ws, | |
| "K": all_Ks, | |
| "input_indices": list(range(num_inputs + num_targets)), | |
| } | |
| # Run rendering. | |
| num_steps = 50 | |
| options_ori = VERSION_DICT["options"] | |
| options = copy.deepcopy(options_ori) | |
| options["chunk_strategy"] = chunk_strategy | |
| options["video_save_fps"] = 30.0 | |
| options["beta_linear_start"] = 5e-6 | |
| options["log_snr_shift"] = 2.4 | |
| options["guider_types"] = [1, 2] | |
| options["cfg"] = [ | |
| float(cfg), | |
| 3.0 if num_inputs >= 9 else 2.0, | |
| ] # We define semi-dense-view regime to have 9 input views. | |
| options["camera_scale"] = camera_scale | |
| options["num_steps"] = num_steps | |
| options["cfg_min"] = 1.2 | |
| options["encoding_t"] = 1 | |
| options["decoding_t"] = 1 | |
| task = "img2trajvid" | |
| # Get number of first pass chunks. | |
| T_first_pass = T[0] if isinstance(T, (list, tuple)) else T | |
| chunk_strategy_first_pass = options.get( | |
| "chunk_strategy_first_pass", "gt-nearest" | |
| ) | |
| num_chunks_0 = len( | |
| chunk_input_and_test( | |
| T_first_pass, | |
| input_c2ws, | |
| anchor_c2ws, | |
| input_indices, | |
| image_cond["prior_indices"], | |
| options={**options, "sampler_verbose": False}, | |
| task=task, | |
| chunk_strategy=chunk_strategy_first_pass, | |
| gt_input_inds=list(range(input_c2ws.shape[0])), | |
| )[1] | |
| ) | |
| # Get number of second pass chunks. | |
| anchor_argsort = np.argsort(input_indices + anchor_indices).tolist() | |
| anchor_indices = np.array(input_indices + anchor_indices)[ | |
| anchor_argsort | |
| ].tolist() | |
| gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])] | |
| anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[ | |
| anchor_argsort | |
| ] | |
| T_second_pass = T[1] if isinstance(T, (list, tuple)) else T | |
| chunk_strategy = options.get("chunk_strategy", "nearest") | |
| num_chunks_1 = len( | |
| chunk_input_and_test( | |
| T_second_pass, | |
| anchor_c2ws_second_pass, | |
| target_c2ws, | |
| anchor_indices, | |
| target_indices, | |
| options={**options, "sampler_verbose": False}, | |
| task=task, | |
| chunk_strategy=chunk_strategy, | |
| gt_input_inds=gt_input_inds, | |
| )[1] | |
| ) | |
| video_path_generator = run_one_scene( | |
| task=task, | |
| version_dict={ | |
| "H": H, | |
| "W": W, | |
| "T": T, | |
| "C": VERSION_DICT["C"], | |
| "f": VERSION_DICT["f"], | |
| "options": options, | |
| }, | |
| model=MODEL, | |
| ae=AE, | |
| conditioner=CONDITIONER, | |
| denoiser=DENOISER, | |
| image_cond=image_cond, | |
| camera_cond=camera_cond, | |
| save_path=render_dir, | |
| use_traj_prior=True, | |
| traj_prior_c2ws=anchor_c2ws, | |
| traj_prior_Ks=anchor_Ks, | |
| seed=seed, | |
| gradio=True, | |
| ) | |
| for video_path in video_path_generator: | |
| return video_path | |
| return "" | |
| def get_target_c2ws_and_Ks_from_preset( | |
| self, | |
| preprocessed: dict, | |
| preset_traj: Literal[ | |
| "orbit", | |
| "spiral", | |
| "lemniscate", | |
| "zoom-in", | |
| "zoom-out", | |
| "dolly zoom-in", | |
| "dolly zoom-out", | |
| "move-forward", | |
| "move-backward", | |
| "move-up", | |
| "move-down", | |
| "move-left", | |
| "move-right", | |
| ], | |
| num_frames: int, | |
| zoom_factor: float | None, | |
| ): | |
| img_wh = preprocessed["input_wh"] | |
| start_c2w = preprocessed["input_c2ws"][0] | |
| start_w2c = torch.linalg.inv(start_c2w) | |
| look_at = torch.tensor([0, 0, 10]) | |
| start_fov = DEFAULT_FOV_RAD | |
| target_c2ws, target_fovs = get_preset_pose_fov( | |
| preset_traj, | |
| num_frames, | |
| start_w2c, | |
| look_at, | |
| -start_c2w[:3, 1], | |
| start_fov, | |
| spiral_radii=[1.0, 1.0, 0.5], | |
| zoom_factor=zoom_factor, | |
| ) | |
| target_c2ws = torch.as_tensor(target_c2ws) | |
| target_fovs = torch.as_tensor(target_fovs) | |
| target_Ks = get_default_intrinsics( | |
| target_fovs, # type: ignore | |
| aspect_ratio=img_wh[0] / img_wh[1], | |
| ) | |
| return target_c2ws, target_Ks | |
| def main( | |
| input_img_path: str, | |
| preset_traj: List[Literal[ | |
| "orbit", | |
| "spiral", | |
| "lemniscate", | |
| "zoom-in", | |
| "zoom-out", | |
| "dolly zoom-in", | |
| "dolly zoom-out", | |
| "move-forward", | |
| "move-backward", | |
| "move-up", | |
| "move-down", | |
| "move-left", | |
| "move-right", | |
| ]], | |
| num_frames: int = 80, | |
| zoom_factor: float | None = None, | |
| seed: int = 23, | |
| chunk_strategy: str = "interp", | |
| cfg: float = 4.0, | |
| camera_scale: float = 2.0, | |
| output_dir: str = WORK_DIR, | |
| ): | |
| renderer = SevaRenderer() | |
| preprocessed = renderer.preprocess(input_img_path) | |
| preprocessed["input_img_path"] = input_img_path # Add input_img_path to preprocessed dict | |
| for traj in preset_traj: | |
| video_path = renderer.render( | |
| preprocessed, | |
| seed, | |
| chunk_strategy, | |
| cfg, | |
| traj, | |
| num_frames, | |
| zoom_factor, | |
| camera_scale, | |
| output_dir, | |
| ) | |
| print(f"Rendered video saved to: {video_path}") | |
| if __name__ == "__main__": | |
| tyro.cli(main) | |