import torch
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from omegaconf import OmegaConf
from safetensors import safe_open

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import load_weights
from animatediff.utils.convert_from_ckpt import (
    convert_ldm_unet_checkpoint,
    convert_ldm_clip_checkpoint,
    convert_ldm_vae_checkpoint,
)
from faceadapter.face_adapter import FaceAdapterPlusForVideoLora
model_style_type2base_model_path = {
    "realistic": "models/rv51/realisticVisionV51_v51VAE_dste8.safetensors",
    "anime": "models/aingdiffusion/aingdiffusion_v170_ar.safetensors",
    "photorealistic": "models/sar/sar.safetensors",  # LDM-format checkpoint; converted at load time below.
}
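
# Sketch (not part of the original file): a quick way to check whether a
# .safetensors checkpoint is in LDM format (keys such as
# "model.diffusion_model.*"), which is what the conversion branch in
# load_model handles. Uses the same safe_open API as the loader; the helper
# name is ours.
def _peek_checkpoint_keys(path, n=5):
    with safe_open(path, framework="pt", device="cpu") as f:
        keys = sorted(f.keys())
    print(f"{len(keys)} tensors; first {n}: {keys[:n]}")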
def load_model(model_style_type="realistic", device="cuda"):
    inference_config_path = "inference-v2.yaml"
    sd_version = "models/animatediff/sd"
    id_ckpt = "models/animator.ckpt"
    image_encoder_path = "models/image_encoder"
    base_model_path = model_style_type2base_model_path[model_style_type]
    motion_module_path = "models/v3_sd15_mm.ckpt"
    motion_lora_path = "models/v3_sd15_adapter.ckpt"

    inference_config = OmegaConf.load(inference_config_path)
    # Stable Diffusion 1.5 components, loaded in fp16.
    tokenizer = CLIPTokenizer.from_pretrained(
        sd_version, subfolder="tokenizer", torch_dtype=torch.float16,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        sd_version, subfolder="text_encoder", torch_dtype=torch.float16,
    ).to(device=device)
    vae = AutoencoderKL.from_pretrained(
        sd_version, subfolder="vae", torch_dtype=torch.float16,
    ).to(device=device)
    unet = UNet3DConditionModel.from_pretrained_2d(
        sd_version, subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    ).to(device=device)
    # unet.to(dtype=torch.float16) does not work on HF Spaces (ZeroGPU),
    # so cast with .half() instead.
    unet = unet.half()
    pipeline = AnimationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        controlnet=None,
        scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
        # Alternative schedulers kept for reference:
        # scheduler=DPMSolverMultistepScheduler(**OmegaConf.to_container(inference_config.DPMSolver_scheduler_kwargs)),
        # scheduler=EulerAncestralDiscreteScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
        # scheduler=EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear", steps_offset=1),
        torch_dtype=torch.float16,
    ).to(device=device)
    pipeline = load_weights(
        pipeline,
        # motion module
        motion_module_path=motion_module_path,
        motion_module_lora_configs=[],
        # domain adapter LoRA
        adapter_lora_path=motion_lora_path,
        adapter_lora_scale=1,
        # image layers
        dreambooth_model_path=None,
        lora_model_path="",
        lora_alpha=0.8,
    ).to(device=device)
    if base_model_path != "":
        print(f"loading dreambooth model from {base_model_path}")
        dreambooth_state_dict = {}
        with safe_open(base_model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                dreambooth_state_dict[key] = f.get_tensor(key)
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
        # Newer diffusers versions renamed the VAE attention projections
        # (query/key/value/proj_attn -> to_q/to_k/to_v/to_out.0), so remap the
        # mid-block attention keys before loading the state dict.
        convert_vae_keys = list(converted_vae_checkpoint.keys())
        for key in convert_vae_keys:
            if "encoder.mid_block.attentions" in key or "decoder.mid_block.attentions" in key:
                new_key = None
                if "key" in key:
                    new_key = key.replace("key", "to_k")
                elif "query" in key:
                    new_key = key.replace("query", "to_q")
                elif "value" in key:
                    new_key = key.replace("value", "to_v")
                elif "proj_attn" in key:
                    new_key = key.replace("proj_attn", "to_out.0")
                if new_key:
                    converted_vae_checkpoint[new_key] = converted_vae_checkpoint.pop(key)
        pipeline.vae.load_state_dict(converted_vae_checkpoint)
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
        m, u = pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
        print(f"### custom unet missing keys: {len(m)};\n### unexpected keys: {len(u)};")
        pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict, dtype=torch.float16).to(device=device)
        del dreambooth_state_dict

    pipeline = pipeline.to(torch.float16)
    # Wrap the pipeline with the face-adapter (identity conditioning) module.
    id_animator = FaceAdapterPlusForVideoLora(
        pipeline, image_encoder_path, id_ckpt, num_tokens=16,
        device=torch.device(device), torch_type=torch.float16,
    )
    return id_animator
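
# Usage sketch (not in the original file): load_model is this module's entry
# point; the returned FaceAdapterPlusForVideoLora drives generation elsewhere
# in the app, so only construction is demonstrated here.
if __name__ == "__main__":
    id_animator = load_model(model_style_type="realistic", device="cuda")
    print(f"loaded animator: {type(id_animator).__name__}")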