# NOTE: The lines below were page chrome from the Hugging Face Spaces listing
# ("Spaces: Runtime error") captured when this file was scraped; kept as a
# comment so the module remains valid Python.
import torch
from tqdm import tqdm

from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
    """Fine-tune a Stable Diffusion model to erase a concept (ESD-style).

    Each iteration partially denoises a random latent toward *prompt*, then
    trains the fine-tuned modules so their noise prediction for *prompt* is
    pushed AWAY from the concept: the target is the neutral ('' prompt)
    prediction minus ``negative_guidance`` times the (positive - neutral) gap.

    Args:
        prompt: Text describing the concept to erase.
        modules: Module selector(s) passed to ``FineTunedModel`` — the
            weights that will be trained. (Exact format is defined by
            ``FineTunedModel``; presumably a name/pattern — confirm there.)
        freeze_modules: Module selector(s) kept frozen inside the trainable set.
        iterations: Number of optimization steps.
        negative_guidance: Strength of the push away from the concept.
        lr: Adam learning rate.
        save_path: Where the fine-tuned ``state_dict`` is written.
    """
    nsteps = 50  # DDIM sampling steps used to reach a random partial-denoise point

    diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
    diffuser.train()

    finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
    criteria = torch.nn.MSELoss()

    pbar = tqdm(range(iterations))

    # Text embeddings are fixed for the whole run; compute once without grad.
    with torch.no_grad():
        neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
        positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

    # The VAE / text encoder / tokenizer are no longer needed (training works
    # purely in latent space with precomputed embeddings) — free GPU memory.
    del diffuser.vae
    del diffuser.text_encoder
    del diffuser.tokenizer

    torch.cuda.empty_cache()

    for i in pbar:
        with torch.no_grad():
            diffuser.set_scheduler_timesteps(nsteps)

            optimizer.zero_grad()

            # Random intermediate timestep in (0, nsteps-1): torch.randint's
            # upper bound is exclusive, so this draws from 1..nsteps-2.
            iteration = torch.randint(1, nsteps - 1, (1,)).item()

            latents = diffuser.get_initial_latents(1, 512, 1)

            # Denoise up to the sampled step with the CURRENT fine-tuned
            # weights (no grad) to obtain the training latent.
            with finetuner:
                latents_steps, _ = diffuser.diffusion(
                    latents,
                    positive_text_embeddings,
                    start_iteration=0,
                    end_iteration=iteration,
                    guidance_scale=3,
                    show_progress=False,
                )

            # Map the DDIM step index onto the scheduler's full 1000-step scale.
            diffuser.set_scheduler_timesteps(1000)
            iteration = int(iteration / nsteps * 1000)

            # Reference predictions (grad-free): original-model noise for the
            # concept prompt and for the empty prompt.
            positive_latents = diffuser.predict_noise(
                iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
            neutral_latents = diffuser.predict_noise(
                iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

        # Trainable forward: fine-tuned model's prediction for the concept prompt.
        with finetuner:
            negative_latents = diffuser.predict_noise(
                iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

        # Defensive: these were computed under no_grad so they already carry
        # requires_grad=False; keep them out of the autograd graph regardless.
        positive_latents.requires_grad = False
        neutral_latents.requires_grad = False

        # ESD objective: move the fine-tuned prediction toward
        # neutral - g * (positive - neutral), i.e. away from the concept.
        # (loss = criteria(e_n, e_0) works the best — try 5000 epochs.)
        loss = criteria(
            negative_latents,
            neutral_latents - (negative_guidance * (positive_latents - neutral_latents)),
        )

        loss.backward()
        optimizer.step()

    # Persist only the fine-tuned deltas, not the whole diffuser.
    torch.save(finetuner.state_dict(), save_path)

    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, \
        positive_latents, latents_steps, latents

    torch.cuda.empty_cache()
if __name__ == '__main__':
    import argparse

    # CLI wrapper: every option is required and maps 1:1 onto train()'s
    # keyword arguments via vars()/**-unpacking.
    parser = argparse.ArgumentParser()

    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)

    train(**vars(parser.parse_args()))