import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from tqdm import tqdm

from seva.geometry import get_camera_dist


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]


def append_zero(x: torch.Tensor) -> torch.Tensor:
    """Appends a final zero sigma (the clean-sample endpoint of a schedule)."""
    return torch.cat([x, x.new_zeros([1])])


def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    """Converts a denoised estimate into the EDM ODE derivative dx/dsigma."""
    return (x - denoised) / append_dims(sigma, x.ndim)


def make_betas(
    num_timesteps: int, linear_start: float = 1e-4, linear_end: float = 2e-2
) -> np.ndarray:
    # Linear schedule in sqrt(beta) space, squared back to betas.
    betas = (
        torch.linspace(
            linear_start**0.5, linear_end**0.5, num_timesteps, dtype=torch.float64
        )
        ** 2
    )
    return betas.numpy()


def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
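
# For example, generate_roughly_equally_spaced_steps(5, 1000) walks down from
# step 999 and returns the ascending indices array([199, 399, 599, 799, 999]).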


class EpsScaling(object):
    """Preconditioning coefficients for an eps-predicting network."""

    def __call__(
        self, sigma: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = torch.ones_like(sigma, device=sigma.device)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise
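
# DiscreteDenoiser below combines these coefficients as
#     denoised = network(c_in * x, c_noise, cond) * c_out + c_skip * x,
# which for this eps scaling reduces to the familiar parameterization
#     denoised = x - sigma * eps_pred(x / sqrt(sigma**2 + 1), t).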


class DDPMDiscretization(object):
    def __init__(
        self,
        linear_start: float = 5e-06,
        linear_end: float = 0.012,
        num_timesteps: int = 1000,
        log_snr_shift: float | None = 2.4,
    ):
        self.num_timesteps = num_timesteps
        betas = make_betas(
            num_timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
        )
        self.log_snr_shift = log_snr_shift
        alphas = 1.0 - betas  # first alpha here is on data side
        self.alphas_cumprod = np.cumprod(alphas, axis=0)

    def get_sigmas(self, n: int, device: str | torch.device = "cpu") -> torch.Tensor:
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(f"Expected n <= {self.num_timesteps}, but got n = {n}.")

        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        if self.log_snr_shift is not None:
            # Scale every sigma by exp(log_snr_shift), shifting the schedule
            # toward higher noise levels.
            sigmas = sigmas * np.exp(self.log_snr_shift)
        return torch.flip(
            torch.tensor(sigmas, dtype=torch.float32, device=device), (0,)
        )

    def __call__(
        self,
        n: int,
        do_append_zero: bool = True,
        flip: bool = False,
        device: str | torch.device = "cpu",
    ) -> torch.Tensor:
        sigmas = self.get_sigmas(n, device=device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))
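
# Usage sketch: the default schedule sampled at 50 steps comes back descending,
# with a trailing 0.0 appended for the final clean sample.
#     disc = DDPMDiscretization()
#     sigmas = disc(50)  # shape (51,); sigmas[0] is largest, sigmas[-1] == 0.0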


class DiscreteDenoiser(object):
    sigmas: torch.Tensor

    def __init__(
        self,
        discretization: DDPMDiscretization,
        num_idx: int = 1000,
        device: str | torch.device = "cpu",
    ):
        self.scaling = EpsScaling()
        self.discretization = discretization
        self.num_idx = num_idx
        self.device = device
        self.register_sigmas()

    def register_sigmas(self):
        self.sigmas = self.discretization(
            self.num_idx, do_append_zero=False, flip=True, device=self.device
        )

    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        # Nearest-neighbor lookup of each sigma in the registered (ascending) table.
        dists = sigma - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx: torch.Tensor | int) -> torch.Tensor:
        return self.sigmas[idx]

    def __call__(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        # Snap sigma to the discrete grid; the grid index doubles as the
        # timestep fed to the network.
        sigma = self.idx_to_sigma(self.sigma_to_idx(sigma))
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.sigma_to_idx(c_noise.reshape(sigma_shape))
        if "replace" in cond:
            # cond["replace"] packs replacement latents and a mask along the
            # channel dimension; masked regions are overwritten before denoising.
            x, mask = cond.pop("replace").split((input.shape[1], 1), dim=1)
            input = input * (1 - mask) + x * mask
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )
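
# EulerEDMSampler calls its denoiser as denoiser(x, sigma, cond), so a
# DiscreteDenoiser is typically wrapped with the network bound in, e.g. (a sketch):
#     denoiser_fn = lambda x, s, c: discrete_denoiser(model, x, s, c)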


class ConstantScaleRule(object):
    def __call__(self, scale: float | torch.Tensor) -> float | torch.Tensor:
        return scale


class MultiviewScaleRule(object):
    def __init__(self, min_scale: float = 1.0):
        self.min_scale = min_scale

    def __call__(
        self,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Drop the guidance scale to min_scale for target frames that (nearly)
        # coincide with an input frame in pose and intrinsics.
        c2w_input = c2w[input_frame_mask]
        rotation_diff = get_camera_dist(c2w, c2w_input, mode="rotation").min(-1).values
        translation_diff = (
            get_camera_dist(c2w, c2w_input, mode="translation").min(-1).values
        )
        K_diff = (
            ((K[:, None] - K[input_frame_mask][None]).flatten(-2) == 0).all(-1).any(-1)
        )
        close_frame = (rotation_diff < 10.0) & (translation_diff < 1e-5) & K_diff
        if isinstance(scale, torch.Tensor):
            scale = scale.clone()
            scale[close_frame] = self.min_scale
        elif isinstance(scale, float):
            scale = torch.where(close_frame, self.min_scale, scale)
        else:
            raise ValueError(f"Invalid scale type {type(scale)}.")
        return scale


class ConstantScaleSchedule(object):
    def __call__(
        self, sigma: float | torch.Tensor, scale: float | torch.Tensor
    ) -> float | torch.Tensor:
        if isinstance(sigma, float):
            return scale
        elif isinstance(sigma, torch.Tensor):
            if len(sigma.shape) == 1 and isinstance(scale, torch.Tensor):
                sigma = append_dims(sigma, scale.ndim)
            return scale * torch.ones_like(sigma)
        else:
            raise ValueError(f"Invalid sigma type {type(sigma)}.")


class ConstantGuidance(object):
    def __call__(
        self,
        uncond: torch.Tensor,
        cond: torch.Tensor,
        scale: float | torch.Tensor,
    ) -> torch.Tensor:
        if isinstance(scale, torch.Tensor) and len(scale.shape) == 1:
            scale = append_dims(scale, cond.ndim)
        return uncond + scale * (cond - uncond)


class VanillaCFG(object):
    def __init__(self):
        self.scale_rule = ConstantScaleRule()
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(
        self, x: torch.Tensor, sigma: float | torch.Tensor, scale: float | torch.Tensor
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        # Batch the unconditional and conditional branches into one forward pass.
        c_out = dict()
        for k in c:
            if k in ["vector", "crossattn", "concat", "replace", "dense_vector"]:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
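
# With the constant rule and schedule this is textbook classifier-free guidance:
# for scale = 2.0, x_pred = x_u + 2.0 * (x_c - x_u) = 2 * x_c - x_u.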


class MultiviewCFG(VanillaCFG):
    def __init__(self, cfg_min: float = 1.0):
        self.scale_min = cfg_min
        self.scale_rule = MultiviewScaleRule(min_scale=cfg_min)
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(  # type: ignore
        self,
        x: torch.Tensor,
        sigma: float | torch.Tensor,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale, c2w, K, input_frame_mask)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred


class MultiviewTemporalCFG(MultiviewCFG):
    def __init__(self, num_frames: int, cfg_min: float = 1.0):
        super().__init__(cfg_min=cfg_min)
        self.num_frames = num_frames
        # Pairwise |i - j| frame distances within a temporal window.
        distance_matrix = (
            torch.arange(num_frames)[None] - torch.arange(num_frames)[:, None]
        ).abs()
        self.distance_matrix = distance_matrix

    def __call__(
        self,
        x: torch.Tensor,
        sigma: float | torch.Tensor,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        input_frame_mask = rearrange(
            input_frame_mask, "(b t) ... -> b t ...", t=self.num_frames
        )
        # Distance of each frame to its nearest input frame, normalized to [0, 1];
        # non-input columns are pushed out of the min by a num_frames penalty.
        min_distance = (
            self.distance_matrix[None].to(x.device)
            + (~input_frame_mask[:, None]) * self.num_frames
        ).min(-1)[0]
        min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(min=1)
        scale = min_distance * (scale - self.scale_min) + self.scale_min
        scale = rearrange(scale, "b t ... -> (b t) ...")
        scale = append_dims(scale, x.ndim)
        return super().__call__(x, sigma, scale, c2w, K, input_frame_mask.flatten(0, 1))
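
# Net effect: frames that coincide with an input frame are guided at scale_min,
# and the guidance scale ramps up linearly with temporal distance from the
# nearest input frame, reaching the full scale at the farthest frame.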


class EulerEDMSampler(object):
    def __init__(
        self,
        discretization: DDPMDiscretization,
        guider: VanillaCFG | MultiviewCFG | MultiviewTemporalCFG,
        num_steps: int | None = None,
        verbose: bool = False,
        device: str | torch.device = "cuda",
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
    ):
        self.num_steps = num_steps
        self.discretization = discretization
        self.guider = guider
        self.verbose = verbose
        self.device = device
        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def prepare_sampling_loop(
        self, x: torch.Tensor, cond: dict, uc: dict, num_steps: int | None = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, dict, dict]:
        num_steps = num_steps or self.num_steps
        assert num_steps is not None, "num_steps must be specified"
        sigmas = self.discretization(num_steps, device=self.device)
        # Scale the unit-variance input noise to std = sqrt(1 + sigma_max**2).
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc

    def get_sigma_gen(self, num_sigmas: int, verbose: bool = True) -> range | tqdm:
        sigma_generator = range(num_sigmas - 1)
        if self.verbose and verbose:
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas - 1,
                desc="Sampling",
                leave=False,
            )
        return sigma_generator

    def sampler_step(
        self,
        sigma: torch.Tensor,
        next_sigma: torch.Tensor,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict,
        gamma: float = 0.0,
        **guider_kwargs,
    ) -> torch.Tensor:
        # Optional churn: bump sigma up by gamma and inject matching fresh noise.
        sigma_hat = sigma * (gamma + 1.0) + 1e-6
        eps = torch.randn_like(x) * self.s_noise
        x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
        # One guided denoiser evaluation, then an Euler step along dx/dsigma.
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma_hat, cond, uc))
        denoised = self.guider(denoised, sigma_hat, scale, **guider_kwargs)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)
        return x + dt * d

    def __call__(
        self,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict | None = None,
        num_steps: int | None = None,
        verbose: bool = True,
        **guider_kwargs,
    ) -> torch.Tensor:
        uc = cond if uc is None else uc
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x,
            cond,
            uc,
            num_steps,
        )
        for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                scale,
                cond,
                uc,
                gamma,
                **guider_kwargs,
            )
        return x
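

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: run the sampler on
    # random latents with a dummy network that always predicts zero noise, so
    # the eps denoiser returns its input unchanged and the Euler steps are
    # near no-ops. This only checks shapes and plumbing; sizes, step count,
    # and guidance scale below are illustrative assumptions.
    class _ZeroEps(nn.Module):
        def forward(self, x, t, cond, **kwargs):
            return torch.zeros_like(x)

    discretization = DDPMDiscretization()
    discrete_denoiser = DiscreteDenoiser(discretization)
    sampler = EulerEDMSampler(
        discretization, VanillaCFG(), num_steps=10, verbose=True, device="cpu"
    )

    model = _ZeroEps()

    # Bind the network so the sampler can call denoiser(x, sigma, cond).
    def denoiser_fn(x, s, c):
        return discrete_denoiser(model, x, s, c)

    samples = sampler(denoiser_fn, torch.randn(2, 4, 8, 8), scale=2.0, cond={})
    print(samples.shape)  # torch.Size([2, 4, 8, 8])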