| from .diffusion_utils import build_pipeline | |
# Registry of supported diffusion checkpoints, keyed by the name passed to
# get_model(). Each entry lists the repository identifiers (these look like
# Hugging Face Hub repo ids) that get_model() forwards to build_pipeline for
# the base model, the UNet, the tokenizer and the text encoder.
# NOTE: the keys mix "-" and "_" separators ("stable-diffusion-v1-4" vs
# "stable_diffusion_v2_1").
NAME_TO_MODEL = {
    "stable-diffusion-v1-4": dict(
        model="CompVis/stable-diffusion-v1-4",
        unet="CompVis/stable-diffusion-v1-4",
        tokenizer="openai/clip-vit-large-patch14",
        text_encoder="openai/clip-vit-large-patch14",
    ),
    "stable_diffusion_v2_1": dict(
        model="stabilityai/stable-diffusion-2-1",
        unet="stabilityai/stable-diffusion-2-1",
        tokenizer="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
        text_encoder="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    ),
}
def get_model(model_name):
    """Build the pipeline components for a model registered in NAME_TO_MODEL.

    The registry keys mix "-" and "_" separators (e.g. "stable-diffusion-v1-4"
    vs "stable_diffusion_v2_1"). An exact key match is tried first. If that
    fails, the lookup is retried treating "-" and "_" as interchangeable, so
    callers need not guess which convention a given entry uses.

    Args:
        model_name: Registry key of the model to load.

    Returns:
        The ``(vae, tokenizer, text_encoder, unet)`` tuple returned by
        ``build_pipeline`` for the selected entry.

    Raises:
        ValueError: If ``model_name`` matches no registry entry, even after
            separator normalization.
    """
    config = NAME_TO_MODEL.get(model_name)
    # Fallback: tolerate "-" vs "_" mismatches. The str check keeps non-string
    # names on the original path (they fall through to the ValueError below).
    if config is None and isinstance(model_name, str):
        wanted = model_name.replace("_", "-")
        config = next(
            (
                entry
                for key, entry in NAME_TO_MODEL.items()
                if key.replace("_", "-") == wanted
            ),
            None,
        )
    if config is None:
        raise ValueError(
            f"Model name {model_name} not found. "
            f"Available models: {list(NAME_TO_MODEL.keys())}"
        )
    vae, tokenizer, text_encoder, unet = build_pipeline(
        config["model"],
        config["tokenizer"],
        config["text_encoder"],
        config["unet"],
    )
    return vae, tokenizer, text_encoder, unet