# ref:
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py

import os
import math
from typing import Optional, List, Type, Set, Literal

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
    # "Transformer2DModel",  # apparently this one matches fewer modules  # attn1, 2
    "Attention"
]
UNET_TARGET_REPLACE_MODULE_CONV = [
    "ResnetBlock2D",
    "Downsample2D",
    "Upsample2D",
    "DownBlock2D",
    "UpBlock2D",
]  # locon, 3clier

LORA_PREFIX_UNET = "lora_unet"

DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
TRAINING_METHODS = Literal[
    "noxattn",  # train all layers except x-attns and time_embed layers
    "innoxattn",  # train all layers except self attention layers
    "selfattn",  # ESD-u, train only self attention layers
    "xattn",  # ESD-x, train only x attention layers
    "full",  # train all layers
    "xattn-strict",  # q and k values
    "noxattn-hspace",
    "noxattn-hspace-last",
    # "xlayer",
    # "outxattn",
    # "outsattn",
    # "inxattn",
    # "inmidsattn",
    # "selflayer",
]
class LoRAModule(nn.Module):
    """
    Replaces the forward method of the original Linear (or Conv2d) module
    instead of replacing the module itself.
    """

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """If alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim
| if "Linear" in org_module.__class__.__name__: | |
| in_dim = org_module.in_features | |
| out_dim = org_module.out_features | |
| self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) | |
| self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) | |
| elif "Conv" in org_module.__class__.__name__: # ไธๅฟ | |
| in_dim = org_module.in_channels | |
| out_dim = org_module.out_channels | |
| self.lora_dim = min(self.lora_dim, in_dim, out_dim) | |
| if self.lora_dim != lora_dim: | |
| print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") | |
| kernel_size = org_module.kernel_size | |
| stride = org_module.stride | |
| padding = org_module.padding | |
| self.lora_down = nn.Conv2d( | |
| in_dim, self.lora_dim, kernel_size, stride, padding, bias=False | |
| ) | |
| self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) | |
        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        nn.init.kaiming_uniform_(self.lora_down.weight, a=1)
        nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # removed in apply_to()
    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        return (
            self.org_forward(x)
            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
        )
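
# Minimal sketch of using LoRAModule on its own; the layer sizes and the name
# "lora_unet_example" are arbitrary illustrations, not values taken from the
# training scripts. apply_to() reroutes the wrapped layer's forward through the
# LoRA branch; because lora_up is zero-initialized, the patched layer initially
# reproduces the original output exactly.
def _lora_module_example():
    base = nn.Linear(64, 64)
    lora = LoRAModule("lora_unet_example", base, multiplier=1.0, lora_dim=4, alpha=1)
    lora.apply_to()  # base.forward is now org_forward(x) + up(down(x)) * multiplier * scale
    x = torch.randn(2, 64)
    return base(x)
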
class LoRANetwork(nn.Module):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        rank: int = 4,
        multiplier: float = 1.0,
        alpha: float = 1.0,
        train_method: TRAINING_METHODS = "full",
    ) -> None:
        super().__init__()
        self.lora_scale = 1
        self.multiplier = multiplier
        self.lora_dim = rank
        self.alpha = alpha

        # LoRA only
        self.module = LoRAModule

        # create the LoRAs for the U-Net
        self.unet_loras = self.create_modules(
            LORA_PREFIX_UNET,
            unet,
            DEFAULT_TARGET_REPLACE,
            self.lora_dim,
            self.multiplier,
            train_method=train_method,
        )
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # assertion: make sure no LoRA name is duplicated
        lora_names = set()
        for lora in self.unet_loras:
            assert (
                lora.lora_name not in lora_names
            ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
            lora_names.add(lora.lora_name)

        # apply them
        for lora in self.unet_loras:
            lora.apply_to()
            self.add_module(
                lora.lora_name,
                lora,
            )

        del unet
        torch.cuda.empty_cache()
    def create_modules(
        self,
        prefix: str,
        root_module: nn.Module,
        target_replace_modules: List[str],
        rank: int,
        multiplier: float,
        train_method: TRAINING_METHODS,
    ) -> list:
        loras = []
        names = []
        for name, module in root_module.named_modules():
            if train_method in ("noxattn", "noxattn-hspace", "noxattn-hspace-last"):
                # train everything except cross attention and time embedding
                if "attn2" in name or "time_embed" in name:
                    continue
            elif train_method == "innoxattn":  # train everything except cross attention
                if "attn2" in name:
                    continue
            elif train_method == "selfattn":  # train self attention only
                if "attn1" not in name:
                    continue
            elif train_method == "xattn" or train_method == "xattn-strict":  # train cross attention only
                if "attn2" not in name:
                    continue
            elif train_method == "full":  # train everything
                pass
            else:
                raise NotImplementedError(
                    f"train_method: {train_method} is not implemented."
                )
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ in [
                        "Linear",
                        "Conv2d",
                        "LoRACompatibleLinear",
                        "LoRACompatibleConv",
                    ]:
                        if train_method == "xattn-strict":
                            if "out" in child_name:
                                continue
                        if train_method == "noxattn-hspace":
                            if "mid_block" not in name:
                                continue
                        if train_method == "noxattn-hspace-last":
                            if "mid_block" not in name or ".1" not in name or "conv2" not in child_name:
                                continue
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        # print(f"{lora_name}")
                        lora = self.module(
                            lora_name, child_module, multiplier, rank, self.alpha
                        )
                        # print(name, child_name)
                        # print(child_module.weight.shape)
                        if lora_name not in names:
                            loras.append(lora)
                            names.append(lora_name)
        # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
        return loras
    def prepare_optimizer_params(self):
        all_params = []

        if self.unet_loras:  # in practice this is the only case
            params = []
            for lora in self.unet_loras:
                params.extend(lora.parameters())
            param_data = {"params": params}
            all_params.append(param_data)

        return all_params
    def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        # for key in list(state_dict.keys()):
        #     if not key.startswith("lora"):
        #         # drop everything that is not a LoRA weight
        #         del state_dict[key]

        if os.path.splitext(file)[1] == ".safetensors":
            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)
    def set_lora_slider(self, scale):
        self.lora_scale = scale

    def __enter__(self):
        for lora in self.unet_loras:
            lora.multiplier = 1.0 * self.lora_scale

    def __exit__(self, exc_type, exc_value, tb):
        for lora in self.unet_loras:
            lora.multiplier = 0
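
# Minimal end-to-end sketch of wiring the network up. The tiny
# UNet2DConditionModel config, the slider scale, and the output file name are
# assumptions chosen only so the example runs quickly on CPU; they are not the
# settings used by the actual training scripts.
if __name__ == "__main__":
    unet = UNet2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        layers_per_block=2,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
    )

    # Wrap the U-Net: LoRA modules are created for the cross-attention layers
    # ("xattn") and patched into the model's forward passes.
    network = LoRANetwork(unet, rank=4, multiplier=1.0, alpha=1.0, train_method="xattn")
    optimizer = torch.optim.AdamW(network.prepare_optimizer_params(), lr=1e-4)

    sample = torch.randn(1, 4, 32, 32)
    timestep = torch.tensor([10])
    encoder_hidden_states = torch.randn(1, 77, 32)

    # Inside the context manager the LoRA branch is active at the chosen slider
    # scale; on exit the multipliers are reset to 0 (plain U-Net behaviour).
    network.set_lora_slider(scale=2.0)
    with network:
        out = unet(sample, timestep, encoder_hidden_states).sample
    print("output shape:", tuple(out.shape))

    network.save_weights("slider_example.safetensors", dtype=torch.float32)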