from typing import Union

import numpy as np
import torch
import torch.nn.functional as nnf
from PIL import Image
from torch.optim import Adam
from torchvision.transforms import ToTensor
from torchvision.utils import save_image
from tqdm import tqdm

def load_512(image_path, left=0, right=0, top=0, bottom=0):
    if type(image_path) is str:
        image = np.array(Image.open(image_path))[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    # Clamp crop offsets so at least one row and one column survive the crop.
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    # Center-crop to a square, then resize to 512x512.
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    return image

def invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path):
    print("Start null text inversion")
    null_inversion = NullInversion(ldm_stable, ldm_stable_config)
    (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(
        args.real_image_path, prompts[0], offsets=(0, 0, 0, 0), verbose=True)
    save_image(ToTensor()(image_gt), f"{exp_path}/real_image.jpg")
    save_image(ToTensor()(image_enc), f"{exp_path}/image_enc.jpg")
    print("End null text inversion")
    return x_t, uncond_embeddings

class NullInversion:

    def __init__(self, model, model_config):
        self.model = model
        self.model_config = model_config
        self.tokenizer = self.model.tokenizer
        self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"])
        self.prompt = None
        self.context = None

    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
                  sample: Union[torch.FloatTensor, np.ndarray]):
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
                  sample: Union[torch.FloatTensor, np.ndarray]):
        timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample
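
    # Both steps above implement the deterministic DDIM update: given the noise
    # prediction eps and the current latent x_t, recover
    #     x0_hat = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
    # and form the neighbouring latent as
    #     x_{t'} = sqrt(a_{t'}) * x0_hat + sqrt(1 - a_{t'}) * eps,
    # where a_t = alphas_cumprod[t]. prev_step moves toward the image (denoising);
    # next_step moves toward noise, which is what makes DDIM inversion possible.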

    def get_noise_pred_single(self, latents, t, context):
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        latents_input = torch.cat([latents] * 2)
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else self.model_config["guidance_scale"]
        noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        if is_forward:
            latents = self.next_step(noise_pred, t, latents)
        else:
            latents = self.prev_step(noise_pred, t, latents)
        return latents
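
    # The noise estimate above is classifier-free guidance,
    #     eps = eps_uncond + w * (eps_text - eps_uncond).
    # In the forward (inversion) direction the guidance scale is forced to 1, i.e.
    # the plain conditional prediction, which keeps the inversion trajectory close
    # to the input image; the configured guidance scale is only used when denoising.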

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        # Undo the Stable Diffusion VAE latent scaling before decoding.
        latents = 1 / 0.18215 * latents.detach()
        image = self.model.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    def image2latent(self, image):
        with torch.no_grad():
            if isinstance(image, Image.Image):
                image = np.array(image)
            if isinstance(image, torch.Tensor) and image.dim() == 4:
                # Already a batch of latents; pass through unchanged.
                latents = image
            else:
                image = torch.from_numpy(image).float() / 127.5 - 1
                image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device)
                latents = self.model.vae.encode(image)['latent_dist'].mean
                latents = latents * 0.18215
        return latents

    def init_prompt(self, prompt: str):
        uncond_input = self.model.tokenizer(
            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
            return_tensors="pt"
        )
        uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt

    @torch.no_grad()
    def ddim_loop(self, latent):
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in tqdm(range(self.model_config["num_diffusion_steps"])):
            # Walk the scheduler's timesteps from last (least noisy) to first.
            t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent
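
    # ddim_loop runs the diffusion process backwards: starting from the encoded
    # image latent, it repeatedly applies next_step with the conditional noise
    # prediction and records every intermediate latent. The last entry is the
    # fully noised x_T used as the starting point for editing, and the whole
    # trajectory is the reconstruction target for null_optimization.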

    @property
    def scheduler(self):
        return self.model.scheduler

    def ddim_inversion(self, image):
        latent = self.image2latent(image)
        image_rec = self.latent2image(latent)
        ddim_latents = self.ddim_loop(latent)
        return image_rec, ddim_latents

    def null_optimization(self, latents, num_inner_steps, epsilon):
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        with tqdm(total=num_inner_steps * self.model_config["num_diffusion_steps"]) as bar:
            for i in range(self.model_config["num_diffusion_steps"]):
                # Optimize a fresh copy of the null-text embedding for this timestep.
                uncond_embeddings = uncond_embeddings.clone().detach()
                uncond_embeddings.requires_grad = True
                optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
                latent_prev = latents[len(latents) - i - 2]
                t = self.model.scheduler.timesteps[i]
                with torch.no_grad():
                    noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
                for j in range(num_inner_steps):
                    noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                    noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_pred_cond - noise_pred_uncond)
                    latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                    loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    loss_item = loss.item()
                    bar.update()
                    if loss_item < epsilon + i * 2e-5:
                        break
                # Account for any inner steps skipped by the early stop.
                bar.update(num_inner_steps - j - 1)
                uncond_embeddings_list.append(uncond_embeddings[:1].detach())
                with torch.no_grad():
                    context = torch.cat([uncond_embeddings, cond_embeddings])
                    latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        return uncond_embeddings_list
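
    # Each timestep gets its own optimized "null" embedding: the returned list is
    # meant to replace the empty-prompt embedding, step by step, when the inverted
    # latent x_T is later denoised with classifier-free guidance, so the guided
    # trajectory reproduces the DDIM inversion trajectory (and hence the original
    # image) almost exactly.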

    def invert(self, image_path: str, prompt: str, offsets=(0, 0, 0, 0),
               num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
        self.init_prompt(prompt)
        image_gt = load_512(image_path, *offsets)
        if verbose:
            print("DDIM inversion...")
        image_rec, ddim_latents = self.ddim_inversion(image_gt)
        if verbose:
            print("Null-text optimization...")
        uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
        return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
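

# A minimal usage sketch (not part of the original file). It assumes the model is a
# diffusers StableDiffusionPipeline with a DDIMScheduler, and that the config dict
# carries the two keys read above; the checkpoint id, image path, and prompt below
# are placeholders.
if __name__ == "__main__":
    from diffusers import DDIMScheduler, StableDiffusionPipeline

    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                              beta_schedule="scaled_linear",
                              clip_sample=False, set_alpha_to_one=False)
    ldm_stable = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", scheduler=scheduler
    ).to("cuda")
    ldm_stable_config = {"num_diffusion_steps": 50, "guidance_scale": 7.5}

    null_inversion = NullInversion(ldm_stable, ldm_stable_config)
    (image_gt, image_rec), x_t, uncond_embeddings = null_inversion.invert(
        "example.jpg", "a photo of a cat", verbose=True)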