Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -9,6 +9,16 @@ try: 
     | 
|
| 9 | 
         
             
            except ImportError:
         
     | 
| 10 | 
         
             
                print("Flash attention not available - using standard attention (this is fine)")
         
     | 
| 11 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 12 | 
         
             
            import spaces
         
     | 
| 13 | 
         
             
            import argparse
         
     | 
| 14 | 
         
             
            import os
         
     | 
| 
         @@ -122,14 +132,34 @@ pipe = FluxPipeline.from_pretrained( 
     | 
|
| 122 | 
         
             
                torch_dtype=torch.bfloat16
         
     | 
| 123 | 
         
             
            )
         
     | 
| 124 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 125 | 
         
             
            print("Loading LoRA weights...")
         
     | 
| 126 | 
         
            -
             
     | 
| 127 | 
         
            -
                 
     | 
| 
         | 
|
| 128 | 
         
             
                    "ByteDance/Hyper-SD", 
         
     | 
| 129 | 
         
             
                    "Hyper-FLUX.1-dev-8steps-lora.safetensors"
         
     | 
| 130 | 
         
             
                )
         
     | 
| 131 | 
         
            -
            )
         
     | 
| 132 | 
         
            -
            pipe.fuse_lora(lora_scale=0.125)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 133 | 
         
             
            pipe.to(device="cuda", dtype=torch.bfloat16)
         
     | 
| 134 | 
         | 
| 135 | 
         
             
            # Safety checker initialization
         
     | 
| 
         @@ -387,12 +417,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo: 
     | 
|
| 387 | 
         
             
                                )
         
     | 
| 388 | 
         | 
| 389 | 
         
             
                            with gr.Row():
         
     | 
| 
         | 
|
| 
         | 
|
| 390 | 
         
             
                                steps = gr.Slider(
         
     | 
| 391 | 
         
            -
                                    label="Inference Steps",
         
     | 
| 392 | 
         
             
                                    minimum=6,
         
     | 
| 393 | 
         
            -
                                    maximum= 
     | 
| 394 | 
         
             
                                    step=1,
         
     | 
| 395 | 
         
            -
                                    value= 
     | 
| 396 | 
         
             
                                )
         
     | 
| 397 | 
         
             
                                scales = gr.Slider(
         
     | 
| 398 | 
         
             
                                    label="Guidance Scale",
         
     | 
| 
         | 
|
| 9 | 
         
             
            except ImportError:
         
     | 
| 10 | 
         
             
                print("Flash attention not available - using standard attention (this is fine)")
         
     | 
| 11 | 
         | 
| 12 | 
         
            +
            # Ensure PEFT is available for LoRA
         
     | 
| 13 | 
         
            +
            try:
         
     | 
| 14 | 
         
            +
                import peft
         
     | 
| 15 | 
         
            +
                print("PEFT library is available")
         
     | 
| 16 | 
         
            +
            except ImportError:
         
     | 
| 17 | 
         
            +
                print("Installing PEFT for LoRA support...")
         
     | 
| 18 | 
         
            +
                subprocess.run([sys.executable, "-m", "pip", "install", "peft>=0.7.0"], check=True)
         
     | 
| 19 | 
         
            +
                import peft
         
     | 
| 20 | 
         
            +
                print("PEFT installed successfully")
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
             
            import spaces
         
     | 
| 23 | 
         
             
            import argparse
         
     | 
| 24 | 
         
             
            import os
         
     | 
| 
         | 
|
| 132 | 
         
             
                torch_dtype=torch.bfloat16
         
     | 
| 133 | 
         
             
            )
         
     | 
| 134 | 
         | 
| 135 | 
         
            +
            # Track whether LoRA was loaded successfully
         
     | 
| 136 | 
         
            +
            LORA_LOADED = False
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
             
            print("Loading LoRA weights...")
         
     | 
| 139 | 
         
            +
            try:
         
     | 
| 140 | 
         
            +
                # Method 1: Try loading with file path
         
     | 
| 141 | 
         
            +
                lora_path = hf_hub_download(
         
     | 
| 142 | 
         
             
                    "ByteDance/Hyper-SD", 
         
     | 
| 143 | 
         
             
                    "Hyper-FLUX.1-dev-8steps-lora.safetensors"
         
     | 
| 144 | 
         
             
                )
         
     | 
| 145 | 
         
            +
                pipe.load_lora_weights(lora_path)
         
     | 
| 146 | 
         
            +
                pipe.fuse_lora(lora_scale=0.125)
         
     | 
| 147 | 
         
            +
                LORA_LOADED = True
         
     | 
| 148 | 
         
            +
                print("LoRA weights loaded and fused successfully (Method 1)")
         
     | 
| 149 | 
         
            +
            except Exception as e1:
         
     | 
| 150 | 
         
            +
                print(f"Method 1 failed: {e1}")
         
     | 
| 151 | 
         
            +
                try:
         
     | 
| 152 | 
         
            +
                    # Method 2: Try loading directly from repo
         
     | 
| 153 | 
         
            +
                    pipe.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
         
     | 
| 154 | 
         
            +
                    pipe.fuse_lora(lora_scale=0.125)
         
     | 
| 155 | 
         
            +
                    LORA_LOADED = True
         
     | 
| 156 | 
         
            +
                    print("LoRA weights loaded and fused successfully (Method 2)")
         
     | 
| 157 | 
         
            +
                except Exception as e2:
         
     | 
| 158 | 
         
            +
                    print(f"Method 2 failed: {e2}")
         
     | 
| 159 | 
         
            +
                    print("WARNING: Could not load LoRA weights. Continuing without LoRA.")
         
     | 
| 160 | 
         
            +
                    print("The model will still work but may require more inference steps for good quality.")
         
     | 
| 161 | 
         
            +
                    print("Recommended: Use 20-30 inference steps instead of 8.")
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
             
            pipe.to(device="cuda", dtype=torch.bfloat16)
         
     | 
| 164 | 
         | 
| 165 | 
         
             
            # Safety checker initialization
         
     | 
| 
         | 
|
| 417 | 
         
             
                                )
         
     | 
| 418 | 
         | 
| 419 | 
         
             
                            with gr.Row():
         
     | 
| 420 | 
         
            +
                                # Adjust default steps based on whether LoRA is loaded
         
     | 
| 421 | 
         
            +
                                default_steps = 8 if LORA_LOADED else 20
         
     | 
| 422 | 
         
             
                                steps = gr.Slider(
         
     | 
| 423 | 
         
            +
                                    label="Inference Steps" + (" (LoRA Enabled)" if LORA_LOADED else " (No LoRA - More Steps Recommended)"),
         
     | 
| 424 | 
         
             
                                    minimum=6,
         
     | 
| 425 | 
         
            +
                                    maximum=50,
         
     | 
| 426 | 
         
             
                                    step=1,
         
     | 
| 427 | 
         
            +
                                    value=default_steps
         
     | 
| 428 | 
         
             
                                )
         
     | 
| 429 | 
         
             
                                scales = gr.Slider(
         
     | 
| 430 | 
         
             
                                    label="Guidance Scale",
         
     |