HAL1993 committed on
Commit
16d49c4
·
verified ·
1 Parent(s): 99b1f00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -133
app.py CHANGED
@@ -1,16 +1,15 @@
1
  import os
2
- import math # Added to fix NameError
3
- import gradio as gr
4
- import numpy as np
5
  import random
 
 
6
  import torch
7
- import spaces
8
  from PIL import Image
9
- from diffusers import FlowMatchEulerDiscreteScheduler
10
- from optimization import optimize_pipeline_
11
- from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
12
- from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
13
- from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
 
14
  import requests # For translation API
15
 
16
  # --- Translation Function ---
@@ -37,114 +36,133 @@ def translate_albanian_to_english(text):
37
  raise gr.Error("Translation failed. Please try again.")
38
  raise gr.Error("Translation failed. Please try again.")
39
 
40
- # --- Model Loading ---
41
- dtype = torch.bfloat16
42
- device = "cuda" if torch.cuda.is_available() else "cpu"
43
 
44
- # Scheduler configuration for Lightning
45
- scheduler_config = {
46
- "base_image_seq_len": 256,
47
- "base_shift": math.log(3),
48
- "invert_sigmas": False,
49
- "max_image_seq_len": 8192,
50
- "max_shift": math.log(3),
51
- "num_train_timesteps": 1000,
52
- "shift": 1.0,
53
- "shift_terminal": None,
54
- "stochastic_sampling": False,
55
- "time_shift_type": "exponential",
56
- "use_beta_sigmas": False,
57
- "use_dynamic_shifting": True,
58
- "use_exponential_sigmas": False,
59
- "use_karras_sigmas": False,
60
- }
61
 
62
- # Initialize scheduler with Lightning config
63
- scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
 
 
64
 
65
- # Load the model pipeline
66
- pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509",
67
- scheduler=scheduler,
68
- torch_dtype=dtype).to(device)
69
- pipe.load_lora_weights(
70
- "lightx2v/Qwen-Image-Lightning",
71
- weight_name="Qwen-Image-Lightning-4steps-V2.0.safetensors"
72
- )
73
- pipe.fuse_lora()
74
 
75
- # Apply the same optimizations from the first version
76
- pipe.transformer.__class__ = QwenImageTransformer2DModel
77
- pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
78
 
79
- # --- Ahead-of-time compilation ---
80
- optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))], prompt="prompt")
81
 
82
- # --- UI Constants and Helpers ---
83
- MAX_SEED = np.iinfo(np.int32).max
84
- QUALITY_PROMPT = ", high quality, detailed, vibrant, professional lighting"
 
 
 
 
 
 
 
85
 
86
- # --- Main Inference Function ---
87
- @spaces.GPU(duration=40)
88
- def infer(
89
- images,
90
- prompt,
91
- progress=gr.Progress(track_tqdm=True),
92
- ):
93
- """
94
- Generates an image using the local Qwen-Image diffusers pipeline.
95
- """
96
- negative_prompt = "" # Empty as in original
97
- seed = random.randint(0, MAX_SEED) # Default: randomize_seed=True
98
- true_guidance_scale = 1.0 # Default
99
- num_inference_steps = 4 # Default
100
- height = None # Default
101
- width = None # Default
102
- num_images_per_prompt = 1 # Default
 
 
 
 
 
 
 
103
 
104
- # Translate prompt from Albanian to English
105
- prompt_final = translate_albanian_to_english(prompt.strip()) + QUALITY_PROMPT
 
 
 
 
 
 
 
 
 
106
 
107
- # Set up the generator for reproducibility
108
- generator = torch.Generator(device=device).manual_seed(seed)
109
-
110
- # Load input images into PIL Images
111
- pil_images = []
112
- if images is not None:
113
- for item in images:
114
- try:
115
- if isinstance(item[0], Image.Image):
116
- pil_images.append(item[0].convert("RGB"))
117
- elif isinstance(item[0], str):
118
- pil_images.append(Image.open(item[0]).convert("RGB"))
119
- elif hasattr(item, "name"):
120
- pil_images.append(Image.open(item.name).convert("RGB"))
121
- except Exception:
122
- continue
123
 
124
- if height == 256 and width == 256:
125
- height, width = None, None
126
- print(f"Calling pipeline with prompt: '{prompt_final}'")
127
- print(f"Negative Prompt: '{negative_prompt}'")
128
- print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}, Size: {width}x{height}")
 
 
 
 
 
129
 
130
- # Generate the image
131
- image = pipe(
132
- image=pil_images if len(pil_images) > 0 else None,
133
- prompt=prompt_final,
134
- height=height,
135
- width=width,
136
- negative_prompt=negative_prompt,
137
- num_inference_steps=num_inference_steps,
138
- generator=generator,
139
- true_cfg_scale=true_guidance_scale,
140
- num_images_per_prompt=num_images_per_prompt,
141
- ).images
142
 
143
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  # --- Gradio User Interface ---
146
  def create_demo():
147
- with gr.Blocks(css="", title="Qwen Image Editor") as demo:
148
  gr.HTML("""
149
  <style>
150
  @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;600;700&display=swap');
@@ -202,12 +220,12 @@ def create_demo():
202
  margin: 0.75rem 0;
203
  width: 100%;
204
  }
205
- .gr-gallery {
206
  width: 100%;
207
  border: 1px solid #FFFFFF;
208
  border-radius: 4px;
209
  }
210
- input, textarea {
211
  background: #000000;
212
  color: #FFFFFF;
213
  border: 1px solid #FFFFFF;
@@ -216,7 +234,7 @@ def create_demo():
216
  width: 100%;
217
  box-sizing: border-box;
218
  }
219
- input:hover, textarea:hover {
220
  box-shadow: 0 0 8px rgba(255, 255, 255, 0.3);
221
  transition: box-shadow 0.3s;
222
  }
@@ -279,44 +297,44 @@ def create_demo():
279
  """)
280
 
281
  with gr.Row(elem_id="general_items"):
282
- gr.Markdown("# Qwen Image Editor")
283
- gr.Markdown("Edit your images with precise instructions", elem_id="subtitle")
284
  with gr.Column(elem_id="input_column"):
285
- input_images = gr.Gallery(
286
- label="Input Images",
287
- show_label=True,
288
- type="pil",
289
- interactive=True,
290
- elem_classes=["gradio-component", "gr-gallery"]
291
- )
292
- result = gr.Gallery(
293
- label="Result",
294
- show_label=True,
295
- type="pil",
296
- elem_classes=["gradio-component", "gr-gallery"]
297
  )
298
- prompt = gr.Textbox(
299
- label="Prompt",
300
- placeholder="Describe the edit instruction",
301
- lines=3,
302
- elem_classes="gradio-component"
303
  )
304
- run_button = gr.Button(
305
- "Edit!",
306
  variant="primary",
307
  elem_classes="gradio-component"
308
  )
 
 
 
 
 
 
 
309
 
310
  gr.on(
311
- triggers=[run_button.click, prompt.submit],
312
- fn=infer,
313
- inputs=[input_images, prompt],
314
- outputs=[result],
315
  )
316
 
317
- return demo
318
 
319
  if __name__ == "__main__":
320
  print(f"Gradio version: {gr.__version__}")
321
- demo = create_demo()
322
- demo.queue().launch(share=True)
 
1
  import os
 
 
 
2
  import random
3
+ import sys
4
+ from typing import Sequence, Mapping, Any, Union
5
  import torch
 
6
  from PIL import Image
7
+ from huggingface_hub import hf_hub_download
8
+ import spaces
9
+ import subprocess
10
+ import gradio
11
+ import gradio_client
12
+ import gradio as gr
13
  import requests # For translation API
14
 
15
  # --- Translation Function ---
 
36
  raise gr.Error("Translation failed. Please try again.")
37
  raise gr.Error("Translation failed. Please try again.")
38
 
39
+ # --- Monkey-patch for gradio_client ---
40
+ import gradio_client.utils as _gc_utils
 
41
 
42
+ _orig_js2pt = _gc_utils._json_schema_to_python_type
43
+ _orig_get_type = _gc_utils.get_type
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
def _safe_json_schema_to_python_type(schema, defs=None):
    """Guarded stand-in for gradio_client's ``_json_schema_to_python_type``.

    A schema that is a plain ``bool`` is reported as the type name ``"Any"``
    without calling the original function; any other schema is forwarded,
    together with ``defs``, to the original captured in ``_orig_js2pt``.
    """
    is_boolean_schema = isinstance(schema, bool)
    # The conditional expression keeps the original call lazy: it only runs
    # for non-boolean schemas.
    return "Any" if is_boolean_schema else _orig_js2pt(schema, defs)
49
 
50
def _safe_get_type(schema):
    """Guarded stand-in for gradio_client's ``get_type``.

    Returns ``"Any"`` for a schema that is a plain ``bool``; otherwise defers
    to the original function captured in ``_orig_get_type``.
    """
    is_boolean_schema = isinstance(schema, bool)
    return "Any" if is_boolean_schema else _orig_get_type(schema)
 
 
 
 
 
54
 
55
+ _gc_utils._json_schema_to_python_type = _safe_json_schema_to_python_type
56
+ _gc_utils.get_type = _safe_get_type
 
57
 
58
+ print("gradio version:", gradio.__version__)
59
+ print("gradio_client version:", gradio_client.__version__)
60
 
61
+ # --- Model Downloads ---
62
# Files fetched from the Hugging Face Hub when this module is imported.
# Each entry is (repo_id, filename, local_dir); the order matches the download order.
_MODEL_DOWNLOADS = (
    ("ezioruan/inswapper_128.onnx", "inswapper_128.onnx", "models/insightface"),
    ("martintomov/comfy", "facerestore_models/GPEN-BFR-512.onnx", "models/facerestore_models"),
    ("darkeril/collection", "detection_Resnet50_Final.pth", "models/facedetection"),
    ("gmk123/GFPGAN", "parsing_parsenet.pth", "models/facedetection"),
    ("MonsterMMORPG/tools", "1k3d68.onnx", "models/insightface/models/buffalo_l"),
    ("MonsterMMORPG/tools", "2d106det.onnx", "models/insightface/models/buffalo_l"),
    ("maze/faceX", "det_10g.onnx", "models/insightface/models/buffalo_l"),
    ("typhoon01/aux_models", "genderage.onnx", "models/insightface/models/buffalo_l"),
    ("maze/faceX", "w600k_r50.onnx", "models/insightface/models/buffalo_l"),
)

for _repo_id, _filename, _local_dir in _MODEL_DOWNLOADS:
    hf_hub_download(repo_id=_repo_id, filename=_filename, local_dir=_local_dir)
71
 
72
+ # --- Utility Functions ---
73
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, retrying as ``obj["result"][index]`` on ``KeyError``.

    Only ``KeyError`` triggers the fallback; any other exception raised by the
    direct lookup (for example ``IndexError`` from a list) propagates.
    """
    try:
        value = obj[index]
    except KeyError:
        # Direct lookup failed on a mapping: look inside its "result" entry.
        value = obj["result"][index]
    return value
78
+
79
def find_path(name: str, path: Union[str, None] = None) -> Union[str, None]:
    """Search ``path`` and each of its ancestors for a directory entry called ``name``.

    Args:
        name: Entry name to look for. Only the name is matched against the
            directory listing, so it may be a file or a directory.
        path: Directory to start from. Defaults to the current working directory.

    Returns:
        ``os.path.join(directory, name)`` for the first directory (starting at
        ``path`` and moving upwards) whose listing contains ``name``, or
        ``None`` when no ancestor contains it.
    """
    current = os.getcwd() if path is None else path
    while True:
        if name in os.listdir(current):
            path_name = os.path.join(current, name)
            print(f"{name} found: {path_name}")
            return path_name
        parent_directory = os.path.dirname(current)
        # dirname() gives back its argument once there is no parent left,
        # which is the signal that the search is exhausted.
        if parent_directory == current:
            return None
        current = parent_directory
90
+
91
def add_comfyui_directory_to_sys_path() -> None:
    """Append the ``ComfyUI`` directory located by ``find_path`` to ``sys.path``.

    Nothing happens when ``find_path`` returns ``None`` or the match is not a
    directory.
    """
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is None or not os.path.isdir(comfyui_path):
        return
    sys.path.append(comfyui_path)
    print(f"'{comfyui_path}' added to sys.path")
96
 
97
def add_extra_model_paths() -> None:
    """Load ``extra_model_paths.yaml`` through ``load_extra_path_config`` if it can be found.

    ``load_extra_path_config`` is imported from ``main`` first and, when that
    import fails, from ``utils.extra_config``. The config file is located with
    ``find_path``; a message is printed when it is missing.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.")
        from utils.extra_config import load_extra_path_config

    config_file = find_path("extra_model_paths.yaml")
    if config_file is None:
        print("Could not find the extra_model_paths config file.")
        return
    load_extra_path_config(config_file)
108
 
109
+ add_comfyui_directory_to_sys_path()
110
+ add_extra_model_paths()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
def import_custom_nodes() -> None:
    """Set up a prompt server and queue, then call ``init_extra_nodes``.

    Creates a fresh asyncio event loop and installs it as the current loop,
    builds ``server.PromptServer`` on that loop, constructs
    ``execution.PromptQueue`` with the server (the queue object itself is not
    kept), and finally runs ``init_extra_nodes()``.
    """
    # Imports stay inside the function and in their original order.
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    prompt_server = server.PromptServer(event_loop)
    execution.PromptQueue(prompt_server)

    init_extra_nodes()
122
 
123
+ import_custom_nodes()
124
+ from nodes import NODE_CLASS_MAPPINGS
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # --- Main Inference Function ---
127
@spaces.GPU(duration=20)
def generate_image(source_image, target_image):
    """
    Performs face swapping between source and target images using ReActorFaceSwap.

    Args:
        source_image: Value handed to ``LoadImage.load_image`` and then used as
            ReActor's ``source_image`` (the face to copy).
        target_image: Value handed to ``LoadImage.load_image`` and then used as
            ReActor's ``input_image`` (the picture that receives the face).

    Returns:
        A fully loaded PIL image read back from the file written by ``SaveImage``.

    Raises:
        gr.Error: If either image is missing, so the user sees a clear message
            instead of a failure inside the node graph.
    """
    if not source_image or not target_image:
        raise gr.Error("Please provide both a source image and a target image.")

    restore_strength = 1.0  # Default
    target_index = 0  # Default

    with torch.inference_mode():
        loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
        loadimage_1 = loadimage.load_image(image=target_image)
        loadimage_3 = loadimage.load_image(image=source_image)
        reactorfaceswap = NODE_CLASS_MAPPINGS["ReActorFaceSwap"]()
        saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
        reactorfaceswap_2 = reactorfaceswap.execute(
            enabled=True,
            swap_model="inswapper_128.onnx",
            facedetection="retinaface_resnet50",
            face_restore_model="GPEN-BFR-512.onnx",
            face_restore_visibility=restore_strength,
            codeformer_weight=0.5,
            detect_gender_input="no",
            detect_gender_source="no",
            input_faces_index=str(target_index),
            source_faces_index="0",
            console_log_level=1,
            input_image=get_value_at_index(loadimage_1, 0),
            source_image=get_value_at_index(loadimage_3, 0),
        )
        saveimage_4 = saveimage.save_images(
            filename_prefix="ComfyUI",
            images=get_value_at_index(reactorfaceswap_2, 0),
        )

    saved_path = f"output/{saveimage_4['ui']['images'][0]['filename']}"
    # Image.open() is lazy and keeps the file handle open; copy() loads the
    # pixels so the handle can be closed before returning.
    with Image.open(saved_path) as saved_image:
        return saved_image.copy()
162
 
163
  # --- Gradio User Interface ---
164
  def create_demo():
165
+ with gr.Blocks(css="", title="Face Swapper") as app:
166
  gr.HTML("""
167
  <style>
168
  @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;600;700&display=swap');
 
220
  margin: 0.75rem 0;
221
  width: 100%;
222
  }
223
+ .gr-image {
224
  width: 100%;
225
  border: 1px solid #FFFFFF;
226
  border-radius: 4px;
227
  }
228
+ input {
229
  background: #000000;
230
  color: #FFFFFF;
231
  border: 1px solid #FFFFFF;
 
234
  width: 100%;
235
  box-sizing: border-box;
236
  }
237
+ input:hover {
238
  box-shadow: 0 0 8px rgba(255, 255, 255, 0.3);
239
  transition: box-shadow 0.3s;
240
  }
 
297
  """)
298
 
299
  with gr.Row(elem_id="general_items"):
300
+ gr.Markdown("# Face Swapper")
301
+ gr.Markdown("Swap faces in photos with advanced AI technology", elem_id="subtitle")
302
  with gr.Column(elem_id="input_column"):
303
+ source_image = gr.Image(
304
+ label="Source Image",
305
+ type="filepath",
306
+ sources=["upload", "clipboard"],
307
+ elem_classes=["gradio-component", "gr-image"]
 
 
 
 
 
 
 
308
  )
309
+ target_image = gr.Image(
310
+ label="Target Image",
311
+ type="filepath",
312
+ sources=["upload", "clipboard"],
313
+ elem_classes=["gradio-component", "gr-image"]
314
  )
315
+ generate_btn = gr.Button(
316
+ "Generate",
317
  variant="primary",
318
  elem_classes="gradio-component"
319
  )
320
+ output_image = gr.Image(
321
+ label="Generated Image",
322
+ show_download_button=True,
323
+ show_share_button=False,
324
+ interactive=False,
325
+ elem_classes=["gradio-component", "gr-image"]
326
+ )
327
 
328
  gr.on(
329
+ triggers=[generate_btn.click],
330
+ fn=generate_image,
331
+ inputs=[source_image, target_image],
332
+ outputs=[output_image]
333
  )
334
 
335
+ return app
336
 
337
  if __name__ == "__main__":
338
  print(f"Gradio version: {gr.__version__}")
339
+ app = create_demo()
340
+ app.queue().launch(share=True)