#!/usr/bin/env python3
"""
Convert SimpleTuner LoRA weights to diffusers-compatible format for AuraFlow.

This script converts LoRA weights saved by SimpleTuner into a format that can be
directly loaded by diffusers' load_lora_weights() method.

Usage:
    python convert_simpletuner_lora.py <input_lora.safetensors> <output_lora.safetensors>

Example:
    python convert_simpletuner_lora.py input_lora.safetensors diffusers_compatible_lora.safetensors
"""

import argparse
import sys
from pathlib import Path
from typing import Dict

import safetensors.torch
import torch


def detect_lora_format(state_dict: Dict[str, torch.Tensor]) -> str:
    """
    Detect the format of the LoRA state dict.
    
    Returns:
        "peft" if already in PEFT/diffusers format
        "mixed" if mixed format (some lora_A/B, some lora.down/up)
        "simpletuner_transformer" if in SimpleTuner format with transformer prefix
        "simpletuner_auraflow" if in SimpleTuner AuraFlow format
        "kohya" if in Kohya format
        "unknown" otherwise
    """
    keys = list(state_dict.keys())
    
    # Check the actual weight naming convention (lora_A/lora_B vs lora_down/lora_up)
    has_lora_a_b = any((".lora_A." in k or ".lora_B." in k) for k in keys)
    has_lora_down_up = any((".lora_down." in k or ".lora_up." in k) for k in keys)
    has_lora_dot_down_up = any((".lora.down." in k or ".lora.up." in k) for k in keys)
    
    # Check prefixes
    has_transformer_prefix = any(k.startswith("transformer.") for k in keys)
    has_lora_transformer_prefix = any(k.startswith("lora_transformer_") for k in keys)
    has_lora_unet_prefix = any(k.startswith("lora_unet_") for k in keys)
    
    # Mixed format: has both lora_A/B AND lora.down/up (SimpleTuner hybrid)
    if has_transformer_prefix and has_lora_a_b and (has_lora_down_up or has_lora_dot_down_up):
        return "mixed"
    
    # Pure PEFT format: transformer.* with ONLY lora_A/lora_B
    if has_transformer_prefix and has_lora_a_b and not has_lora_down_up and not has_lora_dot_down_up:
        return "peft"
    
    # SimpleTuner with transformer prefix but old naming: transformer.* with lora_down/lora_up
    if has_transformer_prefix and (has_lora_down_up or has_lora_dot_down_up):
        return "simpletuner_transformer"
    
    # SimpleTuner AuraFlow format: lora_transformer_* with lora_down/lora_up
    if has_lora_transformer_prefix and has_lora_down_up:
        return "simpletuner_auraflow"
    
    # Traditional Kohya format: lora_unet_* with lora_down/lora_up
    if has_lora_unet_prefix and has_lora_down_up:
        return "kohya"
    
    return "unknown"


def convert_mixed_lora_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert mixed LoRA format to pure PEFT format.
    
    SimpleTuner sometimes saves a hybrid format where some layers use lora_A/lora_B
    and others use .lora.down./.lora.up. This converts all to lora_A/lora_B.
    """
    new_state_dict = {}
    converted_count = 0
    kept_count = 0
    skipped_count = 0
    renames = []
    
    # Get all keys
    all_keys = sorted(state_dict.keys())
    
    print("\nProcessing keys:")
    print("-" * 80)
    
    for key in all_keys:
        # Already in correct format (lora_A or lora_B)
        if ".lora_A." in key or ".lora_B." in key:
            new_state_dict[key] = state_dict[key]
            kept_count += 1
        
        # Needs conversion: .lora.down. -> .lora_A.
        elif ".lora.down.weight" in key:
            new_key = key.replace(".lora.down.weight", ".lora_A.weight")
            new_state_dict[new_key] = state_dict[key]
            renames.append((key, new_key))
            converted_count += 1
        
        # Needs conversion: .lora.up. -> .lora_B.
        elif ".lora.up.weight" in key:
            new_key = key.replace(".lora.up.weight", ".lora_B.weight")
            new_state_dict[new_key] = state_dict[key]
            renames.append((key, new_key))
            converted_count += 1
        
        # Skip alpha keys (not used in PEFT format)
        elif ".alpha" in key:
            skipped_count += 1
            continue
        
        # Other keys (shouldn't happen, but keep them just in case)
        else:
            new_state_dict[key] = state_dict[key]
            print(f"⚠ Warning: Unexpected key format: {key}")
    
    print(f"\nSummary:")
    print(f"  βœ“ Kept {kept_count} keys already in correct format (lora_A/lora_B)")
    print(f"  βœ“ Converted {converted_count} keys from .lora.down/.lora.up to lora_A/lora_B")
    print(f"  βœ“ Skipped {skipped_count} alpha keys")
    
    if renames:
        print(f"\nRenames applied ({len(renames)} conversions):")
        print("-" * 80)
        for old_key, new_key in renames:
            # Show the difference more clearly
            if ".lora.down.weight" in old_key:
                layer = old_key.replace(".lora.down.weight", "")
                print(f"  {layer}")
                print(f"    .lora.down.weight  β†’  .lora_A.weight")
            elif ".lora.up.weight" in old_key:
                layer = old_key.replace(".lora.up.weight", "")
                print(f"  {layer}")
                print(f"    .lora.up.weight    β†’  .lora_B.weight")
    
    return new_state_dict


def convert_simpletuner_transformer_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert SimpleTuner transformer format (already has transformer. prefix but uses lora_down/lora_up)
    to diffusers PEFT format (transformer. prefix with lora_A/lora_B).
    
    This is a simpler conversion since the key structure is already correct.
    """
    new_state_dict = {}
    renames = []
    
    # Get all unique LoRA layer base names (without .lora_down/.lora_up/.alpha suffix)
    all_keys = list(state_dict.keys())
    base_keys = set()
    
    for key in all_keys:
        if ".lora_down.weight" in key:
            base_key = key.replace(".lora_down.weight", "")
            base_keys.add(base_key)
    
    print(f"\nFound {len(base_keys)} LoRA layers to convert")
    print("-" * 80)
    
    # Convert each layer
    for base_key in sorted(base_keys):
        down_key = f"{base_key}.lora_down.weight"
        up_key = f"{base_key}.lora_up.weight"
        alpha_key = f"{base_key}.alpha"
        
        if down_key not in state_dict or up_key not in state_dict:
            print(f"⚠ Warning: Missing weights for {base_key}")
            continue
        
        down_weight = state_dict.pop(down_key)
        up_weight = state_dict.pop(up_key)
        
        # Handle alpha scaling
        has_alpha = False
        if alpha_key in state_dict:
            alpha = state_dict.pop(alpha_key)
            lora_rank = down_weight.shape[0]
            scale = alpha / lora_rank
            
            # Split the scale between the down and up weights; the loop keeps
            # scale_down * scale_up equal to alpha / rank while balancing the two factors.
            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2
            
            down_weight = down_weight * scale_down
            up_weight = up_weight * scale_up
            has_alpha = True
        
        # Store in PEFT format (lora_A = down, lora_B = up)
        new_down_key = f"{base_key}.lora_A.weight"
        new_up_key = f"{base_key}.lora_B.weight"
        
        new_state_dict[new_down_key] = down_weight
        new_state_dict[new_up_key] = up_weight
        
        renames.append((down_key, new_down_key, has_alpha))
        renames.append((up_key, new_up_key, has_alpha))
    
    # Check for any remaining keys
    remaining = [k for k in state_dict.keys() if not k.startswith("text_encoder")]
    if remaining:
        print(f"⚠ Warning: {len(remaining)} keys were not converted: {remaining[:5]}")
    
    print(f"\nRenames applied ({len(renames)} conversions):")
    print("-" * 80)
    
    # Group by layer
    current_layer = None
    for old_key, new_key, has_alpha in renames:
        layer = old_key.replace(".lora_down.weight", "").replace(".lora_up.weight", "")
        
        if layer != current_layer:
            alpha_str = " (alpha scaled)" if has_alpha else ""
            print(f"\n  {layer}{alpha_str}")
            current_layer = layer
        
        if ".lora_down.weight" in old_key:
            print(f"    .lora_down.weight  β†’  .lora_A.weight")
        elif ".lora_up.weight" in old_key:
            print(f"    .lora_up.weight    β†’  .lora_B.weight")
    
    return new_state_dict


def convert_simpletuner_auraflow_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert SimpleTuner AuraFlow LoRA format to diffusers PEFT format.
    
    SimpleTuner typically saves LoRAs in a format similar to Kohya's sd-scripts,
    but for transformer-based models like AuraFlow, the keys may differ.
    """
    new_state_dict = {}
    
    def _convert(original_key, diffusers_key, state_dict, new_state_dict):
        """Helper to convert a single LoRA layer."""
        down_key = f"{original_key}.lora_down.weight"
        if down_key not in state_dict:
            return False
            
        down_weight = state_dict.pop(down_key)
        lora_rank = down_weight.shape[0]
        
        up_weight_key = f"{original_key}.lora_up.weight"
        up_weight = state_dict.pop(up_weight_key)
        
        # Handle alpha scaling
        alpha_key = f"{original_key}.alpha"
        if alpha_key in state_dict:
            alpha = state_dict.pop(alpha_key)
            scale = alpha / lora_rank
            
            # Split the scale between the down and up weights; the loop keeps
            # scale_down * scale_up equal to alpha / rank while balancing the two factors.
            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2
            
            down_weight = down_weight * scale_down
            up_weight = up_weight * scale_up
        
        # Store in PEFT format (lora_A = down, lora_B = up)
        diffusers_down_key = f"{diffusers_key}.lora_A.weight"
        new_state_dict[diffusers_down_key] = down_weight
        new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
        
        return True
    
    # Get all unique LoRA layer names
    all_unique_keys = {
        k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
        for k in state_dict
        if ".lora_down.weight" in k or ".lora_up.weight" in k or ".alpha" in k
    }
    
    # Process transformer blocks
    for original_key in sorted(all_unique_keys):
        if original_key.startswith("lora_transformer_single_transformer_blocks_"):
            # Single transformer blocks
            parts = original_key.split("lora_transformer_single_transformer_blocks_")[-1].split("_")
            block_idx = int(parts[0])
            diffusers_key = f"single_transformer_blocks.{block_idx}"
            
            # Map the rest of the key
            remaining = "_".join(parts[1:])
            if "attn_to_q" in remaining:
                diffusers_key += ".attn.to_q"
            elif "attn_to_k" in remaining:
                diffusers_key += ".attn.to_k"
            elif "attn_to_v" in remaining:
                diffusers_key += ".attn.to_v"
            elif "proj_out" in remaining:
                diffusers_key += ".proj_out"
            elif "proj_mlp" in remaining:
                diffusers_key += ".proj_mlp"
            elif "norm_linear" in remaining:
                diffusers_key += ".norm.linear"
            else:
                print(f"Warning: Unhandled single block key pattern: {original_key}")
                continue
                
        elif original_key.startswith("lora_transformer_transformer_blocks_"):
            # Double transformer blocks
            parts = original_key.split("lora_transformer_transformer_blocks_")[-1].split("_")
            block_idx = int(parts[0])
            diffusers_key = f"transformer_blocks.{block_idx}"
            
            # Map the rest of the key
            remaining = "_".join(parts[1:])
            if "attn_to_out_0" in remaining:
                diffusers_key += ".attn.to_out.0"
            elif "attn_to_add_out" in remaining:
                diffusers_key += ".attn.to_add_out"
            elif "attn_to_q" in remaining:
                diffusers_key += ".attn.to_q"
            elif "attn_to_k" in remaining:
                diffusers_key += ".attn.to_k"
            elif "attn_to_v" in remaining:
                diffusers_key += ".attn.to_v"
            elif "attn_add_q_proj" in remaining:
                diffusers_key += ".attn.add_q_proj"
            elif "attn_add_k_proj" in remaining:
                diffusers_key += ".attn.add_k_proj"
            elif "attn_add_v_proj" in remaining:
                diffusers_key += ".attn.add_v_proj"
            elif "ff_net_0_proj" in remaining:
                diffusers_key += ".ff.net.0.proj"
            elif "ff_net_2" in remaining:
                diffusers_key += ".ff.net.2"
            elif "ff_context_net_0_proj" in remaining:
                diffusers_key += ".ff_context.net.0.proj"
            elif "ff_context_net_2" in remaining:
                diffusers_key += ".ff_context.net.2"
            elif "norm1_linear" in remaining:
                diffusers_key += ".norm1.linear"
            elif "norm1_context_linear" in remaining:
                diffusers_key += ".norm1_context.linear"
            else:
                print(f"Warning: Unhandled double block key pattern: {original_key}")
                continue
        
        elif original_key.startswith("lora_te1_") or original_key.startswith("lora_te_"):
            # Text encoder keys - handle separately
            print(f"Found text encoder key: {original_key}")
            continue
        
        else:
            print(f"Warning: Unknown key pattern: {original_key}")
            continue
        
        # Perform the conversion
        _convert(original_key, diffusers_key, state_dict, new_state_dict)
    
    # Add "transformer." prefix to all keys
    transformer_state_dict = {
        f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
    }
    
    # Check for remaining unconverted keys
    if len(state_dict) > 0:
        remaining_keys = [k for k in state_dict.keys() if not k.startswith("lora_te")]
        if remaining_keys:
            print(f"Warning: Some keys were not converted: {remaining_keys[:10]}")
    
    return transformer_state_dict


def convert_lora(input_path: str, output_path: str) -> None:
    """
    Main conversion function.
    
    Args:
        input_path: Path to input LoRA safetensors file
        output_path: Path to output diffusers-compatible safetensors file
    """
    print(f"Loading LoRA from: {input_path}")
    state_dict = safetensors.torch.load_file(input_path)
    
    print(f"Detecting LoRA format...")
    format_type = detect_lora_format(state_dict)
    print(f"Detected format: {format_type}")
    
    if format_type == "peft":
        print("LoRA is already in diffusers-compatible PEFT format!")
        print("No conversion needed. Copying file...")
        import shutil
        shutil.copy(input_path, output_path)
        return
    
    elif format_type == "mixed":
        print("Converting MIXED format LoRA to pure PEFT format...")
        print("(Some layers use lora_A/B, others use .lora.down/.lora.up)")
        converted_state_dict = convert_mixed_lora_to_diffusers(state_dict.copy())
    
    elif format_type == "simpletuner_transformer":
        print("Converting SimpleTuner transformer format to diffusers...")
        print("(has transformer. prefix but uses lora_down/lora_up naming)")
        converted_state_dict = convert_simpletuner_transformer_to_diffusers(state_dict.copy())
    
    elif format_type == "simpletuner_auraflow":
        print("Converting SimpleTuner AuraFlow format to diffusers...")
        converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy())
    
    elif format_type == "kohya":
        print("Note: Detected Kohya format. This converter is optimized for AuraFlow.")
        print("For other models, diffusers has built-in conversion.")
        converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy())
    
    else:
        print("Error: Unknown LoRA format!")
        print("Sample keys from the state dict:")
        for i, key in enumerate(list(state_dict.keys())[:20]):
            print(f"  {key}")
        sys.exit(1)
    
    print(f"Saving converted LoRA to: {output_path}")
    safetensors.torch.save_file(converted_state_dict, output_path)
    
    print("\nConversion complete!")
    print(f"Original keys: {len(state_dict)}")
    print(f"Converted keys: {len(converted_state_dict)}")

def main():
    parser = argparse.ArgumentParser(
        description="Convert SimpleTuner LoRA to diffusers-compatible format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert a SimpleTuner LoRA for AuraFlow
  python convert_simpletuner_lora.py my_lora.safetensors diffusers_lora.safetensors
  
  # Check the format without converting
  python convert_simpletuner_lora.py my_lora.safetensors /tmp/test.safetensors --dry-run
        """
    )
    
    parser.add_argument(
        "input",
        type=str,
        help="Input LoRA file (SimpleTuner format)"
    )
    
    parser.add_argument(
        "output",
        type=str,
        help="Output LoRA file (diffusers-compatible format)"
    )
    
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only detect format, don't convert"
    )
    
    args = parser.parse_args()
    
    # Validate input file exists
    if not Path(args.input).exists():
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)
    
    if args.dry_run:
        print(f"Loading LoRA from: {args.input}")
        state_dict = safetensors.torch.load_file(args.input)
        format_type = detect_lora_format(state_dict)
        print(f"Detected format: {format_type}")
        print(f"\nSample keys ({min(10, len(state_dict))} of {len(state_dict)}):")
        for key in list(state_dict.keys())[:10]:
            print(f"  {key}")
        return
    
    convert_lora(args.input, args.output)


if __name__ == "__main__":
    main()