Spaces:
Running
Running
| import glob | |
| from os import path | |
| from paths import get_file_name, FastStableDiffusionPaths | |
| from pathlib import Path | |
| # A basic class to keep track of the currently loaded LoRAs and | |
| # their weights; the diffusers function \c get_active_adapters() | |
| # returns a list of adapter names but not their weights so we need | |
| # a way to keep track of the current LoRA weights to set whenever | |
| # a new LoRA is loaded | |
| class _lora_info: | |
| def __init__( | |
| self, | |
| path: str, | |
| weight: float, | |
| ): | |
| self.path = path | |
| self.adapter_name = get_file_name(path) | |
| self.weight = weight | |
| def __del__(self): | |
| self.path = None | |
| self.adapter_name = None | |
| _loaded_loras = [] | |
| _current_pipeline = None | |
def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    """Load the LoRA named by ``lcm_diffusion_setting.lora.path`` into ``pipeline``.

    Calling this more than once with different ``lora.path`` settings loads
    several LoRAs onto the same pipeline. If you plan to load multiple LoRAs
    and change their weights dynamically, you might want to set the
    ``lora.fuse`` option to False (when it is True, ``pipeline.fuse_lora()``
    is called after loading).

    Args:
        pipeline: Pipeline to load the adapter into. If it differs from the
            pipeline of the previous call, the tracked LoRA list is reset.
        lcm_diffusion_setting: Settings object; reads ``lora.path``,
            ``lora.weight``, ``lora.enabled`` and ``lora.fuse`` (and, through
            ``update_lora_weights``, ``use_lcm_lora``).

    Raises:
        ValueError: ``lora.path`` is empty.
        FileNotFoundError: ``lora.path`` does not exist.
        Both subclass ``Exception``, so existing ``except Exception`` handlers
        keep working.
    """
    lora_settings = lcm_diffusion_setting.lora
    if not lora_settings.path:
        raise ValueError("Empty lora model path")
    if not path.exists(lora_settings.path):
        raise FileNotFoundError("Lora model path is invalid")

    global _loaded_loras
    global _current_pipeline
    # If the pipeline has been rebuilt since the last call, forget the LoRAs
    # tracked for the old pipeline and remember the new one. Rebinding the
    # list is sufficient; `del` on a loop variable only unbinds the name.
    if pipeline is not _current_pipeline:
        _loaded_loras = []
        _current_pipeline = pipeline

    # Only adapters that are actually loaded into the pipeline are tracked:
    # a disabled (never loaded) LoRA in _loaded_loras would be passed to
    # set_adapters() later even though the pipeline does not have it.
    if not lora_settings.enabled:
        return

    current_lora = _lora_info(
        lora_settings.path,
        lora_settings.weight,
    )
    print(f"LoRA adapter name : {current_lora.adapter_name}")
    pipeline.load_lora_weights(
        FastStableDiffusionPaths.get_lora_models_path(),
        weight_name=Path(lora_settings.path).name,
        local_files_only=True,
        adapter_name=current_lora.adapter_name,
    )
    # Record the adapter only once loading succeeded, then re-apply the
    # weights of every tracked adapter, including the new one.
    _loaded_loras.append(current_lora)
    update_lora_weights(
        pipeline,
        lcm_diffusion_setting,
    )

    if lora_settings.fuse:
        pipeline.fuse_lora()
def get_lora_models(root_dir: str):
    """Map LoRA names to the ``.safetensors`` files found under ``root_dir``.

    ``root_dir`` is searched recursively. Each key is the value returned by
    ``get_file_name`` for the file path; files for which it returns None are
    skipped. When two files yield the same name, the one glob lists last wins.

    Args:
        root_dir: Directory to search.

    Returns:
        dict mapping LoRA name to file path.
    """
    # Escape the directory part so glob metacharacters that occur in the
    # path itself ("[", "]", "*", "?") are matched literally instead of
    # selecting other directories.
    pattern = f"{glob.escape(root_dir)}/**/*.safetensors"
    lora_models_map = {}
    for file_path in glob.glob(pattern, recursive=True):
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map
def get_active_lora_weights():
    """Return ``(adapter_name, weight)`` tuples for the tracked LoRAs.

    The tuples are in the same order as ``_loaded_loras``, i.e. load order.
    """
    return [(info.adapter_name, info.weight) for info in _loaded_loras]
def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    """Apply the tracked LoRA weights, optionally updated, to ``pipeline``.

    Args:
        pipeline: Must be the pipeline recorded by ``load_lora_weight``;
            otherwise a message is printed and nothing is changed.
        lcm_diffusion_setting: Settings object; when ``use_lcm_lora`` is true,
            an adapter named ``"lcm"`` is included with weight 1.0
            (presumably registered under that name elsewhere; confirm).
        lora_weights: Optional iterable of ``(adapter_name, weight)`` tuples
            in the order of ``get_active_lora_weights()``. Entry ``i`` updates
            the weight of the ``i``-th tracked LoRA. Entries whose name does
            not match, or that have no corresponding tracked LoRA, are
            reported and skipped.
    """
    if pipeline is not _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return

    if lora_weights:
        for idx, lora in enumerate(lora_weights):
            # Guard against a weight list longer than the tracked LoRAs,
            # which previously raised IndexError.
            if idx >= len(_loaded_loras):
                print("Too many LoRA weights for the number of loaded LoRAs!")
                break
            if _loaded_loras[idx].adapter_name != lora[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            _loaded_loras[idx].weight = lora[1]

    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for lora in _loaded_loras:
        adapter_names.append(lora.adapter_name)
        adapter_weights.append(lora.weight)

    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    print(f"Adapters: {list(zip(adapter_names, adapter_weights))}")