Cicici1109 committed on
Commit
5f84ebe
·
verified ·
1 Parent(s): 1b448e0

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +65 -40
utils.py CHANGED
@@ -19,7 +19,6 @@ import subprocess
19
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
20
  from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
21
 
22
-
23
  from src.flux.generate import generate, seed_everything
24
 
25
  try:
@@ -30,54 +29,82 @@ except ImportError:
30
 
31
  import re
32
 
 
33
  pipe = None
34
  model_dict = {}
35
-
 
36
 
37
def init_flux_pipeline():
    """Create the global FluxPipeline on first call; later calls reuse it.

    Authenticates against the Hugging Face Hub with the HF_TOKEN environment
    variable and moves the pipeline to CUDA.

    Returns:
        The initialized (CUDA-resident) FluxPipeline instance.

    Raises:
        ValueError: if the HF_TOKEN environment variable is not set.
    """
    global pipe
    if pipe is None:
        token = os.getenv("HF_TOKEN")
        if not token:
            raise ValueError("HF_TOKEN environment variable not set.")
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            use_auth_token=token,
            torch_dtype=torch.bfloat16
        )
        pipe = pipe.to("cuda")
    # Return the pipeline so callers can use the instance directly instead of
    # reaching for the module global (backward-compatible: previous callers
    # ignored the implicit None return).
    return pipe
 
50
def load_all_lora_adapters():
    """Preload every task-specific LoRA adapter onto the global pipeline.

    Adapter weights are pulled from the Hub repo "Cicici1109/IEAP". The base
    pipeline from init_flux_pipeline() is created first if needed.
    """
    global pipe
    if pipe is None:
        init_flux_pipeline()

    # All LoRA adapters used for editing tasks, keyed by adapter name.
    LORA_ADAPTERS = {
        "add": "weights/add.safetensors",
        "remove": "weights/remove.safetensors",
        "action": "weights/action.safetensors",
        "expression": "weights/expression.safetensors",
        "addition": "weights/addition.safetensors",
        "material": "weights/material.safetensors",
        "color": "weights/color.safetensors",
        "bg": "weights/bg.safetensors",
        "appearance": "weights/appearance.safetensors",
        "fusion": "weights/fusion.safetensors",
        "overall": "weights/overall.safetensors",
    }

    print(LORA_ADAPTERS)

    for adapter_name, weight_path in LORA_ADAPTERS.items():
        pipe.load_lora_weights(
            "Cicici1109/IEAP",
            weight_name=weight_path,
            adapter_name=adapter_name,
        )
        print(f"Loaded LoRA adapter: {adapter_name}")

    # BUG FIX: the previous default was set_adapters("scene"), but "scene" is
    # not among the adapters loaded above, so diffusers raised at startup.
    # Default to an adapter that was actually loaded.
    pipe.set_adapters("add")
82
  def get_model(model_path):
83
  global model_dict
@@ -198,8 +225,9 @@ def extract_last_bbox(result):
198
 
199
  @spaces.GPU
200
  def infer_with_DiT(task, image, instruction, category):
201
- init_flux_pipeline()
202
-
 
203
  if task == 'RoI Inpainting':
204
  if category == 'Add' or category == 'Replace':
205
  adapter_name = "add"
@@ -243,7 +271,8 @@ def infer_with_DiT(task, image, instruction, category):
243
  else:
244
  raise ValueError(f"Invalid task: '{task}'")
245
 
246
- # Switch to the preloaded adapter
 
247
  pipe.set_adapters(adapter_name)
248
 
249
  result_img = generate(
@@ -621,8 +650,4 @@ def layout_change(bbox, instruction):
621
  result = response.choices[0].message.content.strip()
622
 
623
  bbox = extract_last_bbox(result)
624
- return bbox
625
-
626
- if __name__ == "__main__":
627
- init_flux_pipeline()
628
- load_all_lora_adapters() # Preload all LoRA adapters at startup
 
19
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
20
  from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
21
 
 
22
  from src.flux.generate import generate, seed_everything
23
 
24
  try:
 
29
 
30
  import re
31
 
32
# Global variables
# Module-level singletons shared across the functions in this file.
pipe = None  # lazily-created FluxPipeline; built by init_flux_pipeline()
model_dict = {}  # cache of loaded models, presumably keyed by model_path (see get_model) -- TODO confirm
_MODEL_INITIALIZED = False  # set True once the Flux pipeline has been built
_ADAPTERS_LOADED = False  # set True once all LoRA adapters have been attached
37
 
38
def init_flux_pipeline():
    """Lazily build the global FluxPipeline; repeated calls return the same instance.

    Authenticates via the HF_TOKEN environment variable and places the
    pipeline on CUDA. Raises ValueError when HF_TOKEN is missing.
    """
    global pipe, _MODEL_INITIALIZED

    # Guard clause: already built, nothing to do.
    if pipe is not None:
        return pipe

    print("Initializing Flux pipeline...")
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN environment variable not set.")

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        use_auth_token=hf_token,
        torch_dtype=torch.bfloat16
    ).to("cuda")
    _MODEL_INITIALIZED = True
    print("Flux pipeline initialized successfully.")

    return pipe
58
 
59
def load_all_lora_adapters():
    """Load all LoRA adapters onto the global pipeline, running only once.

    Ensures the base pipeline exists, then attaches every task adapter from
    the Hub repo "Cicici1109/IEAP" and selects the first successfully loaded
    one as the default. Subsequent calls are no-ops.

    Returns:
        The global FluxPipeline with adapters attached.
    """
    global pipe, _ADAPTERS_LOADED

    # Ensure the base model is initialized before attaching adapters.
    init_flux_pipeline()

    if _ADAPTERS_LOADED:
        return pipe

    print("Loading all LoRA adapters...")

    LORA_ADAPTERS = {
        "add": "weights/add.safetensors",
        "remove": "weights/remove.safetensors",
        "action": "weights/action.safetensors",
        "expression": "weights/expression.safetensors",
        "addition": "weights/addition.safetensors",
        "material": "weights/material.safetensors",
        "color": "weights/color.safetensors",
        "bg": "weights/bg.safetensors",
        "appearance": "weights/appearance.safetensors",
        "fusion": "weights/fusion.safetensors",
        "overall": "weights/overall.safetensors",
    }

    # Track successes locally: diffusers pipelines expose get_list_adapters(),
    # not a `.lora_adapters` attribute, so the previous
    # `pipe.lora_adapters.keys()` raised AttributeError.
    loaded_adapters = []

    for adapter_name, weight_path in LORA_ADAPTERS.items():
        # BUG FIX: do NOT pre-check os.path.exists(weight_path) here — the
        # weights are fetched from the Hub repo "Cicici1109/IEAP", so a local
        # existence check skipped every adapter in a fresh environment.
        # Failures are still contained by the try/except below.
        try:
            pipe.load_lora_weights(
                "Cicici1109/IEAP",
                weight_name=weight_path,
                adapter_name=adapter_name,
            )
            loaded_adapters.append(adapter_name)
            print(f"✅ Successfully loaded adapter: {adapter_name}")
        except Exception as e:
            print(f"❌ Failed to load adapter {adapter_name}: {e}")

    print(f"Loaded adapters: {loaded_adapters}")

    if loaded_adapters:
        pipe.set_adapters(loaded_adapters[0])
        print(f"Default adapter set to: {loaded_adapters[0]}")

    _ADAPTERS_LOADED = True

    return pipe
108
 
109
  def get_model(model_path):
110
  global model_dict
 
225
 
226
  @spaces.GPU
227
  def infer_with_DiT(task, image, instruction, category):
228
+ # Ensure model and adapters are initialized
229
+ load_all_lora_adapters()
230
+
231
  if task == 'RoI Inpainting':
232
  if category == 'Add' or category == 'Replace':
233
  adapter_name = "add"
 
271
  else:
272
  raise ValueError(f"Invalid task: '{task}'")
273
 
274
+ # Switch to the specified adapter
275
+ print(f"Switching to adapter: {adapter_name}")
276
  pipe.set_adapters(adapter_name)
277
 
278
  result_img = generate(
 
650
  result = response.choices[0].message.content.strip()
651
 
652
  bbox = extract_last_bbox(result)
653
+ return bbox