import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, DDIMInverseScheduler, DPMSolverMultistepScheduler
from .unet_2d_condition import UNet2DConditionModel
from easydict import EasyDict
import numpy as np

# For compatibility
from utils.latents import get_unscaled_latents, get_scaled_latents, blend_latents
from utils import torch_device


def load_sd(key="runwayml/stable-diffusion-v1-5", use_fp16=False, load_inverse_scheduler=True):
    """
    Keys:
        key = "CompVis/stable-diffusion-v1-4"
        key = "runwayml/stable-diffusion-v1-5"
        key = "stabilityai/stable-diffusion-2-1-base"

    Unpack with:
    ```
    model_dict = load_sd(key=key, use_fp16=use_fp16)
    vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
    ```

    use_fp16: loading in fp16 reduces memory use but may degrade output quality.
    """
    # fp16 loads the dedicated half-precision weights revision; run final results in fp32.
    if use_fp16:
        dtype = torch.float16
        revision = "fp16"
    else:
        dtype = torch.float32
        revision = "main"
    vae = AutoencoderKL.from_pretrained(key, subfolder="vae", revision=revision, torch_dtype=dtype).to(torch_device)
    # The tokenizer stays on CPU; the compute-heavy modules are moved to `torch_device`.
    tokenizer = CLIPTokenizer.from_pretrained(key, subfolder="tokenizer", revision=revision, torch_dtype=dtype)
    text_encoder = CLIPTextModel.from_pretrained(key, subfolder="text_encoder", revision=revision, torch_dtype=dtype).to(torch_device)
    unet = UNet2DConditionModel.from_pretrained(key, subfolder="unet", revision=revision, torch_dtype=dtype).to(torch_device)

    # DDIM is returned as the default `scheduler`; a DPM-Solver multistep
    # scheduler is loaded from the same config as an alternative.
    dpm_scheduler = DPMSolverMultistepScheduler.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype)
    scheduler = DDIMScheduler.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype)

    model_dict = EasyDict(vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, dpm_scheduler=dpm_scheduler, dtype=dtype)

    if load_inverse_scheduler:
        # The inverse scheduler runs DDIM backwards (image -> noise), e.g. for DDIM inversion.
        inverse_scheduler = DDIMInverseScheduler.from_config(scheduler.config)
        model_dict.inverse_scheduler = inverse_scheduler

    return model_dict
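
# A minimal usage sketch for load_sd (assumes network access to the Hugging
# Face Hub and that `torch_device` resolves to a usable device):
#
#     model_dict = load_sd(key="runwayml/stable-diffusion-v1-5", use_fp16=False)
#     unet, scheduler = model_dict.unet, model_dict.scheduler
#     # model_dict.dpm_scheduler and, when load_inverse_scheduler=True,
#     # model_dict.inverse_scheduler are also available.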


def encode_prompts(tokenizer, text_encoder, prompts, negative_prompt="", return_full_only=False, one_uncond_input_only=False):
    """
    Encode `prompts` (and a shared `negative_prompt`) into CLIP text embeddings
    for classifier-free guidance.

    By default returns (text_embeddings, uncond_embeddings, cond_embeddings),
    where text_embeddings = cat([uncond, cond]) with the unconditional
    embeddings first; see the flags for the shorter return forms.
    """
    if negative_prompt == "":
        print("Note that negative_prompt is an empty string")

    text_input = tokenizer(
        prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
    )
    max_length = text_input.input_ids.shape[-1]
    if one_uncond_input_only:
        num_uncond_input = 1
    else:
        num_uncond_input = len(prompts)
    uncond_input = tokenizer([negative_prompt] * num_uncond_input, padding="max_length", max_length=max_length, return_tensors="pt")

    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
        cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    if one_uncond_input_only:
        return uncond_embeddings, cond_embeddings

    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    if return_full_only:
        return text_embeddings
    return text_embeddings, uncond_embeddings, cond_embeddings
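
# A usage sketch for encode_prompts (names follow the load_sd sketch above;
# shapes assume SD v1.5's CLIP text encoder, i.e. a 77-token context and a
# 768-dim hidden state):
#
#     text_embeddings, uncond_embeddings, cond_embeddings = encode_prompts(
#         model_dict.tokenizer, model_dict.text_encoder,
#         prompts=["a photo of a cat"], negative_prompt="blurry",
#     )
#     # text_embeddings: (2 * len(prompts), 77, 768), unconditional first.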


def process_input_embeddings(input_embeddings):
    assert isinstance(input_embeddings, (tuple, list))
    if len(input_embeddings) == 3:
        # input_embeddings: text_embeddings, uncond_embeddings, cond_embeddings
        # Assume `uncond_embeddings` is full (same batch size as `cond_embeddings`)
        _, uncond_embeddings, cond_embeddings = input_embeddings
        assert uncond_embeddings.shape[0] == cond_embeddings.shape[0], f"{uncond_embeddings.shape[0]} != {cond_embeddings.shape[0]}"
        return input_embeddings
    elif len(input_embeddings) == 2:
        # input_embeddings: uncond_embeddings, cond_embeddings
        # uncond_embeddings may have only one item; broadcast it to the conditional batch size
        uncond_embeddings, cond_embeddings = input_embeddings
        if uncond_embeddings.shape[0] == 1:
            uncond_embeddings = uncond_embeddings.expand(cond_embeddings.shape)
        # We follow the convention: the negative (unconditional) prompt comes first
        text_embeddings = torch.cat((uncond_embeddings, cond_embeddings), dim=0)
        return text_embeddings, uncond_embeddings, cond_embeddings
    else:
        raise ValueError(f"input_embeddings length: {len(input_embeddings)}")
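

if __name__ == "__main__":
    # Minimal smoke test (not part of the original API): exercises the
    # two-tuple path of process_input_embeddings. Assumes network access to
    # download the default checkpoint; run this file as a module
    # (python -m ...) so the relative import above resolves.
    model_dict = load_sd(use_fp16=False, load_inverse_scheduler=False)
    embeddings = encode_prompts(
        model_dict.tokenizer, model_dict.text_encoder,
        prompts=["a photo of a cat"], negative_prompt="blurry",
        one_uncond_input_only=True,
    )
    text_embeddings, uncond_embeddings, cond_embeddings = process_input_embeddings(embeddings)
    print(text_embeddings.shape)  # expected: torch.Size([2, 77, 768]) for SD v1.5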