| import collections | |
| import json | |
| import math | |
| import os | |
| import re | |
| import threading | |
| from typing import List, Literal, Optional, Tuple, Union | |
| import gradio as gr | |
| from colorama import Fore, Style, init | |
| init(autoreset=True) | |
| import imageio.v3 as iio | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms.functional as TF | |
| from einops import repeat | |
| from PIL import Image | |
| from tqdm.auto import tqdm | |
| from seva.geometry import get_camera_dist, get_plucker_coordinates, to_hom_pose | |
| from seva.sampling import ( | |
| EulerEDMSampler, | |
| MultiviewCFG, | |
| MultiviewTemporalCFG, | |
| VanillaCFG, | |
| ) | |
| from seva.utils import seed_everything | |
| try: | |
    # Nightly/dev builds of torch carry "dev" in their version string.
    version = torch.__version__
    IS_TORCH_NIGHTLY = "dev" in version
| if IS_TORCH_NIGHTLY: | |
| torch._dynamo.config.cache_size_limit = 128 # type: ignore[assignment] | |
| torch._dynamo.config.accumulated_cache_size_limit = 1024 # type: ignore[assignment] | |
| torch._dynamo.config.force_parameter_static_shapes = False # type: ignore[assignment] | |
| except Exception: | |
| IS_TORCH_NIGHTLY = False | |
| def pad_indices( | |
| input_indices: List[int], | |
| test_indices: List[int], | |
| T: int, | |
| padding_mode: Literal["first", "last", "none"] = "last", | |
| ): | |
| assert padding_mode in ["last", "none"], "`first` padding is not supported yet." | |
| if padding_mode == "last": | |
| padded_indices = [ | |
| i for i in range(T) if i not in (input_indices + test_indices) | |
| ] | |
| else: | |
| padded_indices = [] | |
| input_selects = list(range(len(input_indices))) | |
| test_selects = list(range(len(test_indices))) | |
| if max(input_indices) > max(test_indices): | |
| # last elem from input | |
| input_selects += [input_selects[-1]] * len(padded_indices) | |
| input_indices = input_indices + padded_indices | |
| sorted_inds = np.argsort(input_indices) | |
| input_indices = [input_indices[ind] for ind in sorted_inds] | |
| input_selects = [input_selects[ind] for ind in sorted_inds] | |
| else: | |
| # last elem from test | |
| test_selects += [test_selects[-1]] * len(padded_indices) | |
| test_indices = test_indices + padded_indices | |
| sorted_inds = np.argsort(test_indices) | |
| test_indices = [test_indices[ind] for ind in sorted_inds] | |
| test_selects = [test_selects[ind] for ind in sorted_inds] | |
| if padding_mode == "last": | |
| input_maps = np.array([-1] * T) | |
| test_maps = np.array([-1] * T) | |
| else: | |
| input_maps = np.array([-1] * (len(input_indices) + len(test_indices))) | |
| test_maps = np.array([-1] * (len(input_indices) + len(test_indices))) | |
| input_maps[input_indices] = input_selects | |
| test_maps[test_indices] = test_selects | |
| return input_indices, test_indices, input_maps, test_maps | |
| def assemble( | |
| input, | |
| test, | |
| input_maps, | |
| test_maps, | |
| ): | |
| T = len(input_maps) | |
| assembled = torch.zeros_like(test[-1:]).repeat_interleave(T, dim=0) | |
| assembled[input_maps != -1] = input[input_maps[input_maps != -1]] | |
| assembled[test_maps != -1] = test[test_maps[test_maps != -1]] | |
| assert np.logical_xor(input_maps != -1, test_maps != -1).all() | |
| return assembled | |
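# Illustrative sketch (added for documentation; `_demo_pad_and_assemble` is an
# assumption of this note, is not part of the original pipeline, and is never
# called): `pad_indices` fills the unused slots of a length-T window, and
# `assemble` scatters input/test frames into that window, repeating the last
# test frame in the padded slots.
def _demo_pad_and_assemble():
    input_frames = torch.zeros(1, 3, 8, 8)  # one input view
    test_frames = torch.ones(2, 3, 8, 8)  # two test views
    _, _, input_maps, test_maps = pad_indices([0], [1, 2], T=5)
    # input_maps == [0, -1, -1, -1, -1]; test_maps == [-1, 0, 1, 1, 1]
    window = assemble(input_frames, test_frames, input_maps, test_maps)
    assert window.shape == (5, 3, 8, 8)
    return window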
| def get_resizing_factor( | |
| target_shape: Tuple[int, int], # H, W | |
| current_shape: Tuple[int, int], # H, W | |
| cover_target: bool = True, | |
    # If True, the output shape will fully cover the target shape.
    # If False, the target shape will fully cover the output shape.
| ) -> float: | |
| r_bound = target_shape[1] / target_shape[0] | |
| aspect_r = current_shape[1] / current_shape[0] | |
| if r_bound >= 1.0: | |
| if cover_target: | |
| if aspect_r >= r_bound: | |
| factor = min(target_shape) / min(current_shape) | |
| elif aspect_r < 1.0: | |
| factor = max(target_shape) / min(current_shape) | |
| else: | |
| factor = max(target_shape) / max(current_shape) | |
| else: | |
| if aspect_r >= r_bound: | |
| factor = max(target_shape) / max(current_shape) | |
| elif aspect_r < 1.0: | |
| factor = min(target_shape) / max(current_shape) | |
| else: | |
| factor = min(target_shape) / min(current_shape) | |
| else: | |
| if cover_target: | |
| if aspect_r <= r_bound: | |
| factor = min(target_shape) / min(current_shape) | |
| elif aspect_r > 1.0: | |
| factor = max(target_shape) / min(current_shape) | |
| else: | |
| factor = max(target_shape) / max(current_shape) | |
| else: | |
| if aspect_r <= r_bound: | |
| factor = max(target_shape) / max(current_shape) | |
| elif aspect_r > 1.0: | |
| factor = min(target_shape) / max(current_shape) | |
| else: | |
| factor = min(target_shape) / min(current_shape) | |
| return factor | |
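# Example (illustrative): for a 720x1280 (H, W) image and a 576x576 target with
# cover_target=True, aspect_r (~1.78) >= r_bound (1.0), so the factor is
# min(target) / min(current) = 576 / 720 = 0.8; the resized 576x1024 image then
# fully covers the 576x576 target.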
| def get_unique_embedder_keys_from_conditioner(conditioner): | |
| keys = [x.input_key for x in conditioner.embedders if x.input_key is not None] | |
| keys = [item for sublist in keys for item in sublist] # Flatten list | |
| return set(keys) | |
| def get_wh_with_fixed_shortest_side(w, h, size): | |
    # If size is None or not positive, return the original (w, h).
| if size is None or size <= 0: | |
| return w, h | |
| if w < h: | |
| new_w = size | |
| new_h = int(size * h / w) | |
| else: | |
| new_h = size | |
| new_w = int(size * w / h) | |
| return new_w, new_h | |
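# Example (illustrative): get_wh_with_fixed_shortest_side(1280, 720, 576)
# returns (1024, 576); the shorter side is pinned to 576 and the longer side is
# scaled by the same ratio.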
| def load_img_and_K( | |
| image_path_or_size: Union[str, torch.Size], | |
| size: Optional[Union[int, Tuple[int, int]]], | |
| scale: float = 1.0, | |
| center: Tuple[float, float] = (0.5, 0.5), | |
| K: torch.Tensor | None = None, | |
| size_stride: int = 1, | |
| center_crop: bool = False, | |
| image_as_tensor: bool = True, | |
| context_rgb: np.ndarray | None = None, | |
| device: str = "cuda", | |
| ): | |
| if isinstance(image_path_or_size, torch.Size): | |
| image = Image.new("RGBA", image_path_or_size[::-1]) | |
| else: | |
| image = Image.open(image_path_or_size).convert("RGBA") | |
| w, h = image.size | |
| if size is None: | |
| size = (w, h) | |
| image = np.array(image).astype(np.float32) / 255 | |
| if image.shape[-1] == 4: | |
| rgb, alpha = image[:, :, :3], image[:, :, 3:] | |
| if context_rgb is not None: | |
| image = rgb * alpha + context_rgb * (1 - alpha) | |
| else: | |
| image = rgb * alpha + (1 - alpha) | |
| image = image.transpose(2, 0, 1) | |
| image = torch.from_numpy(image).to(dtype=torch.float32) | |
| image = image.unsqueeze(0) | |
| if isinstance(size, (tuple, list)): | |
        # => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop the `size` area from the rescaled image
| W, H = size | |
| else: | |
| # => if size is int, we rescale the image to fit the shortest side to size | |
| # => if size is None, no rescaling is applied | |
| W, H = get_wh_with_fixed_shortest_side(w, h, size) | |
| W, H = ( | |
| math.floor(W / size_stride + 0.5) * size_stride, | |
| math.floor(H / size_stride + 0.5) * size_stride, | |
| ) | |
| rfs = get_resizing_factor((math.floor(H * scale), math.floor(W * scale)), (h, w)) | |
| resize_size = rh, rw = [int(np.ceil(rfs * s)) for s in (h, w)] | |
| image = torch.nn.functional.interpolate( | |
| image, resize_size, mode="area", antialias=False | |
| ) | |
| if scale < 1.0: | |
| pw = math.ceil((W - resize_size[1]) * 0.5) | |
| ph = math.ceil((H - resize_size[0]) * 0.5) | |
| image = F.pad(image, (pw, pw, ph, ph), "constant", 1.0) | |
| cy_center = int(center[1] * image.shape[-2]) | |
| cx_center = int(center[0] * image.shape[-1]) | |
| if center_crop: | |
| side = min(H, W) | |
| ct = max(0, cy_center - side // 2) | |
| cl = max(0, cx_center - side // 2) | |
| ct = min(ct, image.shape[-2] - side) | |
| cl = min(cl, image.shape[-1] - side) | |
| image = TF.crop(image, top=ct, left=cl, height=side, width=side) | |
| else: | |
| ct = max(0, cy_center - H // 2) | |
| cl = max(0, cx_center - W // 2) | |
| ct = min(ct, image.shape[-2] - H) | |
| cl = min(cl, image.shape[-1] - W) | |
| image = TF.crop(image, top=ct, left=cl, height=H, width=W) | |
| if K is not None: | |
| K = K.clone() | |
| if torch.all(K[:2, -1] >= 0) and torch.all(K[:2, -1] <= 1): | |
| K[:2] *= K.new_tensor([rw, rh])[:, None] # normalized K | |
| else: | |
| K[:2] *= K.new_tensor([rw / w, rh / h])[:, None] # unnormalized K | |
| K[:2, 2] -= K.new_tensor([cl, ct]) | |
| if image_as_tensor: | |
| # tensor of shape (1, 3, H, W) with values ranging from (-1, 1) | |
| image = image.to(device) * 2.0 - 1.0 | |
| else: | |
| # PIL Image with values ranging from (0, 255) | |
| image = image.permute(0, 2, 3, 1).numpy()[0] | |
| image = Image.fromarray((image * 255).astype(np.uint8)) | |
| return image, K | |
| def transform_img_and_K( | |
| image: torch.Tensor, | |
| size: Union[int, Tuple[int, int]], | |
| scale: float = 1.0, | |
| center: Tuple[float, float] = (0.5, 0.5), | |
| K: torch.Tensor | None = None, | |
| size_stride: int = 1, | |
| mode: str = "crop", | |
| ): | |
| assert mode in [ | |
| "crop", | |
| "pad", | |
| "stretch", | |
| ], f"mode should be one of ['crop', 'pad', 'stretch'], got {mode}" | |
| h, w = image.shape[-2:] | |
| if isinstance(size, (tuple, list)): | |
        # => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop the `size` area from the rescaled image
| W, H = size | |
| else: | |
| # => if size is int, we rescale the image to fit the shortest side to size | |
| # => if size is None, no rescaling is applied | |
| W, H = get_wh_with_fixed_shortest_side(w, h, size) | |
| W, H = ( | |
| math.floor(W / size_stride + 0.5) * size_stride, | |
| math.floor(H / size_stride + 0.5) * size_stride, | |
| ) | |
| if mode == "stretch": | |
| rh, rw = H, W | |
| else: | |
| rfs = get_resizing_factor( | |
| (H, W), | |
| (h, w), | |
| cover_target=mode != "pad", | |
| ) | |
| (rh, rw) = [int(np.ceil(rfs * s)) for s in (h, w)] | |
| rh, rw = int(rh / scale), int(rw / scale) | |
| image = torch.nn.functional.interpolate( | |
| image, (rh, rw), mode="area", antialias=False | |
| ) | |
| cy_center = int(center[1] * image.shape[-2]) | |
| cx_center = int(center[0] * image.shape[-1]) | |
| if mode != "pad": | |
| ct = max(0, cy_center - H // 2) | |
| cl = max(0, cx_center - W // 2) | |
| ct = min(ct, image.shape[-2] - H) | |
| cl = min(cl, image.shape[-1] - W) | |
| image = TF.crop(image, top=ct, left=cl, height=H, width=W) | |
| pl, pt = 0, 0 | |
| else: | |
| pt = max(0, H // 2 - cy_center) | |
| pl = max(0, W // 2 - cx_center) | |
| pb = max(0, H - pt - image.shape[-2]) | |
| pr = max(0, W - pl - image.shape[-1]) | |
| image = TF.pad( | |
| image, | |
| [pl, pt, pr, pb], | |
| ) | |
| cl, ct = 0, 0 | |
| if K is not None: | |
| K = K.clone() | |
| # K[:, :2, 2] += K.new_tensor([pl, pt]) | |
| if torch.all(K[:, :2, -1] >= 0) and torch.all(K[:, :2, -1] <= 1): | |
| K[:, :2] *= K.new_tensor([rw, rh])[None, :, None] # normalized K | |
| else: | |
| K[:, :2] *= K.new_tensor([rw / w, rh / h])[None, :, None] # unnormalized K | |
| K[:, :2, 2] += K.new_tensor([pl - cl, pt - ct]) | |
| return image, K | |
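# Example (illustrative, assuming a normalized intrinsics matrix and the default
# center of (0.5, 0.5)): a (1, 3, 720, 1280) tensor transformed with
# size=(576, 576) and mode="crop" is first resized to 576x1024 (factor 0.8) and
# then center-cropped to 576x576; the principal point in K is shifted by the
# crop offset accordingly.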
| lowvram_mode = False | |
| def set_lowvram_mode(mode): | |
| global lowvram_mode | |
| lowvram_mode = mode | |
| def load_model(model, device: str = "cuda"): | |
| model.to(device) | |
| def unload_model(model): | |
| global lowvram_mode | |
| if lowvram_mode: | |
| model.cpu() | |
| torch.cuda.empty_cache() | |
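# Usage pattern (illustrative): with low-VRAM mode enabled, each submodule is
# moved to the GPU only around its forward pass and offloaded afterwards, e.g.
#     set_lowvram_mode(True)
#     load_model(ae)     # ae.to("cuda")
#     ...                # run the autoencoder
#     unload_model(ae)   # ae.cpu() + torch.cuda.empty_cache()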
| def infer_prior_stats( | |
| T, | |
| num_input_frames, | |
| num_total_frames, | |
| version_dict, | |
| ): | |
| options = version_dict["options"] | |
| chunk_strategy = options.get("chunk_strategy", "nearest") | |
| T_first_pass = T[0] if isinstance(T, (list, tuple)) else T | |
| T_second_pass = T[1] if isinstance(T, (list, tuple)) else T | |
| # get traj_prior_c2ws for 2-pass sampling | |
| if chunk_strategy.startswith("interp"): | |
        # Start and end have already taken up two slots.
        # The +1 means we need X + 1 prior frames to bound X forward passes covering all
        # test frames. Tuning up `num_prior_frames_ratio` is helpful when you observe a
        # sudden jump in the generated frames due to insufficient prior frames. This
        # option is effective for complicated trajectories and when the `interp` strategy
        # is used (usually the semi-dense-view regime). The recommended range is
        # [1.0 (default), 1.5].
| if num_input_frames >= options.get("num_input_semi_dense", 9): | |
| num_prior_frames = ( | |
| math.ceil( | |
| num_total_frames | |
| / (T_second_pass - 2) | |
| * options.get("num_prior_frames_ratio", 1.0) | |
| ) | |
| + 1 | |
| ) | |
| if num_prior_frames + num_input_frames < T_first_pass: | |
| num_prior_frames = T_first_pass - num_input_frames | |
| num_prior_frames = max( | |
| num_prior_frames, | |
| options.get("num_prior_frames", 0), | |
| ) | |
| T_first_pass = num_prior_frames + num_input_frames | |
| if "gt" in chunk_strategy: | |
| T_second_pass = T_second_pass + num_input_frames | |
| # Dynamically update context window length. | |
| version_dict["T"] = [T_first_pass, T_second_pass] | |
| else: | |
| num_prior_frames = ( | |
| math.ceil( | |
| num_total_frames | |
| / ( | |
| T_second_pass | |
| - 2 | |
| - (num_input_frames if "gt" in chunk_strategy else 0) | |
| ) | |
| * options.get("num_prior_frames_ratio", 1.0) | |
| ) | |
| + 1 | |
| ) | |
| if num_prior_frames + num_input_frames < T_first_pass: | |
| num_prior_frames = T_first_pass - num_input_frames | |
| num_prior_frames = max( | |
| num_prior_frames, | |
| options.get("num_prior_frames", 0), | |
| ) | |
| else: | |
| num_prior_frames = max( | |
| T_first_pass - num_input_frames, | |
| options.get("num_prior_frames", 0), | |
| ) | |
| if num_input_frames >= options.get("num_input_semi_dense", 9): | |
| T_first_pass = num_prior_frames + num_input_frames | |
| # Dynamically update context window length. | |
| version_dict["T"] = [T_first_pass, T_second_pass] | |
| return num_prior_frames | |
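# Example (illustrative): with T=21, chunk_strategy="interp", 3 input frames
# (fewer than `num_input_semi_dense`), and otherwise default options, this
# returns max(T_first_pass - num_input_frames, 0) = 21 - 3 = 18 prior frames.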
| def infer_prior_inds( | |
| c2ws, | |
| num_prior_frames, | |
| input_frame_indices, | |
| options, | |
| ): | |
| chunk_strategy = options.get("chunk_strategy", "nearest") | |
| if chunk_strategy.startswith("interp"): | |
| prior_frame_indices = np.array( | |
| [i for i in range(c2ws.shape[0]) if i not in input_frame_indices] | |
| ) | |
| prior_frame_indices = prior_frame_indices[ | |
| np.ceil( | |
| np.linspace( | |
| 0, prior_frame_indices.shape[0] - 1, num_prior_frames, endpoint=True | |
| ) | |
| ).astype(int) | |
        ]  # having a ceil here is actually safer for corner cases
| else: | |
| prior_frame_indices = [] | |
| while len(prior_frame_indices) < num_prior_frames: | |
| closest_distance = np.abs( | |
| np.arange(c2ws.shape[0])[None] | |
| - np.concatenate( | |
| [np.array(input_frame_indices), np.array(prior_frame_indices)] | |
| )[:, None] | |
| ).min(0) | |
| prior_frame_indices.append(np.argsort(closest_distance)[-1]) | |
| return np.sort(prior_frame_indices) | |
| def compute_relative_inds( | |
| source_inds, | |
| target_inds, | |
| ): | |
| assert len(source_inds) > 2 | |
| # compute relative indices of target_inds within source_inds | |
| relative_inds = [] | |
| for ind in target_inds: | |
| if ind in source_inds: | |
| relative_ind = int(np.where(source_inds == ind)[0][0]) | |
| elif ind < source_inds[0]: | |
| # extrapolate | |
| relative_ind = -((source_inds[0] - ind) / (source_inds[1] - source_inds[0])) | |
| elif ind > source_inds[-1]: | |
| # extrapolate | |
| relative_ind = len(source_inds) + ( | |
| (ind - source_inds[-1]) / (source_inds[-1] - source_inds[-2]) | |
| ) | |
| else: | |
| # interpolate | |
| lower_inds = source_inds[source_inds < ind] | |
| upper_inds = source_inds[source_inds > ind] | |
| if len(lower_inds) > 0 and len(upper_inds) > 0: | |
| lower_ind = lower_inds[-1] | |
| upper_ind = upper_inds[0] | |
| relative_lower_ind = int(np.where(source_inds == lower_ind)[0][0]) | |
| relative_upper_ind = int(np.where(source_inds == upper_ind)[0][0]) | |
| relative_ind = relative_lower_ind + (ind - lower_ind) / ( | |
| upper_ind - lower_ind | |
| ) * (relative_upper_ind - relative_lower_ind) | |
| else: | |
| # Out of range | |
| relative_inds.append(float("nan")) # Or some other placeholder | |
| relative_inds.append(relative_ind) | |
| return relative_inds | |
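# Example (illustrative): compute_relative_inds(np.array([0, 4, 8]), [2, 10])
# returns [0.5, 3.5]: 2 is interpolated halfway between source indices 0 and 1,
# while 10 is extrapolated half a step beyond the last source index.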
| def find_nearest_source_inds( | |
| source_c2ws, | |
| target_c2ws, | |
| nearest_num=1, | |
| mode="translation", | |
| ): | |
| dists = get_camera_dist(source_c2ws, target_c2ws, mode=mode).cpu().numpy() | |
| sorted_inds = np.argsort(dists, axis=0).T | |
| return sorted_inds[:, :nearest_num] | |
| def chunk_input_and_test( | |
| T, | |
| input_c2ws, | |
| test_c2ws, | |
| input_ords, # orders | |
| test_ords, # orders | |
| options, | |
| task: str = "img2img", | |
| chunk_strategy: str = "gt", | |
| gt_input_inds: list = [], | |
| ): | |
| M, N = input_c2ws.shape[0], test_c2ws.shape[0] | |
| chunks = [] | |
| if chunk_strategy.startswith("gt"): | |
| assert len(gt_input_inds) < T, ( | |
| f"Number of gt input frames {len(gt_input_inds)} should be " | |
| f"less than {T} when `gt` chunking strategy is used." | |
| ) | |
| assert ( | |
| list(range(M)) == gt_input_inds | |
| ), "All input_c2ws should be gt when `gt` chunking strategy is used." | |
| # LEGACY CHUNKING STRATEGY | |
| # num_test_per_chunk = T - len(gt_input_inds) | |
| # test_inds_per_chunk = [i for i in range(T) if i not in gt_input_inds] | |
| # for i in range(0, test_c2ws.shape[0], num_test_per_chunk): | |
| # chunk = ["NULL"] * T | |
| # for j, k in enumerate(gt_input_inds): | |
| # chunk[k] = f"!{j:03d}" | |
| # for j, k in enumerate( | |
| # test_inds_per_chunk[: test_c2ws[i : i + num_test_per_chunk].shape[0]] | |
| # ): | |
| # chunk[k] = f">{i + j:03d}" | |
| # chunks.append(chunk) | |
| num_test_seen = 0 | |
| while num_test_seen < N: | |
| chunk = [f"!{i:03d}" for i in gt_input_inds] | |
| if chunk_strategy != "gt" and num_test_seen > 0: | |
| pseudo_num_ratio = options.get("pseudo_num_ratio", 0.33) | |
| if (N - num_test_seen) >= math.floor( | |
| (T - len(gt_input_inds)) * pseudo_num_ratio | |
| ): | |
| pseudo_num = math.ceil((T - len(gt_input_inds)) * pseudo_num_ratio) | |
| else: | |
| pseudo_num = (T - len(gt_input_inds)) - (N - num_test_seen) | |
| pseudo_num = min(pseudo_num, options.get("pseudo_num_max", 10000)) | |
| if "ltr" in chunk_strategy: | |
| chunk.extend( | |
| [ | |
| f"!{i + len(gt_input_inds):03d}" | |
| for i in range(num_test_seen - pseudo_num, num_test_seen) | |
| ] | |
| ) | |
| elif "nearest" in chunk_strategy: | |
| source_inds = np.concatenate( | |
| [ | |
| find_nearest_source_inds( | |
| test_c2ws[:num_test_seen], | |
| test_c2ws[num_test_seen:], | |
| nearest_num=1, # pseudo_num, | |
| mode="rotation", | |
| ), | |
| find_nearest_source_inds( | |
| test_c2ws[:num_test_seen], | |
| test_c2ws[num_test_seen:], | |
| nearest_num=1, # pseudo_num, | |
| mode="translation", | |
| ), | |
| ], | |
| axis=1, | |
| ) | |
                    ####### [HACK ALERT] keep running until pseudo num is stabilized ########
| temp_pseudo_num = pseudo_num | |
| while True: | |
| nearest_source_inds = np.concatenate( | |
| [ | |
| np.sort( | |
| [ | |
| ind | |
| for (ind, _) in collections.Counter( | |
| [ | |
| item | |
| for item in source_inds[ | |
| : T | |
| - len(gt_input_inds) | |
| - temp_pseudo_num | |
| ] | |
| .flatten() | |
| .tolist() | |
| if item | |
| != ( | |
| num_test_seen - 1 | |
| ) # exclude the last one here | |
| ] | |
| ).most_common(pseudo_num - 1) | |
| ], | |
| ).astype(int), | |
| [num_test_seen - 1], # always keep the last one | |
| ] | |
| ) | |
| if len(nearest_source_inds) >= temp_pseudo_num: | |
                            break  # stabilized
| else: | |
| temp_pseudo_num = len(nearest_source_inds) | |
| pseudo_num = len(nearest_source_inds) | |
| ######################################################################## | |
| chunk.extend( | |
| [f"!{i + len(gt_input_inds):03d}" for i in nearest_source_inds] | |
| ) | |
| else: | |
| raise NotImplementedError( | |
| f"Chunking strategy {chunk_strategy} for the first pass is not implemented." | |
| ) | |
| chunk.extend( | |
| [ | |
| f">{i:03d}" | |
| for i in range( | |
| num_test_seen, | |
| min(num_test_seen + T - len(gt_input_inds) - pseudo_num, N), | |
| ) | |
| ] | |
| ) | |
| else: | |
| chunk.extend( | |
| [ | |
| f">{i:03d}" | |
| for i in range( | |
| num_test_seen, | |
| min(num_test_seen + T - len(gt_input_inds), N), | |
| ) | |
| ] | |
| ) | |
| num_test_seen += sum([1 for c in chunk if c.startswith(">")]) | |
| if len(chunk) < T: | |
| chunk.extend(["NULL"] * (T - len(chunk))) | |
| chunks.append(chunk) | |
| elif chunk_strategy.startswith("nearest"): | |
| input_imgs = np.array([f"!{i:03d}" for i in range(M)]) | |
| test_imgs = np.array([f">{i:03d}" for i in range(N)]) | |
| match = re.match(r"^nearest-(\d+)$", chunk_strategy) | |
| if match: | |
| nearest_num = int(match.group(1)) | |
| assert ( | |
| nearest_num < T | |
| ), f"Nearest number of {nearest_num} should be less than {T}." | |
| source_inds = find_nearest_source_inds( | |
| input_c2ws, | |
| test_c2ws, | |
| nearest_num=nearest_num, | |
                mode="translation",  # translation-only distance is sufficient for the second pass
| ) | |
| for i in range(0, N, T - nearest_num): | |
| nearest_source_inds = np.sort( | |
| [ | |
| ind | |
| for (ind, _) in collections.Counter( | |
| source_inds[i : i + T - nearest_num].flatten().tolist() | |
| ).most_common(nearest_num) | |
| ] | |
| ) | |
| chunk = ( | |
| input_imgs[nearest_source_inds].tolist() | |
| + test_imgs[i : i + T - nearest_num].tolist() | |
| ) | |
| chunks.append(chunk + ["NULL"] * (T - len(chunk))) | |
| else: | |
| # do not always condition on gt cond frames | |
| if "gt" not in chunk_strategy: | |
| gt_input_inds = [] | |
| source_inds = find_nearest_source_inds( | |
| input_c2ws, | |
| test_c2ws, | |
| nearest_num=1, | |
                mode="translation",  # translation-only distance is sufficient for the second pass
| )[:, 0] | |
| test_inds_per_input = {} | |
| for test_idx, input_idx in enumerate(source_inds): | |
| if input_idx not in test_inds_per_input: | |
| test_inds_per_input[input_idx] = [] | |
| test_inds_per_input[input_idx].append(test_idx) | |
| num_test_seen = 0 | |
| chunk = input_imgs[gt_input_inds].tolist() | |
| candidate_input_inds = sorted(list(test_inds_per_input.keys())) | |
| while num_test_seen < N: | |
| input_idx = candidate_input_inds[0] | |
| test_inds = test_inds_per_input[input_idx] | |
| input_is_cond = input_idx in gt_input_inds | |
| prefix_inds = [] if input_is_cond else [input_idx] | |
| if len(chunk) == T - len(prefix_inds) or not candidate_input_inds: | |
| if chunk: | |
| chunk += ["NULL"] * (T - len(chunk)) | |
| chunks.append(chunk) | |
| chunk = input_imgs[gt_input_inds].tolist() | |
| if num_test_seen >= N: | |
| break | |
| continue | |
| candidate_chunk = ( | |
| input_imgs[prefix_inds].tolist() + test_imgs[test_inds].tolist() | |
| ) | |
| space_left = T - len(chunk) | |
| if len(candidate_chunk) <= space_left: | |
| chunk.extend(candidate_chunk) | |
| num_test_seen += len(test_inds) | |
| candidate_input_inds.pop(0) | |
| else: | |
| chunk.extend(candidate_chunk[:space_left]) | |
| num_input_idx = 0 if input_is_cond else 1 | |
| num_test_seen += space_left - num_input_idx | |
| test_inds_per_input[input_idx] = test_inds[ | |
| space_left - num_input_idx : | |
| ] | |
| if len(chunk) == T: | |
| chunks.append(chunk) | |
| chunk = input_imgs[gt_input_inds].tolist() | |
| if chunk and chunk != input_imgs[gt_input_inds].tolist(): | |
| chunks.append(chunk + ["NULL"] * (T - len(chunk))) | |
| elif chunk_strategy.startswith("interp"): | |
| # `interp` chunk requires ordering info | |
| assert input_ords is not None and test_ords is not None, ( | |
| "When using `interp` chunking strategy, ordering of input " | |
| "and test frames should be provided." | |
| ) | |
| # if chunk_strategy is `interp*`` and task is `img2trajvid*`, we will not | |
| # use input views since their order info within target views is unknown | |
| if "img2trajvid" in task: | |
| assert ( | |
| list(range(len(gt_input_inds))) == gt_input_inds | |
| ), "`img2trajvid` task should put `gt_input_inds` in start." | |
| input_c2ws = input_c2ws[ | |
| [ind for ind in range(M) if ind not in gt_input_inds] | |
| ] | |
| input_ords = [ | |
| input_ords[ind] for ind in range(M) if ind not in gt_input_inds | |
| ] | |
| M = input_c2ws.shape[0] | |
| input_ords = [0] + input_ords # this is a hack accounting for test views | |
| # before the first input view | |
| input_ords[-1] += 0.01 # this is a hack ensuring last test stop is included | |
| # in the last forward when input_ords[-1] == test_ords[-1] | |
| input_ords = np.array(input_ords)[:, None] | |
| input_ords_ = np.concatenate([input_ords[1:], np.full((1, 1), np.inf)]) | |
| test_ords = np.array(test_ords)[None] | |
| in_stop_ranges = np.logical_and( | |
| np.repeat(input_ords, N, axis=1) <= np.repeat(test_ords, M + 1, axis=0), | |
| np.repeat(input_ords_, N, axis=1) > np.repeat(test_ords, M + 1, axis=0), | |
| ) # (M, N) | |
        assert (in_stop_ranges.sum(1) <= T - 2).all(), (
            "More input frames need to be sampled during the first pass to ensure "
            f"#test frames during each forward in the second pass will not exceed {T - 2}."
        )
| if input_ords[1, 0] <= test_ords[0, 0]: | |
| assert not in_stop_ranges[0].any() | |
| if input_ords[-1, 0] >= test_ords[0, -1]: | |
| assert not in_stop_ranges[-1].any() | |
| gt_chunk = ( | |
| [f"!{i:03d}" for i in gt_input_inds] if "gt" in chunk_strategy else [] | |
| ) | |
| chunk = gt_chunk + [] | |
| # any test views before the first input views | |
| if in_stop_ranges[0].any(): | |
| for j, in_range in enumerate(in_stop_ranges[0]): | |
| if in_range: | |
| chunk.append(f">{j:03d}") | |
| in_stop_ranges = in_stop_ranges[1:] | |
| i = 0 | |
| base_i = len(gt_input_inds) if "img2trajvid" in task else 0 | |
| chunk.append(f"!{i + base_i:03d}") | |
| while i < len(in_stop_ranges): | |
| in_stop_range = in_stop_ranges[i] | |
| if not in_stop_range.any(): | |
| i += 1 | |
| continue | |
| input_left = i + 1 < M | |
| space_left = T - len(chunk) | |
| if sum(in_stop_range) + input_left <= space_left: | |
| for j, in_range in enumerate(in_stop_range): | |
| if in_range: | |
| chunk.append(f">{j:03d}") | |
| i += 1 | |
| if input_left: | |
| chunk.append(f"!{i + base_i:03d}") | |
| else: | |
| chunk += ["NULL"] * space_left | |
| chunks.append(chunk) | |
| chunk = gt_chunk + [f"!{i + base_i:03d}"] | |
| if len(chunk) > 1: | |
| chunk += ["NULL"] * (T - len(chunk)) | |
| chunks.append(chunk) | |
| else: | |
| raise NotImplementedError | |
| ( | |
| input_inds_per_chunk, | |
| input_sels_per_chunk, | |
| test_inds_per_chunk, | |
| test_sels_per_chunk, | |
| ) = ( | |
| [], | |
| [], | |
| [], | |
| [], | |
| ) | |
| for chunk in chunks: | |
| input_inds = [ | |
| int(img.removeprefix("!")) for img in chunk if img.startswith("!") | |
| ] | |
| input_sels = [chunk.index(img) for img in chunk if img.startswith("!")] | |
| test_inds = [int(img.removeprefix(">")) for img in chunk if img.startswith(">")] | |
| test_sels = [chunk.index(img) for img in chunk if img.startswith(">")] | |
| input_inds_per_chunk.append(input_inds) | |
| input_sels_per_chunk.append(input_sels) | |
| test_inds_per_chunk.append(test_inds) | |
| test_sels_per_chunk.append(test_sels) | |
| if options.get("sampler_verbose", True): | |
| def colorize(item): | |
| if item.startswith("!"): | |
| return f"{Fore.RED}{item}{Style.RESET_ALL}" # Red for items starting with '!' | |
| elif item.startswith(">"): | |
| return f"{Fore.GREEN}{item}{Style.RESET_ALL}" # Green for items starting with '>' | |
| return item # Default color if neither '!' nor '>' | |
| print("\nchunks:") | |
| for chunk in chunks: | |
| print(", ".join(colorize(item) for item in chunk)) | |
| return ( | |
| chunks, | |
| input_inds_per_chunk, # ordering of input in raw sequence | |
| input_sels_per_chunk, # ordering of input in one-forward sequence of length T | |
| test_inds_per_chunk, # ordering of test in raw sequence | |
        test_sels_per_chunk,  # ordering of test in one-forward sequence of length T
| ) | |
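# Note on the chunk encoding above (added for clarity): each chunk is a list of
# length T whose entries are "!xxx" for an input/prior frame (index xxx into the
# input sequence), ">xxx" for a test frame (index xxx into the test sequence),
# and "NULL" for padding when a chunk is not completely filled.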
| def is_k_in_dict(d, k): | |
| return any(map(lambda x: x.startswith(k), d.keys())) | |
| def get_k_from_dict(d, k): | |
| media_d = {} | |
| for key, value in d.items(): | |
| if key == k: | |
| return value | |
| if key.startswith(k): | |
| media = key.split("/")[-1] | |
| if media == "raw": | |
| return value | |
| media_d[media] = value | |
| if len(media_d) == 0: | |
| return torch.tensor([]) | |
| assert ( | |
| len(media_d) == 1 | |
| ), f"multiple media found in {d} for key {k}: {media_d.keys()}" | |
| return media_d[media] | |
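# Example (illustrative): with d = {"samples-rgb/image": imgs}, both
# get_k_from_dict(d, "samples-rgb") and get_k_from_dict(d, "samples-rgb/image")
# return `imgs`; if no key matches, an empty tensor is returned.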
| def update_kv_for_dict(d, k, v): | |
| for key in d.keys(): | |
| if key.startswith(k): | |
| d[key] = v | |
| return d | |
| def extend_dict(ds, d): | |
| for key in d.keys(): | |
| if key in ds: | |
| ds[key] = torch.cat([ds[key], d[key]], 0) | |
| else: | |
| ds[key] = d[key] | |
| return ds | |
| def replace_or_include_input_for_dict( | |
| samples, | |
| test_indices, | |
| imgs, | |
| c2w, | |
| K, | |
| ): | |
| samples_new = {} | |
| for sample, value in samples.items(): | |
| if "rgb" in sample: | |
| imgs[test_indices] = ( | |
| value[test_indices] if value.shape[0] == imgs.shape[0] else value | |
| ).to(device=imgs.device, dtype=imgs.dtype) | |
| samples_new[sample] = imgs | |
| elif "c2w" in sample: | |
| c2w[test_indices] = ( | |
| value[test_indices] if value.shape[0] == c2w.shape[0] else value | |
| ).to(device=c2w.device, dtype=c2w.dtype) | |
| samples_new[sample] = c2w | |
| elif "intrinsics" in sample: | |
| K[test_indices] = ( | |
| value[test_indices] if value.shape[0] == K.shape[0] else value | |
| ).to(device=K.device, dtype=K.dtype) | |
| samples_new[sample] = K | |
| else: | |
| samples_new[sample] = value | |
| return samples_new | |
| def decode_output( | |
| samples, | |
| T, | |
| indices=None, | |
| ): | |
| # decode model output into dict if it is not | |
| if isinstance(samples, dict): | |
| # model with postprocessor and outputs dict | |
| for sample, value in samples.items(): | |
| if isinstance(value, torch.Tensor): | |
| value = value.detach().cpu() | |
| elif isinstance(value, np.ndarray): | |
| value = torch.from_numpy(value) | |
| else: | |
| value = torch.tensor(value) | |
| if indices is not None and value.shape[0] == T: | |
| value = value[indices] | |
| samples[sample] = value | |
| else: | |
| # model without postprocessor and outputs tensor (rgb) | |
| samples = samples.detach().cpu() | |
| if indices is not None and samples.shape[0] == T: | |
| samples = samples[indices] | |
| samples = {"samples-rgb/image": samples} | |
| return samples | |
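# Example (illustrative): a raw (T, 3, H, W) tensor output is wrapped as
# {"samples-rgb/image": tensor}; a dict output is kept as-is, with each value
# moved to the CPU and subset by `indices` when its first dimension equals T.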
| def save_output( | |
| samples, | |
| save_path, | |
| video_save_fps=2, | |
| ): | |
| os.makedirs(save_path, exist_ok=True) | |
| for sample in samples: | |
| media_type = "video" | |
| if "/" in sample: | |
| sample_, media_type = sample.split("/") | |
| else: | |
| sample_ = sample | |
| value = samples[sample] | |
| if isinstance(value, torch.Tensor): | |
| value = value.detach().cpu() | |
| elif isinstance(value, np.ndarray): | |
| value = torch.from_numpy(value) | |
| else: | |
| value = torch.tensor(value) | |
| if media_type == "image": | |
| value = (value.permute(0, 2, 3, 1) + 1) / 2.0 | |
| value = (value * 255).clamp(0, 255).to(torch.uint8) | |
| iio.imwrite( | |
| os.path.join(save_path, f"{sample_}.mp4") | |
| if sample_ | |
| else f"{save_path}.mp4", | |
| value, | |
| fps=video_save_fps, | |
| macro_block_size=1, | |
| ffmpeg_log_level="error", | |
| ) | |
| os.makedirs(os.path.join(save_path, sample_), exist_ok=True) | |
| for i, s in enumerate(value): | |
| iio.imwrite( | |
| os.path.join(save_path, sample_, f"{i:03d}.png"), | |
| s, | |
| ) | |
| elif media_type == "video": | |
| value = (value.permute(0, 2, 3, 1) + 1) / 2.0 | |
| value = (value * 255).clamp(0, 255).to(torch.uint8) | |
| iio.imwrite( | |
| os.path.join(save_path, f"{sample_}.mp4"), | |
| value, | |
| fps=video_save_fps, | |
| macro_block_size=1, | |
| ffmpeg_log_level="error", | |
| ) | |
| elif media_type == "raw": | |
| torch.save( | |
| value, | |
| os.path.join(save_path, f"{sample_}.pt"), | |
| ) | |
| else: | |
| pass | |
| def create_transforms_simple(save_path, img_paths, img_whs, c2ws, Ks): | |
| import os.path as osp | |
| out_frames = [] | |
| for img_path, img_wh, c2w, K in zip(img_paths, img_whs, c2ws, Ks): | |
| out_frame = { | |
| "fl_x": K[0][0].item(), | |
| "fl_y": K[1][1].item(), | |
| "cx": K[0][2].item(), | |
| "cy": K[1][2].item(), | |
| "w": img_wh[0].item(), | |
| "h": img_wh[1].item(), | |
| "file_path": f"./{osp.relpath(img_path, start=save_path)}" | |
| if img_path is not None | |
| else None, | |
| "transform_matrix": c2w.tolist(), | |
| } | |
| out_frames.append(out_frame) | |
| out = { | |
| # "camera_model": "PINHOLE", | |
| "orientation_override": "none", | |
| "frames": out_frames, | |
| } | |
| with open(osp.join(save_path, "transforms.json"), "w") as of: | |
| json.dump(out, of, indent=5) | |
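# Example (illustrative, with hypothetical values) of one entry in the resulting
# transforms.json:
#     {"fl_x": 500.0, "fl_y": 500.0, "cx": 288.0, "cy": 288.0, "w": 576, "h": 576,
#      "file_path": "./input/000.png", "transform_matrix": [[...], [...], [...], [...]]}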
class GradioTrackedSampler(EulerEDMSampler):
    """
    A thin wrapper around the EulerEDMSampler that allows tracking progress and
    aborting sampling in the gradio demo.
    """
| def __init__(self, abort_event: threading.Event, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self.abort_event = abort_event | |
| def __call__( # type: ignore | |
| self, | |
| denoiser, | |
| x: torch.Tensor, | |
| scale: float | torch.Tensor, | |
| cond: dict, | |
| uc: dict | None = None, | |
| num_steps: int | None = None, | |
| verbose: bool = True, | |
| global_pbar: gr.Progress | None = None, | |
| **guider_kwargs, | |
| ) -> torch.Tensor | None: | |
| uc = cond if uc is None else uc | |
| x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( | |
| x, | |
| cond, | |
| uc, | |
| num_steps, | |
| ) | |
| for i in self.get_sigma_gen(num_sigmas, verbose=verbose): | |
| gamma = ( | |
| min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) | |
| if self.s_tmin <= sigmas[i] <= self.s_tmax | |
| else 0.0 | |
| ) | |
| x = self.sampler_step( | |
| s_in * sigmas[i], | |
| s_in * sigmas[i + 1], | |
| denoiser, | |
| x, | |
| scale, | |
| cond, | |
| uc, | |
| gamma, | |
| **guider_kwargs, | |
| ) | |
| # Allow tracking progress in gradio demo. | |
| if global_pbar is not None: | |
| global_pbar.update() | |
| # Allow aborting sampling in gradio demo. | |
| if self.abort_event.is_set(): | |
| return None | |
| return x | |
| def create_samplers( | |
| guider_types: int | list[int], | |
| discretization, | |
| num_frames: list[int] | None, | |
| num_steps: int, | |
| cfg_min: float = 1.0, | |
| device: str | torch.device = "cuda", | |
| abort_event: threading.Event | None = None, | |
| ): | |
| guider_mapping = { | |
| 0: VanillaCFG, | |
| 1: MultiviewCFG, | |
| 2: MultiviewTemporalCFG, | |
| } | |
| samplers = [] | |
| if not isinstance(guider_types, (list, tuple)): | |
| guider_types = [guider_types] | |
| for i, guider_type in enumerate(guider_types): | |
| if guider_type not in guider_mapping: | |
| raise ValueError( | |
| f"Invalid guider type {guider_type}. Must be one of {list(guider_mapping.keys())}" | |
| ) | |
| guider_cls = guider_mapping[guider_type] | |
| guider_args = () | |
| if guider_type > 0: | |
| guider_args += (cfg_min,) | |
| if guider_type == 2: | |
| assert num_frames is not None | |
| guider_args = (num_frames[i], cfg_min) | |
| guider = guider_cls(*guider_args) | |
| if abort_event is not None: | |
| sampler = GradioTrackedSampler( | |
| abort_event, | |
| discretization=discretization, | |
| guider=guider, | |
| num_steps=num_steps, | |
| s_churn=0.0, | |
| s_tmin=0.0, | |
| s_tmax=999.0, | |
| s_noise=1.0, | |
| verbose=True, | |
| device=device, | |
| ) | |
| else: | |
| sampler = EulerEDMSampler( | |
| discretization=discretization, | |
| guider=guider, | |
| num_steps=num_steps, | |
| s_churn=0.0, | |
| s_tmin=0.0, | |
| s_tmax=999.0, | |
| s_noise=1.0, | |
| verbose=True, | |
| device=device, | |
| ) | |
| samplers.append(sampler) | |
| return samplers | |
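# Example (illustrative, with example values): create_samplers([1, 2],
# discretization, [T1, T2], num_steps=50, cfg_min=1.2) builds two
# EulerEDMSamplers, the first guided by MultiviewCFG(1.2) and the second by
# MultiviewTemporalCFG(T2, 1.2); when `abort_event` is provided, each is a
# GradioTrackedSampler instead.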
| def get_value_dict( | |
| curr_imgs, | |
| curr_imgs_clip, | |
| curr_input_frame_indices, | |
| curr_c2ws, | |
| curr_Ks, | |
| curr_input_camera_indices, | |
| all_c2ws, | |
| camera_scale=2.0, | |
| ): | |
| assert sorted(curr_input_camera_indices) == sorted( | |
| range(len(curr_input_camera_indices)) | |
| ) | |
| H, W, T, F = curr_imgs.shape[-2], curr_imgs.shape[-1], len(curr_imgs), 8 | |
| value_dict = {} | |
| value_dict["cond_frames_without_noise"] = curr_imgs_clip[curr_input_frame_indices] | |
| value_dict["cond_frames"] = curr_imgs + 0.0 * torch.randn_like(curr_imgs) | |
| value_dict["cond_frames_mask"] = torch.zeros(T, dtype=torch.bool) | |
| value_dict["cond_frames_mask"][curr_input_frame_indices] = True | |
| value_dict["cond_aug"] = 0.0 | |
| c2w = to_hom_pose(curr_c2ws.float()) | |
| w2c = torch.linalg.inv(c2w) | |
| # camera centering | |
| ref_c2ws = all_c2ws | |
| camera_dist_2med = torch.norm( | |
| ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values, | |
| dim=-1, | |
| ) | |
| valid_mask = camera_dist_2med <= torch.clamp( | |
| torch.quantile(camera_dist_2med, 0.97) * 10, | |
| max=1e6, | |
| ) | |
| c2w[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True) | |
| w2c = torch.linalg.inv(c2w) | |
| # camera normalization | |
| camera_dists = c2w[:, :3, 3].clone() | |
| translation_scaling_factor = ( | |
| camera_scale | |
| if torch.isclose( | |
| torch.norm(camera_dists[0]), | |
| torch.zeros(1), | |
| atol=1e-5, | |
| ).any() | |
| else (camera_scale / torch.norm(camera_dists[0])) | |
| ) | |
| w2c[:, :3, 3] *= translation_scaling_factor | |
| c2w[:, :3, 3] *= translation_scaling_factor | |
| value_dict["plucker_coordinate"], _ = get_plucker_coordinates( | |
| extrinsics_src=w2c[0], | |
| extrinsics=w2c, | |
| intrinsics=curr_Ks.float().clone(), | |
| mode="plucker", | |
| rel_zero_translation=True, | |
| target_size=(H // F, W // F), | |
| return_grid_cam=True, | |
| ) | |
| value_dict["c2w"] = c2w | |
| value_dict["K"] = curr_Ks | |
| value_dict["camera_mask"] = torch.zeros(T, dtype=torch.bool) | |
| value_dict["camera_mask"][curr_input_camera_indices] = True | |
| return value_dict | |
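# Note (added for clarity): the returned value_dict carries the per-window
# conditioning consumed by `do_sample`: `cond_frames_without_noise` (from
# curr_imgs_clip), `cond_frames` plus their boolean mask, the Plucker coordinate
# maps, and the re-centered, re-scaled `c2w` / `K` matrices with a per-frame
# `camera_mask`.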
| def do_sample( | |
| model, | |
| ae, | |
| conditioner, | |
| denoiser, | |
| sampler, | |
| value_dict, | |
| H, | |
| W, | |
| C, | |
| F, | |
| T, | |
| cfg, | |
| encoding_t=1, | |
| decoding_t=1, | |
| verbose=True, | |
| global_pbar=None, | |
| **_, | |
| ): | |
| imgs = value_dict["cond_frames"].to("cuda") | |
| input_masks = value_dict["cond_frames_mask"].to("cuda") | |
| pluckers = value_dict["plucker_coordinate"].to("cuda") | |
| num_samples = [1, T] | |
| with torch.inference_mode(), torch.autocast("cuda"): | |
| load_model(ae) | |
| load_model(conditioner) | |
| latents = torch.nn.functional.pad( | |
| ae.encode(imgs[input_masks], encoding_t), (0, 0, 0, 0, 0, 1), value=1.0 | |
| ) | |
| c_crossattn = repeat(conditioner(imgs[input_masks]).mean(0), "d -> n 1 d", n=T) | |
| uc_crossattn = torch.zeros_like(c_crossattn) | |
| c_replace = latents.new_zeros(T, *latents.shape[1:]) | |
| c_replace[input_masks] = latents | |
| uc_replace = torch.zeros_like(c_replace) | |
| c_concat = torch.cat( | |
| [ | |
| repeat( | |
| input_masks, | |
| "n -> n 1 h w", | |
| h=pluckers.shape[2], | |
| w=pluckers.shape[3], | |
| ), | |
| pluckers, | |
| ], | |
| 1, | |
| ) | |
| uc_concat = torch.cat( | |
| [pluckers.new_zeros(T, 1, *pluckers.shape[-2:]), pluckers], 1 | |
| ) | |
| c_dense_vector = pluckers | |
| uc_dense_vector = c_dense_vector | |
| # TODO(hangg): concat and dense are problematic. | |
| c = { | |
| "crossattn": c_crossattn, | |
| "replace": c_replace, | |
| "concat": c_concat, | |
| "dense_vector": c_dense_vector, | |
| } | |
| uc = { | |
| "crossattn": uc_crossattn, | |
| "replace": uc_replace, | |
| "concat": uc_concat, | |
| "dense_vector": uc_dense_vector, | |
| } | |
| unload_model(ae) | |
| unload_model(conditioner) | |
| additional_model_inputs = {"num_frames": T} | |
| additional_sampler_inputs = { | |
| "c2w": value_dict["c2w"].to("cuda"), | |
| "K": value_dict["K"].to("cuda"), | |
| "input_frame_mask": value_dict["cond_frames_mask"].to("cuda"), | |
| } | |
| if global_pbar is not None: | |
| additional_sampler_inputs["global_pbar"] = global_pbar | |
| shape = (math.prod(num_samples), C, H // F, W // F) | |
| randn = torch.randn(shape).to("cuda") | |
| load_model(model) | |
| samples_z = sampler( | |
| lambda input, sigma, c: denoiser( | |
| model, | |
| input, | |
| sigma, | |
| c, | |
| **additional_model_inputs, | |
| ), | |
| randn, | |
| scale=cfg, | |
| cond=c, | |
| uc=uc, | |
| verbose=verbose, | |
| **additional_sampler_inputs, | |
| ) | |
| if samples_z is None: | |
| return | |
| unload_model(model) | |
| load_model(ae) | |
| samples = ae.decode(samples_z, decoding_t) | |
| unload_model(ae) | |
| return samples | |
| def run_one_scene( | |
| task, | |
| version_dict, | |
| model, | |
| ae, | |
| conditioner, | |
| denoiser, | |
| image_cond, | |
| camera_cond, | |
| save_path, | |
| use_traj_prior, | |
| traj_prior_Ks, | |
| traj_prior_c2ws, | |
| seed=23, | |
| gradio=False, | |
| abort_event=None, | |
| first_pass_pbar=None, | |
| second_pass_pbar=None, | |
| ): | |
| H, W, T, C, F, options = ( | |
| version_dict["H"], | |
| version_dict["W"], | |
| version_dict["T"], | |
| version_dict["C"], | |
| version_dict["f"], | |
| version_dict["options"], | |
| ) | |
| if isinstance(image_cond, str): | |
| image_cond = {"img": [image_cond]} | |
| imgs_clip, imgs, img_size = [], [], None | |
| for i, (img, K) in enumerate(zip(image_cond["img"], camera_cond["K"])): | |
| if isinstance(img, str) or img is None: | |
| img, K = load_img_and_K(img or img_size, None, K=K, device="cpu") # type: ignore | |
| img_size = img.shape[-2:] | |
| if options.get("L_short", -1) == -1: | |
| img, K = transform_img_and_K( | |
| img, | |
| (W, H), | |
| K=K[None], | |
| mode=( | |
| options.get("transform_input", "crop") | |
| if i in image_cond["input_indices"] | |
| else options.get("transform_target", "crop") | |
| ), | |
| scale=( | |
| 1.0 | |
| if i in image_cond["input_indices"] | |
| else options.get("transform_scale", 1.0) | |
| ), | |
| ) | |
| else: | |
| downsample = 3 | |
                assert options["L_short"] % (F * 2**downsample) == 0, (
                    "Short side of the image should be divisible by "
                    f"F*2**{downsample}={F * 2**downsample}."
                )
| img, K = transform_img_and_K( | |
| img, | |
| options["L_short"], | |
| K=K[None], | |
| size_stride=F * 2**downsample, | |
| mode=( | |
| options.get("transform_input", "crop") | |
| if i in image_cond["input_indices"] | |
| else options.get("transform_target", "crop") | |
| ), | |
| scale=( | |
| 1.0 | |
| if i in image_cond["input_indices"] | |
| else options.get("transform_scale", 1.0) | |
| ), | |
| ) | |
| version_dict["W"] = W = img.shape[-1] | |
| version_dict["H"] = H = img.shape[-2] | |
| K = K[0] | |
| K[0] /= W | |
| K[1] /= H | |
| camera_cond["K"][i] = K | |
| img_clip = img | |
| elif isinstance(img, np.ndarray): | |
| img_size = torch.Size(img.shape[:2]) | |
| img = torch.as_tensor(img).permute(2, 0, 1) | |
| img = img.unsqueeze(0) | |
| img = img / 255.0 * 2.0 - 1.0 | |
| if not gradio: | |
| img, K = transform_img_and_K(img, (W, H), K=K[None]) | |
| assert K is not None | |
| K = K[0] | |
| K[0] /= W | |
| K[1] /= H | |
| camera_cond["K"][i] = K | |
| img_clip = img | |
| else: | |
| assert ( | |
| False | |
| ), f"Variable `img` got {type(img)} type which is not supported!!!" | |
| imgs_clip.append(img_clip) | |
| imgs.append(img) | |
| imgs_clip = torch.cat(imgs_clip, dim=0) | |
| imgs = torch.cat(imgs, dim=0) | |
| if traj_prior_Ks is not None: | |
| assert img_size is not None | |
| for i, prior_k in enumerate(traj_prior_Ks): | |
| img, prior_k = load_img_and_K(img_size, None, K=prior_k, device="cpu") # type: ignore | |
| img, prior_k = transform_img_and_K( | |
| img, | |
| (W, H), | |
| K=prior_k[None], | |
| mode=options.get( | |
| "transform_target", "crop" | |
| ), # mode for prior is always same as target | |
| scale=options.get( | |
| "transform_scale", 1.0 | |
| ), # scale for prior is always same as target | |
| ) | |
| prior_k = prior_k[0] | |
| prior_k[0] /= W | |
| prior_k[1] /= H | |
| traj_prior_Ks[i] = prior_k | |
| options["num_frames"] = T | |
| discretization = denoiser.discretization | |
| torch.cuda.empty_cache() | |
| seed_everything(seed) | |
| # Get Data | |
| input_indices = image_cond["input_indices"] | |
| input_imgs = imgs[input_indices] | |
| input_imgs_clip = imgs_clip[input_indices] | |
| input_c2ws = camera_cond["c2w"][input_indices] | |
| input_Ks = camera_cond["K"][input_indices] | |
| test_indices = [i for i in range(len(imgs)) if i not in input_indices] | |
| test_imgs = imgs[test_indices] | |
| test_imgs_clip = imgs_clip[test_indices] | |
| test_c2ws = camera_cond["c2w"][test_indices] | |
| test_Ks = camera_cond["K"][test_indices] | |
| if options.get("save_input", True): | |
| save_output( | |
| {"/image": input_imgs}, | |
| save_path=os.path.join(save_path, "input"), | |
| video_save_fps=2, | |
| ) | |
| if not use_traj_prior: | |
| chunk_strategy = options.get("chunk_strategy", "gt") | |
| ( | |
| _, | |
| input_inds_per_chunk, | |
| input_sels_per_chunk, | |
| test_inds_per_chunk, | |
| test_sels_per_chunk, | |
| ) = chunk_input_and_test( | |
| T, | |
| input_c2ws, | |
| test_c2ws, | |
| input_indices, | |
| test_indices, | |
| options=options, | |
| task=task, | |
| chunk_strategy=chunk_strategy, | |
| gt_input_inds=list(range(input_c2ws.shape[0])), | |
| ) | |
| print( | |
| f"One pass - chunking with `{chunk_strategy}` strategy: total " | |
| f"{len(input_inds_per_chunk)} forward(s) ..." | |
| ) | |
| all_samples = {} | |
| all_test_inds = [] | |
| for i, ( | |
| chunk_input_inds, | |
| chunk_input_sels, | |
| chunk_test_inds, | |
| chunk_test_sels, | |
| ) in tqdm( | |
| enumerate( | |
| zip( | |
| input_inds_per_chunk, | |
| input_sels_per_chunk, | |
| test_inds_per_chunk, | |
| test_sels_per_chunk, | |
| ) | |
| ), | |
| total=len(input_inds_per_chunk), | |
| leave=False, | |
| ): | |
| ( | |
| curr_input_sels, | |
| curr_test_sels, | |
| curr_input_maps, | |
| curr_test_maps, | |
| ) = pad_indices( | |
| chunk_input_sels, | |
| chunk_test_sels, | |
| T=T, | |
| padding_mode=options.get("t_padding_mode", "last"), | |
| ) | |
| curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [ | |
| assemble( | |
| input=x[chunk_input_inds], | |
| test=y[chunk_test_inds], | |
| input_maps=curr_input_maps, | |
| test_maps=curr_test_maps, | |
| ) | |
| for x, y in zip( | |
| [ | |
| torch.cat( | |
| [ | |
| input_imgs, | |
| get_k_from_dict(all_samples, "samples-rgb").to( | |
| input_imgs.device | |
| ), | |
| ], | |
| dim=0, | |
| ), | |
| torch.cat( | |
| [ | |
| input_imgs_clip, | |
| get_k_from_dict(all_samples, "samples-rgb").to( | |
| input_imgs.device | |
| ), | |
| ], | |
| dim=0, | |
| ), | |
| torch.cat([input_c2ws, test_c2ws[all_test_inds]], dim=0), | |
| torch.cat([input_Ks, test_Ks[all_test_inds]], dim=0), | |
                    ],  # procedurally append generated prior views to the input views
| [test_imgs, test_imgs_clip, test_c2ws, test_Ks], | |
| ) | |
| ] | |
| value_dict = get_value_dict( | |
| curr_imgs.to("cuda"), | |
| curr_imgs_clip.to("cuda"), | |
| curr_input_sels | |
| + [ | |
| sel | |
| for (ind, sel) in zip( | |
| np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]], | |
| curr_test_sels, | |
| ) | |
| if test_indices[ind] in image_cond["input_indices"] | |
| ], | |
| curr_c2ws, | |
| curr_Ks, | |
| curr_input_sels | |
| + [ | |
| sel | |
| for (ind, sel) in zip( | |
| np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]], | |
| curr_test_sels, | |
| ) | |
| if test_indices[ind] in camera_cond["input_indices"] | |
| ], | |
| all_c2ws=camera_cond["c2w"], | |
| ) | |
| samplers = create_samplers( | |
| options["guider_types"], | |
| discretization, | |
| [len(curr_imgs)], | |
| options["num_steps"], | |
| options["cfg_min"], | |
| abort_event=abort_event, | |
| ) | |
| assert len(samplers) == 1 | |
| samples = do_sample( | |
| model, | |
| ae, | |
| conditioner, | |
| denoiser, | |
| samplers[0], | |
| value_dict, | |
| H, | |
| W, | |
| C, | |
| F, | |
| T=len(curr_imgs), | |
| cfg=( | |
| options["cfg"][0] | |
| if isinstance(options["cfg"], (list, tuple)) | |
| else options["cfg"] | |
| ), | |
| **{k: options[k] for k in options if k not in ["cfg", "T"]}, | |
| ) | |
| samples = decode_output( | |
| samples, len(curr_imgs), chunk_test_sels | |
| ) # decode into dict | |
| if options.get("save_first_pass", False): | |
| save_output( | |
| replace_or_include_input_for_dict( | |
| samples, | |
| chunk_test_sels, | |
| curr_imgs, | |
| curr_c2ws, | |
| curr_Ks, | |
| ), | |
| save_path=os.path.join(save_path, "first-pass", f"forward_{i}"), | |
| video_save_fps=2, | |
| ) | |
| extend_dict(all_samples, samples) | |
| all_test_inds.extend(chunk_test_inds) | |
| else: | |
        assert traj_prior_c2ws is not None, (
            "`traj_prior_c2ws` should be set when using 2-pass sampling. One "
            "potential reason is that the number of input frames is larger than "
            "T. Set `num_prior_frames` manually to override the inferred stats."
| ) | |
| traj_prior_c2ws = torch.as_tensor( | |
| traj_prior_c2ws, | |
| device=input_c2ws.device, | |
| dtype=input_c2ws.dtype, | |
| ) | |
| if traj_prior_Ks is None: | |
| traj_prior_Ks = test_Ks[:1].repeat_interleave( | |
| traj_prior_c2ws.shape[0], dim=0 | |
| ) | |
| traj_prior_imgs = imgs.new_zeros(traj_prior_c2ws.shape[0], *imgs.shape[1:]) | |
| traj_prior_imgs_clip = imgs_clip.new_zeros( | |
| traj_prior_c2ws.shape[0], *imgs_clip.shape[1:] | |
| ) | |
| # ---------------------------------- first pass ---------------------------------- | |
| T_first_pass = T[0] if isinstance(T, (list, tuple)) else T | |
| T_second_pass = T[1] if isinstance(T, (list, tuple)) else T | |
| chunk_strategy_first_pass = options.get( | |
| "chunk_strategy_first_pass", "gt-nearest" | |
| ) | |
| ( | |
| _, | |
| input_inds_per_chunk, | |
| input_sels_per_chunk, | |
| prior_inds_per_chunk, | |
| prior_sels_per_chunk, | |
| ) = chunk_input_and_test( | |
| T_first_pass, | |
| input_c2ws, | |
| traj_prior_c2ws, | |
| input_indices, | |
| image_cond["prior_indices"], | |
| options=options, | |
| task=task, | |
| chunk_strategy=chunk_strategy_first_pass, | |
| gt_input_inds=list(range(input_c2ws.shape[0])), | |
| ) | |
| print( | |
| f"Two passes (first) - chunking with `{chunk_strategy_first_pass}` strategy: total " | |
| f"{len(input_inds_per_chunk)} forward(s) ..." | |
| ) | |
| all_samples = {} | |
| all_prior_inds = [] | |
| for i, ( | |
| chunk_input_inds, | |
| chunk_input_sels, | |
| chunk_prior_inds, | |
| chunk_prior_sels, | |
| ) in tqdm( | |
| enumerate( | |
| zip( | |
| input_inds_per_chunk, | |
| input_sels_per_chunk, | |
| prior_inds_per_chunk, | |
| prior_sels_per_chunk, | |
| ) | |
| ), | |
| total=len(input_inds_per_chunk), | |
| leave=False, | |
| ): | |
| ( | |
| curr_input_sels, | |
| curr_prior_sels, | |
| curr_input_maps, | |
| curr_prior_maps, | |
| ) = pad_indices( | |
| chunk_input_sels, | |
| chunk_prior_sels, | |
| T=T_first_pass, | |
| padding_mode=options.get("t_padding_mode", "last"), | |
| ) | |
| curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [ | |
| assemble( | |
| input=x[chunk_input_inds], | |
| test=y[chunk_prior_inds], | |
| input_maps=curr_input_maps, | |
| test_maps=curr_prior_maps, | |
| ) | |
| for x, y in zip( | |
| [ | |
| torch.cat( | |
| [ | |
| input_imgs, | |
| get_k_from_dict(all_samples, "samples-rgb").to( | |
| input_imgs.device | |
| ), | |
| ], | |
| dim=0, | |
| ), | |
| torch.cat( | |
| [ | |
| input_imgs_clip, | |
| get_k_from_dict(all_samples, "samples-rgb").to( | |
| input_imgs.device | |
| ), | |
| ], | |
| dim=0, | |
| ), | |
| torch.cat([input_c2ws, traj_prior_c2ws[all_prior_inds]], dim=0), | |
| torch.cat([input_Ks, traj_prior_Ks[all_prior_inds]], dim=0), | |
                    ],  # procedurally append generated prior views to the input views
| [ | |
| traj_prior_imgs, | |
| traj_prior_imgs_clip, | |
| traj_prior_c2ws, | |
| traj_prior_Ks, | |
| ], | |
| ) | |
| ] | |
| value_dict = get_value_dict( | |
| curr_imgs.to("cuda"), | |
| curr_imgs_clip.to("cuda"), | |
| curr_input_sels, | |
| curr_c2ws, | |
| curr_Ks, | |
| list(range(T_first_pass)), | |
| all_c2ws=camera_cond["c2w"], # traj_prior_c2ws, | |
| ) | |
| samplers = create_samplers( | |
| options["guider_types"], | |
| discretization, | |
| [T_first_pass, T_second_pass], | |
| options["num_steps"], | |
| options["cfg_min"], | |
| abort_event=abort_event, | |
| ) | |
| samples = do_sample( | |
| model, | |
| ae, | |
| conditioner, | |
| denoiser, | |
| ( | |
| samplers[1] | |
| if len(samplers) > 1 | |
| and options.get("ltr_first_pass", False) | |
| and chunk_strategy_first_pass != "gt" | |
| and i > 0 | |
| else samplers[0] | |
| ), | |
| value_dict, | |
| H, | |
| W, | |
| C, | |
| F, | |
| cfg=( | |
| options["cfg"][0] | |
| if isinstance(options["cfg"], (list, tuple)) | |
| else options["cfg"] | |
| ), | |
| T=T_first_pass, | |
| global_pbar=first_pass_pbar, | |
| **{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]}, | |
| ) | |
| if samples is None: | |
| return | |
| samples = decode_output( | |
| samples, T_first_pass, chunk_prior_sels | |
| ) # decode into dict | |
| extend_dict(all_samples, samples) | |
| all_prior_inds.extend(chunk_prior_inds) | |
| if options.get("save_first_pass", True): | |
| save_output( | |
| all_samples, | |
| save_path=os.path.join(save_path, "first-pass"), | |
| video_save_fps=5, | |
| ) | |
| video_path_0 = os.path.join(save_path, "first-pass", "samples-rgb.mp4") | |
| yield video_path_0 | |
| # ---------------------------------- second pass ---------------------------------- | |
| prior_indices = image_cond["prior_indices"] | |
        assert (
            prior_indices is not None
        ), "`image_cond['prior_indices']` needs to be set if using 2-pass sampling."
| prior_argsort = np.argsort(input_indices + prior_indices).tolist() | |
| prior_indices = np.array(input_indices + prior_indices)[prior_argsort].tolist() | |
| gt_input_inds = [prior_argsort.index(i) for i in range(input_c2ws.shape[0])] | |
| traj_prior_imgs = torch.cat( | |
| [input_imgs, get_k_from_dict(all_samples, "samples-rgb")], dim=0 | |
| )[prior_argsort] | |
| traj_prior_imgs_clip = torch.cat( | |
| [ | |
| input_imgs_clip, | |
| get_k_from_dict(all_samples, "samples-rgb"), | |
| ], | |
| dim=0, | |
| )[prior_argsort] | |
| traj_prior_c2ws = torch.cat([input_c2ws, traj_prior_c2ws], dim=0)[prior_argsort] | |
| traj_prior_Ks = torch.cat([input_Ks, traj_prior_Ks], dim=0)[prior_argsort] | |
| update_kv_for_dict(all_samples, "samples-rgb", traj_prior_imgs) | |
| update_kv_for_dict(all_samples, "samples-c2ws", traj_prior_c2ws) | |
| update_kv_for_dict(all_samples, "samples-intrinsics", traj_prior_Ks) | |
| chunk_strategy = options.get("chunk_strategy", "nearest") | |
| ( | |
| _, | |
| prior_inds_per_chunk, | |
| prior_sels_per_chunk, | |
| test_inds_per_chunk, | |
| test_sels_per_chunk, | |
| ) = chunk_input_and_test( | |
| T_second_pass, | |
| traj_prior_c2ws, | |
| test_c2ws, | |
| prior_indices, | |
| test_indices, | |
| options=options, | |
| task=task, | |
| chunk_strategy=chunk_strategy, | |
| gt_input_inds=gt_input_inds, | |
| ) | |
| print( | |
| f"Two passes (second) - chunking with `{chunk_strategy}` strategy: total " | |
| f"{len(prior_inds_per_chunk)} forward(s) ..." | |
| ) | |
| all_samples = {} | |
| all_test_inds = [] | |
| for i, ( | |
| chunk_prior_inds, | |
| chunk_prior_sels, | |
| chunk_test_inds, | |
| chunk_test_sels, | |
| ) in tqdm( | |
| enumerate( | |
| zip( | |
| prior_inds_per_chunk, | |
| prior_sels_per_chunk, | |
| test_inds_per_chunk, | |
| test_sels_per_chunk, | |
| ) | |
| ), | |
| total=len(prior_inds_per_chunk), | |
| leave=False, | |
| ): | |
| ( | |
| curr_prior_sels, | |
| curr_test_sels, | |
| curr_prior_maps, | |
| curr_test_maps, | |
| ) = pad_indices( | |
| chunk_prior_sels, | |
| chunk_test_sels, | |
| T=T_second_pass, | |
| padding_mode="last", | |
| ) | |
| curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [ | |
| assemble( | |
| input=x[chunk_prior_inds], | |
| test=y[chunk_test_inds], | |
| input_maps=curr_prior_maps, | |
| test_maps=curr_test_maps, | |
| ) | |
| for x, y in zip( | |
| [ | |
| traj_prior_imgs, | |
| traj_prior_imgs_clip, | |
| traj_prior_c2ws, | |
| traj_prior_Ks, | |
| ], | |
| [test_imgs, test_imgs_clip, test_c2ws, test_Ks], | |
| ) | |
| ] | |
| value_dict = get_value_dict( | |
| curr_imgs.to("cuda"), | |
| curr_imgs_clip.to("cuda"), | |
| curr_prior_sels, | |
| curr_c2ws, | |
| curr_Ks, | |
| list(range(T_second_pass)), | |
| all_c2ws=camera_cond["c2w"], # test_c2ws, | |
| ) | |
| samples = do_sample( | |
| model, | |
| ae, | |
| conditioner, | |
| denoiser, | |
| samplers[1] if len(samplers) > 1 else samplers[0], | |
| value_dict, | |
| H, | |
| W, | |
| C, | |
| F, | |
| T=T_second_pass, | |
| cfg=( | |
| options["cfg"][1] | |
| if isinstance(options["cfg"], (list, tuple)) | |
| and len(options["cfg"]) > 1 | |
| else options["cfg"] | |
| ), | |
| global_pbar=second_pass_pbar, | |
| **{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]}, | |
| ) | |
| if samples is None: | |
| return | |
| samples = decode_output( | |
| samples, T_second_pass, chunk_test_sels | |
| ) # decode into dict | |
| if options.get("save_second_pass", False): | |
| save_output( | |
| replace_or_include_input_for_dict( | |
| samples, | |
| chunk_test_sels, | |
| curr_imgs, | |
| curr_c2ws, | |
| curr_Ks, | |
| ), | |
| save_path=os.path.join(save_path, "second-pass", f"forward_{i}"), | |
| video_save_fps=2, | |
| ) | |
| extend_dict(all_samples, samples) | |
| all_test_inds.extend(chunk_test_inds) | |
| all_samples = { | |
| key: value[np.argsort(all_test_inds)] for key, value in all_samples.items() | |
| } | |
| save_output( | |
| replace_or_include_input_for_dict( | |
| all_samples, | |
| test_indices, | |
| imgs.clone(), | |
| camera_cond["c2w"].clone(), | |
| camera_cond["K"].clone(), | |
| ) | |
| if options.get("replace_or_include_input", False) | |
| else all_samples, | |
| save_path=save_path, | |
| video_save_fps=options.get("video_save_fps", 2), | |
| ) | |
| video_path_1 = os.path.join(save_path, "samples-rgb.mp4") | |
| yield video_path_1 | |