# Null-text inversion utilities (extraction artifacts removed from this header).
| import torch.nn.functional as nnf | |
| import torch | |
| import numpy as np | |
| from tqdm import tqdm | |
| from torch.optim.adam import Adam | |
| from PIL import Image | |
| from generation import load_512 | |
| from p2p import register_attention_control | |
def null_optimization(solver,
                      latents,
                      guidance_scale,
                      num_inner_steps,
                      epsilon):
    """Null-text inversion: per-timestep optimization of the unconditional
    (null) text embedding so that classifier-free-guided denoising
    reconstructs the inverted latent trajectory.

    Args:
        solver: inversion helper exposing ``context``, ``model``, ``n_steps``,
            ``get_noise_pred_single``, ``get_noise_pred`` and ``prev_step``.
            (Project type — exact contract assumed from call sites.)
        latents: sequence of latents from the forward (inversion) pass;
            ``latents[-1]`` is the starting noise, earlier entries are targets.
        guidance_scale: classifier-free guidance weight used while optimizing.
        num_inner_steps: max optimization iterations per diffusion step.
        epsilon: base early-stopping threshold on the reconstruction MSE
            (relaxed by ``i * 2e-5`` at later steps).

    Returns:
        list of optimized unconditional embeddings (one ``[:1]`` slice,
        detached, per diffusion step).
    """
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    uncond_embeddings_list = []
    latent_cur = latents[-1]
    bar = tqdm(total=num_inner_steps * solver.n_steps)
    for i in range(solver.n_steps):
        uncond_embeddings = uncond_embeddings.clone().detach()
        uncond_embeddings.requires_grad = True
        # Linearly decaying lr; clamped at 0 so it can never go negative
        # when solver.n_steps > 100 (the original formula assumed <= 100
        # steps and would otherwise perform gradient *ascent*).
        lr = 1e-2 * max(0., 1. - i / 100.)
        optimizer = Adam([uncond_embeddings], lr=lr)
        latent_prev = latents[len(latents) - i - 2]
        t = solver.model.scheduler.timesteps[i]
        with torch.no_grad():
            # Conditional branch is constant during the inner optimization.
            noise_pred_cond = solver.get_noise_pred_single(latent_cur, t, cond_embeddings)
        # Guard: if num_inner_steps == 0 the inner loop never runs and `j`
        # would otherwise be unbound (NameError in the catch-up loop below).
        j = -1
        for j in range(num_inner_steps):
            noise_pred_uncond = solver.get_noise_pred_single(latent_cur, t, uncond_embeddings)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents_prev_rec = solver.prev_step(noise_pred, t, latent_cur)
            loss = nnf.mse_loss(latents_prev_rec, latent_prev)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_item = loss.item()
            bar.update()
            # Early stop once reconstruction is good enough; threshold is
            # relaxed slightly at later diffusion steps.
            if loss_item < epsilon + i * 2e-5:
                break
        # Keep the progress bar consistent with `total` after an early break.
        for _ in range(j + 1, num_inner_steps):
            bar.update()
        uncond_embeddings_list.append(uncond_embeddings[:1].detach())
        with torch.no_grad():
            # Advance the trajectory one step using the freshly optimized
            # null embedding.
            context = torch.cat([uncond_embeddings, cond_embeddings])
            noise_pred = solver.get_noise_pred(solver.model, latent_cur, t, guidance_scale, context)
            latent_cur = solver.prev_step(noise_pred, t, latent_cur)
    bar.close()
    return uncond_embeddings_list
def invert(solver,
           stop_step,
           is_cons_inversion=False,
           inv_guidance_scale=1,
           nti_guidance_scale=8,
           dynamic_guidance=False,
           tau1=0.4,
           tau2=0.6,
           w_embed_dim=0,
           image_path=None,
           prompt='',
           offsets=(0, 0, 0, 0),
           do_nti=False,
           do_npi=False,
           num_inner_steps=10,
           early_stop_epsilon=1e-5,
           seed=0,
           ):
    """Invert an image into the diffusion latent space.

    Loads the image (from one path, a list of paths, or a raw array),
    runs either consistency inversion or DDIM inversion via ``solver``,
    and optionally refines the unconditional embeddings with null-text
    inversion (NTI) or negative-prompt inversion (NPI).

    Returns:
        ``((image_gt, image_rec), last_latent, uncond_embeddings)`` where
        ``uncond_embeddings`` is a per-step list (NTI/NPI) or ``None``.
    """
    solver.init_prompt(prompt)
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    register_attention_control(solver.model, None)

    # Accept a list of paths, a single path, or an already-decoded array.
    if isinstance(image_path, list):
        image_gt = [load_512(p, *offsets) for p in image_path]
    elif isinstance(image_path, str):
        image_gt = load_512(image_path, *offsets)
    else:
        image_gt = np.array(Image.fromarray(image_path).resize((512, 512)))

    # Forward pass: consistency inversion or plain DDIM inversion.
    if is_cons_inversion:
        image_rec, ddim_latents = solver.cons_inversion(
            image_gt,
            w_embed_dim=w_embed_dim,
            guidance_scale=inv_guidance_scale,
            seed=seed,)
    else:
        image_rec, ddim_latents = solver.ddim_inversion(
            image_gt,
            n_steps=stop_step,
            guidance_scale=inv_guidance_scale,
            dynamic_guidance=dynamic_guidance,
            tau1=tau1, tau2=tau2,
            w_embed_dim=w_embed_dim)

    # Optional refinement of the null-text embeddings.
    if do_nti:
        print("Null-text optimization...")
        uncond_embeddings = null_optimization(solver,
                                              ddim_latents,
                                              nti_guidance_scale,
                                              num_inner_steps,
                                              early_stop_epsilon)
    elif do_npi:
        # Negative-prompt inversion: reuse the conditional embedding as
        # the "unconditional" one at every step.
        uncond_embeddings = [cond_embeddings] * solver.n_steps
    else:
        uncond_embeddings = None

    return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings