from omegaconf import OmegaConf
import torch
from PIL import Image
from torchvision import transforms
import os
from tqdm import tqdm
from einops import rearrange
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
import random
import glob
import re
import shutil
import pdb
import argparse
from convertModels import savemodelDiffusers
# Util Functions
def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Load a model from a config and a checkpoint.
    If config is a path, it is loaded with OmegaConf.
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    pl_sd = torch.load(ckpt, map_location="cpu")
    global_step = pl_sd["global_step"]
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    model.to(device)
    model.eval()
    model.cond_stage_model.device = device
    return model
def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None, n_samples=1, t_start=-1, log_every_t=None, till_T=None, verbose=True):
    """Sample the model"""
    uc = None
    if scale != 1.0:
        uc = model.get_learned_conditioning(n_samples * [""])
    log_t = 100
    if log_every_t is not None:
        log_t = log_every_t
    shape = [4, h // 8, w // 8]
    samples_ddim, inters = sampler.sample(S=ddim_steps,
                                          conditioning=c,
                                          batch_size=n_samples,
                                          shape=shape,
                                          verbose=False,
                                          x_T=start_code,
                                          unconditional_guidance_scale=scale,
                                          unconditional_conditioning=uc,
                                          eta=ddim_eta,
                                          verbose_iter=verbose,
                                          t_start=t_start,
                                          log_every_t=log_t,
                                          till_T=till_T)
    if log_every_t is not None:
        return samples_ddim, inters
    return samples_ddim
def load_img(path, target_size=512):
    """Load an image, resize and output in the range -1..1"""
    image = Image.open(path).convert("RGB")
    tform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
    ])
    image = tform(image)
    return 2. * image - 1.
def moving_average(a, n=3):
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n

def plot_loss(losses, path, word, n=100):
    v = moving_average(losses, n)
    plt.plot(v, label=f'{word}_loss')
    plt.legend(loc="upper left")
    plt.title('Average loss during training', fontsize=20)
    plt.xlabel('Data point', fontsize=16)
    plt.ylabel('Loss value', fontsize=16)
    plt.savefig(path)
##################### ESD Functions
def get_models(config_path, ckpt_path, devices):
    model_orig = load_model_from_config(config_path, ckpt_path, devices[1])
    sampler_orig = DDIMSampler(model_orig)

    model = load_model_from_config(config_path, ckpt_path, devices[0])
    sampler = DDIMSampler(model)

    return model_orig, sampler_orig, model, sampler
def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr, config_path, ckpt_path, diffusers_config_path, devices, seperator=None, image_size=512, ddim_steps=50):
    '''
    Train a diffusion model to erase a concept from its weights (ESD).

    Parameters
    ----------
    prompt : str
        The concept to erase from the diffusion model (e.g. "Van Gogh").
    train_method : str
        Which parameters to fine-tune for erasure: 'noxattn' (ESD-u), 'xattn' (ESD-x),
        'selfattn', 'full', 'notime', 'xlayer', or 'selflayer'.
    start_guidance : float
        Guidance scale used to generate images for training.
    negative_guidance : float
        Guidance scale used to erase the concept from the diffusion model.
    iterations : int
        Number of training iterations.
    lr : float
        Learning rate for fine-tuning.
    config_path : str
        Config path for the CompVis diffusion format.
    ckpt_path : str
        Checkpoint path for pre-trained CompVis diffusion weights.
    diffusers_config_path : str
        Config path for the diffusers UNet in JSON format.
    devices : list of str
        Two devices used to load the models (e.g. ['cuda:0', 'cuda:1']).
    seperator : str, optional
        If the prompt contains this separator, it is split into individual concepts
        that are erased simultaneously. The default is None.
    image_size : int, optional
        Image size for generated images. The default is 512.
    ddim_steps : int, optional
        Number of diffusion time steps. The default is 50.

    Returns
    -------
    model_orig, model
        The frozen original model and the fine-tuned (erased) model.
    '''
    # PROMPT CLEANING
    word_print = prompt.replace(' ', '')
    if prompt == 'allartist':
        prompt = "Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, Alena Aenami, Tyler Edlin, Kilian Eng"
    if prompt == 'i2p':
        prompt = "hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood"
    if prompt == "artifact":
        prompt = "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy"

    if seperator is not None:
        words = prompt.split(seperator)
        words = [word.strip() for word in words]
    else:
        words = [prompt]
    print(words)
    ddim_eta = 0

    # MODEL TRAINING SETUP
    model_orig, sampler_orig, model, sampler = get_models(config_path, ckpt_path, devices)
    # choose parameters to train based on train_method
    parameters = []
    for name, param in model.model.diffusion_model.named_parameters():
        # ESD-u: train all layers except cross-attention (attn2), time-embedding, and final output layers
        if train_method == 'noxattn':
            if name.startswith('out.') or 'attn2' in name or 'time_embed' in name:
                pass
            else:
                print(name)
                parameters.append(param)
        # train only self-attention layers
        if train_method == 'selfattn':
            if 'attn1' in name:
                print(name)
                parameters.append(param)
        # ESD-x: train only cross-attention layers
        if train_method == 'xattn':
            if 'attn2' in name:
                print(name)
                parameters.append(param)
        # train all layers
        if train_method == 'full':
            print(name)
            parameters.append(param)
        # train all layers except time-embedding and final output layers
        if train_method == 'notime':
            if not (name.startswith('out.') or 'time_embed' in name):
                print(name)
                parameters.append(param)
        # train cross-attention in selected output blocks only
        if train_method == 'xlayer':
            if 'attn2' in name:
                if 'output_blocks.6.' in name or 'output_blocks.8.' in name:
                    print(name)
                    parameters.append(param)
        # train self-attention in selected input blocks only
        if train_method == 'selflayer':
            if 'attn1' in name:
                if 'input_blocks.4.' in name or 'input_blocks.7.' in name:
                    print(name)
                    parameters.append(param)

    # set model to train
    model.train()
    # create a lambda function for cleaner use of sampling code (only denoising till time step t)
    quick_sample_till_t = lambda x, s, code, t: sample_model(model, sampler,
                                                             x, image_size, image_size, ddim_steps, s, ddim_eta,
                                                             start_code=code, till_T=t, verbose=False)

    losses = []
    opt = torch.optim.Adam(parameters, lr=lr)
    criteria = torch.nn.MSELoss()
    history = []

    name = f'compvis-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{negative_guidance}-iter_{iterations}-lr_{lr}'

    # TRAINING CODE
    pbar = tqdm(range(iterations))
    for i in pbar:
        word = random.sample(words, 1)[0]
        # get text embeddings for unconditional and conditional prompts
        emb_0 = model.get_learned_conditioning([''])
        emb_p = model.get_learned_conditioning([word])
        emb_n = model.get_learned_conditioning([f'{word}'])

        opt.zero_grad()

        t_enc = torch.randint(ddim_steps, (1,), device=devices[0])
        # time step from 1000 to 0 (0 being good)
        og_num = round((int(t_enc)/ddim_steps)*1000)
        og_num_lim = round((int(t_enc+1)/ddim_steps)*1000)

        t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=devices[0])

        start_code = torch.randn((1, 4, 64, 64)).to(devices[0])

        with torch.no_grad():
            # generate an image with the concept from the ESD model
            z = quick_sample_till_t(emb_p.to(devices[0]), start_guidance, start_code, int(t_enc))  # emb_p seems to work better instead of emb_0
            # get conditional and unconditional scores from the frozen model at time step t and image z
            e_0 = model_orig.apply_model(z.to(devices[1]), t_enc_ddpm.to(devices[1]), emb_0.to(devices[1]))
            e_p = model_orig.apply_model(z.to(devices[1]), t_enc_ddpm.to(devices[1]), emb_p.to(devices[1]))
        # breakpoint()
        # get conditional score from the ESD model
        e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_n.to(devices[0]))
        e_0.requires_grad = False
        e_p.requires_grad = False
        # reconstruction loss for the ESD objective from the frozen model and the conditional score of the ESD model
        loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0]))))  # loss = criteria(e_n, e_0) works the best; try 5000 epochs
        # update weights to erase the concept
        loss.backward()
        losses.append(loss.item())
        pbar.set_postfix({"loss": loss.item()})
        history.append(loss.item())
        opt.step()
        # save checkpoint and loss curve
        # if (i+1) % 500 == 0 and i+1 != iterations and i+1 >= 500:
        #     save_model(model, name, i-1, save_compvis=True, save_diffusers=False)
        # if i % 100 == 0:
        #     save_history(losses, name, word_print)

    model.eval()

    # save_model(model, name, None, save_compvis=True, save_diffusers=True, compvis_config_file=config_path, diffusers_config_file=diffusers_config_path)
    # save_history(losses, name, word_print)
    return model_orig, model
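# Note on the loss used in train_esd: the trainable model's conditional prediction e_n is
# regressed onto a negatively guided target built from the frozen model's predictions,
#     target = e_0 - negative_guidance * (e_p - e_0),
# which pushes the erased concept's score away from the prompt direction. The helper below
# is an illustrative sketch of that target and is not part of the original script.
def esd_target(e_uncond, e_cond, negative_guidance):
    """Illustrative helper mirroring the target tensor constructed in the loss above."""
    return e_uncond - negative_guidance * (e_cond - e_uncond)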
def save_model(model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu', save_compvis=True, save_diffusers=True):
    # SAVE MODEL
    # PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt'
    folder_path = f'models/{name}'
    os.makedirs(folder_path, exist_ok=True)
    if num is not None:
        path = f'{folder_path}/{name}-epoch_{num}.pt'
    else:
        path = f'{folder_path}/{name}.pt'
    if save_compvis:
        torch.save(model.state_dict(), path)
    if save_diffusers:
        print('Saving Model in Diffusers Format')
        savemodelDiffusers(name, compvis_config_file, diffusers_config_file, device=device)

def save_history(losses, name, word_print):
    folder_path = f'models/{name}'
    os.makedirs(folder_path, exist_ok=True)
    with open(f'{folder_path}/loss.txt', 'w') as f:
        f.writelines([str(i) + '\n' for i in losses])
    plot_loss(losses, f'{folder_path}/loss.png', word_print, n=3)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        prog='TrainESD',
        description='Fine-tune a Stable Diffusion model to erase concepts using the ESD method')
    parser.add_argument('--prompt', help='prompt corresponding to concept to erase', type=str, required=True)
    parser.add_argument('--train_method', help='method of training', type=str, required=True)
    parser.add_argument('--start_guidance', help='guidance of start image used to train', type=float, required=False, default=3)
    parser.add_argument('--negative_guidance', help='guidance of negative training used to train', type=float, required=False, default=1)
    parser.add_argument('--iterations', help='iterations used to train', type=int, required=False, default=1000)
    parser.add_argument('--lr', help='learning rate used to train', type=float, required=False, default=1e-5)
    parser.add_argument('--config_path', help='config path for stable diffusion v1-4 inference', type=str, required=False, default='configs/stable-diffusion/v1-inference.yaml')
    parser.add_argument('--ckpt_path', help='ckpt path for stable diffusion v1-4', type=str, required=False, default='models/ldm/stable-diffusion-v1/sd-v1-4-full-ema.ckpt')
    parser.add_argument('--diffusers_config_path', help='diffusers unet config json path', type=str, required=False, default='diffusers_unet_config.json')
    parser.add_argument('--devices', help='cuda devices to train on', type=str, required=False, default='0,0')
    parser.add_argument('--seperator', help='separator if you want to train a bunch of words separately', type=str, required=False, default=None)
    parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
    parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50)
    args = parser.parse_args()

    prompt = args.prompt
    train_method = args.train_method
    start_guidance = args.start_guidance
    negative_guidance = args.negative_guidance
    iterations = args.iterations
    lr = args.lr
    config_path = args.config_path
    ckpt_path = args.ckpt_path
    diffusers_config_path = args.diffusers_config_path
    devices = [f'cuda:{int(d.strip())}' for d in args.devices.split(',')]
    seperator = args.seperator
    image_size = args.image_size
    ddim_steps = args.ddim_steps

    train_esd(prompt=prompt, train_method=train_method, start_guidance=start_guidance, negative_guidance=negative_guidance, iterations=iterations, lr=lr, config_path=config_path, ckpt_path=ckpt_path, diffusers_config_path=diffusers_config_path, devices=devices, seperator=seperator, image_size=image_size, ddim_steps=ddim_steps)
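# Example invocation (the script filename train-esd.py is assumed; paths fall back to the
# defaults defined above and should be adjusted to your local checkpoint and config):
#   python train-esd.py --prompt "Van Gogh" --train_method xattn \
#       --devices 0,1 --iterations 1000 --lr 1e-5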