Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer | |
| from omegaconf import OmegaConf | |
| import math | |
| import imageio | |
| from PIL import Image | |
| import torchvision | |
| import torch.nn.functional as F | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import time | |
| import datetime | |
| import torch | |
| import sys | |
| import os | |
| from torchvision import datasets | |
| import pickle | |
| # StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl | |
# Precision switch: True selects the `my_half_diffusers` package and float16
# dtypes; False selects `my_diffusers` and float64.
use_half_prec = True
if use_half_prec:
    from my_half_diffusers import AutoencoderKL, UNet2DConditionModel
    from my_half_diffusers.schedulers.scheduling_utils import SchedulerOutput
    from my_half_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
else:
    from my_diffusers import AutoencoderKL, UNet2DConditionModel
    from my_diffusers.schedulers.scheduling_utils import SchedulerOutput
    from my_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
# dtypes used below when loading the models (torch) and converting images (numpy)
torch_dtype = torch.float16 if use_half_prec else torch.float64
np_dtype = np.float16 if use_half_prec else np.float64
| import random | |
| from tqdm.auto import tqdm | |
| from torch import autocast | |
| from difflib import SequenceMatcher | |
# Text encoder: CLIP tokenizer plus the text tower of the CLIP model
model_path_clip = "openai/clip-vit-large-patch14"
clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)
clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch_dtype)
clip = clip_model.text_model

# Hugging Face auth token: environment variable first, otherwise the first
# line of the local `hf_auth` file
auth_token = os.environ.get('auth_token')
if auth_token is None:
    with open('hf_auth', 'r') as f:
        auth_token = f.readlines()[0].strip()

# Stable Diffusion UNet and VAE, loaded from the fp16 revision of the weights
model_path_diffusion = "CompVis/stable-diffusion-v1-4"
unet = UNet2DConditionModel.from_pretrained(
    model_path_diffusion,
    subfolder="unet",
    use_auth_token=auth_token,
    revision="fp16",
    torch_dtype=torch_dtype,
)
vae = AutoencoderKL.from_pretrained(
    model_path_diffusion,
    subfolder="vae",
    use_auth_token=auth_token,
    revision="fp16",
    torch_dtype=torch_dtype,
)

# Move all three networks to the GPU; in the non-half configuration they are
# first converted to double precision
device = 'cuda'
if use_half_prec:
    unet.to(device)
    vae.to(device)
    clip.to(device)
else:
    unet.double().to(device)
    vae.double().to(device)
    clip.double().to(device)
print("Loaded all models")
| from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker | |
| from transformers import AutoFeatureExtractor | |
# load safety model
# The feature extractor prepares the checker's `clip_input`; the checker itself
# is called from check_safety() on decoded images.
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
def load_replacement(x):
    """Return a placeholder image shaped like ``x``, or ``x`` itself on any failure.

    The placeholder is read from ``assets/rick.jpeg``, converted to RGB, resized
    to the first two dimensions of ``x`` and scaled to [0, 1] in ``x``'s dtype.
    This is best-effort: any exception (missing file, shape mismatch, ...)
    results in ``x`` being returned unchanged.
    """
    try:
        dims = x.shape
        placeholder = Image.open("assets/rick.jpeg").convert("RGB")
        # PIL sizes are (width, height); array shapes are (height, width, ...)
        placeholder = placeholder.resize((dims[1], dims[0]))
        scaled = (np.array(placeholder) / 255.0).astype(x.dtype)
        assert scaled.shape == x.shape
    except Exception:
        return x
    return scaled
def check_safety(x_image):
    """Run the safety checker on a batch of images and blank any flagged ones.

    Args:
        x_image: batch of images passed both to ``numpy_to_pil`` (for feature
            extraction) and to the safety checker as ``images``.
            NOTE(review): ``numpy_to_pil`` is not defined in the visible part
            of this file - confirm it is provided elsewhere.

    Returns:
        ``(checked_images, has_nsfw_concept)`` where every flagged image has
        been multiplied by zero in place.
    """
    extractor_out = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(
        images=x_image,
        clip_input=extractor_out.pixel_values,
    )
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    for idx, flagged in enumerate(has_nsfw_concept):
        if flagged:
            # black out the image instead of swapping in a placeholder
            x_checked_image[idx] *= 0
    return x_checked_image, has_nsfw_concept
def EDICT_editing(im_path,
                  base_prompt,
                  edit_prompt,
                  use_p2p=False,
                  steps=50,
                  mix_weight=0.93,
                  init_image_strength=0.8,
                  guidance_scale=3,
                  run_baseline=False,
                  width=512, height=512):
    """
    Main call of our research, performs editing with either EDICT or DDIM

    Args:
        im_path: path to the image to run on, or an already loaded PIL image.
            A path is center-cropped and resized to 512x512. A PIL image is
            rescaled so its longest side is at most 1024 and its shortest side
            at least 512, and ``width``/``height`` are replaced by its size
            rounded down to multiples of 64. Any other object is passed
            through unchanged.
        base_prompt: conditional prompt to deterministically noise with
        edit_prompt: desired text conditioning
        use_p2p: if True, denoise with prompt-to-prompt attention control
            (``base_prompt`` as the prompt, ``edit_prompt`` as the edit);
            if False, denoise directly with ``edit_prompt``.
        steps: ddim steps
        mix_weight: Weight of mixing layers.
            Higher means more consistent generations but divergence in inversion
            Lower means opposite
            This is fairly tuned and can get good results
        init_image_strength: Editing strength. Higher = more dramatic edit.
            Typically [0.6, 0.9] is good range.
            Definitely tunable per-image/maybe best results are at a different value
        guidance_scale: classifier-free guidance scale
            3 I've found is the best for both our method and basic DDIM inversion
            Higher can result in more distorted results
        run_baseline:
            VERY IMPORTANT
            False (default) is EDICT, True is the single-sequence DDIM baseline
        width, height: working resolution; overridden when a PIL image is given.

    Output:
        Whatever ``coupled_stablediffusion`` returns for the denoising pass,
        described by the original authors as a PAIR of Images (tuple):
        If run_baseline=True then [0] will be edit and [1] will be original
        If run_baseline=False then they will be two nearly identical edited versions
    """
    if isinstance(im_path, str):
        # Resize/center crop to 512x512
        orig_im = load_im_into_format_from_path(im_path)
    elif isinstance(im_path, Image.Image):
        width, height = im_path.size

        # cap the longest side for the sake of memory
        max_dim = max(width, height)
        if max_dim > 1024:
            factor = 1024 / max_dim
            width = int(width * factor)
            height = int(height * factor)
            im_path = im_path.resize((width, height))

        # make sure the shortest side is at least 512
        min_dim = min(width, height)
        if min_dim < 512:
            factor = 512 / min_dim
            width = int(width * factor)
            height = int(height * factor)
            im_path = im_path.resize((width, height))

        # the model needs dimensions that are multiples of 64; the image itself
        # is resized to (width, height) downstream when it is encoded
        width = width - (width % 64)
        height = height - (height % 64)
        orig_im = im_path
    else:
        orig_im = im_path

    # compute latent pair (second one will be original latent if run_baseline=True)
    latents = coupled_stablediffusion(base_prompt,
                                      reverse=True,
                                      init_image=orig_im,
                                      init_image_strength=init_image_strength,
                                      steps=steps,
                                      mix_weight=mix_weight,
                                      guidance_scale=guidance_scale,
                                      run_baseline=run_baseline,
                                      width=width, height=height)

    # Denoise intermediate state with new conditioning
    if use_p2p:
        prompt, prompt_edit = base_prompt, edit_prompt
    else:
        prompt, prompt_edit = edit_prompt, None
    gen = coupled_stablediffusion(prompt,
                                  prompt_edit,
                                  fixed_starting_latent=latents,
                                  init_image_strength=init_image_strength,
                                  steps=steps,
                                  mix_weight=mix_weight,
                                  guidance_scale=guidance_scale,
                                  run_baseline=run_baseline,
                                  width=width, height=height)
    return gen
def img2img_editing(im_path,
                    edit_prompt,
                    steps=50,
                    init_image_strength=0.7,
                    guidance_scale=3):
    """
    Plain SDEdit / img2img: load the image at ``im_path`` (center-cropped and
    resized by ``load_im_into_format_from_path``) and hand it to
    ``baseline_stablediffusion`` together with ``edit_prompt``.

    Returns whatever ``baseline_stablediffusion`` returns.
    """
    source_image = load_im_into_format_from_path(im_path)
    sd_kwargs = dict(
        init_image_strength=init_image_strength,
        steps=steps,
        init_image=source_image,
        guidance_scale=guidance_scale,
    )
    return baseline_stablediffusion(edit_prompt, **sd_kwargs)
def center_crop(im):
    """Crop ``im`` to the largest centered square and return the result.

    ``im`` must expose ``size`` as ``(width, height)`` and a ``crop(box)``
    method taking ``(left, top, right, bottom)``.
    """
    w, h = im.size
    side = min(w, h)
    box = (
        (w - side) / 2,  # left
        (h - side) / 2,  # top
        (w + side) / 2,  # right
        (h + side) / 2,  # bottom
    )
    return im.crop(box)
def general_crop(im, target_w, target_h):
    """Crop the centered ``target_w`` x ``target_h`` region out of ``im``.

    Args:
        im: object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, top, right, bottom))`` method (e.g. a PIL image).
        target_w: width of the crop box.
        target_h: height of the crop box.

    Returns:
        The result of ``im.crop`` for a box of the requested size centered
        in the image.
    """
    width, height = im.size
    # The previous box (target/2 .. size - target/2) produced a crop of size
    # (width - target_w, height - target_h) instead of the requested target.
    left = (width - target_w) / 2
    top = (height - target_h) / 2
    return im.crop((left, top, left + target_w, top + target_h))
def load_im_into_format_from_path(im_path):
    """Open the image at ``im_path``, center-crop it to a square and resize to 512x512."""
    opened = Image.open(im_path)
    square = center_crop(opened)
    return square.resize((512, 512))
| #### P2P STUFF #### | |
def init_attention_weights(weight_tuples):
    """Store per-token attention weights on the UNet's CrossAttention modules.

    Args:
        weight_tuples: iterable of ``(token_index, weight)`` pairs. Indices
            outside ``[0, clip_tokenizer.model_max_length)`` are ignored;
            every token not listed keeps a weight of 1.

    Modules whose name contains ``attn2`` receive the weight vector (moved to
    ``device``) as ``last_attn_slice_weights``; modules whose name contains
    ``attn1`` get ``None``.
    """
    max_tokens = clip_tokenizer.model_max_length
    weights = torch.ones(max_tokens)
    for position, value in weight_tuples:
        if 0 <= position < max_tokens:
            weights[position] = value

    for module_path, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" in module_path:
            layer.last_attn_slice_weights = weights.to(device)
        if "attn1" in module_path:
            layer.last_attn_slice_weights = None
def init_attention_edit(tokens, tokens_edit):
    """Compute the token alignment between two prompts and store it on the UNet.

    The two token id sequences (first row of ``input_ids`` of each tokenizer
    output) are diffed with ``SequenceMatcher``. For every ``equal`` span, and
    every ``replace`` span of unchanged length, the edited-prompt positions are
    marked in a mask and mapped to the matching base-prompt positions.

    ``attn2`` CrossAttention modules receive the mask and index map as
    ``last_attn_slice_mask`` / ``last_attn_slice_indices`` (on ``device``);
    ``attn1`` modules get ``None`` for both.
    """
    max_tokens = clip_tokenizer.model_max_length
    mask = torch.zeros(max_tokens)
    indices = torch.zeros(max_tokens, dtype=torch.long)
    source_positions = torch.arange(max_tokens, dtype=torch.long)

    base_ids = tokens.input_ids.numpy()[0]
    edit_ids = tokens_edit.input_ids.numpy()[0]

    matcher = SequenceMatcher(None, base_ids, edit_ids)
    for tag, a0, a1, b0, b1 in matcher.get_opcodes():
        if b0 >= max_tokens:
            continue
        same_length_replace = tag == "replace" and a1 - a0 == b1 - b0
        if tag == "equal" or same_length_replace:
            mask[b0:b1] = 1
            indices[b0:b1] = source_positions[a0:a1]

    for module_path, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" in module_path:
            layer.last_attn_slice_mask = mask.to(device)
            layer.last_attn_slice_indices = indices.to(device)
        if "attn1" in module_path:
            layer.last_attn_slice_mask = None
            layer.last_attn_slice_indices = None
def init_attention_func():
    """Install the attention-control hook on every CrossAttention module of the UNet.

    Defines ``new_attention`` and binds it as ``_attention`` on each module whose
    class name is ``CrossAttention``, after resetting the saved slice and the
    three control flags that the hook reads. The mask, indices and weights the
    hook also reads are set separately by ``init_attention_edit`` and
    ``init_attention_weights``.
    """
    def new_attention(self, query, key, value, sequence_length, dim):
        # Sliced attention that can save, reuse or re-weight the softmaxed attention map
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        # With no slice size configured, the whole batch is processed as one slice
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            attn_slice = (
                torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
            )
            attn_slice = attn_slice.softmax(dim=-1)
            if self.use_last_attn_slice:
                if self.last_attn_slice_mask is not None:
                    # Blend: where the mask is 1 take the saved map (columns
                    # re-ordered by the stored indices), elsewhere keep the new map
                    new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
                else:
                    # No mask: replace the map entirely with the saved one
                    attn_slice = self.last_attn_slice
                # the flag is cleared after one use
                self.use_last_attn_slice = False
            if self.save_last_attn_slice:
                self.last_attn_slice = attn_slice
                # the flag is cleared after one save
                self.save_last_attn_slice = False
            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                attn_slice = attn_slice * self.last_attn_slice_weights
                self.use_last_attn_weights = False
            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
            hidden_states[start_idx:end_idx] = attn_slice
        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention":
            # reset state, then bind the hook as a method of this module instance
            module.last_attn_slice = None
            module.use_last_attn_slice = False
            module.use_last_attn_weights = False
            module.save_last_attn_slice = False
            module._attention = new_attention.__get__(module, type(module))
def use_last_tokens_attention(use=True):
    """Set ``use_last_attn_slice`` to ``use`` on every ``attn2`` CrossAttention module of the UNet."""
    for module_path, layer in unet.named_modules():
        is_cross_attention = type(layer).__name__ == "CrossAttention"
        if is_cross_attention and "attn2" in module_path:
            layer.use_last_attn_slice = use
def use_last_tokens_attention_weights(use=True):
    """Set ``use_last_attn_weights`` to ``use`` on every ``attn2`` CrossAttention module of the UNet."""
    for module_path, layer in unet.named_modules():
        is_cross_attention = type(layer).__name__ == "CrossAttention"
        if is_cross_attention and "attn2" in module_path:
            layer.use_last_attn_weights = use
def use_last_self_attention(use=True):
    """Set ``use_last_attn_slice`` to ``use`` on every ``attn1`` CrossAttention module of the UNet."""
    for module_path, layer in unet.named_modules():
        is_cross_attention = type(layer).__name__ == "CrossAttention"
        if is_cross_attention and "attn1" in module_path:
            layer.use_last_attn_slice = use
def save_last_tokens_attention(save=True):
    """Set ``save_last_attn_slice`` to ``save`` on every ``attn2`` CrossAttention module of the UNet."""
    for module_path, layer in unet.named_modules():
        is_cross_attention = type(layer).__name__ == "CrossAttention"
        if is_cross_attention and "attn2" in module_path:
            layer.save_last_attn_slice = save
def save_last_self_attention(save=True):
    """Set ``save_last_attn_slice`` to ``save`` on every ``attn1`` CrossAttention module of the UNet."""
    for module_path, layer in unet.named_modules():
        is_cross_attention = type(layer).__name__ == "CrossAttention"
        if is_cross_attention and "attn1" in module_path:
            layer.save_last_attn_slice = save
| #################################### | |
| ##### BASELINE ALGORITHM, ONLY USED NOW FOR SDEDIT ##### | |
def baseline_stablediffusion(prompt="",
                             prompt_edit=None,
                             null_prompt='',
                             prompt_edit_token_weights=[],
                             prompt_edit_tokens_start=0.0,
                             prompt_edit_tokens_end=1.0,
                             prompt_edit_spatial_start=0.0,
                             prompt_edit_spatial_end=1.0,
                             clip_start=0.0,
                             clip_end=1.0,
                             guidance_scale=7,
                             steps=50,
                             seed=1,
                             width=512, height=512,
                             init_image=None, init_image_strength=0.5,
                             fixed_starting_latent = None,
                             prev_image= None,
                             grid=None,
                             clip_guidance=None,
                             clip_guidance_scale=1,
                             num_cutouts=4,
                             cut_power=1,
                             scheduler_str='lms',
                             return_latent=False,
                             one_pass=False,
                             normalize_noise_pred=False):
    """Standard (non-coupled) Stable Diffusion sampling with optional img2img
    and prompt-to-prompt attention editing.

    Args:
        prompt: conditioning text.
        prompt_edit: optional edited prompt. When given, attention maps saved
            while predicting noise for ``prompt`` are reused while predicting
            noise for ``prompt_edit``, and the latter prediction is the one used.
        null_prompt: text for the unconditional branch of classifier-free guidance.
        prompt_edit_token_weights: iterable of ``(token_index, weight)`` pairs
            forwarded to ``init_attention_weights``; it is only read here.
        prompt_edit_tokens_start, prompt_edit_tokens_end: range of
            ``t / num_train_timesteps`` in which saved ``attn2`` maps are reused.
        prompt_edit_spatial_start, prompt_edit_spatial_end: same range for the
            saved ``attn1`` maps.
        clip_start, clip_end: range of ``t / num_train_timesteps`` in which
            CLIP guidance is applied when ``clip_guidance`` is set.
        guidance_scale: classifier-free guidance scale; must be an ``int``.
        steps: number of scheduler inference steps.
        seed: value passed to ``torch.cuda.manual_seed``; drawn at random if None.
        width, height: output size, rounded down to multiples of 64.
        init_image: optional image (with a PIL-style ``resize``) for img2img.
            Not supported with the ``'ddim'`` scheduler.
        init_image_strength: sampling starts at step
            ``steps - int(steps * init_image_strength)`` when an init image or
            a fixed starting latent is used.
        fixed_starting_latent: latent to start from instead of fresh noise.
        prev_image: legacy option; the sampling loop raises
            ``NotImplementedError`` as soon as it is used.
        grid: unused.
        clip_guidance, clip_guidance_scale, num_cutouts, cut_power: forwarded
            to ``new_cond_fn``. NOTE(review): ``new_cond_fn`` is not defined in
            the visible part of this file - confirm it exists.
        scheduler_str: one of ``'ddim'``, ``'lms'``, ``'pndm'``, ``'ddpm'``.
        return_latent: return the final latent instead of a decoded image.
        one_pass: stop after the first denoising step.
        normalize_noise_pred: rescale the guided noise prediction to the norm
            of the unconditional prediction.

    Returns:
        A ``PIL.Image`` built from the first decoded sample after the safety
        check, or the final latent tensor when ``return_latent`` is True.

    Raises:
        KeyError: ``scheduler_str`` is not a known scheduler.
        NotImplementedError: ``'ddim'`` combined with ``init_image``, or
            ``prev_image`` is given.
    """
    width = width - width % 64
    height = height - height % 64

    #If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None: seed = random.randrange(2**32 - 1)
    generator = torch.cuda.manual_seed(seed)

    #Set inference timesteps to scheduler
    scheduler_dict = {'ddim':DDIMScheduler,
                      'lms':LMSDiscreteScheduler,
                      'pndm':PNDMScheduler,
                      'ddpm':DDPMScheduler}
    scheduler_call = scheduler_dict[scheduler_str]
    if scheduler_str == 'ddim':
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                                  beta_schedule="scaled_linear",
                                  clip_sample=False, set_alpha_to_one=False)
    else:
        scheduler = scheduler_call(beta_schedule="scaled_linear",
                                   num_train_timesteps=1000)
    scheduler.set_timesteps(steps)

    if prev_image is not None:
        prev_scheduler = LMSDiscreteScheduler(beta_start=0.00085,
                                              beta_end=0.012,
                                              beta_schedule="scaled_linear",
                                              num_train_timesteps=1000)
        prev_scheduler.set_timesteps(steps)

    #Preprocess image if it exists (img2img)
    if init_image is not None:
        init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        # scale pixel values from [0, 255] to [-1, 1], then h w c -> b c h w
        init_image = np.array(init_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
        init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
        if init_image.shape[1] > 3:
            init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])

        #Move image to GPU
        init_image = init_image.to(device)

        #Encode image
        with autocast(device):
            init_latent = vae.encode(init_image).latent_dist.sample(generator=generator) * 0.18215
        t_start = steps - int(steps * init_image_strength)
    else:
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
        t_start = 0

    #Generate random normal noise
    if fixed_starting_latent is None:
        noise = torch.randn(init_latent.shape, generator=generator, device=device, dtype=unet.dtype)
        if scheduler_str == 'ddim':
            if init_image is not None:
                # was `raise notImplementedError`, which failed with a NameError
                raise NotImplementedError("img2img is not implemented for the 'ddim' scheduler")
            latent = noise
        else:
            latent = scheduler.add_noise(init_latent, noise,
                                         t_start).to(device)
    else:
        latent = fixed_starting_latent
        t_start = steps - int(steps * init_image_strength)

    if prev_image is not None:
        #Resize and prev_image for numpy b h w c -> torch b c h w
        prev_image = prev_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        prev_image = np.array(prev_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
        prev_image = torch.from_numpy(prev_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
        if prev_image.shape[1] > 3:
            prev_image = prev_image[:, :3] * prev_image[:, 3:] + (1 - prev_image[:, 3:])

        #Move image to GPU
        prev_image = prev_image.to(device)

        #Encode image
        with autocast(device):
            prev_init_latent = vae.encode(prev_image).latent_dist.sample(generator=generator) * 0.18215
        t_start = steps - int(steps * init_image_strength)
        prev_latent = prev_scheduler.add_noise(prev_init_latent, noise, t_start).to(device)
    else:
        prev_latent = None

    #Process clip
    # NOTE(review): the extent of this autocast block was reconstructed (the
    # source had lost its indentation); it is assumed to cover sampling and
    # VAE decoding - confirm against the original file.
    with autocast(device):
        tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
        embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state

        tokens_conditional = clip_tokenizer(prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
        embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state

        #Process prompt editing
        assert not ((prompt_edit is not None) and (prev_image is not None))
        if prompt_edit is not None:
            tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
            embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state
            init_attention_edit(tokens_conditional, tokens_conditional_edit)
        elif prev_image is not None:
            init_attention_edit(tokens_conditional, tokens_conditional)

        init_attention_func()
        init_attention_weights(prompt_edit_token_weights)

        timesteps = scheduler.timesteps[t_start:]

        assert isinstance(guidance_scale, int)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
            t_index = t_start + i

            latent_model_input = latent
            if scheduler_str=='lms':
                sigma = scheduler.sigmas[t_index] # last is first and first is last
                latent_model_input = (latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
            else:
                assert scheduler_str in ['ddim', 'pndm', 'ddpm']

            #Predict the unconditional noise residual
            if len(t.shape) == 0:
                t = t[None].to(unet.device)
            noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=embedding_unconditional,
                                     ).sample

            if prev_latent is not None:
                prev_latent_model_input = prev_latent
                prev_latent_model_input = (prev_latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
                prev_noise_pred_uncond = unet(prev_latent_model_input, t,
                                              encoder_hidden_states=embedding_unconditional,
                                              ).sample

            #Prepare the Cross-Attention layers
            if prompt_edit is not None or prev_latent is not None:
                save_last_tokens_attention()
                save_last_self_attention()
            else:
                #Use weights on non-edited prompt when edit is None
                use_last_tokens_attention_weights()

            #Predict the conditional noise residual and save the cross-attention layer activations
            if prev_latent is not None:
                raise NotImplementedError # I totally lost track of what this is
                prev_noise_pred_cond = unet(prev_latent_model_input, t, encoder_hidden_states=embedding_conditional,
                                            ).sample
            else:
                noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional,
                                       ).sample

            #Edit the Cross-Attention layer activations
            t_scale = t / scheduler.num_train_timesteps
            if prompt_edit is not None or prev_latent is not None:
                if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
                    use_last_tokens_attention()
                if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
                    use_last_self_attention()

                #Use weights on edited prompt
                use_last_tokens_attention_weights()

                #Predict the edited conditional noise residual using the cross-attention masks
                if prompt_edit is not None:
                    noise_pred_cond = unet(latent_model_input, t,
                                           encoder_hidden_states=embedding_conditional_edit).sample

            #Perform guidance
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            if clip_guidance is not None and t_scale >= clip_start and t_scale <= clip_end:
                noise_pred, latent = new_cond_fn(latent, t, t_index,
                                                 embedding_conditional, noise_pred,clip_guidance,
                                                 clip_guidance_scale,
                                                 num_cutouts,
                                                 scheduler, unet,use_cutouts=True,
                                                 cut_power=cut_power)

            if normalize_noise_pred:
                noise_pred = noise_pred * noise_pred_uncond.norm() / noise_pred.norm()

            if scheduler_str == 'ddim':
                # forward_step returns the new latent tensor directly; it has
                # no `.prev_sample` attribute
                latent = forward_step(scheduler, noise_pred,
                                      t,
                                      latent)
            else:
                latent = scheduler.step(noise_pred,
                                        t_index,
                                        latent).prev_sample

            if prev_latent is not None:
                prev_noise_pred = prev_noise_pred_uncond + guidance_scale * (prev_noise_pred_cond - prev_noise_pred_uncond)
                prev_latent = prev_scheduler.step(prev_noise_pred, t_index, prev_latent).prev_sample
            if one_pass: break

        #scale and decode the image latents with vae
        if return_latent: return latent
        latent = latent / 0.18215
        image = vae.decode(latent.to(vae.dtype)).sample

    # map from [-1, 1] to [0, 1], b c h w -> b h w c, then to uint8
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image, _ = check_safety(image)
    image = (image[0] * 255).round().astype("uint8")
    return Image.fromarray(image)
| #################################### | |
| #### HELPER FUNCTIONS FOR OUR METHOD ##### | |
def get_alpha_and_beta(t, scheduler):
    """Return ``(alpha_cumprod, 1 - alpha_cumprod)`` at timestep ``t``.

    Meant to be called for both the current and the previous timestep.

    Args:
        t: tensor timestep. An integer (``torch.long``) tensor indexes
            ``scheduler.alphas_cumprod`` directly. A negative value yields
            ``scheduler.final_alpha_cumprod``. Any other value is linearly
            interpolated between the two neighbouring integer timesteps.
        scheduler: object exposing ``alphas_cumprod`` and ``final_alpha_cumprod``.

    Returns:
        Tuple ``(alpha, beta)`` with ``beta == 1 - alpha``.
    """
    if t.dtype == torch.long:
        alpha = scheduler.alphas_cumprod[t]
        return alpha, 1 - alpha

    if t < 0:
        return scheduler.final_alpha_cumprod, 1 - scheduler.final_alpha_cumprod

    low = t.floor().long()
    high = t.ceil().long()
    rem = t - low  # fractional distance from `low` towards `high`

    low_alpha = scheduler.alphas_cumprod[low]
    high_alpha = scheduler.alphas_cumprod[high]
    # Linear interpolation: rem == 0 must give low_alpha. The weights were
    # previously swapped; whole-number timesteps are unaffected by the fix
    # because low == high there.
    interpolated_alpha = low_alpha * (1 - rem) + high_alpha * rem
    interpolated_beta = 1 - interpolated_alpha
    return interpolated_alpha, interpolated_beta
| # A DDIM forward step function | |
def forward_step(
    self,
    model_output,
    timestep: int,
    sample,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    return_dict: bool = True,
    use_double=False,
):
    """DDIM step from ``timestep`` to the previous (less noisy) timestep.

    Args:
        self: scheduler; must have had ``set_timesteps`` called.
        model_output: predicted noise for ``sample`` at ``timestep``.
        timestep: current timestep (must not exceed ``self.timesteps.max()``).
        sample: current latent.
        eta, use_clipped_model_output, generator, return_dict, use_double:
            accepted but unused.

    Returns:
        The latent at the previous timestep (a plain value, not a scheduler
        output object).

    Raises:
        ValueError: ``set_timesteps`` has not been run.
        NotImplementedError: ``timestep`` is beyond the scheduler's range.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    stride = self.config.num_train_timesteps / self.num_inference_steps
    prev_timestep = timestep - stride

    if timestep > self.timesteps.max():
        raise NotImplementedError("Need to double check what the overflow is")

    alpha_t, beta_t = get_alpha_and_beta(timestep, self)
    alpha_prev, _ = get_alpha_and_beta(prev_timestep, self)

    quotient = (alpha_t / alpha_prev) ** 0.5
    scaled_sample = (1. / quotient) * sample
    removed_noise = (1. / quotient) * (beta_t ** 0.5) * model_output
    readded_noise = ((1 - alpha_prev) ** 0.5) * model_output
    return scaled_sample - removed_noise + readded_noise
| # A DDIM reverse step function, the inverse of above | |
def reverse_step(
    self,
    model_output,
    timestep: int,
    sample,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    return_dict: bool = True,
    use_double=False,
):
    """DDIM inversion step: algebraic inverse of ``forward_step``.

    Given the latent at the previous (less noisy) timestep, returns the latent
    at ``timestep`` for the same ``model_output``.

    Args:
        self: scheduler; must have had ``set_timesteps`` called.
        model_output: predicted noise.
        timestep: timestep to step to (must not exceed ``self.timesteps.max()``).
        sample: latent at the previous timestep.
        eta, use_clipped_model_output, generator, return_dict, use_double:
            accepted but unused.

    Returns:
        The latent at ``timestep`` (a plain value, not a scheduler output object).

    Raises:
        ValueError: ``set_timesteps`` has not been run.
        NotImplementedError: ``timestep`` is beyond the scheduler's range.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps

    if timestep > self.timesteps.max():
        raise NotImplementedError("Need to double check what the overflow is")

    # A direct `self.alphas_cumprod[timestep]` lookup used to precede this call,
    # but its result was overwritten immediately, so it was dropped.
    alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self)
    alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self)

    alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5)

    first_term = alpha_quotient * sample
    second_term = ((beta_prod_t)**0.5) * model_output
    third_term = alpha_quotient * ((1 - alpha_prod_t_prev)**0.5) * model_output
    return first_term + second_term - third_term
def latent_to_image(latent):
    """Decode ``latent`` with the VAE (after dividing by 0.18215) and convert it via ``prep_image_for_return``."""
    decoded = vae.decode(latent.to(vae.dtype)/0.18215).sample
    return prep_image_for_return(decoded)
def prep_image_for_return(image):
    """Convert a decoded image batch into a PIL image built from its first element.

    Values are mapped with ``x / 2 + 0.5`` and clamped to [0, 1], dimensions
    are permuted ``(0, 2, 3, 1)``, and the first element is scaled to 0-255
    and rounded to ``uint8``.
    """
    unit_range = (image / 2 + 0.5).clamp(0, 1)
    channels_last = unit_range.cpu().permute(0, 2, 3, 1).numpy()
    first_uint8 = (channels_last[0] * 255).round().astype("uint8")
    return Image.fromarray(first_uint8)
| ############################# | |
| ##### MAIN EDICT FUNCTION ####### | |
| # Use EDICT_editing to perform calls | |
| def coupled_stablediffusion(prompt="", | |
| prompt_edit=None, | |
| null_prompt='', | |
| prompt_edit_token_weights=[], | |
| prompt_edit_tokens_start=0.0, | |
| prompt_edit_tokens_end=1.0, | |
| prompt_edit_spatial_start=0.0, | |
| prompt_edit_spatial_end=1.0, | |
| guidance_scale=7.0, steps=50, | |
| seed=1, width=512, height=512, | |
| init_image=None, init_image_strength=1.0, | |
| run_baseline=False, | |
| use_lms=False, | |
| leapfrog_steps=True, | |
| reverse=False, | |
| return_latents=False, | |
| fixed_starting_latent=None, | |
| beta_schedule='scaled_linear', | |
| mix_weight=0.93): | |
| #If seed is None, randomly select seed from 0 to 2^32-1 | |
| if seed is None: seed = random.randrange(2**32 - 1) | |
| generator = torch.cuda.manual_seed(seed) | |
| def image_to_latent(im): | |
| if isinstance(im, torch.Tensor): | |
| # assume it's the latent | |
| # used to avoid clipping new generation before inversion | |
| init_latent = im.to(device) | |
| else: | |
| #Resize and transpose for numpy b h w c -> torch b c h w | |
| im = im.resize((width, height), resample=Image.Resampling.LANCZOS) | |
| im = np.array(im).astype(np_dtype) / 255.0 * 2.0 - 1.0 | |
| # check if black and white | |
| if len(im.shape) < 3: | |
| im = np.stack([im for _ in range(3)], axis=2) # putting at end b/c channels | |
| im = torch.from_numpy(im[np.newaxis, ...].transpose(0, 3, 1, 2)) | |
| #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel | |
| if im.shape[1] > 3: | |
| im = im[:, :3] * im[:, 3:] + (1 - im[:, 3:]) | |
| #Move image to GPU | |
| im = im.to(device) | |
| #Encode image | |
| if use_half_prec: | |
| init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215 | |
| else: | |
| with autocast(device): | |
| init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215 | |
| return init_latent | |
| assert not use_lms, "Can't invert LMS the same as DDIM" | |
| if run_baseline: leapfrog_steps=False | |
| #Change size to multiple of 64 to prevent size mismatches inside model | |
| width = width - width % 64 | |
| height = height - height % 64 | |
| #Preprocess image if it exists (img2img) | |
| if init_image is not None: | |
| assert reverse # want to be performing deterministic noising | |
| # can take either pair (output of generative process) or single image | |
| if isinstance(init_image, list): | |
| if isinstance(init_image[0], torch.Tensor): | |
| init_latent = [t.clone() for t in init_image] | |
| else: | |
| init_latent = [image_to_latent(im) for im in init_image] | |
| else: | |
| init_latent = image_to_latent(init_image) | |
| # this is t_start for forward, t_end for reverse | |
| t_limit = steps - int(steps * init_image_strength) | |
| else: | |
| assert not reverse, 'Need image to reverse from' | |
| init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device) | |
| t_limit = 0 | |
| if reverse: | |
| latent = init_latent | |
| else: | |
| #Generate random normal noise | |
| noise = torch.randn(init_latent.shape, | |
| generator=generator, | |
| device=device, | |
| dtype=torch_dtype) | |
| if fixed_starting_latent is None: | |
| latent = noise | |
| else: | |
| if isinstance(fixed_starting_latent, list): | |
| latent = [l.clone() for l in fixed_starting_latent] | |
| else: | |
| latent = fixed_starting_latent.clone() | |
| t_limit = steps - int(steps * init_image_strength) | |
| if isinstance(latent, list): # initializing from pair of images | |
| latent_pair = latent | |
| else: # initializing from noise | |
| latent_pair = [latent.clone(), latent.clone()] | |
| if steps==0: | |
| if init_image is not None: | |
| return image_to_latent(init_image) | |
| else: | |
| image = vae.decode(latent.to(vae.dtype) / 0.18215).sample | |
| return prep_image_for_return(image) | |
| #Set inference timesteps to scheduler | |
| schedulers = [] | |
| for i in range(2): | |
| # num_raw_timesteps = max(1000, steps) | |
| scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, | |
| beta_schedule=beta_schedule, | |
| num_train_timesteps=1000, | |
| clip_sample=False, | |
| set_alpha_to_one=False) | |
| scheduler.set_timesteps(steps) | |
| schedulers.append(scheduler) | |
| with autocast(device): | |
| # CLIP Text Embeddings | |
| tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length", | |
| max_length=clip_tokenizer.model_max_length, | |
| truncation=True, return_tensors="pt", | |
| return_overflowing_tokens=True) | |
| embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state | |
| tokens_conditional = clip_tokenizer(prompt, padding="max_length", | |
| max_length=clip_tokenizer.model_max_length, | |
| truncation=True, return_tensors="pt", | |
| return_overflowing_tokens=True) | |
| embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state | |
| #Process prompt editing (if running Prompt-to-Prompt) | |
| if prompt_edit is not None: | |
| tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length", | |
| max_length=clip_tokenizer.model_max_length, | |
| truncation=True, return_tensors="pt", | |
| return_overflowing_tokens=True) | |
| embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state | |
| init_attention_edit(tokens_conditional, tokens_conditional_edit) | |
| init_attention_func() | |
| init_attention_weights(prompt_edit_token_weights) | |
| timesteps = schedulers[0].timesteps[t_limit:] | |
| if reverse: timesteps = timesteps.flip(0) | |
| for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): | |
| t_scale = t / schedulers[0].num_train_timesteps | |
| if (reverse) and (not run_baseline): | |
| # Reverse mixing layer | |
| new_latents = [l.clone() for l in latent_pair] | |
| new_latents[1] = (new_latents[1].clone() - (1-mix_weight)*new_latents[0].clone()) / mix_weight | |
| new_latents[0] = (new_latents[0].clone() - (1-mix_weight)*new_latents[1].clone()) / mix_weight | |
| latent_pair = new_latents | |
| # alternate EDICT steps | |
| for latent_i in range(2): | |
| if run_baseline and latent_i==1: continue # just have one sequence for baseline | |
| # this updates latent_pair[latent_i] using a noise prediction computed | |
| # from latent_pair[(latent_i+1)%2] (from latent_pair[latent_i] itself when run_baseline) | |
| if reverse and (not run_baseline): | |
| if leapfrog_steps: | |
| # what i would be from going other way | |
| orig_i = len(timesteps) - (i+1) | |
| offset = (orig_i+1) % 2 | |
| latent_i = (latent_i + offset) % 2 | |
| else: | |
| # Do 1 then 0 | |
| latent_i = (latent_i+1)%2 | |
| else: | |
| if leapfrog_steps: | |
| offset = i%2 | |
| latent_i = (latent_i + offset) % 2 | |
| latent_j = ((latent_i+1) % 2) if not run_baseline else latent_i | |
| latent_model_input = latent_pair[latent_j] | |
| latent_base = latent_pair[latent_i] | |
| #Predict the unconditional noise residual | |
| noise_pred_uncond = unet(latent_model_input, t, | |
| encoder_hidden_states=embedding_unconditional).sample | |
| #Prepare the Cross-Attention layers | |
| if prompt_edit is not None: | |
| save_last_tokens_attention() | |
| save_last_self_attention() | |
| else: | |
| #Use weights on non-edited prompt when edit is None | |
| use_last_tokens_attention_weights() | |
| #Predict the conditional noise residual and save the cross-attention layer activations | |
| noise_pred_cond = unet(latent_model_input, t, | |
| encoder_hidden_states=embedding_conditional).sample | |
| #Edit the Cross-Attention layer activations | |
| if prompt_edit is not None: | |
| t_scale = t / schedulers[0].num_train_timesteps | |
| if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end: | |
| use_last_tokens_attention() | |
| if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end: | |
| use_last_self_attention() | |
| #Use weights on edited prompt | |
| use_last_tokens_attention_weights() | |
| #Predict the edited conditional noise residual using the cross-attention masks | |
| noise_pred_cond = unet(latent_model_input, | |
| t, | |
| encoder_hidden_states=embedding_conditional_edit).sample | |
| #Perform guidance | |
| grad = (noise_pred_cond - noise_pred_uncond) | |
| noise_pred = noise_pred_uncond + guidance_scale * grad | |
| step_call = reverse_step if reverse else forward_step | |
| new_latent = step_call(schedulers[latent_i], | |
| noise_pred, | |
| t, | |
| latent_base)# .prev_sample | |
| new_latent = new_latent.to(latent_base.dtype) | |
| latent_pair[latent_i] = new_latent | |
| if (not reverse) and (not run_baseline): | |
| # Mixing layer (contraction) during generative process | |
| new_latents = [l.clone() for l in latent_pair] | |
| new_latents[0] = (mix_weight*new_latents[0] + (1-mix_weight)*new_latents[1]).clone() | |
| new_latents[1] = ((1-mix_weight)*new_latents[0] + (mix_weight)*new_latents[1]).clone() | |
| latent_pair = new_latents | |
| #scale and decode the image latents with vae, can return latents instead of images | |
| if reverse or return_latents: | |
| results = [latent_pair] | |
| return results if len(results)>1 else results[0] | |
| # decode latents to images | |
| images = [] | |
| for latent_i in range(2): | |
| latent = latent_pair[latent_i] / 0.18215 | |
| image = vae.decode(latent.to(vae.dtype)).sample | |
| images.append(image) | |
| # Return images | |
| return_arr = [] | |
| for image in images: | |
| image = prep_image_for_return(image) | |
| return_arr.append(image) | |
| results = [return_arr] | |
| return results if len(results)>1 else results[0] | |