import inspect
from typing import Union, Optional, List

import torch
import numpy as np
from tqdm.auto import tqdm
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _resize_with_antialiasing,
    StableVideoDiffusionPipelineOutput,
    StableVideoDiffusionPipeline,
)
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
from einops import rearrange
class ChronoDepthPipeline(StableVideoDiffusionPipeline):

    def encode_images(self,
                      images: torch.Tensor,
                      decode_chunk_size=5,
                      ):
        video_length = images.shape[1]
        images = rearrange(images, "b f c h w -> (b f) c h w")
        latents = []
        for i in range(0, images.shape[0], decode_chunk_size):
            latents_chunk = self.vae.encode(images[i : i + decode_chunk_size]).latent_dist.sample()
            latents.append(latents_chunk)
        latents = torch.cat(latents, dim=0)
        latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
        latents = latents * self.vae.config.scaling_factor
        return latents
    def _encode_image(self, images, device, discard=True, chunk_size=14):
        '''
        If `discard` is True, the images are replaced with zero tensors, so the
        resulting CLIP image embeddings carry no image content.
        '''
        dtype = next(self.image_encoder.parameters()).dtype

        images = _resize_with_antialiasing(images, (224, 224))
        images = (images + 1.0) / 2.0
        if discard:
            images = torch.zeros_like(images)

        image_embeddings = []
        for i in range(0, images.shape[0], chunk_size):
            tmp = self.feature_extractor(
                images=images[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values
            tmp = tmp.to(device=device, dtype=dtype)
            image_embeddings.append(self.image_encoder(tmp).image_embeds)
        image_embeddings = torch.cat(image_embeddings, dim=0)
        image_embeddings = image_embeddings.unsqueeze(1)  # [t, 1, 1024]

        return image_embeddings
    def decode_depth(self, depth_latent: torch.Tensor, decode_chunk_size=5) -> torch.Tensor:
        num_frames = depth_latent.shape[1]
        depth_latent = rearrange(depth_latent, "b f c h w -> (b f) c h w")
        depth_latent = depth_latent / self.vae.config.scaling_factor

        forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
        accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())

        depth_frames = []
        for i in range(0, depth_latent.shape[0], decode_chunk_size):
            num_frames_in = depth_latent[i : i + decode_chunk_size].shape[0]
            decode_kwargs = {}
            if accepts_num_frames:
                # we only pass num_frames_in if it's expected
                decode_kwargs["num_frames"] = num_frames_in
            depth_frame = self.vae.decode(depth_latent[i : i + decode_chunk_size], **decode_kwargs).sample
            depth_frames.append(depth_frame)
        depth_frames = torch.cat(depth_frames, dim=0)
        depth_frames = depth_frames.reshape(-1, num_frames, *depth_frames.shape[1:])
        # average over the decoded RGB channels to obtain a single-channel depth map
        depth_mean = depth_frames.mean(dim=2, keepdim=True)

        return depth_mean
    def check_inputs(self, images, height, width):
        if (
            not isinstance(images, torch.Tensor)
            and not isinstance(images, np.ndarray)
        ):
            raise ValueError(
                "`images` has to be of type `torch.Tensor` or `numpy.ndarray` but is"
                f" {type(images)}"
            )
        if height % 64 != 0 or width % 64 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
    def __call__(
        self,
        input_images: Union[np.ndarray, torch.FloatTensor],
        height: int = 576,
        width: int = 768,
        num_inference_steps: int = 10,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        show_progress_bar: bool = True,
        latents: Optional[torch.Tensor] = None,
        infer_mode: str = 'ours',
        sigma_epsilon: float = -4,
    ):
| """ | |
| Args: | |
| input_images: shape [T, H, W, 3] if np.ndarray or [T, 3, H, W] if torch.FloatTensor, range [0, 1] | |
| height: int, height of the input image | |
| width: int, width of the input image | |
| num_inference_steps: int, number of inference steps | |
| fps: int, frames per second | |
| motion_bucket_id: int, motion bucket id | |
| noise_aug_strength: float, noise augmentation strength | |
| decode_chunk_size: int, decode chunk size | |
| generator: torch.Generator or List[torch.Generator], random number generator | |
| show_progress_bar: bool, show progress bar | |
| """ | |
        assert height >= 0 and width >= 0
        assert num_inference_steps >= 1

        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(input_images, height, width)

        # 2. Define call parameters
        batch_size = 1  # only support batch size 1 for now
        device = self._execution_device

        # 3. Encode input frames with the CLIP image encoder
        if isinstance(input_images, np.ndarray):
            input_images = torch.from_numpy(input_images.transpose(0, 3, 1, 2))
        else:
            assert isinstance(input_images, torch.Tensor)
        input_images = input_images.to(device=device)
        input_images = input_images * 2.0 - 1.0  # [0, 1] -> [-1, 1], layout [t, c, h, w]
        discard_clip_features = True
        image_embeddings = self._encode_image(
            input_images,
            device,
            discard=discard_clip_features,
            chunk_size=decode_chunk_size,
        )

        # NOTE: Stable Video Diffusion was conditioned on fps - 1, which
        # is why it is reduced here.
        # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
        fps = fps - 1

        # 4. Noise-augment the input frames (VAE encoding happens inside the samplers)
        noise = randn_tensor(input_images.shape, generator=generator, device=device, dtype=input_images.dtype)
        input_images = input_images + noise_aug_strength * noise
        rgb_batch = input_images.unsqueeze(0)
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            1,      # do not modify this!
            False,  # do not modify this!
        )
        added_time_ids = added_time_ids.to(device)
        if infer_mode == 'ours':
            depth_pred_raw = self.single_infer_ours(
                rgb_batch,
                image_embeddings,
                added_time_ids,
                num_inference_steps,
                show_progress_bar,
                generator,
                decode_chunk_size=decode_chunk_size,
                latents=latents,
                sigma_epsilon=sigma_epsilon,
            )
        elif infer_mode == 'replacement':
            depth_pred_raw = self.single_infer_replacement(
                rgb_batch,
                image_embeddings,
                added_time_ids,
                num_inference_steps,
                show_progress_bar,
                generator,
                decode_chunk_size=decode_chunk_size,
                latents=latents,
            )
        elif infer_mode == 'naive':
            depth_pred_raw = self.single_infer_naive_sliding_window(
                rgb_batch,
                image_embeddings,
                added_time_ids,
                num_inference_steps,
                show_progress_bar,
                generator,
                decode_chunk_size=decode_chunk_size,
                latents=latents,
            )
        else:
            raise ValueError(f"Unknown infer_mode: {infer_mode!r}; expected 'ours', 'replacement', or 'naive'.")
        depth_frames = depth_pred_raw.cpu().numpy().astype(np.float32)

        self.maybe_free_model_hooks()

        return StableVideoDiffusionPipelineOutput(
            frames=depth_frames,
        )
    def single_infer_ours(self,
                          input_rgb: torch.Tensor,
                          image_embeddings: torch.Tensor,
                          added_time_ids: torch.Tensor,
                          num_inference_steps: int,
                          show_pbar: bool,
                          generator: Optional[Union[torch.Generator, List[torch.Generator]]],
                          decode_chunk_size=1,
                          latents: Optional[torch.Tensor] = None,
                          sigma_epsilon: float = -4,
                          ):
        device = input_rgb.device
        H, W = input_rgb.shape[-2:]

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)
        rgb_latent = self.encode_images(input_rgb)
        rgb_latent = rgb_latent.to(image_embeddings.dtype)
        torch.cuda.empty_cache()
        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        batch_size, n_frames, _, _, _ = rgb_latent.shape
        num_channels_latents = self.unet.config.in_channels
        curr_frame = 0
        depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
        pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")

        # first chunk
        horizon = min(n_frames - curr_frame, self.n_tokens)
        start_frame = 0
        chunk = self.prepare_latents(
            batch_size,
            horizon,
            num_channels_latents,
            H,
            W,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )
        depth_latent = torch.cat([depth_latent, chunk], 1)

        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising first chunk",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            curr_timesteps = torch.tensor([t] * horizon).to(device)
            depth_latent = self.scheduler.scale_model_input(depth_latent, t)
            noise_pred = self.unet(
                torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], depth_latent[:, start_frame:]], dim=2),
                curr_timesteps[start_frame:],
                image_embeddings[start_frame:curr_frame + horizon],
                added_time_ids=added_time_ids,
            )[0]
            depth_latent[:, curr_frame:] = self.scheduler.step(
                noise_pred[:, -horizon:], t, depth_latent[:, curr_frame:]
            ).prev_sample
        # reset the scheduler's internal step index so the next chunk restarts the schedule
        self.scheduler._step_index = None

        curr_frame += horizon
        pbar.update(horizon)
        while curr_frame < n_frames:
            if self.chunk_size > 0:
                horizon = min(n_frames - curr_frame, self.chunk_size)
            else:
                horizon = min(n_frames - curr_frame, self.n_tokens)
            assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
            chunk = self.prepare_latents(
                batch_size,
                horizon,
                num_channels_latents,
                H,
                W,
                image_embeddings.dtype,
                device,
                generator,
                latents,
            )
            depth_latent = torch.cat([depth_latent, chunk], 1)
            start_frame = max(0, curr_frame + horizon - self.n_tokens)

            pbar.set_postfix(
                {
                    "start": start_frame,
                    "end": curr_frame + horizon,
                }
            )

            if show_pbar:
                iterable = tqdm(
                    enumerate(timesteps),
                    total=len(timesteps),
                    leave=False,
                    desc=" " * 4 + "Diffusion denoising",
                )
            else:
                iterable = enumerate(timesteps)

            for i, t in iterable:
                t_horizon = torch.tensor([t] * horizon).to(device)
                # t_context = timesteps[-1] * torch.ones((curr_frame,), dtype=t.dtype).to(device)
                t_context = sigma_epsilon * torch.ones((curr_frame,), dtype=t.dtype).to(device)
                curr_timesteps = torch.cat((t_context, t_horizon), 0)
                depth_latent[:, curr_frame:] = self.scheduler.scale_model_input(depth_latent[:, curr_frame:], t)
                noise_pred = self.unet(
                    torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], depth_latent[:, start_frame:]], dim=2),
                    curr_timesteps[start_frame:],
                    image_embeddings[start_frame:curr_frame + horizon],
                    added_time_ids=added_time_ids,
                )[0]
                depth_latent[:, curr_frame:] = self.scheduler.step(
                    noise_pred[:, -horizon:], t, depth_latent[:, curr_frame:]
                ).prev_sample
            self.scheduler._step_index = None

            curr_frame += horizon
            pbar.update(horizon)

        torch.cuda.empty_cache()
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
        depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
        # clip prediction
        depth = torch.clip(depth, -1.0, 1.0)
        # shift to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth.squeeze(0)
    def single_infer_replacement(self,
                                 input_rgb: torch.Tensor,
                                 image_embeddings: torch.Tensor,
                                 added_time_ids: torch.Tensor,
                                 num_inference_steps: int,
                                 show_pbar: bool,
                                 generator: Optional[Union[torch.Generator, List[torch.Generator]]],
                                 decode_chunk_size=1,
                                 latents: Optional[torch.Tensor] = None,
                                 ):
        device = input_rgb.device
        H, W = input_rgb.shape[-2:]

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)
        rgb_latent = self.encode_images(input_rgb)
        rgb_latent = rgb_latent.to(image_embeddings.dtype)
        torch.cuda.empty_cache()
        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        batch_size, n_frames, _, _, _ = rgb_latent.shape
        num_channels_latents = self.unet.config.in_channels
        curr_frame = 0
        depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
        pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")

        # first chunk
        horizon = min(n_frames - curr_frame, self.n_tokens)
        start_frame = 0
        chunk = self.prepare_latents(
            batch_size,
            horizon,
            num_channels_latents,
            H,
            W,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )
        depth_latent = torch.cat([depth_latent, chunk], 1)
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising first chunk",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            curr_timesteps = torch.tensor([t] * horizon).to(device)
            depth_latent = self.scheduler.scale_model_input(depth_latent, t)
            noise_pred = self.unet(
                torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], depth_latent[:, start_frame:]], dim=2),
                curr_timesteps[start_frame:],
                image_embeddings[start_frame:curr_frame + horizon],
                added_time_ids=added_time_ids,
            )[0]
            depth_latent[:, curr_frame:] = self.scheduler.step(
                noise_pred[:, -horizon:], t, depth_latent[:, curr_frame:]
            ).prev_sample
        self.scheduler._step_index = None

        curr_frame += horizon
        pbar.update(horizon)
        while curr_frame < n_frames:
            if self.chunk_size > 0:
                horizon = min(n_frames - curr_frame, self.chunk_size)
            else:
                horizon = min(n_frames - curr_frame, self.n_tokens)
            assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
            chunk = self.prepare_latents(
                batch_size,
                horizon,
                num_channels_latents,
                H,
                W,
                image_embeddings.dtype,
                device,
                generator,
                latents,
            )
            depth_latent = torch.cat([depth_latent, chunk], 1)
            start_frame = max(0, curr_frame + horizon - self.n_tokens)
            depth_pred_last_latent = depth_latent[:, start_frame:curr_frame].clone()

            pbar.set_postfix(
                {
                    "start": start_frame,
                    "end": curr_frame + horizon,
                }
            )

            if show_pbar:
                iterable = tqdm(
                    enumerate(timesteps),
                    total=len(timesteps),
                    leave=False,
                    desc=" " * 4 + "Diffusion denoising",
                )
            else:
                iterable = enumerate(timesteps)

            for i, t in iterable:
                curr_timesteps = torch.tensor([t] * (curr_frame + horizon - start_frame)).to(device)
                # replacement-style conditioning: re-noise the clean context latents
                # to the current noise level before every denoising step
                epsilon = randn_tensor(
                    depth_pred_last_latent.shape,
                    generator=generator,
                    device=device,
                    dtype=image_embeddings.dtype,
                )
                depth_latent[:, start_frame:curr_frame] = depth_pred_last_latent + epsilon * self.scheduler.sigmas[i]
                depth_latent[:, start_frame:] = self.scheduler.scale_model_input(depth_latent[:, start_frame:], t)
                noise_pred = self.unet(
                    torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], depth_latent[:, start_frame:]], dim=2),
                    curr_timesteps,
                    image_embeddings[start_frame:curr_frame + horizon],
                    added_time_ids=added_time_ids,
                )[0]
                depth_latent[:, start_frame:] = self.scheduler.step(
                    noise_pred, t, depth_latent[:, start_frame:]
                ).prev_sample
            # restore the clean context latents after the chunk has been denoised
            depth_latent[:, start_frame:curr_frame] = depth_pred_last_latent
            self.scheduler._step_index = None

            curr_frame += horizon
            pbar.update(horizon)

        torch.cuda.empty_cache()
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
        depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
        # clip prediction
        depth = torch.clip(depth, -1.0, 1.0)
        # shift to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth.squeeze(0)
    def single_infer_naive_sliding_window(self,
                                          input_rgb: torch.Tensor,
                                          image_embeddings: torch.Tensor,
                                          added_time_ids: torch.Tensor,
                                          num_inference_steps: int,
                                          show_pbar: bool,
                                          generator: Optional[Union[torch.Generator, List[torch.Generator]]],
                                          decode_chunk_size=1,
                                          latents: Optional[torch.Tensor] = None,
                                          ):
        device = input_rgb.device
        H, W = input_rgb.shape[-2:]

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)
        rgb_latent = self.encode_images(input_rgb)
        rgb_latent = rgb_latent.to(image_embeddings.dtype)
        torch.cuda.empty_cache()
        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        batch_size, n_frames, _, _, _ = rgb_latent.shape
        num_channels_latents = self.unet.config.in_channels
        curr_frame = 0
        depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
        pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")

        # first chunk
        horizon = min(n_frames - curr_frame, self.n_tokens)
        start_frame = 0
        chunk = self.prepare_latents(
            batch_size,
            horizon,
            num_channels_latents,
            H,
            W,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )
        depth_latent = torch.cat([depth_latent, chunk], 1)
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising first chunk",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            curr_timesteps = torch.tensor([t] * horizon).to(device)
            depth_latent = self.scheduler.scale_model_input(depth_latent, t)
            noise_pred = self.unet(
                torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], depth_latent[:, start_frame:]], dim=2),
                curr_timesteps[start_frame:],
                image_embeddings[start_frame:curr_frame + horizon],
                added_time_ids=added_time_ids,
            )[0]
            depth_latent[:, curr_frame:] = self.scheduler.step(
                noise_pred[:, -horizon:], t, depth_latent[:, curr_frame:]
            ).prev_sample
        self.scheduler._step_index = None

        curr_frame += horizon
        pbar.update(horizon)
        while curr_frame < n_frames:
            if self.chunk_size > 0:
                horizon = min(n_frames - curr_frame, self.chunk_size)
            else:
                horizon = min(n_frames - curr_frame, self.n_tokens)
            assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
            start_frame = max(0, curr_frame + horizon - self.n_tokens)
            chunk = self.prepare_latents(
                batch_size,
                curr_frame + horizon - start_frame,
                num_channels_latents,
                H,
                W,
                image_embeddings.dtype,
                device,
                generator,
                latents,
            )

            pbar.set_postfix(
                {
                    "start": start_frame,
                    "end": curr_frame + horizon,
                }
            )

            if show_pbar:
                iterable = tqdm(
                    enumerate(timesteps),
                    total=len(timesteps),
                    leave=False,
                    desc=" " * 4 + "Diffusion denoising",
                )
            else:
                iterable = enumerate(timesteps)

            for i, t in iterable:
                curr_timesteps = torch.tensor([t] * (curr_frame + horizon - start_frame)).to(device)
                chunk = self.scheduler.scale_model_input(chunk, t)
                noise_pred = self.unet(
                    torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], chunk], dim=2),
                    curr_timesteps,
                    image_embeddings[start_frame:curr_frame + horizon],
                    added_time_ids=added_time_ids,
                )[0]
                chunk = self.scheduler.step(noise_pred, t, chunk).prev_sample
            # keep only the newly denoised frames; the re-sampled overlap frames are discarded
            depth_latent = torch.cat([depth_latent, chunk[:, -horizon:]], 1)
            self.scheduler._step_index = None

            curr_frame += horizon
            pbar.update(horizon)

        torch.cuda.empty_cache()
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
        depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
        # clip prediction
        depth = torch.clip(depth, -1.0, 1.0)
        # shift to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth.squeeze(0)
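

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original pipeline).
# The checkpoint path and the `n_tokens` / `chunk_size` values below are
# placeholder assumptions: both attributes are read by the samplers but are
# expected to be set on the pipeline from the outside.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = ChronoDepthPipeline.from_pretrained(
        "path/to/chronodepth-checkpoint",  # hypothetical checkpoint laid out like an SVD pipeline
        torch_dtype=torch.float16,
    )
    pipe.to("cuda")
    pipe.n_tokens = 10   # sliding-window length used by the samplers (assumed value)
    pipe.chunk_size = 5  # new frames denoised per window step (assumed value)

    # Dummy clip: 20 RGB frames in [0, 1], shape [T, H, W, 3].
    video = np.random.rand(20, 576, 768, 3).astype(np.float32)
    result = pipe(video, num_inference_steps=10, infer_mode='ours')
    depth = result.frames  # np.ndarray of shape [T, 1, H, W], values in [0, 1]
    print(depth.shape)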