pony-v7-base / lora / convert_simpletuner_lora.py
#!/usr/bin/env python3
"""
Convert SimpleTuner LoRA weights to diffusers-compatible format for AuraFlow.
This script converts LoRA weights saved by SimpleTuner into a format that can be
directly loaded by diffusers' load_lora_weights() method.
Usage:
python convert_simpletuner_lora.py <input_lora.safetensors> <output_lora.safetensors>
Example:
    python convert_simpletuner_lora.py input_lora.safetensors diffusers_compatible_lora.safetensors
Pass --dry-run to print the detected format and sample keys without writing any output.
"""
import argparse
import sys
from pathlib import Path
from typing import Dict
import safetensors.torch
import torch
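# Hedged usage sketch, never called by this script: once a LoRA has been converted, the
# output file is meant to load directly via diffusers' load_lora_weights(). Assumes a
# diffusers build with AuraFlow LoRA support; the base-model repo id and file name below
# are illustrative placeholders.
def _example_load_converted_lora(lora_path: str = "diffusers_compatible_lora.safetensors") -> None:
    from diffusers import AuraFlowPipeline  # local import keeps the converter dependency-light
    pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow-v0.3", torch_dtype=torch.float16)
    pipe.load_lora_weights(lora_path)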
def detect_lora_format(state_dict: Dict[str, torch.Tensor]) -> str:
"""
Detect the format of the LoRA state dict.
Returns:
"peft" if already in PEFT/diffusers format
"mixed" if mixed format (some lora_A/B, some lora.down/up)
"simpletuner_transformer" if in SimpleTuner format with transformer prefix
"simpletuner_auraflow" if in SimpleTuner AuraFlow format
"kohya" if in Kohya format
"unknown" otherwise
"""
keys = list(state_dict.keys())
# Check the actual weight naming convention (lora_A/lora_B vs lora_down/lora_up)
has_lora_a_b = any((".lora_A." in k or ".lora_B." in k) for k in keys)
has_lora_down_up = any((".lora_down." in k or ".lora_up." in k) for k in keys)
has_lora_dot_down_up = any((".lora.down." in k or ".lora.up." in k) for k in keys)
# Check prefixes
has_transformer_prefix = any(k.startswith("transformer.") for k in keys)
has_lora_transformer_prefix = any(k.startswith("lora_transformer_") for k in keys)
has_lora_unet_prefix = any(k.startswith("lora_unet_") for k in keys)
# Mixed format: has both lora_A/B AND lora.down/up (SimpleTuner hybrid)
if has_transformer_prefix and has_lora_a_b and (has_lora_down_up or has_lora_dot_down_up):
return "mixed"
# Pure PEFT format: transformer.* with ONLY lora_A/lora_B
if has_transformer_prefix and has_lora_a_b and not has_lora_down_up and not has_lora_dot_down_up:
return "peft"
# SimpleTuner with transformer prefix but old naming: transformer.* with lora_down/lora_up
if has_transformer_prefix and (has_lora_down_up or has_lora_dot_down_up):
return "simpletuner_transformer"
# SimpleTuner AuraFlow format: lora_transformer_* with lora_down/lora_up
if has_lora_transformer_prefix and has_lora_down_up:
return "simpletuner_auraflow"
# Traditional Kohya format: lora_unet_* with lora_down/lora_up
if has_lora_unet_prefix and has_lora_down_up:
return "kohya"
return "unknown"
def convert_mixed_lora_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert mixed LoRA format to pure PEFT format.
SimpleTuner sometimes saves a hybrid format where some layers use lora_A/lora_B
and others use .lora.down./.lora.up. This converts all to lora_A/lora_B.
"""
new_state_dict = {}
converted_count = 0
kept_count = 0
skipped_count = 0
renames = []
# Get all keys
all_keys = sorted(state_dict.keys())
print("\nProcessing keys:")
print("-" * 80)
for key in all_keys:
# Already in correct format (lora_A or lora_B)
if ".lora_A." in key or ".lora_B." in key:
new_state_dict[key] = state_dict[key]
kept_count += 1
# Needs conversion: .lora.down. -> .lora_A.
elif ".lora.down.weight" in key:
new_key = key.replace(".lora.down.weight", ".lora_A.weight")
new_state_dict[new_key] = state_dict[key]
renames.append((key, new_key))
converted_count += 1
# Needs conversion: .lora.up. -> .lora_B.
elif ".lora.up.weight" in key:
new_key = key.replace(".lora.up.weight", ".lora_B.weight")
new_state_dict[new_key] = state_dict[key]
renames.append((key, new_key))
converted_count += 1
# Skip alpha keys (not used in PEFT format)
elif ".alpha" in key:
skipped_count += 1
continue
# Other keys (shouldn't happen, but keep them just in case)
else:
new_state_dict[key] = state_dict[key]
print(f"⚠ Warning: Unexpected key format: {key}")
print(f"\nSummary:")
print(f" ✓ Kept {kept_count} keys already in correct format (lora_A/lora_B)")
print(f" ✓ Converted {converted_count} keys from .lora.down/.lora.up to lora_A/lora_B")
print(f" ✓ Skipped {skipped_count} alpha keys")
if renames:
print(f"\nRenames applied ({len(renames)} conversions):")
print("-" * 80)
for old_key, new_key in renames:
# Show the difference more clearly
if ".lora.down.weight" in old_key:
layer = old_key.replace(".lora.down.weight", "")
print(f" {layer}")
print(f" .lora.down.weight → .lora_A.weight")
elif ".lora.up.weight" in old_key:
layer = old_key.replace(".lora.up.weight", "")
print(f" {layer}")
print(f" .lora.up.weight → .lora_B.weight")
return new_state_dict
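# Illustrative sketch only: a hybrid dummy state dict (hypothetical layer names) where one
# layer already uses lora_A/lora_B and another uses .lora.down/.lora.up; after conversion
# every key uses the PEFT lora_A/lora_B naming.
def _example_mixed_conversion() -> None:
    dummy = {
        "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 8),
        "transformer.transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(8, 4),
        "transformer.transformer_blocks.0.ff.net.0.proj.lora.down.weight": torch.zeros(4, 8),
        "transformer.transformer_blocks.0.ff.net.0.proj.lora.up.weight": torch.zeros(8, 4),
    }
    converted = convert_mixed_lora_to_diffusers(dummy)
    assert "transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight" in converted
    assert all(".lora.down." not in k and ".lora.up." not in k for k in converted)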
def convert_simpletuner_transformer_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert SimpleTuner transformer format (already has transformer. prefix but uses lora_down/lora_up)
to diffusers PEFT format (transformer. prefix with lora_A/lora_B).
This is a simpler conversion since the key structure is already correct.
"""
new_state_dict = {}
renames = []
# Get all unique LoRA layer base names (without .lora_down/.lora_up/.alpha suffix)
all_keys = list(state_dict.keys())
base_keys = set()
for key in all_keys:
if ".lora_down.weight" in key:
base_key = key.replace(".lora_down.weight", "")
base_keys.add(base_key)
print(f"\nFound {len(base_keys)} LoRA layers to convert")
print("-" * 80)
# Convert each layer
for base_key in sorted(base_keys):
down_key = f"{base_key}.lora_down.weight"
up_key = f"{base_key}.lora_up.weight"
alpha_key = f"{base_key}.alpha"
if down_key not in state_dict or up_key not in state_dict:
print(f"⚠ Warning: Missing weights for {base_key}")
continue
down_weight = state_dict.pop(down_key)
up_weight = state_dict.pop(up_key)
# Handle alpha scaling
has_alpha = False
if alpha_key in state_dict:
alpha = state_dict.pop(alpha_key)
lora_rank = down_weight.shape[0]
scale = alpha / lora_rank
# Calculate scale_down and scale_up to preserve the scale value
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
has_alpha = True
# Store in PEFT format (lora_A = down, lora_B = up)
new_down_key = f"{base_key}.lora_A.weight"
new_up_key = f"{base_key}.lora_B.weight"
new_state_dict[new_down_key] = down_weight
new_state_dict[new_up_key] = up_weight
renames.append((down_key, new_down_key, has_alpha))
renames.append((up_key, new_up_key, has_alpha))
# Check for any remaining keys
remaining = [k for k in state_dict.keys() if not k.startswith("text_encoder")]
if remaining:
print(f"⚠ Warning: {len(remaining)} keys were not converted: {remaining[:5]}")
print(f"\nRenames applied ({len(renames)} conversions):")
print("-" * 80)
# Group by layer
current_layer = None
for old_key, new_key, has_alpha in renames:
layer = old_key.replace(".lora_down.weight", "").replace(".lora_up.weight", "")
if layer != current_layer:
alpha_str = " (alpha scaled)" if has_alpha else ""
print(f"\n {layer}{alpha_str}")
current_layer = layer
if ".lora_down.weight" in old_key:
print(f" .lora_down.weight → .lora_A.weight")
elif ".lora_up.weight" in old_key:
print(f" .lora_up.weight → .lora_B.weight")
return new_state_dict
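# Illustrative sketch only (hypothetical layer name, random tensors): the alpha handling
# above folds the effective LoRA scale (alpha / rank) into the stored matrices, so a
# loader that never sees the alpha key still reproduces up @ down * (alpha / rank).
def _example_alpha_folding() -> None:
    rank, fan_in, fan_out = 4, 16, 16
    down = torch.randn(rank, fan_in)
    up = torch.randn(fan_out, rank)
    alpha = 2.0
    sd = {
        "transformer.transformer_blocks.0.attn.to_q.lora_down.weight": down.clone(),
        "transformer.transformer_blocks.0.attn.to_q.lora_up.weight": up.clone(),
        "transformer.transformer_blocks.0.attn.to_q.alpha": torch.tensor(alpha),
    }
    converted = convert_simpletuner_transformer_to_diffusers(sd)
    a = converted["transformer.transformer_blocks.0.attn.to_q.lora_A.weight"]
    b = converted["transformer.transformer_blocks.0.attn.to_q.lora_B.weight"]
    assert torch.allclose(b @ a, (alpha / rank) * (up @ down), atol=1e-5)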
def convert_simpletuner_auraflow_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert SimpleTuner AuraFlow LoRA format to diffusers PEFT format.
SimpleTuner typically saves LoRAs in a format similar to Kohya's sd-scripts,
but for transformer-based models like AuraFlow, the keys may differ.
"""
new_state_dict = {}
def _convert(original_key, diffusers_key, state_dict, new_state_dict):
"""Helper to convert a single LoRA layer."""
down_key = f"{original_key}.lora_down.weight"
if down_key not in state_dict:
return False
down_weight = state_dict.pop(down_key)
lora_rank = down_weight.shape[0]
up_weight_key = f"{original_key}.lora_up.weight"
up_weight = state_dict.pop(up_weight_key)
# Handle alpha scaling
alpha_key = f"{original_key}.alpha"
if alpha_key in state_dict:
alpha = state_dict.pop(alpha_key)
scale = alpha / lora_rank
# Calculate scale_down and scale_up to preserve the scale value
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
# Store in PEFT format (lora_A = down, lora_B = up)
diffusers_down_key = f"{diffusers_key}.lora_A.weight"
new_state_dict[diffusers_down_key] = down_weight
new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
return True
# Get all unique LoRA layer names
all_unique_keys = {
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
for k in state_dict
if ".lora_down.weight" in k or ".lora_up.weight" in k or ".alpha" in k
}
# Process transformer blocks
for original_key in sorted(all_unique_keys):
if original_key.startswith("lora_transformer_single_transformer_blocks_"):
# Single transformer blocks
parts = original_key.split("lora_transformer_single_transformer_blocks_")[-1].split("_")
block_idx = int(parts[0])
diffusers_key = f"single_transformer_blocks.{block_idx}"
# Map the rest of the key
remaining = "_".join(parts[1:])
if "attn_to_q" in remaining:
diffusers_key += ".attn.to_q"
elif "attn_to_k" in remaining:
diffusers_key += ".attn.to_k"
elif "attn_to_v" in remaining:
diffusers_key += ".attn.to_v"
elif "proj_out" in remaining:
diffusers_key += ".proj_out"
elif "proj_mlp" in remaining:
diffusers_key += ".proj_mlp"
elif "norm_linear" in remaining:
diffusers_key += ".norm.linear"
else:
print(f"Warning: Unhandled single block key pattern: {original_key}")
continue
elif original_key.startswith("lora_transformer_transformer_blocks_"):
# Double transformer blocks
parts = original_key.split("lora_transformer_transformer_blocks_")[-1].split("_")
block_idx = int(parts[0])
diffusers_key = f"transformer_blocks.{block_idx}"
# Map the rest of the key
remaining = "_".join(parts[1:])
if "attn_to_out_0" in remaining:
diffusers_key += ".attn.to_out.0"
elif "attn_to_add_out" in remaining:
diffusers_key += ".attn.to_add_out"
elif "attn_to_q" in remaining:
diffusers_key += ".attn.to_q"
elif "attn_to_k" in remaining:
diffusers_key += ".attn.to_k"
elif "attn_to_v" in remaining:
diffusers_key += ".attn.to_v"
elif "attn_add_q_proj" in remaining:
diffusers_key += ".attn.add_q_proj"
elif "attn_add_k_proj" in remaining:
diffusers_key += ".attn.add_k_proj"
elif "attn_add_v_proj" in remaining:
diffusers_key += ".attn.add_v_proj"
elif "ff_net_0_proj" in remaining:
diffusers_key += ".ff.net.0.proj"
elif "ff_net_2" in remaining:
diffusers_key += ".ff.net.2"
elif "ff_context_net_0_proj" in remaining:
diffusers_key += ".ff_context.net.0.proj"
elif "ff_context_net_2" in remaining:
diffusers_key += ".ff_context.net.2"
elif "norm1_linear" in remaining:
diffusers_key += ".norm1.linear"
elif "norm1_context_linear" in remaining:
diffusers_key += ".norm1_context.linear"
else:
print(f"Warning: Unhandled double block key pattern: {original_key}")
continue
elif original_key.startswith("lora_te1_") or original_key.startswith("lora_te_"):
# Text encoder keys - handle separately
print(f"Found text encoder key: {original_key}")
continue
else:
print(f"Warning: Unknown key pattern: {original_key}")
continue
# Perform the conversion
_convert(original_key, diffusers_key, state_dict, new_state_dict)
# Add "transformer." prefix to all keys
transformer_state_dict = {
f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
}
# Check for remaining unconverted keys
if len(state_dict) > 0:
remaining_keys = [k for k in state_dict.keys() if not k.startswith("lora_te")]
if remaining_keys:
print(f"Warning: Some keys were not converted: {remaining_keys[:10]}")
return transformer_state_dict
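# Illustrative sketch only (hypothetical block index, placeholder tensors): one Kohya-style
# SimpleTuner/AuraFlow key and the diffusers-PEFT keys it becomes after conversion.
def _example_auraflow_key_mapping() -> None:
    sd = {
        "lora_transformer_single_transformer_blocks_7_attn_to_q.lora_down.weight": torch.zeros(4, 16),
        "lora_transformer_single_transformer_blocks_7_attn_to_q.lora_up.weight": torch.zeros(16, 4),
        "lora_transformer_single_transformer_blocks_7_attn_to_q.alpha": torch.tensor(4.0),
    }
    converted = convert_simpletuner_auraflow_to_diffusers(sd)
    assert "transformer.single_transformer_blocks.7.attn.to_q.lora_A.weight" in converted
    assert "transformer.single_transformer_blocks.7.attn.to_q.lora_B.weight" in converted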
def convert_lora(input_path: str, output_path: str) -> None:
"""
Main conversion function.
Args:
input_path: Path to input LoRA safetensors file
output_path: Path to output diffusers-compatible safetensors file
"""
print(f"Loading LoRA from: {input_path}")
state_dict = safetensors.torch.load_file(input_path)
print(f"Detecting LoRA format...")
format_type = detect_lora_format(state_dict)
print(f"Detected format: {format_type}")
if format_type == "peft":
print("LoRA is already in diffusers-compatible PEFT format!")
print("No conversion needed. Copying file...")
import shutil
shutil.copy(input_path, output_path)
return
elif format_type == "mixed":
print("Converting MIXED format LoRA to pure PEFT format...")
print("(Some layers use lora_A/B, others use .lora.down/.lora.up)")
converted_state_dict = convert_mixed_lora_to_diffusers(state_dict.copy())
elif format_type == "simpletuner_transformer":
print("Converting SimpleTuner transformer format to diffusers...")
print("(has transformer. prefix but uses lora_down/lora_up naming)")
converted_state_dict = convert_simpletuner_transformer_to_diffusers(state_dict.copy())
elif format_type == "simpletuner_auraflow":
print("Converting SimpleTuner AuraFlow format to diffusers...")
converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy())
elif format_type == "kohya":
print("Note: Detected Kohya format. This converter is optimized for AuraFlow.")
print("For other models, diffusers has built-in conversion.")
converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy())
else:
print("Error: Unknown LoRA format!")
print("Sample keys from the state dict:")
        for key in list(state_dict.keys())[:20]:
print(f" {key}")
sys.exit(1)
print(f"Saving converted LoRA to: {output_path}")
safetensors.torch.save_file(converted_state_dict, output_path)
print("\nConversion complete!")
print(f"Original keys: {len(state_dict)}")
print(f"Converted keys: {len(converted_state_dict)}")
def main():
parser = argparse.ArgumentParser(
description="Convert SimpleTuner LoRA to diffusers-compatible format",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Convert a SimpleTuner LoRA for AuraFlow
python convert_simpletuner_lora.py my_lora.safetensors diffusers_lora.safetensors
# Check format without converting
  python convert_simpletuner_lora.py my_lora.safetensors /tmp/test.safetensors --dry-run
"""
)
parser.add_argument(
"input",
type=str,
help="Input LoRA file (SimpleTuner format)"
)
parser.add_argument(
"output",
type=str,
help="Output LoRA file (diffusers-compatible format)"
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Only detect format, don't convert"
)
args = parser.parse_args()
# Validate input file exists
if not Path(args.input).exists():
print(f"Error: Input file not found: {args.input}")
sys.exit(1)
if args.dry_run:
print(f"Loading LoRA from: {args.input}")
state_dict = safetensors.torch.load_file(args.input)
format_type = detect_lora_format(state_dict)
print(f"Detected format: {format_type}")
print(f"\nSample keys ({min(10, len(state_dict))} of {len(state_dict)}):")
for key in list(state_dict.keys())[:10]:
print(f" {key}")
return
convert_lora(args.input, args.output)
if __name__ == "__main__":
main()