Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update utils.py
Browse files
    	
        utils.py
    CHANGED
    
    | @@ -19,6 +19,7 @@ import subprocess | |
| 19 | 
             
            subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
         | 
| 20 | 
             
            from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
         | 
| 21 |  | 
|  | |
| 22 | 
             
            from src.flux.generate import generate, seed_everything
         | 
| 23 |  | 
| 24 | 
             
            try:
         | 
| @@ -29,78 +30,21 @@ except ImportError: | |
| 29 |  | 
| 30 | 
             
            import re
         | 
| 31 |  | 
| 32 | 
            -
            # Global variables
         | 
| 33 | 
             
            pipe = None
         | 
| 34 | 
             
            model_dict = {}
         | 
| 35 | 
            -
            _MODEL_INITIALIZED = False
         | 
| 36 | 
            -
            _ADAPTERS_LOADED = False
         | 
| 37 |  | 
| 38 | 
             
            def init_flux_pipeline():
         | 
| 39 | 
            -
                 | 
| 40 | 
            -
                global pipe, _MODEL_INITIALIZED
         | 
| 41 | 
            -
                
         | 
| 42 | 
             
                if pipe is None:
         | 
| 43 | 
            -
                    print("Initializing Flux pipeline...")
         | 
| 44 | 
             
                    token = os.getenv("HF_TOKEN")
         | 
| 45 | 
             
                    if not token:
         | 
| 46 | 
             
                        raise ValueError("HF_TOKEN environment variable not set.")
         | 
| 47 | 
            -
                        
         | 
| 48 | 
             
                    pipe = FluxPipeline.from_pretrained(
         | 
| 49 | 
             
                        "black-forest-labs/FLUX.1-schnell",
         | 
| 50 | 
             
                        use_auth_token=token,
         | 
| 51 | 
             
                        torch_dtype=torch.bfloat16
         | 
| 52 | 
             
                    )
         | 
| 53 | 
             
                    pipe = pipe.to("cuda")
         | 
| 54 | 
            -
                    _MODEL_INITIALIZED = True
         | 
| 55 | 
            -
                    print("Flux pipeline initialized successfully.")
         | 
| 56 | 
            -
                
         | 
| 57 | 
            -
                return pipe
         | 
| 58 | 
            -
             | 
| 59 | 
            -
            def load_all_lora_adapters():
         | 
| 60 | 
            -
                """Load all LoRA adapters, ensuring it runs only once"""
         | 
| 61 | 
            -
                global pipe, _ADAPTERS_LOADED
         | 
| 62 | 
            -
                
         | 
| 63 | 
            -
                # Ensure model is initialized
         | 
| 64 | 
            -
                init_flux_pipeline()
         | 
| 65 | 
            -
                
         | 
| 66 | 
            -
                if not _ADAPTERS_LOADED:
         | 
| 67 | 
            -
                    print("Loading all LoRA adapters...")
         | 
| 68 | 
            -
                    
         | 
| 69 | 
            -
                    LORA_ADAPTERS = {
         | 
| 70 | 
            -
                        "add": "weights/add.safetensors",
         | 
| 71 | 
            -
                        "remove": "weights/remove.safetensors",
         | 
| 72 | 
            -
                        "action": "weights/action.safetensors",
         | 
| 73 | 
            -
                        "expression": "weights/expression.safetensors",
         | 
| 74 | 
            -
                        "addition": "weights/addition.safetensors",
         | 
| 75 | 
            -
                        "material": "weights/material.safetensors",
         | 
| 76 | 
            -
                        "color": "weights/color.safetensors",
         | 
| 77 | 
            -
                        "bg": "weights/bg.safetensors",
         | 
| 78 | 
            -
                        "appearance": "weights/appearance.safetensors",
         | 
| 79 | 
            -
                        "fusion": "weights/fusion.safetensors",
         | 
| 80 | 
            -
                        "overall": "weights/overall.safetensors",
         | 
| 81 | 
            -
                    }
         | 
| 82 | 
            -
                    
         | 
| 83 | 
            -
                    for adapter_name, weight_path in LORA_ADAPTERS.items():
         | 
| 84 | 
            -
                        try:
         | 
| 85 | 
            -
                            pipe.load_lora_weights(
         | 
| 86 | 
            -
                                "Cicici1109/IEAP",
         | 
| 87 | 
            -
                                weight_name=weight_path,
         | 
| 88 | 
            -
                                adapter_name=adapter_name,
         | 
| 89 | 
            -
                            )
         | 
| 90 | 
            -
                            print(f"✅ Successfully loaded adapter: {adapter_name}")
         | 
| 91 | 
            -
                        except Exception as e:
         | 
| 92 | 
            -
                            print(f"❌ Failed to load adapter {adapter_name}: {e}")
         | 
| 93 | 
            -
                    
         | 
| 94 | 
            -
                    loaded_adapters = list(pipe.lora_adapters.keys())
         | 
| 95 | 
            -
                    print(f"Loaded adapters: {loaded_adapters}")
         | 
| 96 | 
            -
                    
         | 
| 97 | 
            -
                    if loaded_adapters:
         | 
| 98 | 
            -
                        pipe.set_adapters(loaded_adapters[0])
         | 
| 99 | 
            -
                        print(f"Default adapter set to: {loaded_adapters[0]}")
         | 
| 100 | 
            -
                    
         | 
| 101 | 
            -
                    _ADAPTERS_LOADED = True
         | 
| 102 | 
            -
                
         | 
| 103 | 
            -
                return pipe
         | 
| 104 |  | 
| 105 | 
             
            def get_model(model_path):
         | 
| 106 | 
             
                global model_dict
         | 
| @@ -221,55 +165,57 @@ def extract_last_bbox(result): | |
| 221 |  | 
| 222 | 
             
            @spaces.GPU
         | 
| 223 | 
             
            def infer_with_DiT(task, image, instruction, category):
         | 
| 224 | 
            -
                 | 
| 225 | 
            -
             | 
| 226 | 
            -
                
         | 
| 227 | 
             
                if task == 'RoI Inpainting':
         | 
| 228 | 
             
                    if category == 'Add' or category == 'Replace':
         | 
| 229 | 
            -
                         | 
| 230 | 
             
                        added = extract_object_with_gpt(instruction)
         | 
| 231 | 
             
                        instruction_dit = f"add {added} on the black region"
         | 
| 232 | 
             
                    elif category == 'Remove' or category == 'Action Change':
         | 
| 233 | 
            -
                         | 
| 234 | 
             
                        instruction_dit = f"Fill the hole of the image"
         | 
|  | |
| 235 | 
             
                    condition = Condition("scene", image, position_delta=(0, 0))
         | 
| 236 | 
            -
                
         | 
| 237 | 
             
                elif task == 'RoI Editing':
         | 
| 238 | 
             
                    image = Image.open(image).convert('RGB').resize((512, 512))
         | 
| 239 | 
             
                    condition = Condition("scene", image, position_delta=(0, -32))
         | 
| 240 | 
             
                    instruction_dit = instruction
         | 
| 241 | 
            -
                    
         | 
| 242 | 
             
                    if category == 'Action Change':
         | 
| 243 | 
            -
                         | 
| 244 | 
             
                    elif category == 'Expression Change':
         | 
| 245 | 
            -
                         | 
| 246 | 
             
                    elif category == 'Add':
         | 
| 247 | 
            -
                         | 
| 248 | 
             
                    elif category == 'Material Change':
         | 
| 249 | 
            -
                         | 
| 250 | 
             
                    elif category == 'Color Change':
         | 
| 251 | 
            -
                         | 
| 252 | 
             
                    elif category == 'Background Change':
         | 
| 253 | 
            -
                         | 
| 254 | 
             
                    elif category == 'Appearance Change':
         | 
| 255 | 
            -
                         | 
| 256 | 
            -
             | 
| 257 | 
             
                elif task == 'RoI Compositioning':
         | 
| 258 | 
            -
                     | 
| 259 | 
             
                    condition = Condition("scene", image, position_delta=(0, 0))
         | 
| 260 | 
             
                    instruction_dit = "inpaint the black-bordered region so that the object's edges blend smoothly with the background"
         | 
| 261 |  | 
| 262 | 
             
                elif task == 'Global Transformation':
         | 
| 263 | 
             
                    image = Image.open(image).convert('RGB').resize((512, 512))
         | 
| 264 | 
             
                    instruction_dit = instruction
         | 
| 265 | 
            -
                     | 
|  | |
| 266 | 
             
                    condition = Condition("scene", image, position_delta=(0, -32))
         | 
| 267 | 
             
                else:
         | 
| 268 | 
             
                    raise ValueError(f"Invalid task: '{task}'")
         | 
| 269 |  | 
| 270 | 
            -
                 | 
| 271 | 
            -
                 | 
| 272 | 
            -
             | 
|  | |
|  | |
|  | |
| 273 |  | 
| 274 | 
             
                result_img = generate(
         | 
| 275 | 
             
                    pipe,
         | 
| @@ -646,4 +592,7 @@ def layout_change(bbox, instruction): | |
| 646 | 
             
                result = response.choices[0].message.content.strip()
         | 
| 647 |  | 
| 648 | 
             
                bbox = extract_last_bbox(result)
         | 
| 649 | 
            -
                return bbox
         | 
|  | |
|  | |
|  | 
|  | |
| 19 | 
             
            subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
         | 
| 20 | 
             
            from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
         | 
| 21 |  | 
| 22 | 
            +
             | 
| 23 | 
             
            from src.flux.generate import generate, seed_everything
         | 
| 24 |  | 
| 25 | 
             
            try:
         | 
|  | |
| 30 |  | 
| 31 | 
             
            import re
         | 
| 32 |  | 
|  | |
| 33 | 
             
            pipe = None
         | 
| 34 | 
             
            model_dict = {}
         | 
|  | |
|  | |
| 35 |  | 
| 36 | 
             
            def init_flux_pipeline():
         | 
| 37 | 
            +
                global pipe
         | 
|  | |
|  | |
| 38 | 
             
                if pipe is None:
         | 
|  | |
| 39 | 
             
                    token = os.getenv("HF_TOKEN")
         | 
| 40 | 
             
                    if not token:
         | 
| 41 | 
             
                        raise ValueError("HF_TOKEN environment variable not set.")
         | 
|  | |
| 42 | 
             
                    pipe = FluxPipeline.from_pretrained(
         | 
| 43 | 
             
                        "black-forest-labs/FLUX.1-schnell",
         | 
| 44 | 
             
                        use_auth_token=token,
         | 
| 45 | 
             
                        torch_dtype=torch.bfloat16
         | 
| 46 | 
             
                    )
         | 
| 47 | 
             
                    pipe = pipe.to("cuda")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 48 |  | 
| 49 | 
             
            def get_model(model_path):
         | 
| 50 | 
             
                global model_dict
         | 
|  | |
| 165 |  | 
| 166 | 
             
            @spaces.GPU
         | 
| 167 | 
             
            def infer_with_DiT(task, image, instruction, category):
         | 
| 168 | 
            +
                init_flux_pipeline()
         | 
| 169 | 
            +
             | 
|  | |
| 170 | 
             
                if task == 'RoI Inpainting':
         | 
| 171 | 
             
                    if category == 'Add' or category == 'Replace':
         | 
| 172 | 
            +
                        lora_path = "weights/add.safetensors"
         | 
| 173 | 
             
                        added = extract_object_with_gpt(instruction)
         | 
| 174 | 
             
                        instruction_dit = f"add {added} on the black region"
         | 
| 175 | 
             
                    elif category == 'Remove' or category == 'Action Change':
         | 
| 176 | 
            +
                        lora_path = "weights/remove.safetensors"
         | 
| 177 | 
             
                        instruction_dit = f"Fill the hole of the image"
         | 
| 178 | 
            +
                   
         | 
| 179 | 
             
                    condition = Condition("scene", image, position_delta=(0, 0))
         | 
|  | |
| 180 | 
             
                elif task == 'RoI Editing':
         | 
| 181 | 
             
                    image = Image.open(image).convert('RGB').resize((512, 512))
         | 
| 182 | 
             
                    condition = Condition("scene", image, position_delta=(0, -32))
         | 
| 183 | 
             
                    instruction_dit = instruction
         | 
|  | |
| 184 | 
             
                    if category == 'Action Change':
         | 
| 185 | 
            +
                        lora_path = "weights/action.safetensors"
         | 
| 186 | 
             
                    elif category == 'Expression Change':
         | 
| 187 | 
            +
                        lora_path = "weights/expression.safetensors"
         | 
| 188 | 
             
                    elif category == 'Add':
         | 
| 189 | 
            +
                        lora_path = "weights/addition.safetensors"
         | 
| 190 | 
             
                    elif category == 'Material Change':
         | 
| 191 | 
            +
                        lora_path = "weights/material.safetensors"
         | 
| 192 | 
             
                    elif category == 'Color Change':
         | 
| 193 | 
            +
                        lora_path = "weights/color.safetensors"
         | 
| 194 | 
             
                    elif category == 'Background Change':
         | 
| 195 | 
            +
                        lora_path = "weights/bg.safetensors"
         | 
| 196 | 
             
                    elif category == 'Appearance Change':
         | 
| 197 | 
            +
                        lora_path = "weights/appearance.safetensors"
         | 
| 198 | 
            +
                    
         | 
| 199 | 
             
                elif task == 'RoI Compositioning':
         | 
| 200 | 
            +
                    lora_path = "weights/fusion.safetensors"
         | 
| 201 | 
             
                    condition = Condition("scene", image, position_delta=(0, 0))
         | 
| 202 | 
             
                    instruction_dit = "inpaint the black-bordered region so that the object's edges blend smoothly with the background"
         | 
| 203 |  | 
| 204 | 
             
                elif task == 'Global Transformation':
         | 
| 205 | 
             
                    image = Image.open(image).convert('RGB').resize((512, 512))
         | 
| 206 | 
             
                    instruction_dit = instruction
         | 
| 207 | 
            +
                    lora_path = "weights/overall.safetensors"
         | 
| 208 | 
            +
             | 
| 209 | 
             
                    condition = Condition("scene", image, position_delta=(0, -32))
         | 
| 210 | 
             
                else:
         | 
| 211 | 
             
                    raise ValueError(f"Invalid task: '{task}'")
         | 
| 212 |  | 
| 213 | 
            +
                pipe.unload_lora_weights()
         | 
| 214 | 
            +
                pipe.load_lora_weights(
         | 
| 215 | 
            +
                    "Cicici1109/IEAP",
         | 
| 216 | 
            +
                    weight_name=lora_path,
         | 
| 217 | 
            +
                    adapter_name="scene",
         | 
| 218 | 
            +
                )
         | 
| 219 |  | 
| 220 | 
             
                result_img = generate(
         | 
| 221 | 
             
                    pipe,
         | 
|  | |
| 592 | 
             
                result = response.choices[0].message.content.strip()
         | 
| 593 |  | 
| 594 | 
             
                bbox = extract_last_bbox(result)
         | 
| 595 | 
            +
                return bbox
         | 
| 596 | 
            +
             | 
| 597 | 
            +
            if __name__ == "__main__":
         | 
| 598 | 
            +
                init_flux_pipeline()
         | 
