primerz commited on
Commit
c358674
ยท
verified ยท
1 Parent(s): 3e6f23b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +219 -226
app.py CHANGED
@@ -2,25 +2,22 @@ import spaces # MUST be first, before any CUDA-related imports
2
  import gradio as gr
3
  import torch
4
  from diffusers import (
 
5
  ControlNetModel,
6
  AutoencoderKL,
7
- DPMSolverMultistepScheduler,
8
- LCMScheduler
9
  )
10
  from diffusers.models.attention_processor import AttnProcessor2_0
11
  from insightface.app import FaceAnalysis
12
  from PIL import Image
13
  import numpy as np
14
  import cv2
 
 
15
  from huggingface_hub import hf_hub_download
16
  import os
17
 
18
- # Import the custom img2img pipeline with InstantID
19
- from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
20
-
21
- # Import ZoeDetector for better depth maps
22
- from controlnet_aux import ZoeDetector
23
-
24
  # Configuration
25
  MODEL_REPO = "primerz/pixagram"
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -29,23 +26,61 @@ dtype = torch.float16 if device == "cuda" else torch.float32
29
  # LORA trigger word
30
  TRIGGER_WORD = "p1x3l4rt, pixel art"
31
 
 
 
 
32
  print(f"Using device: {device}")
33
  print(f"Loading models from: {MODEL_REPO}")
34
  print(f"LORA Trigger Word: {TRIGGER_WORD}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  class RetroArtConverter:
37
- def __init__(self, use_lcm=False):
38
  self.device = device
39
  self.dtype = dtype
40
- self.use_lcm = use_lcm
41
  self.models_loaded = {
42
  'custom_checkpoint': False,
43
  'lora': False,
44
- 'instantid': False
 
45
  }
46
 
47
  # Initialize face analysis for InstantID
48
- print("Loading face analysis model (antelopev2)...")
49
  try:
50
  self.face_app = FaceAnalysis(
51
  name='antelopev2',
@@ -60,7 +95,25 @@ class RetroArtConverter:
60
  self.face_app = None
61
  self.face_detection_enabled = False
62
 
63
- # Load ControlNet for InstantID
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  print("Loading InstantID ControlNet...")
65
  try:
66
  self.controlnet_instantid = ControlNetModel.from_pretrained(
@@ -76,82 +129,42 @@ class RetroArtConverter:
76
  self.controlnet_instantid = None
77
  self.instantid_enabled = False
78
 
79
- # Load ControlNet for Zoe depth
80
- print("Loading Zoe Depth ControlNet...")
81
- self.controlnet_depth = ControlNetModel.from_pretrained(
82
- "diffusers/controlnet-zoe-depth-sdxl-1.0",
83
- torch_dtype=self.dtype
84
- ).to(self.device)
85
-
86
- # Load Zoe depth detector (better than DPT)
87
- print("Loading Zoe depth detector...")
88
- try:
89
- self.zoe_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators")
90
- self.zoe_detector.to(self.device)
91
- print("โœ“ Zoe detector loaded successfully")
92
- except Exception as e:
93
- print(f"โš ๏ธ Could not load Zoe detector: {e}")
94
- self.zoe_detector = None
95
-
96
  # Determine which controlnets to use
97
  if self.instantid_enabled and self.controlnet_instantid is not None:
98
  controlnets = [self.controlnet_instantid, self.controlnet_depth]
99
- print(f"Initializing with multiple ControlNets: InstantID + Zoe Depth")
100
  else:
101
  controlnets = self.controlnet_depth
102
- print(f"Initializing with single ControlNet: Zoe Depth only")
103
-
104
- # Load VAE
105
- print("Loading VAE...")
106
- self.vae = AutoencoderKL.from_pretrained(
107
- "madebyollin/sdxl-vae-fp16-fix",
108
- torch_dtype=self.dtype
109
- ).to(self.device)
110
 
111
  # Load SDXL checkpoint from HuggingFace Hub
112
- print("Loading SDXL checkpoint (horizon) from HuggingFace Hub...")
113
  try:
114
  model_path = hf_hub_download(
115
  repo_id=MODEL_REPO,
116
  filename="horizon.safetensors",
117
  repo_type="model"
118
  )
119
- # Use the custom img2img pipeline for better results
120
- self.pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
121
  model_path,
122
  controlnet=controlnets,
123
- vae=self.vae,
124
  torch_dtype=self.dtype,
125
  use_safetensors=True
126
  ).to(self.device)
127
- print("โœ“ Custom checkpoint loaded successfully")
128
  self.models_loaded['custom_checkpoint'] = True
129
  except Exception as e:
130
  print(f"โš ๏ธ Could not load custom checkpoint: {e}")
131
  print("Using default SDXL base model")
132
- self.pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
133
  "stabilityai/stable-diffusion-xl-base-1.0",
134
  controlnet=controlnets,
135
- vae=self.vae,
136
  torch_dtype=self.dtype,
137
  use_safetensors=True
138
  ).to(self.device)
139
  self.models_loaded['custom_checkpoint'] = False
140
 
141
- # Load InstantID IP-Adapter
142
- if self.instantid_enabled:
143
- print("Loading InstantID IP-Adapter...")
144
- try:
145
- ip_adapter_path = hf_hub_download(
146
- repo_id="InstantX/InstantID",
147
- filename="ip-adapter.bin"
148
- )
149
- self.pipe.load_ip_adapter_instantid(ip_adapter_path)
150
- self.pipe.set_ip_adapter_scale(0.8)
151
- print("โœ“ InstantID IP-Adapter loaded successfully")
152
- except Exception as e:
153
- print(f"โš ๏ธ Could not load IP-Adapter: {e}")
154
-
155
  # Load LORA from HuggingFace Hub
156
  print("Loading LORA (retroart) from HuggingFace Hub...")
157
  try:
@@ -168,14 +181,14 @@ class RetroArtConverter:
168
  print(f"โš ๏ธ Could not load LORA: {e}")
169
  self.models_loaded['lora'] = False
170
 
171
- # Choose scheduler based on mode
172
- if use_lcm:
173
- print("Setting up LCM scheduler for fast generation...")
174
  self.pipe.scheduler = LCMScheduler.from_config(
175
  self.pipe.scheduler.config
176
  )
177
  else:
178
- print("Setting up DPMSolverMultistep scheduler with Karras sigmas for quality...")
179
  self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
180
  self.pipe.scheduler.config,
181
  use_karras_sigmas=True
@@ -192,6 +205,11 @@ class RetroArtConverter:
192
  except Exception as e:
193
  print(f"โš ๏ธ xformers not available: {e}")
194
 
 
 
 
 
 
195
  # Track controlnet configuration
196
  self.using_multiple_controlnets = isinstance(controlnets, list)
197
  print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")
@@ -203,38 +221,36 @@ class RetroArtConverter:
203
  print("===================\n")
204
 
205
  print("โœ“ Model initialization complete!")
206
- if use_lcm:
207
- print("\n=== LCM CONFIGURATION ===")
208
- print("Scheduler: LCM")
209
- print("Recommended Steps: 8-12")
210
  print("Recommended CFG: 1.0-1.5")
211
- print("Recommended Strength: 0.6-0.8")
212
  else:
213
- print("\n=== QUALITY CONFIGURATION ===")
214
- print("Scheduler: DPMSolverMultistep + Karras")
215
- print("Recommended Steps: 25-40")
216
- print("Recommended CFG: 5.0-7.5")
217
- print("Recommended Strength: 0.4-0.7")
218
  print(f"LORA Trigger: '{TRIGGER_WORD}'")
219
- print("=========================\n")
220
 
221
  def get_depth_map(self, image):
222
- """Generate depth map from input image using Zoe"""
223
- if self.zoe_detector is not None:
224
- # Use Zoe detector for better depth maps
225
- depth_image = self.zoe_detector(image)
226
  return depth_image
227
  else:
228
- # Fallback to basic conversion
229
- img_array = np.array(image.convert('L'))
230
- depth_colored = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
231
  return Image.fromarray(depth_colored)
232
 
233
  def calculate_optimal_size(self, original_width, original_height):
234
  """Calculate optimal size from recommended resolutions"""
235
  aspect_ratio = original_width / original_height
236
 
237
- # Recommended resolutions for SDXL
238
  recommended_sizes = [
239
  (896, 1152), # Portrait
240
  (1152, 896), # Landscape
@@ -272,15 +288,14 @@ class RetroArtConverter:
272
  input_image,
273
  prompt="retro game character, vibrant colors, detailed",
274
  negative_prompt="blurry, low quality, ugly, distorted",
275
- num_inference_steps=25,
276
- guidance_scale=5.0,
277
- strength=0.6, # img2img strength
278
  controlnet_conditioning_scale=0.8,
279
  lora_scale=1.0,
280
- face_strength=0.85, # InstantID face strength
281
- depth_control_scale=0.8 # Zoe depth strength
282
  ):
283
- """Generate retro art using img2img pipeline with face keypoints"""
284
 
285
  # Add trigger word to prompt
286
  prompt = self.add_trigger_word(prompt)
@@ -291,6 +306,7 @@ class RetroArtConverter:
291
 
292
  print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
293
  print(f"Prompt: {prompt}")
 
294
 
295
  # Resize with high quality
296
  resized_image = input_image.resize((target_width, target_height), Image.LANCZOS)
@@ -303,33 +319,30 @@ class RetroArtConverter:
303
 
304
  # Handle face detection for InstantID
305
  using_multiple_controlnets = self.using_multiple_controlnets
306
- face_kps = None
307
  face_embeddings = None
308
  has_detected_faces = False
309
 
310
  if using_multiple_controlnets and self.face_app is not None:
311
  print("Detecting faces and extracting keypoints...")
312
- img_array = np.array(resized_image)
313
  faces = self.face_app.get(img_array)
314
 
315
  if len(faces) > 0:
316
  has_detected_faces = True
317
  print(f"Detected {len(faces)} face(s)")
318
 
319
- # Get the largest face
320
- face = sorted(faces,
321
- key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
322
 
323
  # Extract face embeddings
324
- face_embeddings = torch.from_numpy(face.normed_embedding).unsqueeze(0).to(
325
- self.device, dtype=self.dtype
326
- )
327
 
328
- # Draw keypoints (this shows age, gender, expression)
329
- face_kps = draw_kps(resized_image, face.kps)
330
- print(f"Face keypoints drawn (age/gender/expression preserved)")
331
- else:
332
- print("No faces detected in image")
333
 
334
  # Set LORA scale
335
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
@@ -343,52 +356,53 @@ class RetroArtConverter:
343
  pipe_kwargs = {
344
  "prompt": prompt,
345
  "negative_prompt": negative_prompt,
346
- "image": resized_image, # Original image for img2img
 
347
  "num_inference_steps": num_inference_steps,
348
  "guidance_scale": guidance_scale,
349
- "strength": strength, # img2img denoising strength
350
  "generator": torch.Generator(device=self.device).manual_seed(42)
351
  }
352
 
 
 
 
 
353
  # Configure ControlNet inputs
354
- if using_multiple_controlnets and has_detected_faces and face_kps is not None:
355
- print("Using InstantID + Zoe Depth ControlNets with face keypoints")
356
- control_images = [face_kps, depth_image]
357
- conditioning_scales = [face_strength, depth_control_scale]
 
358
 
359
  pipe_kwargs["control_image"] = control_images
360
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
361
-
362
- # Add face embeddings through IP-Adapter
363
- if face_embeddings is not None and hasattr(self.pipe, 'set_ip_adapter_scale'):
364
- pipe_kwargs["ip_adapter_image_embeds"] = [face_embeddings]
365
 
366
- elif using_multiple_controlnets:
367
- print("Multiple ControlNets available but no faces detected - using depth only")
368
- # Use depth for both to maintain structure
369
  control_images = [depth_image, depth_image]
370
- conditioning_scales = [0.0, depth_control_scale] # Disable InstantID
371
 
372
  pipe_kwargs["control_image"] = control_images
373
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
374
 
375
  else:
376
- print("Using Zoe Depth ControlNet only")
377
  pipe_kwargs["control_image"] = depth_image
378
- pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale
379
 
380
  # Generate
381
- mode = "LCM" if self.use_lcm else "Quality"
382
- print(f"Generating with {mode} mode: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
383
  result = self.pipe(**pipe_kwargs)
384
 
385
  return result.images[0]
386
 
 
387
  # Initialize converter
388
  print("Initializing RetroArt Converter...")
389
- print("Choose mode: LCM (fast) or Quality (better)")
390
- converter_lcm = RetroArtConverter(use_lcm=True)
391
- converter_quality = RetroArtConverter(use_lcm=False)
392
 
393
  @spaces.GPU
394
  def process_image(
@@ -397,31 +411,25 @@ def process_image(
397
  negative_prompt,
398
  steps,
399
  guidance_scale,
400
- strength,
401
  controlnet_scale,
402
  lora_scale,
403
- face_strength,
404
- depth_control_scale,
405
- use_lcm_mode
406
  ):
407
  if image is None:
408
  return None
409
 
410
  try:
411
- # Choose the right converter based on mode
412
- converter = converter_lcm if use_lcm_mode else converter_quality
413
-
414
  result = converter.generate_retro_art(
415
  input_image=image,
416
  prompt=prompt,
417
  negative_prompt=negative_prompt,
418
  num_inference_steps=int(steps),
419
  guidance_scale=guidance_scale,
420
- strength=strength,
421
  controlnet_conditioning_scale=controlnet_scale,
422
  lora_scale=lora_scale,
423
- face_strength=face_strength,
424
- depth_control_scale=depth_control_scale
425
  )
426
  return result
427
  except Exception as e:
@@ -430,27 +438,44 @@ def process_image(
430
  traceback.print_exc()
431
  raise gr.Error(f"Generation failed: {str(e)}")
432
 
 
433
  # Gradio UI
434
- with gr.Blocks(title="RetroArt Converter - Improved", theme=gr.themes.Soft()) as demo:
435
- gr.Markdown("""
436
- # ๐ŸŽฎ RetroArt Converter (Improved with True Img2Img)
437
 
438
- Convert images into retro pixel art style with **proper face detection** and **gender/age preservation**!
439
 
440
- **โœจ Key Improvements:**
441
- - ๐ŸŽฏ **True img2img pipeline** for better structure preservation
442
- - ๐Ÿ‘ค **draw_kps**: Detects and preserves age, gender, expression
443
- - ๐Ÿ—บ๏ธ **Zoe Depth**: Superior depth estimation
444
- - โšก **Dual Mode**: Fast LCM or Quality DPM++
445
- - ๐ŸŽจ Custom pixel art LORA with trigger: `p1x3l4rt, pixel art`
 
 
446
  """)
447
 
448
  # Model status
449
- status_text = "**๐Ÿ“ฆ Loaded Models (LCM Mode):**\n"
450
- status_text += f"- Custom Checkpoint: {'โœ“ Loaded' if converter_lcm.models_loaded['custom_checkpoint'] else 'โœ— Using SDXL base'}\n"
451
- status_text += f"- LORA (RetroArt): {'โœ“ Loaded' if converter_lcm.models_loaded['lora'] else 'โœ— Disabled'}\n"
452
- status_text += f"- InstantID: {'โœ“ Loaded' if converter_lcm.models_loaded['instantid'] else 'โœ— Disabled'}\n"
453
- gr.Markdown(status_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
 
455
  with gr.Row():
456
  with gr.Column():
@@ -469,44 +494,29 @@ with gr.Blocks(title="RetroArt Converter - Improved", theme=gr.themes.Soft()) as
469
  lines=2
470
  )
471
 
472
- use_lcm_mode = gr.Checkbox(
473
- label="Use LCM Mode (Fast)",
474
- value=True,
475
- info="Uncheck for Quality mode (slower but better)"
476
- )
477
-
478
- with gr.Accordion("โš™๏ธ Generation Settings", open=True):
479
  steps = gr.Slider(
480
  minimum=4,
481
  maximum=50,
482
- value=12,
483
  step=1,
484
- label="Inference Steps (12 for LCM, 25-40 for Quality)"
485
  )
486
 
487
  guidance_scale = gr.Slider(
488
  minimum=0.5,
489
  maximum=15.0,
490
- value=1.0,
491
  step=0.1,
492
- label="Guidance Scale (1.0-1.5 for LCM, 5-7.5 for Quality)"
493
  )
494
 
495
  strength = gr.Slider(
496
  minimum=0.3,
497
- maximum=1.0,
498
- value=0.7,
499
  step=0.05,
500
- label="Img2Img Strength (how much to change)"
501
- )
502
-
503
- with gr.Accordion("๐ŸŽจ Style Settings", open=True):
504
- lora_scale = gr.Slider(
505
- minimum=0.5,
506
- maximum=1.5,
507
- value=1.0,
508
- step=0.05,
509
- label="RetroArt LORA Scale"
510
  )
511
 
512
  controlnet_scale = gr.Slider(
@@ -514,26 +524,24 @@ with gr.Blocks(title="RetroArt Converter - Improved", theme=gr.themes.Soft()) as
514
  maximum=1.2,
515
  value=0.8,
516
  step=0.05,
517
- label="Overall ControlNet Scale"
518
  )
519
-
520
- with gr.Accordion("๐Ÿ‘ค Face & Depth Settings", open=False):
521
- face_strength = gr.Slider(
522
- minimum=0,
523
- maximum=2.0,
524
- value=0.85,
525
  step=0.05,
526
- label="Face Preservation (InstantID)",
527
- info="Higher = better face likeness"
528
  )
529
-
530
- depth_control_scale = gr.Slider(
 
531
  minimum=0,
532
- maximum=1.0,
533
  value=0.8,
534
- step=0.05,
535
- label="Zoe Depth Control Scale",
536
- info="Higher = more structure preservation"
537
  )
538
 
539
  generate_btn = gr.Button("๐ŸŽจ Generate Retro Art", variant="primary", size="lg")
@@ -541,60 +549,45 @@ with gr.Blocks(title="RetroArt Converter - Improved", theme=gr.themes.Soft()) as
541
  with gr.Column():
542
  output_image = gr.Image(label="Retro Art Output")
543
 
544
- gr.Markdown("""
545
  ### ๐Ÿ’ก Tips for Best Results:
546
 
547
- **Mode Selection:**
548
- - โœ… **LCM Mode**: 12 steps, CFG 1.0-1.5, Strength 0.6-0.8 (โšก fast!)
549
- - โœ… **Quality Mode**: 25-40 steps, CFG 5-7.5, Strength 0.4-0.7 (๐ŸŽจ better!)
 
 
 
 
 
 
 
550
 
551
- **Face Preservation:**
552
- - System automatically detects faces and draws keypoints
553
- - Preserves age, gender, and expression characteristics
554
- - Adjust "Face Preservation" slider for control
555
 
556
- **For Best Quality:**
557
- - Use high-resolution input images (min 512px)
558
- - For portraits: enable Quality mode + high face strength
559
- - For scenes: lower img2img strength for more creativity
560
- - Adjust depth control for structure vs creativity balance
561
 
562
- **Style Control:**
563
- - LORA trigger word auto-added for pixel art style
564
- - Increase LORA scale (1.2-1.5) for stronger retro effect
565
- - Try: "SNES style", "16-bit RPG", "Game Boy advance style"
566
  """)
567
 
568
- # Update defaults when switching modes
569
- def update_mode_defaults(use_lcm):
570
- if use_lcm:
571
- return (
572
- gr.update(value=12), # steps
573
- gr.update(value=1.0), # guidance_scale
574
- gr.update(value=0.7) # strength
575
- )
576
- else:
577
- return (
578
- gr.update(value=30), # steps
579
- gr.update(value=6.0), # guidance_scale
580
- gr.update(value=0.6) # strength
581
- )
582
-
583
- use_lcm_mode.change(
584
- fn=update_mode_defaults,
585
- inputs=[use_lcm_mode],
586
- outputs=[steps, guidance_scale, strength]
587
- )
588
-
589
  generate_btn.click(
590
  fn=process_image,
591
  inputs=[
592
- input_image, prompt, negative_prompt, steps, guidance_scale, strength,
593
- controlnet_scale, lora_scale, face_strength, depth_control_scale, use_lcm_mode
594
  ],
595
  outputs=[output_image]
596
  )
597
 
 
598
  if __name__ == "__main__":
599
  demo.queue(max_size=20)
600
  demo.launch(
 
2
  import gradio as gr
3
  import torch
4
  from diffusers import (
5
+ StableDiffusionXLControlNetImg2ImgPipeline, # Changed to img2img
6
  ControlNetModel,
7
  AutoencoderKL,
8
+ LCMScheduler,
9
+ DPMSolverMultistepScheduler
10
  )
11
  from diffusers.models.attention_processor import AttnProcessor2_0
12
  from insightface.app import FaceAnalysis
13
  from PIL import Image
14
  import numpy as np
15
  import cv2
16
+ import math
17
+ from controlnet_aux import ZoeDetector # Better depth detection
18
  from huggingface_hub import hf_hub_download
19
  import os
20
 
 
 
 
 
 
 
21
  # Configuration
22
  MODEL_REPO = "primerz/pixagram"
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
26
  # LORA trigger word
27
  TRIGGER_WORD = "p1x3l4rt, pixel art"
28
 
29
+ # Use LCM or DPM++ scheduler
30
+ USE_LCM = True # Set to False to use DPM++ 2M Karras
31
+
32
  print(f"Using device: {device}")
33
  print(f"Loading models from: {MODEL_REPO}")
34
  print(f"LORA Trigger Word: {TRIGGER_WORD}")
35
+ print(f"Scheduler: {'LCM' if USE_LCM else 'DPM++ 2M Karras'}")
36
+
37
+
38
+ def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
39
+ """Draw facial keypoints on image for InstantID ControlNet"""
40
+ stickwidth = 4
41
+ limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
42
+ kps = np.array(kps)
43
+
44
+ w, h = image_pil.size
45
+ out_img = np.zeros([h, w, 3])
46
+
47
+ for i in range(len(limbSeq)):
48
+ index = limbSeq[i]
49
+ color = color_list[index[0]]
50
+
51
+ x = kps[index][:, 0]
52
+ y = kps[index][:, 1]
53
+ length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
54
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
55
+ polygon = cv2.ellipse2Poly(
56
+ (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
57
+ )
58
+ out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
59
+ out_img = (out_img * 0.6).astype(np.uint8)
60
+
61
+ for idx_kp, kp in enumerate(kps):
62
+ color = color_list[idx_kp]
63
+ x, y = kp
64
+ out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
65
+
66
+ out_img_pil = Image.fromarray(out_img.astype(np.uint8))
67
+ return out_img_pil
68
+
69
 
70
  class RetroArtConverter:
71
+ def __init__(self):
72
  self.device = device
73
  self.dtype = dtype
74
+ self.use_lcm = USE_LCM
75
  self.models_loaded = {
76
  'custom_checkpoint': False,
77
  'lora': False,
78
+ 'instantid': False,
79
+ 'zoe_depth': False
80
  }
81
 
82
  # Initialize face analysis for InstantID
83
+ print("Loading face analysis model...")
84
  try:
85
  self.face_app = FaceAnalysis(
86
  name='antelopev2',
 
95
  self.face_app = None
96
  self.face_detection_enabled = False
97
 
98
+ # Load Zoe Depth detector (better than DPT)
99
+ print("Loading Zoe Depth detector...")
100
+ try:
101
+ self.zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
102
+ self.zoe_depth.to(self.device)
103
+ print("โœ“ Zoe Depth loaded successfully")
104
+ self.models_loaded['zoe_depth'] = True
105
+ except Exception as e:
106
+ print(f"โš ๏ธ Zoe Depth not available: {e}")
107
+ self.zoe_depth = None
108
+
109
+ # Load ControlNet for depth
110
+ print("Loading ControlNet Zoe Depth model...")
111
+ self.controlnet_depth = ControlNetModel.from_pretrained(
112
+ "diffusers/controlnet-zoe-depth-sdxl-1.0",
113
+ torch_dtype=self.dtype
114
+ ).to(self.device)
115
+
116
+ # Load InstantID ControlNet
117
  print("Loading InstantID ControlNet...")
118
  try:
119
  self.controlnet_instantid = ControlNetModel.from_pretrained(
 
129
  self.controlnet_instantid = None
130
  self.instantid_enabled = False
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  # Determine which controlnets to use
133
  if self.instantid_enabled and self.controlnet_instantid is not None:
134
  controlnets = [self.controlnet_instantid, self.controlnet_depth]
135
+ print(f"Initializing with multiple ControlNets: InstantID + Depth")
136
  else:
137
  controlnets = self.controlnet_depth
138
+ print(f"Initializing with single ControlNet: Depth only")
 
 
 
 
 
 
 
139
 
140
  # Load SDXL checkpoint from HuggingFace Hub
141
+ print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
142
  try:
143
  model_path = hf_hub_download(
144
  repo_id=MODEL_REPO,
145
  filename="horizon.safetensors",
146
  repo_type="model"
147
  )
148
+ # Use Img2Img pipeline
149
+ self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
150
  model_path,
151
  controlnet=controlnets,
 
152
  torch_dtype=self.dtype,
153
  use_safetensors=True
154
  ).to(self.device)
155
+ print("โœ“ Custom checkpoint loaded successfully (VAE bundled)")
156
  self.models_loaded['custom_checkpoint'] = True
157
  except Exception as e:
158
  print(f"โš ๏ธ Could not load custom checkpoint: {e}")
159
  print("Using default SDXL base model")
160
+ self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
161
  "stabilityai/stable-diffusion-xl-base-1.0",
162
  controlnet=controlnets,
 
163
  torch_dtype=self.dtype,
164
  use_safetensors=True
165
  ).to(self.device)
166
  self.models_loaded['custom_checkpoint'] = False
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  # Load LORA from HuggingFace Hub
169
  print("Loading LORA (retroart) from HuggingFace Hub...")
170
  try:
 
181
  print(f"โš ๏ธ Could not load LORA: {e}")
182
  self.models_loaded['lora'] = False
183
 
184
+ # Setup scheduler based on USE_LCM flag
185
+ if self.use_lcm:
186
+ print("Setting up LCM scheduler...")
187
  self.pipe.scheduler = LCMScheduler.from_config(
188
  self.pipe.scheduler.config
189
  )
190
  else:
191
+ print("Setting up DPM++ 2M Karras scheduler...")
192
  self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
193
  self.pipe.scheduler.config,
194
  use_karras_sigmas=True
 
205
  except Exception as e:
206
  print(f"โš ๏ธ xformers not available: {e}")
207
 
208
+ # Set CLIP skip to 2
209
+ if hasattr(self.pipe, 'text_encoder'):
210
+ self.clip_skip = 2
211
+ print(f"โœ“ CLIP skip set to {self.clip_skip}")
212
+
213
  # Track controlnet configuration
214
  self.using_multiple_controlnets = isinstance(controlnets, list)
215
  print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")
 
221
  print("===================\n")
222
 
223
  print("โœ“ Model initialization complete!")
224
+ print("\n=== CONFIGURATION ===")
225
+ print(f"Scheduler: {'LCM' if self.use_lcm else 'DPM++ 2M Karras'}")
226
+ if self.use_lcm:
227
+ print("Recommended Steps: 12")
228
  print("Recommended CFG: 1.0-1.5")
 
229
  else:
230
+ print("Recommended Steps: 30-50")
231
+ print("Recommended CFG: 7.0-8.0")
232
+ print("Recommended Resolution: 896x1152 or 832x1216")
233
+ print("CLIP Skip: 2")
 
234
  print(f"LORA Trigger: '{TRIGGER_WORD}'")
235
+ print("=====================\n")
236
 
237
  def get_depth_map(self, image):
238
+ """Generate depth map using Zoe Depth"""
239
+ if self.zoe_depth is not None:
240
+ # Use Zoe detector
241
+ depth_image = self.zoe_depth(image, detect_resolution=512, image_resolution=1024)
242
  return depth_image
243
  else:
244
+ # Fallback to simple grayscale
245
+ gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
246
+ depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
247
  return Image.fromarray(depth_colored)
248
 
249
  def calculate_optimal_size(self, original_width, original_height):
250
  """Calculate optimal size from recommended resolutions"""
251
  aspect_ratio = original_width / original_height
252
 
253
+ # Recommended resolutions for this model
254
  recommended_sizes = [
255
  (896, 1152), # Portrait
256
  (1152, 896), # Landscape
 
288
  input_image,
289
  prompt="retro game character, vibrant colors, detailed",
290
  negative_prompt="blurry, low quality, ugly, distorted",
291
+ num_inference_steps=12,
292
+ guidance_scale=1.0,
 
293
  controlnet_conditioning_scale=0.8,
294
  lora_scale=1.0,
295
+ identity_preservation=0.8,
296
+ strength=0.75 # img2img strength
297
  ):
298
+ """Generate retro art with img2img pipeline"""
299
 
300
  # Add trigger word to prompt
301
  prompt = self.add_trigger_word(prompt)
 
306
 
307
  print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
308
  print(f"Prompt: {prompt}")
309
+ print(f"Img2Img Strength: {strength}")
310
 
311
  # Resize with high quality
312
  resized_image = input_image.resize((target_width, target_height), Image.LANCZOS)
 
319
 
320
  # Handle face detection for InstantID
321
  using_multiple_controlnets = self.using_multiple_controlnets
322
+ face_kps_image = None
323
  face_embeddings = None
324
  has_detected_faces = False
325
 
326
  if using_multiple_controlnets and self.face_app is not None:
327
  print("Detecting faces and extracting keypoints...")
328
+ img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
329
  faces = self.face_app.get(img_array)
330
 
331
  if len(faces) > 0:
332
  has_detected_faces = True
333
  print(f"Detected {len(faces)} face(s)")
334
 
335
+ # Get largest face
336
+ face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
 
337
 
338
  # Extract face embeddings
339
+ face_embeddings = face.normed_embedding
 
 
340
 
341
+ # Draw keypoints
342
+ face_kps = face.kps
343
+ face_kps_image = draw_kps(resized_image, face_kps)
344
+
345
+ print(f"Face info: bbox={face.bbox}, age={face.age if hasattr(face, 'age') else 'N/A'}, gender={'M' if face.gender == 1 else 'F' if hasattr(face, 'gender') else 'N/A'}")
346
 
347
  # Set LORA scale
348
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
 
356
  pipe_kwargs = {
357
  "prompt": prompt,
358
  "negative_prompt": negative_prompt,
359
+ "image": resized_image, # img2img source
360
+ "strength": strength, # how much to transform
361
  "num_inference_steps": num_inference_steps,
362
  "guidance_scale": guidance_scale,
 
363
  "generator": torch.Generator(device=self.device).manual_seed(42)
364
  }
365
 
366
+ # Add CLIP skip
367
+ if hasattr(self.pipe, 'text_encoder'):
368
+ pipe_kwargs["clip_skip"] = 2
369
+
370
  # Configure ControlNet inputs
371
+ if using_multiple_controlnets and has_detected_faces and face_kps_image is not None:
372
+ print("Using InstantID (keypoints) + Depth ControlNets")
373
+ # Order: [InstantID, Depth]
374
+ control_images = [face_kps_image, depth_image]
375
+ conditioning_scales = [identity_preservation, controlnet_conditioning_scale]
376
 
377
  pipe_kwargs["control_image"] = control_images
378
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
 
 
 
 
379
 
380
+ elif using_multiple_controlnets and not has_detected_faces:
381
+ print("Multiple ControlNets available but no faces detected, using depth only")
382
+ # Use depth for both to avoid errors
383
  control_images = [depth_image, depth_image]
384
+ conditioning_scales = [0.0, controlnet_conditioning_scale]
385
 
386
  pipe_kwargs["control_image"] = control_images
387
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
388
 
389
  else:
390
+ print("Using Depth ControlNet only")
391
  pipe_kwargs["control_image"] = depth_image
392
+ pipe_kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
393
 
394
  # Generate
395
+ scheduler_name = "LCM" if self.use_lcm else "DPM++"
396
+ print(f"Generating with {scheduler_name}: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
397
  result = self.pipe(**pipe_kwargs)
398
 
399
  return result.images[0]
400
 
401
+
402
  # Initialize converter
403
  print("Initializing RetroArt Converter...")
404
+ converter = RetroArtConverter()
405
+
 
406
 
407
  @spaces.GPU
408
  def process_image(
 
411
  negative_prompt,
412
  steps,
413
  guidance_scale,
 
414
  controlnet_scale,
415
  lora_scale,
416
+ identity_preservation,
417
+ strength
 
418
  ):
419
  if image is None:
420
  return None
421
 
422
  try:
 
 
 
423
  result = converter.generate_retro_art(
424
  input_image=image,
425
  prompt=prompt,
426
  negative_prompt=negative_prompt,
427
  num_inference_steps=int(steps),
428
  guidance_scale=guidance_scale,
 
429
  controlnet_conditioning_scale=controlnet_scale,
430
  lora_scale=lora_scale,
431
+ identity_preservation=identity_preservation,
432
+ strength=strength
433
  )
434
  return result
435
  except Exception as e:
 
438
  traceback.print_exc()
439
  raise gr.Error(f"Generation failed: {str(e)}")
440
 
441
+
442
  # Gradio UI
443
+ with gr.Blocks(title="RetroArt Converter - Img2Img", theme=gr.themes.Soft()) as demo:
444
+ gr.Markdown(f"""
445
+ # ๐ŸŽฎ RetroArt Converter (Img2Img + InstantID)
446
 
447
+ Convert images into retro pixel art style using img2img with face preservation!
448
 
449
+ **โœจ Features:**
450
+ - ๐Ÿ–ผ๏ธ **True Img2Img**: Transforms your image while preserving structure
451
+ - ๐Ÿ‘ค **InstantID**: Facial keypoint detection with age/gender detection
452
+ - ๐ŸŽจ Custom pixel art LORA with trigger word: `{TRIGGER_WORD}`
453
+ - ๐Ÿ”๏ธ **Zoe Depth**: Better depth map quality
454
+ - โšก **{'LCM' if USE_LCM else 'DPM++ 2M Karras'}** scheduler
455
+ - ๐Ÿ“ Optimized resolutions: 896x1152 / 832x1216
456
+ - ๐ŸŽฏ CLIP Skip 2 for better style
457
  """)
458
 
459
  # Model status
460
+ if converter.models_loaded:
461
+ status_text = "**๐Ÿ“ฆ Loaded Models:**\n"
462
+ status_text += f"- Custom Checkpoint (Horizon): {'โœ“ Loaded' if converter.models_loaded['custom_checkpoint'] else 'โœ— Using SDXL base'}\n"
463
+ status_text += f"- LORA (RetroArt): {'โœ“ Loaded' if converter.models_loaded['lora'] else 'โœ— Disabled'}\n"
464
+ status_text += f"- InstantID: {'โœ“ Loaded' if converter.models_loaded['instantid'] else 'โœ— Disabled'}\n"
465
+ status_text += f"- Zoe Depth: {'โœ“ Loaded' if converter.models_loaded['zoe_depth'] else 'โœ— Fallback'}\n"
466
+ gr.Markdown(status_text)
467
+
468
+ scheduler_info = f"""
469
+ **โš™๏ธ Configuration:**
470
+ - Pipeline: **Img2Img** (better structure preservation)
471
+ - Scheduler: **{'LCM' if USE_LCM else 'DPM++ 2M Karras'}**
472
+ - Recommended Steps: **{12 if USE_LCM else '30-50'}**
473
+ - Recommended CFG: **{1.0 if USE_LCM else '7.0-8.0'}**
474
+ - CLIP Skip: **2**
475
+ - LORA Trigger: `{TRIGGER_WORD}` (auto-added)
476
+ - Face Detection: **Age & Gender detection enabled**
477
+ """
478
+ gr.Markdown(scheduler_info)
479
 
480
  with gr.Row():
481
  with gr.Column():
 
494
  lines=2
495
  )
496
 
497
+ with gr.Accordion(f"โšก {'LCM' if USE_LCM else 'DPM++'} Settings", open=True):
 
 
 
 
 
 
498
  steps = gr.Slider(
499
  minimum=4,
500
  maximum=50,
501
+ value=12 if USE_LCM else 30,
502
  step=1,
503
+ label=f"Inference Steps ({'LCM works with 12' if USE_LCM else 'DPM++ uses 30-50'})"
504
  )
505
 
506
  guidance_scale = gr.Slider(
507
  minimum=0.5,
508
  maximum=15.0,
509
+ value=1.0 if USE_LCM else 7.5,
510
  step=0.1,
511
+ label=f"Guidance Scale (CFG) - {'LCM uses 1.0-1.5' if USE_LCM else 'DPM++ uses 7-8'}"
512
  )
513
 
514
  strength = gr.Slider(
515
  minimum=0.3,
516
+ maximum=0.95,
517
+ value=0.75,
518
  step=0.05,
519
+ label="Img2Img Strength (how much to transform)"
 
 
 
 
 
 
 
 
 
520
  )
521
 
522
  controlnet_scale = gr.Slider(
 
524
  maximum=1.2,
525
  value=0.8,
526
  step=0.05,
527
+ label="Zoe Depth ControlNet Scale"
528
  )
529
+
530
+ lora_scale = gr.Slider(
531
+ minimum=0.5,
532
+ maximum=1.5,
533
+ value=1.0,
 
534
  step=0.05,
535
+ label="RetroArt LORA Scale"
 
536
  )
537
+
538
+ with gr.Accordion("๐Ÿ‘ค InstantID Settings (for portraits)", open=False):
539
+ identity_preservation = gr.Slider(
540
  minimum=0,
541
+ maximum=1.5,
542
  value=0.8,
543
+ step=0.1,
544
+ label="Identity/Keypoint Preservation"
 
545
  )
546
 
547
  generate_btn = gr.Button("๐ŸŽจ Generate Retro Art", variant="primary", size="lg")
 
549
  with gr.Column():
550
  output_image = gr.Image(label="Retro Art Output")
551
 
552
+ gr.Markdown(f"""
553
  ### ๐Ÿ’ก Tips for Best Results:
554
 
555
+ **For Img2Img:**
556
+ - โœ… **Strength 0.7-0.8**: Good balance of transformation and structure
557
+ - โœ… **Strength 0.5-0.6**: More faithful to original
558
+ - โœ… **Strength 0.8-0.9**: More creative/stylized
559
+
560
+ **For {'LCM' if USE_LCM else 'DPM++'}:**
561
+ - {'โœ… Use **12 steps** (optimized for speed)' if USE_LCM else 'โœ… Use **30-50 steps** (better quality)'}
562
+ - {'โœ… Keep CFG at **1.0-1.5**' if USE_LCM else 'โœ… Keep CFG at **7.0-8.0**'}
563
+ - โœ… LORA trigger word is **auto-added**
564
+ - โœ… Resolution auto-optimized to 896x1152 or 832x1216
565
 
566
+ **For Portraits:**
567
+ - The system detects **age and gender** automatically
568
+ - Facial **keypoints** are used for better face preservation
569
+ - Adjust Identity Preservation: lower = more stylized, higher = more realistic face
570
 
571
+ **For Quality:**
572
+ - Use high-resolution input images
573
+ - Be specific in prompts: "16-bit game character" vs "character"
574
+ - Adjust Depth scale: lower = more creative, higher = more faithful depth
 
575
 
576
+ **For Style:**
577
+ - Increase LORA scale (1.0-1.5) for stronger pixel art effect
578
+ - Try prompts like: "SNES style", "16-bit RPG", "Game Boy advance style"
 
579
  """)
580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  generate_btn.click(
582
  fn=process_image,
583
  inputs=[
584
+ input_image, prompt, negative_prompt, steps, guidance_scale,
585
+ controlnet_scale, lora_scale, identity_preservation, strength
586
  ],
587
  outputs=[output_image]
588
  )
589
 
590
+
591
  if __name__ == "__main__":
592
  demo.queue(max_size=20)
593
  demo.launch(