Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import Tensor | |
| import torch.nn.functional as F | |
| from typing import Optional, Union, Tuple | |
| from utils import normalize | |
| from model import freq_exp, gen_nn_map | |
| from src.ddpm_step import deterministic_ddpm_step | |
| from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput | |
| # Kernel sizes for the DIFT correction at successive time-ranges | |
| DIFT_KERNELS: Tuple[int, int, int, int] = (12, 7, 5, 3) | |
| def _get_kernel_for_timestep(timestep: int) -> Tuple[int, int]: | |
| if timestep >= 799: | |
| return DIFT_KERNELS[0], 1 | |
| if timestep >= 599: | |
| return DIFT_KERNELS[1], 1 | |
| if timestep >= 299: | |
| return DIFT_KERNELS[2], 1 | |
| return DIFT_KERNELS[3], 1 | |
| def step_save_latents( | |
| self, | |
| model_output: torch.FloatTensor, | |
| timestep: int, | |
| sample: torch.FloatTensor, | |
| return_dict: bool = True, | |
| noise_pred_uncond: Optional[torch.FloatTensor] = None, | |
| **kwargs, | |
| ): | |
| timestep_index = self._timesteps.index(timestep) | |
| next_timestep_index = timestep_index + 1 | |
| u_hat_t, beta_coef = deterministic_ddpm_step( | |
| model_output=model_output, | |
| timestep=timestep, | |
| sample=sample, | |
| scheduler=self, | |
| ) | |
| x_t_minus_1 = self.x_ts[next_timestep_index] | |
| self.x_ts_c_predicted.append(u_hat_t) | |
| z_t = x_t_minus_1 - u_hat_t | |
| self.latents.append(z_t) | |
| z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs) | |
| x_t_minus_1_predicted = u_hat_t + z_t | |
| if not return_dict: | |
| return (x_t_minus_1_predicted,) | |
| return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None) | |
| def step_use_latents( | |
| self, | |
| model_output: torch.FloatTensor, | |
| timestep: int, | |
| sample: torch.FloatTensor, | |
| return_dict: bool = True, | |
| noise_pred_uncond: Optional[torch.FloatTensor] = None, | |
| **kwargs, | |
| ): | |
| timestep_index = self._timesteps.index(timestep) | |
| next_timestep_index = timestep_index + 1 | |
| z_t = self.latents[next_timestep_index] | |
| _, normalize_coefficient = normalize( | |
| z_t, | |
| timestep_index, | |
| self._config.max_norm_zs, | |
| ) | |
| x_t_hat_c_hat, beta_coef = deterministic_ddpm_step( | |
| model_output=model_output, | |
| timestep=timestep, | |
| sample=sample, | |
| scheduler=self, | |
| ) | |
| x_t_minus_1_exact = self.x_ts[next_timestep_index] | |
| x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat) | |
| x_t_c_predicted: torch.Tensor = self.x_ts_c_predicted[next_timestep_index] | |
| x_t_c = x_t_c_predicted[0].expand_as(x_t_hat_c_hat) | |
| mask: Optional[Tensor] = kwargs.get("mask", None) | |
| if mask is not None and timestep > 300: | |
| mask = mask.to(x_t_hat_c_hat.device) | |
| movement_intensifier = kwargs.get("movement_intensifier", 0.0) | |
| if timestep > 900 and movement_intensifier > 0.0: | |
| latent_mask_h, *_ = freq_exp( | |
| x_t_hat_c_hat[1:], | |
| "auto_mask", | |
| None, | |
| mask.unsqueeze(0), | |
| movement_intensifier | |
| ) | |
| x_t_hat_c_hat[1:] = latent_mask_h | |
| x_t_hat_c_hat[-1] = x_t_hat_c_hat[-1] * mask + (1-mask) * x_t_c[-1] | |
| edit_prompts_num = model_output.size(0) // 2 | |
| x_t_hat_c_indices = ( | |
| 0, | |
| edit_prompts_num, | |
| ) | |
| edit_images_indices = ( | |
| edit_prompts_num, | |
| (model_output.size(0)), | |
| ) | |
| x_t_hat_c = torch.zeros_like(x_t_hat_c_hat) | |
| x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[ | |
| x_t_hat_c_indices[0] : x_t_hat_c_indices[1] | |
| ] | |
| w1 = kwargs.get("w1", 1.9) | |
| cross_prompt_term = x_t_hat_c_hat - x_t_hat_c | |
| cross_trajectory_term = x_t_hat_c - normalize_coefficient * x_t_c | |
| x_t_minus_1_hat_ = ( | |
| normalize_coefficient * x_t_minus_1_exact | |
| + cross_trajectory_term | |
| + w1 * cross_prompt_term | |
| ) | |
| x_t_minus_1_hat_[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1_hat_[ | |
| edit_images_indices[0] : edit_images_indices[1] | |
| ] | |
| dift_timestep = kwargs.get("dift_timestep", 700) | |
| if timestep < dift_timestep and kwargs.get("apply_dift_correction", False): | |
| z_t = torch.cat([z_t]*x_t_hat_c_hat.shape[0], dim=0) | |
| dift_features: Optional[Tensor] = kwargs.get("dift_features", None) | |
| dift_s, _, dift_t = dift_features.chunk(3) | |
| resized_src_features = F.interpolate(dift_s[0].unsqueeze(0), size=z_t.shape[-1], mode='bilinear', align_corners=False).squeeze(0) | |
| resized_tgt_features = F.interpolate(dift_t[0].unsqueeze(0), size=z_t.shape[-1], mode='bilinear', align_corners=False).squeeze(0) | |
| kernel_size, stride = _get_kernel_for_timestep(timestep) | |
| torch.cuda.empty_cache() | |
| updated_z_t = gen_nn_map(z_t[1], resized_src_features, resized_tgt_features, | |
| kernel_size=kernel_size, stride=stride, | |
| device=z_t.device, timestep=timestep) | |
| alpha = 1.0 | |
| z_t[1] = alpha * updated_z_t + (1 - alpha) * z_t[1] | |
| x_t_minus_1_hat = x_t_hat_c_hat + z_t * normalize_coefficient | |
| else: | |
| x_t_minus_1_hat = x_t_minus_1_hat_ | |
| if not return_dict: | |
| return (x_t_minus_1_hat,) | |
| return DDIMSchedulerOutput( | |
| prev_sample=x_t_minus_1_hat, | |
| pred_original_sample=None, | |
| ) | |
| def get_ddpm_inversion_scheduler( | |
| scheduler, | |
| config, | |
| timesteps, | |
| latents, | |
| x_ts, | |
| **kwargs, | |
| ): | |
| def step( | |
| model_output: torch.FloatTensor, | |
| timestep: int, | |
| sample: torch.FloatTensor, | |
| eta: float = 0.0, | |
| use_clipped_model_output: bool = False, | |
| generator=None, | |
| variance_noise: Optional[torch.FloatTensor] = None, | |
| noise_pred_uncond: Optional[torch.FloatTensor] = None, | |
| dift_features: Optional[torch.FloatTensor] = None, | |
| return_dict: bool = True, | |
| ): | |
| # predict and save x_t_c | |
| res_inv = step_save_latents( | |
| scheduler, | |
| model_output[:1, :, :, :], | |
| timestep, | |
| sample[:1, :, :, :], | |
| return_dict, | |
| noise_pred_uncond[:1, :, :, :], | |
| **kwargs, | |
| ) | |
| res_inf = step_use_latents( | |
| scheduler, | |
| model_output[1:, :, :, :], | |
| timestep, | |
| sample[1:, :, :, :], | |
| return_dict, | |
| noise_pred_uncond[1:, :, :, :], | |
| dift_features=dift_features, | |
| **kwargs, | |
| ) | |
| res = (torch.cat((res_inv[0], res_inf[0]), dim=0),) | |
| return res | |
| scheduler._timesteps = timesteps | |
| scheduler._config = config | |
| scheduler.latents = latents | |
| scheduler.x_ts = x_ts | |
| scheduler.x_ts_c_predicted = [None] | |
| scheduler.step = step | |
| return scheduler | |