from typing import Dict, Union

import imageio
import numpy as np
import torch
from diffusers.loaders.lora import LoraLoaderMixin
def load_lora_weights(unet, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs):
    # If a dict is passed, copy it instead of modifying it in place.
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

    # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
    state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

    # Strip the "base_model.model." prefix if it was not removed when the checkpoint was saved.
    state_dict = {name.replace("base_model.model.", ""): param for name, param in state_dict.items()}

    is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
    if not is_correct_format:
        raise ValueError("Invalid LoRA checkpoint.")

    low_cpu_mem_usage = True
    LoraLoaderMixin.load_lora_into_unet(
        state_dict,
        network_alphas=network_alphas,
        unet=unet,
        low_cpu_mem_usage=low_cpu_mem_usage,
        adapter_name=adapter_name,
    )
def save_video(frames, save_path, fps, quality=9):
    # Encode a sequence of frames into a video file; `quality` follows imageio's 0-10 scale.
    writer = imageio.get_writer(save_path, fps=fps, quality=quality)
    for frame in frames:
        # Convert each frame (e.g. a PIL image) to a NumPy array before appending.
        frame = np.array(frame)
        writer.append_data(frame)
    writer.close()
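# Minimal demo (assumed, not part of the original file): write 16 random RGB
# frames to "sample.mp4" at 8 fps. Writing MP4 with imageio requires the
# imageio-ffmpeg backend to be installed.
if __name__ == "__main__":
    demo_frames = [np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8) for _ in range(16)]
    save_video(demo_frames, "sample.mp4", fps=8)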