"""
Convert SimpleTuner LoRA weights to diffusers-compatible format for AuraFlow.

This script converts LoRA weights saved by SimpleTuner into a format that can be
directly loaded by diffusers' load_lora_weights() method.

Usage:
    python convert_simpletuner_lora.py <input_lora.safetensors> <output_lora.safetensors>

Example:
    python convert_simpletuner_lora.py input_lora.safetensors diffusers_compatible_lora.safetensors
"""

import argparse
import sys
from pathlib import Path
from typing import Dict

import safetensors.torch
import torch


def detect_lora_format(state_dict: Dict[str, torch.Tensor]) -> str:
    """
    Detect the format of the LoRA state dict.

    Returns:
        "peft" if already in PEFT/diffusers format
        "mixed" if mixed format (some lora_A/B, some lora.down/up)
        "simpletuner_transformer" if in SimpleTuner format with transformer prefix
        "simpletuner_auraflow" if in SimpleTuner AuraFlow format
        "kohya" if in Kohya format
        "unknown" otherwise
    """
    keys = list(state_dict.keys())

    has_lora_a_b = any((".lora_A." in k or ".lora_B." in k) for k in keys)
    has_lora_down_up = any((".lora_down." in k or ".lora_up." in k) for k in keys)
    has_lora_dot_down_up = any((".lora.down." in k or ".lora.up." in k) for k in keys)

    has_transformer_prefix = any(k.startswith("transformer.") for k in keys)
    has_lora_transformer_prefix = any(k.startswith("lora_transformer_") for k in keys)
    has_lora_unet_prefix = any(k.startswith("lora_unet_") for k in keys)

    # Check the mixed case first so hybrid dicts are not misclassified by the broader
    # checks below.
    if has_transformer_prefix and has_lora_a_b and (has_lora_down_up or has_lora_dot_down_up):
        return "mixed"

    if has_transformer_prefix and has_lora_a_b and not has_lora_down_up and not has_lora_dot_down_up:
        return "peft"

    if has_transformer_prefix and (has_lora_down_up or has_lora_dot_down_up):
        return "simpletuner_transformer"

    if has_lora_transformer_prefix and has_lora_down_up:
        return "simpletuner_auraflow"

    if has_lora_unet_prefix and has_lora_down_up:
        return "kohya"

    return "unknown"


def convert_mixed_lora_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert mixed LoRA format to pure PEFT format.

    SimpleTuner sometimes saves a hybrid format where some layers use lora_A/lora_B
    and others use .lora.down./.lora.up. This converts all to lora_A/lora_B.
    """
    new_state_dict = {}
    converted_count = 0
    kept_count = 0
    skipped_count = 0
    renames = []

    all_keys = sorted(state_dict.keys())

    print("\nProcessing keys:")
    print("-" * 80)

    for key in all_keys:
        if ".lora_A." in key or ".lora_B." in key:
            # Already in PEFT naming; keep as-is.
            new_state_dict[key] = state_dict[key]
            kept_count += 1

        elif ".lora.down.weight" in key:
            new_key = key.replace(".lora.down.weight", ".lora_A.weight")
            new_state_dict[new_key] = state_dict[key]
            renames.append((key, new_key))
            converted_count += 1

        elif ".lora.up.weight" in key:
            new_key = key.replace(".lora.up.weight", ".lora_B.weight")
            new_state_dict[new_key] = state_dict[key]
            renames.append((key, new_key))
            converted_count += 1

        elif ".alpha" in key:
            # Alpha keys are dropped here without being folded into the weights, which
            # is only equivalent when alpha == rank.
            skipped_count += 1
            continue

        else:
            new_state_dict[key] = state_dict[key]
            print(f"⚠ Warning: Unexpected key format: {key}")

    print("\nSummary:")
    print(f"  ✓ Kept {kept_count} keys already in correct format (lora_A/lora_B)")
    print(f"  ✓ Converted {converted_count} keys from .lora.down/.lora.up to lora_A/lora_B")
    print(f"  ✓ Skipped {skipped_count} alpha keys")

    if renames:
        print(f"\nRenames applied ({len(renames)} conversions):")
        print("-" * 80)
        for old_key, new_key in renames:
            if ".lora.down.weight" in old_key:
                layer = old_key.replace(".lora.down.weight", "")
                print(f"  {layer}")
                print("    .lora.down.weight → .lora_A.weight")
            elif ".lora.up.weight" in old_key:
                layer = old_key.replace(".lora.up.weight", "")
                print(f"  {layer}")
                print("    .lora.up.weight → .lora_B.weight")

    return new_state_dict


def convert_simpletuner_transformer_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert SimpleTuner transformer format (already has the transformer. prefix but
    uses lora_down/lora_up) to diffusers PEFT format (transformer. prefix with
    lora_A/lora_B).

    This is a simpler conversion since the key structure is already correct.
    """
    new_state_dict = {}
    renames = []

    all_keys = list(state_dict.keys())
    base_keys = set()

    for key in all_keys:
        if ".lora_down.weight" in key:
            base_key = key.replace(".lora_down.weight", "")
            base_keys.add(base_key)

    print(f"\nFound {len(base_keys)} LoRA layers to convert")
    print("-" * 80)

    for base_key in sorted(base_keys):
        down_key = f"{base_key}.lora_down.weight"
        up_key = f"{base_key}.lora_up.weight"
        alpha_key = f"{base_key}.alpha"

        if down_key not in state_dict or up_key not in state_dict:
            print(f"⚠ Warning: Missing weights for {base_key}")
            continue

        down_weight = state_dict.pop(down_key)
        up_weight = state_dict.pop(up_key)

        has_alpha = False
        if alpha_key in state_dict:
            # Fold the Kohya-style alpha into the weights so no separate alpha key is needed.
            alpha = state_dict.pop(alpha_key)
            lora_rank = down_weight.shape[0]
            scale = alpha / lora_rank

            # Distribute the scale across both factors so neither becomes
            # disproportionately small; the product scale_down * scale_up stays equal to scale.
            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2

            down_weight = down_weight * scale_down
            up_weight = up_weight * scale_up
            has_alpha = True

        new_down_key = f"{base_key}.lora_A.weight"
        new_up_key = f"{base_key}.lora_B.weight"

        new_state_dict[new_down_key] = down_weight
        new_state_dict[new_up_key] = up_weight

        renames.append((down_key, new_down_key, has_alpha))
        renames.append((up_key, new_up_key, has_alpha))

    remaining = [k for k in state_dict.keys() if not k.startswith("text_encoder")]
    if remaining:
        print(f"⚠ Warning: {len(remaining)} keys were not converted: {remaining[:5]}")

    print(f"\nRenames applied ({len(renames)} conversions):")
    print("-" * 80)

    current_layer = None
    for old_key, new_key, has_alpha in renames:
        layer = old_key.replace(".lora_down.weight", "").replace(".lora_up.weight", "")

        if layer != current_layer:
            alpha_str = " (alpha scaled)" if has_alpha else ""
            print(f"\n  {layer}{alpha_str}")
            current_layer = layer

        if ".lora_down.weight" in old_key:
            print("    .lora_down.weight → .lora_A.weight")
        elif ".lora_up.weight" in old_key:
            print("    .lora_up.weight → .lora_B.weight")

    return new_state_dict
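
# Why folding alpha into the weights (as done above) preserves the effective LoRA delta,
# assuming the loader treats a missing alpha as alpha == rank. A sketch, where `down_w`,
# `up_w`, and `alpha` are hypothetical tensors rather than variables in this script:
#
#     rank = down_w.shape[0]
#     scale = alpha / rank
#     reference = (up_w @ down_w) * scale                        # Kohya-style delta
#     scale_down, scale_up = scale, 1.0
#     while scale_down * 2 < scale_up:                           # same balancing loop as above
#         scale_down *= 2
#         scale_up /= 2
#     converted = (up_w * scale_up) @ (down_w * scale_down)      # what this script stores
#     assert torch.allclose(reference, converted)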


def convert_simpletuner_auraflow_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert SimpleTuner AuraFlow LoRA format to diffusers PEFT format.

    SimpleTuner typically saves LoRAs in a format similar to Kohya's sd-scripts,
    but for transformer-based models like AuraFlow, the keys may differ.
    """
    new_state_dict = {}

    def _convert(original_key, diffusers_key, state_dict, new_state_dict):
        """Helper to convert a single LoRA layer."""
        down_key = f"{original_key}.lora_down.weight"
        if down_key not in state_dict:
            return False

        down_weight = state_dict.pop(down_key)
        lora_rank = down_weight.shape[0]

        up_weight_key = f"{original_key}.lora_up.weight"
        up_weight = state_dict.pop(up_weight_key)

        alpha_key = f"{original_key}.alpha"
        if alpha_key in state_dict:
            # Fold the Kohya-style alpha into the weights (same balancing approach as in
            # convert_simpletuner_transformer_to_diffusers).
            alpha = state_dict.pop(alpha_key)
            scale = alpha / lora_rank

            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2

            down_weight = down_weight * scale_down
            up_weight = up_weight * scale_up

        diffusers_down_key = f"{diffusers_key}.lora_A.weight"
        new_state_dict[diffusers_down_key] = down_weight
        new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight

        return True

    # Collect the unique base key for every LoRA layer (stripping the weight/alpha suffixes).
    all_unique_keys = {
        k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
        for k in state_dict
        if ".lora_down.weight" in k or ".lora_up.weight" in k or ".alpha" in k
    }

    for original_key in sorted(all_unique_keys):
        if original_key.startswith("lora_transformer_single_transformer_blocks_"):
            # Single transformer blocks (mapped to single_transformer_blocks.N).
            parts = original_key.split("lora_transformer_single_transformer_blocks_")[-1].split("_")
            block_idx = int(parts[0])
            diffusers_key = f"single_transformer_blocks.{block_idx}"

            remaining = "_".join(parts[1:])
            if "attn_to_q" in remaining:
                diffusers_key += ".attn.to_q"
            elif "attn_to_k" in remaining:
                diffusers_key += ".attn.to_k"
            elif "attn_to_v" in remaining:
                diffusers_key += ".attn.to_v"
            elif "proj_out" in remaining:
                diffusers_key += ".proj_out"
            elif "proj_mlp" in remaining:
                diffusers_key += ".proj_mlp"
            elif "norm_linear" in remaining:
                diffusers_key += ".norm.linear"
            else:
                print(f"Warning: Unhandled single block key pattern: {original_key}")
                continue

        elif original_key.startswith("lora_transformer_transformer_blocks_"):
            # Double transformer blocks (mapped to transformer_blocks.N).
            parts = original_key.split("lora_transformer_transformer_blocks_")[-1].split("_")
            block_idx = int(parts[0])
            diffusers_key = f"transformer_blocks.{block_idx}"

            remaining = "_".join(parts[1:])
            if "attn_to_out_0" in remaining:
                diffusers_key += ".attn.to_out.0"
            elif "attn_to_add_out" in remaining:
                diffusers_key += ".attn.to_add_out"
            elif "attn_to_q" in remaining:
                diffusers_key += ".attn.to_q"
            elif "attn_to_k" in remaining:
                diffusers_key += ".attn.to_k"
            elif "attn_to_v" in remaining:
                diffusers_key += ".attn.to_v"
            elif "attn_add_q_proj" in remaining:
                diffusers_key += ".attn.add_q_proj"
            elif "attn_add_k_proj" in remaining:
                diffusers_key += ".attn.add_k_proj"
            elif "attn_add_v_proj" in remaining:
                diffusers_key += ".attn.add_v_proj"
            elif "ff_net_0_proj" in remaining:
                diffusers_key += ".ff.net.0.proj"
            elif "ff_net_2" in remaining:
                diffusers_key += ".ff.net.2"
            elif "ff_context_net_0_proj" in remaining:
                diffusers_key += ".ff_context.net.0.proj"
            elif "ff_context_net_2" in remaining:
                diffusers_key += ".ff_context.net.2"
            elif "norm1_linear" in remaining:
                diffusers_key += ".norm1.linear"
            elif "norm1_context_linear" in remaining:
                diffusers_key += ".norm1_context.linear"
            else:
                print(f"Warning: Unhandled double block key pattern: {original_key}")
                continue

        elif original_key.startswith("lora_te1_") or original_key.startswith("lora_te_"):
            # Text encoder keys are reported but not converted.
            print(f"Found text encoder key: {original_key}")
            continue

        else:
            print(f"Warning: Unknown key pattern: {original_key}")
            continue

        _convert(original_key, diffusers_key, state_dict, new_state_dict)

    # diffusers expects the transformer component's LoRA keys under a "transformer." prefix.
    transformer_state_dict = {
        f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
    }

    if len(state_dict) > 0:
        remaining_keys = [k for k in state_dict.keys() if not k.startswith("lora_te")]
        if remaining_keys:
            print(f"Warning: Some keys were not converted: {remaining_keys[:10]}")

    return transformer_state_dict


def convert_lora(input_path: str, output_path: str) -> None:
    """
    Main conversion function.

    Args:
        input_path: Path to input LoRA safetensors file
        output_path: Path to output diffusers-compatible safetensors file
    """
    print(f"Loading LoRA from: {input_path}")
    state_dict = safetensors.torch.load_file(input_path)

    print("Detecting LoRA format...")
    format_type = detect_lora_format(state_dict)
    print(f"Detected format: {format_type}")

    if format_type == "peft":
        print("LoRA is already in diffusers-compatible PEFT format!")
        print("No conversion needed. Copying file...")
        import shutil
        shutil.copy(input_path, output_path)
        return

    elif format_type == "mixed":
        print("Converting MIXED format LoRA to pure PEFT format...")
        print("(Some layers use lora_A/B, others use .lora.down/.lora.up)")
        converted_state_dict = convert_mixed_lora_to_diffusers(state_dict.copy())

    elif format_type == "simpletuner_transformer":
        print("Converting SimpleTuner transformer format to diffusers...")
        print("(has transformer. prefix but uses lora_down/lora_up naming)")
        converted_state_dict = convert_simpletuner_transformer_to_diffusers(state_dict.copy())

    elif format_type == "simpletuner_auraflow":
        print("Converting SimpleTuner AuraFlow format to diffusers...")
        converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy())

    elif format_type == "kohya":
        print("Note: Detected Kohya format. This converter is optimized for AuraFlow.")
        print("For other models, diffusers has built-in conversion.")
        converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy())

    else:
        print("Error: Unknown LoRA format!")
        print("Sample keys from the state dict:")
        for key in list(state_dict.keys())[:20]:
            print(f"  {key}")
        sys.exit(1)

    print(f"Saving converted LoRA to: {output_path}")
    safetensors.torch.save_file(converted_state_dict, output_path)

    print("\nConversion complete!")
    print(f"Original keys: {len(state_dict)}")
    print(f"Converted keys: {len(converted_state_dict)}")


def main():
    parser = argparse.ArgumentParser(
        description="Convert SimpleTuner LoRA to diffusers-compatible format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert a SimpleTuner LoRA for AuraFlow
  python convert_simpletuner_lora.py my_lora.safetensors diffusers_lora.safetensors

  # Check the format without converting (the output path is required but left untouched)
  python convert_simpletuner_lora.py my_lora.safetensors /tmp/test.safetensors --dry-run
"""
    )

    parser.add_argument(
        "input",
        type=str,
        help="Input LoRA file (SimpleTuner format)"
    )

    parser.add_argument(
        "output",
        type=str,
        help="Output LoRA file (diffusers-compatible format)"
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only detect format, don't convert"
    )

    args = parser.parse_args()

    if not Path(args.input).exists():
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)

    if args.dry_run:
        print(f"Loading LoRA from: {args.input}")
        state_dict = safetensors.torch.load_file(args.input)
        format_type = detect_lora_format(state_dict)
        print(f"Detected format: {format_type}")
        print(f"\nSample keys ({min(10, len(state_dict))} of {len(state_dict)}):")
        for key in list(state_dict.keys())[:10]:
            print(f"  {key}")
        return

    convert_lora(args.input, args.output)


if __name__ == "__main__":
    main()