import math
import os
import re
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

import pytorch_lightning as pl
import torch
from diffusers.models.attention_processor import IPAdapterAttnProcessor2_0
from einops import rearrange
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR

from ..modules import UNCONDITIONAL_CONFIG
from ..modules.autoencoding.temporal_ae import VideoDecoder
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import (
    default,
    disabled_train,
    get_obj_from_str,
    instantiate_from_config,
    log_txt_as_img,
)
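
# A minimal sketch of the config convention used throughout this module:
# `instantiate_from_config` expects a dict of the form
# {"target": "<dotted.import.path>", "params": {...}} and instantiates the
# referenced object, while `get_obj_from_str` only resolves the import. The
# example below uses a stock torch class for illustration; it is not one of
# this repo's actual configs.
#
#     optimizer_config = {"target": "torch.optim.AdamW", "params": {"betas": (0.9, 0.999)}}
#     opt_cls = get_obj_from_str(optimizer_config["target"])  # -> torch.optim.AdamW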


class DiffusionEngine(pl.LightningModule):
    def __init__(
        self,
        network_config,
        denoiser_config,
        first_stage_config,
        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        network_wrapper: Union[None, str, Dict, ListConfig, OmegaConf] = None,
        ckpt_path: Union[None, str] = None,
        remove_keys_from_weights: Union[None, List, Tuple] = None,
        pattern_to_remove: Union[None, str] = None,
        remove_keys_from_unet_weights: Union[None, List, Tuple] = None,
        use_ema: bool = False,
        ema_decay_rate: float = 0.9999,
        scale_factor: float = 1.0,
        disable_first_stage_autocast: bool = False,
        input_key: str = "jpg",
        log_keys: Union[List, None] = None,
        no_log_keys: Union[List, None] = None,
        no_cond_log: bool = False,
        compile_model: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
        only_train_ipadapter: Optional[bool] = False,
        to_unfreeze: Optional[List[str]] = None,  # None instead of a mutable default
        to_freeze: Optional[List[str]] = None,  # None instead of a mutable default
        separate_unet_ckpt: Optional[str] = None,
        use_thunder: Optional[bool] = False,
        is_dubbing: Optional[bool] = False,
        bad_model_path: Optional[str] = None,
        bad_model_config: Optional[Dict] = None,
    ):
        super().__init__()
        # self.automatic_optimization = False
        self.log_keys = log_keys
        self.no_log_keys = no_log_keys
        self.input_key = input_key
        self.is_dubbing = is_dubbing
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.AdamW"}
        )
        self.model = self.initialize_network(
            network_config, network_wrapper, compile_model=compile_model
        )
        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = (
            instantiate_from_config(sampler_config)
            if sampler_config is not None
            else None
        )

        self.is_guided = True
        if (
            self.sampler
            and "IdentityGuider" in sampler_config["params"]["guider_config"]["target"]
        ):
            self.is_guided = False

        if self.sampler is not None:
            # Build a second sampler from the same config but without the
            # guider; it is used to log samples without guidance.
            config_guider = sampler_config["params"]["guider_config"]
            sampler_config["params"]["guider_config"] = None
            self.sampler_no_guidance = instantiate_from_config(sampler_config)
            sampler_config["params"]["guider_config"] = config_guider

        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG)
        )
        self.scheduler_config = scheduler_config
        self._init_first_stage(first_stage_config)

        self.loss_fn = (
            instantiate_from_config(loss_fn_config)
            if loss_fn_config is not None
            else None
        )

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model, decay=ema_decay_rate)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log

        if ckpt_path is not None:
            self.init_from_ckpt(
                ckpt_path,
                remove_keys_from_weights=remove_keys_from_weights,
                pattern_to_remove=pattern_to_remove,
            )
        if separate_unet_ckpt is not None:
            sd = torch.load(separate_unet_ckpt, map_location="cpu", weights_only=False)[
                "state_dict"
            ]
            if remove_keys_from_unet_weights is not None:
                for k in list(sd.keys()):
                    for remove_key in remove_keys_from_unet_weights:
                        if remove_key in k:
                            del sd[k]
            self.model.diffusion_model.load_state_dict(sd, strict=False)

        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
        print(
            "Using",
            self.en_and_decode_n_samples_a_time,
            "samples at a time for encoding and decoding",
        )

        if to_freeze:
            # A leading "!" inverts the match: "!layer" freezes every parameter
            # whose name does NOT contain "layer".
            for name, p in self.model.diffusion_model.named_parameters():
                for layer in to_freeze:
                    if layer[0] == "!":
                        if layer[1:] not in name:
                            # print("Freezing", name)
                            p.requires_grad = False
                    else:
                        if layer in name:
                            # print("Freezing", name)
                            p.requires_grad = False
                    # if "time_" in name:
                    #     print("Freezing", name)
                    #     p.requires_grad = False

        if only_train_ipadapter:
            # Freeze the whole model ...
            for p in self.model.parameters():
                p.requires_grad = False
            # ... then unfreeze the adapter projection layer ...
            for p in self.model.diffusion_model.encoder_hid_proj.parameters():
                p.requires_grad = True
            # ... and the IP-Adapter cross-attention processors.
            for att_layer in self.model.diffusion_model.attn_processors.values():
                if isinstance(att_layer, IPAdapterAttnProcessor2_0):
                    for p in att_layer.parameters():
                        p.requires_grad = True
            # for name, p in self.named_parameters():
            #     if p.requires_grad:
            #         print(name)

        if to_unfreeze:
            for name in to_unfreeze:
                for p in getattr(self.model.diffusion_model, name).parameters():
                    p.requires_grad = True

        if use_thunder:
            import thunder

            self.model.diffusion_model = thunder.jit(self.model.diffusion_model)

        if "Karras" in denoiser_config.target:
            assert bad_model_path is not None, (
                "bad_model_path must be provided for KarrasGuidanceDenoiser"
            )
            karras_config = default(bad_model_config, network_config)
            bad_model = self.initialize_network(
                karras_config, network_wrapper, compile_model=compile_model
            )
            state_dict = self.load_bad_model_weights(bad_model_path)
            bad_model.load_state_dict(state_dict)
            self.denoiser.set_bad_network(bad_model)
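
    # Usage sketch (config values are illustrative, not this repo's actual
    # YAML; in practice the engine itself is usually built by
    # `instantiate_from_config` from a config file):
    #
    #     engine = DiffusionEngine(
    #         network_config={"target": "...UNetModel", "params": {...}},
    #         denoiser_config={"target": "...Denoiser", "params": {...}},
    #         first_stage_config={"target": "...AutoencoderKL", "params": {...}},
    #         sampler_config={"target": "...EulerEDMSampler", "params": {...}},
    #         input_key="latents",
    #     )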

    def load_bad_model_weights(self, path: str) -> Dict[str, torch.Tensor]:
        print(f"Restoring bad model from {path}")
        state_dict = torch.load(path, map_location="cpu", weights_only=False)
        new_dict = {}
        for k, v in state_dict["module"].items():
            if "learned_mask" in k:
                new_dict[k.replace("_forward_module.", "").replace("model.", "")] = v
            if "diffusion_model" in k:
                new_dict["diffusion_model" + k.split("diffusion_model")[1]] = v
        return new_dict

    def initialize_network(self, network_config, network_wrapper, compile_model=False):
        model = instantiate_from_config(network_config)
        if isinstance(network_wrapper, str) or network_wrapper is None:
            model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
                model, compile_model=compile_model
            )
        else:
            target = network_wrapper["target"]
            params = network_wrapper.get("params", dict())
            model = get_obj_from_str(target)(
                model, compile_model=compile_model, **params
            )
        return model
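
    # `network_wrapper` may be a dotted import path or a {"target", "params"}
    # dict; when omitted, OPENAIUNETWRAPPER is used. A sketch, assuming a
    # wrapper class (the target path below is illustrative):
    #
    #     network_wrapper = {
    #         "target": "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper",
    #         "params": {},
    #     }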

    def init_from_ckpt(
        self,
        path: str,
        remove_keys_from_weights: Optional[Union[List, Tuple]] = None,
        pattern_to_remove: Optional[str] = None,
    ) -> None:
        print(f"Restoring from {path}")
        if path.endswith("ckpt"):
            sd = torch.load(path, map_location="cpu", weights_only=False)["state_dict"]
        elif path.endswith("pt"):
            sd = torch.load(path, map_location="cpu", weights_only=False)["module"]
            # Remove leading _forward_module from keys
            sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
        elif path.endswith("bin"):
            sd = torch.load(path, map_location="cpu", weights_only=False)
            # Remove leading _forward_module from keys
            sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
        elif path.endswith("safetensors"):
            sd = load_safetensors(path)
        else:
            raise NotImplementedError

        print(f"Loaded state dict from {path} with {len(sd)} keys")

        # if remove_keys_from_weights is not None:
        #     for k in list(sd.keys()):
        #         for remove_key in remove_keys_from_weights:
        #             if remove_key in k:
        #                 del sd[k]
        if pattern_to_remove is not None or remove_keys_from_weights is not None:
            sd = self.remove_mismatched_keys(
                sd, pattern_to_remove, remove_keys_from_weights
            )

        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def remove_mismatched_keys(self, state_dict, pattern=None, additional_keys=None):
        """Remove keys from the state dictionary based on a regex pattern and a
        list of additional specific keys."""
        # Find keys that match the pattern
        if pattern is not None:
            mismatched_keys = [key for key in state_dict if re.search(pattern, key)]
        else:
            mismatched_keys = []
        print(f"Removing {len(mismatched_keys)} keys based on pattern {pattern}")
        print(mismatched_keys)
        # Add specific keys to be removed
        if additional_keys:
            mismatched_keys.extend(
                [key for key in additional_keys if key in state_dict]
            )
        # Remove all identified keys
        for key in mismatched_keys:
            if key in state_dict:
                del state_dict[key]
        return state_dict

    def _init_first_stage(self, config):
        model = instantiate_from_config(config).eval()
        model.train = disabled_train  # keep the first stage permanently in eval mode
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model
        if self.input_key == "latents":
            # Inputs are already latents, so drop the encoder to save memory
            self.first_stage_model.encoder = None
            torch.cuda.empty_cache()

    def get_input(self, batch):
        # Assuming a unified data format, the dataloader returns a dict.
        # Image tensors should be scaled to -1 ... 1 and in bchw format.
        return batch[self.input_key]

    def decode_first_stage(self, z):
        is_video = False
        if len(z.shape) == 5:
            is_video = True
            T = z.shape[2]
            z = rearrange(z, "b c t h w -> (b t) c h w")
        z = 1.0 / self.scale_factor * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(
                    z[n * n_samples : (n + 1) * n_samples], **kwargs
                )
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        if is_video:
            out = rearrange(out, "(b t) c h w -> b c t h w", t=T)
        torch.cuda.empty_cache()
        return out
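
    # Worked example of the chunked decode (shapes are illustrative, assuming
    # a typical 8x-upsampling VAE): a video latent of shape (2, 4, 16, 32, 32)
    # is flattened to (32, 4, 32, 32); with en_and_decode_n_samples_a_time=8
    # it is decoded in ceil(32 / 8) = 4 rounds of 8 frames each, then reshaped
    # back to (2, 3, 16, 256, 256).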

    def encode_first_stage(self, x):
        is_video = False
        if len(x.shape) == 5:
            is_video = True
            T = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        n_rounds = math.ceil(x.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                out = self.first_stage_model.encode(
                    x[n * n_samples : (n + 1) * n_samples]
                )
                all_out.append(out)
        z = torch.cat(all_out, dim=0)
        z = self.scale_factor * z
        if is_video:
            z = rearrange(z, "(b t) c h w -> b c t h w", t=T)
        return z

    def forward(self, x, batch):
        loss_dict = self.loss_fn(
            self.model,
            self.denoiser,
            self.conditioner,
            x,
            batch,
            self.first_stage_model,
        )
        # loss_mean = loss.mean()
        for k in loss_dict:
            loss_dict[k] = loss_dict[k].mean()
        # loss_dict = {"loss": loss_mean}
        return loss_dict["loss"], loss_dict

    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        if self.input_key != "latents":
            x = self.encode_first_stage(x)
        batch["global_step"] = self.global_step
        loss, loss_dict = self(x, batch)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)
        # debugging_message = "Training step"
        # print(f"RANK - {self.trainer.global_rank}: {debugging_message}")
        self.log_dict(
            loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
        )
        self.log(
            "global_step",
            self.global_step,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )
        # debugging_message = "Training step - log"
        # print(f"RANK - {self.trainer.global_rank}: {debugging_message}")
        if self.scheduler_config is not None:
            lr = self.optimizers().param_groups[0]["lr"]
            self.log(
                "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
            )
        # # to prevent other processes from moving forward until all processes are in sync
        # self.trainer.strategy.barrier()
        return loss
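
    # Usage sketch: standard Lightning wiring (trainer flags are illustrative).
    # Note that `configure_optimizers` reads `self.learning_rate`, which is
    # expected to be set externally (e.g. by the training script) rather than
    # in __init__:
    #
    #     engine.learning_rate = 1e-4
    #     trainer = pl.Trainer(devices=1, precision="16-mixed", max_steps=10_000)
    #     trainer.fit(engine, train_dataloader)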

    # def validation_step(self, batch, batch_idx):
    #     # loss, loss_dict = self.shared_step(batch)
    #     # self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False)
    #     self.log(
    #         "global_step",
    #         self.global_step,
    #         prog_bar=True,
    #         logger=True,
    #         on_step=True,
    #         on_epoch=False,
    #     )
    #     return 0

    # def on_train_epoch_start(self, *args, **kwargs):
    #     print(f"RANK - {self.trainer.global_rank}: on_train_epoch_start")

    def on_train_start(self, *args, **kwargs):
        # os.environ["CUDA_VISIBLE_DEVICES"] = str(self.trainer.global_rank)
        # torch.cuda.set_device(self.trainer.global_rank)
        # torch.cuda.empty_cache()
        if self.sampler is None or self.loss_fn is None:
            raise ValueError("Sampler and loss function need to be set for training.")

    # def on_before_batch_transfer(self, batch, dataloader_idx):
    #     print(f"RANK - {self.trainer.global_rank}: on_before_batch_transfer - {dataloader_idx}")
    #     return batch

    # def on_after_batch_transfer(self, batch, dataloader_idx):
    #     print(f"RANK - {self.trainer.global_rank}: on_after_batch_transfer - {dataloader_idx}")
    #     return batch

    def on_train_batch_end(self, *args, **kwargs):
        # print(f"RANK - {self.trainer.global_rank}: on_train_batch_end")
        if self.use_ema:
            self.model_ema(self.model)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        for embedder in self.conditioner.embedders:
            if embedder.is_trainable:
                params = params + list(embedder.parameters())
        opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)
            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [opt], scheduler
        return opt
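
    # The scheduler config must instantiate an object exposing a `schedule`
    # callable that maps the global step to an LR multiplier for LambdaLR.
    # A sketch, assuming a linear-warmup scheduler class (target path and
    # parameters are illustrative):
    #
    #     scheduler_config = {
    #         "target": "sgm.lr_scheduler.LambdaLinearScheduler",
    #         "params": {"warm_up_steps": [1000], "f_start": [1e-6], "f_max": [1.0], "f_min": [1.0]},
    #     }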

    def sample(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        **kwargs,
    ):
        # `shape` is the per-sample latent shape (without the batch dim) and
        # must be provided despite the None default.
        randn = torch.randn(batch_size, *shape).to(self.device)

        denoiser = lambda input, sigma, c: self.denoiser(
            self.model, input, sigma, c, **kwargs
        )
        samples = self.sampler(denoiser, randn, cond, uc=uc)
        return samples

    def sample_no_guider(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        **kwargs,
    ):
        randn = torch.randn(batch_size, *shape).to(self.device)

        denoiser = lambda input, sigma, c: self.denoiser(
            self.model, input, sigma, c, **kwargs
        )
        samples = self.sampler_no_guidance(denoiser, randn, cond, uc=uc)
        return samples
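
    # Usage sketch (shapes are illustrative; `cond`/`uc` come from the
    # conditioner, as in log_images below):
    #
    #     samples = engine.sample(cond, uc=uc, batch_size=2, shape=(4, 64, 64))
    #     images = engine.decode_first_stage(samples)  # e.g. -> (2, 3, 512, 512)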

    def log_conditionings(self, batch: Dict, n: int) -> Dict:
        """
        Defines heuristics to log different conditionings.
        These can be lists of strings (text-to-image), tensors, ints, ...
        """
        image_h, image_w = batch[self.input_key].shape[-2:]
        log = dict()

        for embedder in self.conditioner.embedders:
            if (
                (self.log_keys is None) or (embedder.input_key in self.log_keys)
            ) and not self.no_cond_log:
                # Guard against no_log_keys being None (its default)
                if self.no_log_keys and embedder.input_key in self.no_log_keys:
                    continue
                x = batch[embedder.input_key][:n]
                if isinstance(x, torch.Tensor):
                    if x.dim() == 1:
                        # class-conditional, convert integer to string
                        x = [str(x[i].item()) for i in range(x.shape[0])]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
                    elif x.dim() == 2:
                        # size and crop cond and the like
                        x = [
                            "x".join([str(xx) for xx in x[i].tolist()])
                            for i in range(x.shape[0])
                        ]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    elif x.dim() == 4:
                        # already an image
                        xc = x
                    elif x.dim() == 5:
                        # video: concatenate the frames along the width axis
                        xc = torch.cat([x[:, :, i] for i in range(x.shape[2])], dim=-1)
                    else:
                        print(x.shape, embedder.input_key)
                        raise NotImplementedError()
                elif isinstance(x, (List, ListConfig)):
                    if isinstance(x[0], str):
                        # strings
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()
                log[embedder.input_key] = xc
        return log

    def log_images(
        self,
        batch: Dict,
        N: int = 8,
        sample: bool = True,
        ucg_keys: Optional[List[str]] = None,
        **kwargs,
    ) -> Dict:
        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
        if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys, "
                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
            )
        else:
            ucg_keys = conditioner_input_keys

        log = dict()
        x = self.get_input(batch)

        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            force_uc_zero_embeddings=ucg_keys
            if len(self.conditioner.embedders) > 0
            else [],
        )

        sampling_kwargs = {}

        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        if self.input_key != "latents":
            log["inputs"] = x
            z = self.encode_first_stage(x)
        else:
            z = x
        log["reconstructions"] = self.decode_first_stage(z)
        log.update(self.log_conditionings(batch, N))

        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))

        if sample:
            with self.ema_scope("Plotting"):
                samples = self.sample(
                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
                )
            samples = self.decode_first_stage(samples)
            log["samples"] = samples

            with self.ema_scope("Plotting"):
                samples = self.sample_no_guider(
                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
                )
            samples = self.decode_first_stage(samples)
            log["samples_no_guidance"] = samples
        return log

    def log_videos(
        self,
        batch: Dict,
        N: int = 8,
        sample: bool = True,
        ucg_keys: Optional[List[str]] = None,
        **kwargs,
    ) -> Dict:
        # conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
        # if ucg_keys:
        #     assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
        #         "Each defined ucg key for sampling must be in the provided conditioner input keys, "
        #         f"but we have {ucg_keys} vs. {conditioner_input_keys}"
        #     )
        # else:
        #     ucg_keys = conditioner_input_keys

        log = dict()
        batch_uc = {}
        x = self.get_input(batch)
        num_frames = x.shape[2]  # assuming bcthw format

        for key in batch.keys():
            if key not in batch_uc and isinstance(batch[key], torch.Tensor):
                batch_uc[key] = torch.clone(batch[key])

        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=ucg_keys
            if ucg_keys is not None
            else [
                "cond_frames",
                "cond_frames_without_noise",
            ],
        )

        # for k in ["crossattn", "concat"]:
        #     uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
        #     uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
        #     c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
        #     c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

        sampling_kwargs = {}

        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        if self.input_key != "latents":
            log["inputs"] = x
            z = self.encode_first_stage(x)
        else:
            z = x
        log["reconstructions"] = self.decode_first_stage(z)
        log.update(self.log_conditionings(batch, N))

        if c.get("masks", None) is not None:
            # Create a mask reconstruction
            masks = 1 - c["masks"]
            t = masks.shape[2]
            masks = rearrange(masks, "b c t h w -> (b t) c h w")
            target_size = (
                log["reconstructions"].shape[-2],
                log["reconstructions"].shape[-1],
            )
            masks = torch.nn.functional.interpolate(
                masks, size=target_size, mode="nearest"
            )
            masks = rearrange(masks, "(b t) c h w -> b c t h w", t=t)
            log["mask_reconstructions"] = log["reconstructions"] * masks

        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
            elif isinstance(c[k], list):
                for i in range(len(c[k])):
                    c[k][i], uc[k][i] = map(
                        lambda y: y[k][i][:N].to(self.device), (c, uc)
                    )

        if sample:
            n = 2 if self.is_guided else 1
            # if num_frames == 1:
            #     sampling_kwargs["image_only_indicator"] = torch.ones(n, num_frames).to(self.device)
            # else:
            sampling_kwargs["image_only_indicator"] = torch.zeros(n, num_frames).to(
                self.device
            )
            sampling_kwargs["num_video_frames"] = batch["num_video_frames"]
            with self.ema_scope("Plotting"):
                samples = self.sample(
                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
                )
            samples = self.decode_first_stage(samples)
            if self.is_dubbing:
                # For dubbing, only the mouth region (bottom half) is generated;
                # keep the top half of the frame from the reconstruction.
                samples[:, :, :, : samples.shape[-2] // 2] = log["reconstructions"][
                    :, :, :, : samples.shape[-2] // 2
                ]
            log["samples"] = samples

            # Without guidance
            # if num_frames == 1:
            #     sampling_kwargs["image_only_indicator"] = torch.ones(1, num_frames).to(self.device)
            # else:
            sampling_kwargs["image_only_indicator"] = torch.zeros(1, num_frames).to(
                self.device
            )
            sampling_kwargs["num_video_frames"] = batch["num_video_frames"]
            with self.ema_scope("Plotting"):
                samples = self.sample_no_guider(
                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
                )
            samples = self.decode_first_stage(samples)
            if self.is_dubbing:
                samples[:, :, :, : samples.shape[-2] // 2] = log["reconstructions"][
                    :, :, :, : samples.shape[-2] // 2
                ]
            log["samples_no_guidance"] = samples
        torch.cuda.empty_cache()
        return log
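
    # Usage sketch: typically invoked from a video-logger callback during
    # training; the batch must carry the keys this module expects (e.g.
    # "num_video_frames" and the conditioner input keys):
    #
    #     logs = engine.log_videos(batch, N=2, sample=True)
    #     # logs: "inputs", "reconstructions", "samples", "samples_no_guidance", ...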