primerz committed
Commit c6815c0 · verified · 1 parent: 912e6dd

Update app.py

Files changed (1)
  1. app.py +175 -222
app.py CHANGED
@@ -6,12 +6,11 @@ from diffusers import (
6
  StableDiffusionXLControlNetPipeline,
7
  ControlNetModel,
8
  AutoencoderKL,
9
- DPMSolverMultistepScheduler,
10
- EulerAncestralDiscreteScheduler
11
  )
12
  from diffusers.models.attention_processor import AttnProcessor2_0
13
  from insightface.app import FaceAnalysis
14
- from PIL import Image, ImageEnhance, ImageFilter
15
  import numpy as np
16
  import cv2
17
  from transformers import pipeline as transformers_pipeline
@@ -23,8 +22,12 @@ MODEL_REPO = "primerz/pixagram"
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
  dtype = torch.float16 if device == "cuda" else torch.float32
25

26
  print(f"Using device: {device}")
27
  print(f"Loading models from: {MODEL_REPO}")
 
28
 
29
  class RetroArtConverter:
30
  def __init__(self):
@@ -32,7 +35,6 @@ class RetroArtConverter:
32
  self.dtype = dtype
33
  self.models_loaded = {
34
  'custom_checkpoint': False,
35
- 'custom_vae': False,
36
  'lora': False,
37
  'instantid': False
38
  }
@@ -50,7 +52,6 @@ class RetroArtConverter:
50
  self.face_detection_enabled = True
51
  except Exception as e:
52
  print(f"⚠️ Face detection not available: {e}")
53
- print("Continuing without face detection")
54
  self.face_app = None
55
  self.face_detection_enabled = False
56
 
@@ -61,7 +62,7 @@ class RetroArtConverter:
61
  torch_dtype=self.dtype
62
  ).to(self.device)
63
 
64
- # Load InstantID ControlNet for identity preservation
65
  print("Loading InstantID ControlNet...")
66
  try:
67
  self.controlnet_instantid = ControlNetModel.from_pretrained(
@@ -74,34 +75,10 @@ class RetroArtConverter:
74
  self.models_loaded['instantid'] = True
75
  except Exception as e:
76
  print(f"⚠️ InstantID ControlNet not available: {e}")
77
- print("Running without InstantID")
78
  self.controlnet_instantid = None
79
  self.instantid_enabled = False
80
 
81
- # Load custom VAE from HuggingFace Hub
82
- print("Loading custom VAE (pixelate) from HuggingFace Hub...")
83
- try:
84
- vae_path = hf_hub_download(
85
- repo_id=MODEL_REPO,
86
- filename="pixelate.safetensors",
87
- repo_type="model"
88
- )
89
- self.vae = AutoencoderKL.from_single_file(
90
- vae_path,
91
- torch_dtype=self.dtype
92
- ).to(self.device)
93
- print("✓ Custom VAE loaded successfully")
94
- self.models_loaded['custom_vae'] = True
95
- except Exception as e:
96
- print(f"⚠️ Could not load custom VAE: {e}")
97
- print("Using high-quality SDXL VAE instead")
98
- self.vae = AutoencoderKL.from_pretrained(
99
- "madebyollin/sdxl-vae-fp16-fix",
100
- torch_dtype=self.dtype
101
- ).to(self.device)
102
- self.models_loaded['custom_vae'] = False
103
-
104
- # Load depth estimator for preprocessing
105
  print("Loading depth estimator...")
106
  self.depth_estimator = transformers_pipeline(
107
  'depth-estimation',
@@ -118,7 +95,8 @@ class RetroArtConverter:
118
  print(f"Initializing with single ControlNet: Depth only")
119
 
120
  # Load SDXL checkpoint from HuggingFace Hub
121
- print("Loading SDXL checkpoint (horizon) from HuggingFace Hub...")
 
122
  try:
123
  model_path = hf_hub_download(
124
  repo_id=MODEL_REPO,
@@ -128,11 +106,10 @@ class RetroArtConverter:
128
  self.pipe = StableDiffusionXLControlNetPipeline.from_single_file(
129
  model_path,
130
  controlnet=controlnets,
131
- vae=self.vae,
132
  torch_dtype=self.dtype,
133
  use_safetensors=True
134
  ).to(self.device)
135
- print("✓ Custom checkpoint loaded successfully")
136
  self.models_loaded['custom_checkpoint'] = True
137
  except Exception as e:
138
  print(f"⚠️ Could not load custom checkpoint: {e}")
@@ -140,7 +117,6 @@ class RetroArtConverter:
140
  self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
141
  "stabilityai/stable-diffusion-xl-base-1.0",
142
  controlnet=controlnets,
143
- vae=self.vae,
144
  torch_dtype=self.dtype,
145
  use_safetensors=True
146
  ).to(self.device)
@@ -155,25 +131,23 @@ class RetroArtConverter:
155
  repo_type="model"
156
  )
157
  self.pipe.load_lora_weights(lora_path)
158
- print("✓ LORA loaded successfully")
 
159
  self.models_loaded['lora'] = True
160
  except Exception as e:
161
  print(f"⚠️ Could not load LORA: {e}")
162
- print("Running without LORA")
163
  self.models_loaded['lora'] = False
164
 
165
- # Use EulerAncestral scheduler for better quality
166
- self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
 
167
  self.pipe.scheduler.config
168
  )
169
 
170
- # Disable VAE slicing for better quality (use only if you have VRAM issues)
171
- # self.pipe.enable_vae_slicing()
172
-
173
- # Enable attention slicing for memory efficiency
174
  self.pipe.unet.set_attn_processor(AttnProcessor2_0())
175
 
176
- # Try to enable xformers if available
177
  if self.device == "cuda":
178
  try:
179
  self.pipe.enable_xformers_memory_efficient_attention()
@@ -181,7 +155,12 @@ class RetroArtConverter:
181
  except Exception as e:
182
  print(f"⚠️ xformers not available: {e}")
183
 
184
- # Track whether we're using multiple ControlNets
 
185
  self.using_multiple_controlnets = isinstance(controlnets, list)
186
  print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")
187
 
@@ -191,150 +170,133 @@ class RetroArtConverter:
191
  print(f"{model}: {status}")
192
  print("===================\n")
193
 
194
- print("Model initialization complete!")
 
195
 
196
- def enhance_image_quality(self, image):
197
- """Enhance input image quality before processing"""
198
- # Sharpen slightly
199
- enhancer = ImageEnhance.Sharpness(image)
200
- image = enhancer.enhance(1.2)
201
-
202
- # Enhance contrast slightly
203
- enhancer = ImageEnhance.Contrast(image)
204
- image = enhancer.enhance(1.1)
205
-
206
- return image
207
-
208
- def get_depth_map(self, image, enhance=True):
209
- """Generate depth map from input image with quality improvements"""
210
- # Enhance image before depth estimation if needed
211
- if enhance:
212
- image = self.enhance_image_quality(image)
213
-
214
  depth = self.depth_estimator(image)
215
  depth_image = depth['depth']
216
 
217
  depth_array = np.array(depth_image)
218
 
219
- # Better normalization with histogram stretching
220
  depth_min, depth_max = np.percentile(depth_array, [2, 98])
221
  depth_normalized = np.clip((depth_array - depth_min) / (depth_max - depth_min + 1e-8), 0, 1) * 255
222
  depth_normalized = depth_normalized.astype(np.uint8)
223
 
224
- # Apply slight gaussian blur to reduce noise
225
  depth_normalized = cv2.GaussianBlur(depth_normalized, (3, 3), 0)
226
 
227
- # Convert to 3-channel image
228
  depth_colored = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
229
 
230
  return Image.fromarray(depth_colored)
231
 
232
- def extract_face_embeddings(self, image):
233
- """Extract face embeddings using InsightFace"""
234
- if not self.face_detection_enabled or self.face_app is None:
235
- return None
236
-
237
- try:
238
- img_array = np.array(image)
239
- faces = self.face_app.get(img_array)
240
-
241
- if len(faces) == 0:
242
- return None
243
-
244
- # Use the largest face
245
- face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
246
- return torch.from_numpy(face.normed_embedding).unsqueeze(0)
247
- except Exception as e:
248
- print(f"Face embedding extraction error: {e}")
249
- return None
250
-
251
- def calculate_target_size(self, original_width, original_height, max_dimension=1024):
252
- """Calculate target size maintaining aspect ratio"""
253
  aspect_ratio = original_width / original_height
254
 
255
- if original_width > original_height:
256
- new_width = min(original_width, max_dimension)
257
- new_height = int(new_width / aspect_ratio)
258
- else:
259
- new_height = min(original_height, max_dimension)
260
- new_width = int(new_height * aspect_ratio)
261
-
262
- # Round to nearest multiple of 8
263
- new_width = (new_width // 8) * 8
264
- new_height = (new_height // 8) * 8
265
-
266
- return new_width, new_height
 
267
 
268
  def generate_retro_art(
269
  self,
270
  input_image,
271
- prompt="retro pixel art game, 16-bit style, vibrant colors",
272
- negative_prompt="blurry, low quality, modern, photorealistic, 3d render",
273
- num_inference_steps=40, # Increased for better quality
274
- guidance_scale=7.5,
275
- controlnet_conditioning_scale=0.6, # Reduced for less depth influence
276
- lora_scale=0.85,
277
  identity_preservation=0.8,
278
- image_scale=0.2,
279
- enhance_quality=True # New parameter
280
  ):
281
- """Main generation function with quality improvements"""
 
282
 
283
- # Resize image maintaining aspect ratio
284
  original_width, original_height = input_image.size
285
- target_width, target_height = self.calculate_target_size(original_width, original_height)
286
 
287
  print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
 
288
 
289
- # Use LANCZOS for high-quality resizing
290
  resized_image = input_image.resize((target_width, target_height), Image.LANCZOS)
291
 
292
- # Optionally enhance image quality
293
- if enhance_quality:
294
- resized_image = self.enhance_image_quality(resized_image)
295
-
296
- # Generate depth map with quality enhancements
297
  print("Generating depth map...")
298
- depth_image = self.get_depth_map(resized_image, enhance=enhance_quality)
299
  depth_image = depth_image.resize((target_width, target_height), Image.LANCZOS)
300
 
301
- # Determine if we're using multiple ControlNets
302
  using_multiple_controlnets = self.using_multiple_controlnets
303
-
304
- # Extract face embeddings if InstantID is enabled
305
  face_embeddings = None
306
  has_detected_faces = False
307
 
308
  if using_multiple_controlnets:
309
- print("Extracting face embeddings...")
310
  img_array = np.array(resized_image)
311
  faces = self.face_app.get(img_array) if self.face_app is not None else []
312
 
313
  if len(faces) > 0:
314
  has_detected_faces = True
315
- print(f"Detected {len(faces)} face(s), using for identity preservation")
316
- # Get the largest face
317
  face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
318
  face_embeddings = torch.from_numpy(face.normed_embedding).unsqueeze(0).to(self.device, dtype=self.dtype)
319
-
320
- # Enhance prompt for face preservation
321
- prompt = f"portrait, detailed face, facial features, {prompt}"
322
 
323
  # Set LORA scale
324
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
325
  try:
326
  self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
327
- print(f"LORA scale set to: {lora_scale}")
328
  except Exception as e:
329
- print(f"Could not set LORA adapters: {e}")
330
-
331
- # Enhanced negative prompt for better quality
332
- enhanced_negative_prompt = f"{negative_prompt}, worst quality, low quality, normal quality, lowres, watermark, signature, text, jpeg artifacts, noise, grainy"
333
 
334
- # Prepare pipeline kwargs
335
  pipe_kwargs = {
336
  "prompt": prompt,
337
- "negative_prompt": enhanced_negative_prompt,
338
  "num_inference_steps": num_inference_steps,
339
  "guidance_scale": guidance_scale,
340
  "width": target_width,
@@ -342,21 +304,24 @@ class RetroArtConverter:
342
  "generator": torch.Generator(device=self.device).manual_seed(42)
343
  }
344
 
345
- # Add control images and scales based on ControlNet configuration
 
346
  if using_multiple_controlnets and has_detected_faces:
347
- print("Using multiple ControlNets (Depth + InstantID)")
348
  control_images = [depth_image, resized_image]
349
  conditioning_scales = [controlnet_conditioning_scale, image_scale]
350
 
351
  pipe_kwargs["image"] = control_images
352
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
353
 
354
- # Add face embeddings for InstantID IP-Adapter
355
  if face_embeddings is not None:
356
  pipe_kwargs["cross_attention_kwargs"] = {"ip_adapter_image_embeds": [face_embeddings]}
357
 
358
  elif using_multiple_controlnets and not has_detected_faces:
359
- print("Multiple ControlNets available but no faces detected, using depth only")
360
  control_images = [depth_image, depth_image]
361
  conditioning_scales = [controlnet_conditioning_scale, 0.0]
362
 
@@ -364,22 +329,20 @@ class RetroArtConverter:
364
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
365
 
366
  else:
367
- print("Using single ControlNet (Depth only)")
368
  pipe_kwargs["image"] = depth_image
369
  pipe_kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
370
 
371
- # Generate image
372
- print("Generating retro art...")
373
- print(f"Steps: {num_inference_steps}, Guidance: {guidance_scale}")
374
  result = self.pipe(**pipe_kwargs)
375
 
376
  return result.images[0]
377
 
378
- # Initialize the converter
379
  print("Initializing RetroArt Converter...")
380
  converter = RetroArtConverter()
381
 
382
- # Gradio interface with ZeroGPU support
383
  @spaces.GPU
384
  def process_image(
385
  image,
@@ -390,8 +353,7 @@ def process_image(
390
  controlnet_scale,
391
  lora_scale,
392
  identity_preservation,
393
- image_scale,
394
- enhance_quality
395
  ):
396
  if image is None:
397
  return None
@@ -406,8 +368,7 @@ def process_image(
406
  controlnet_conditioning_scale=controlnet_scale,
407
  lora_scale=lora_scale,
408
  identity_preservation=identity_preservation,
409
- image_scale=image_scale,
410
- enhance_quality=enhance_quality
411
  )
412
  return result
413
  except Exception as e:
@@ -416,87 +377,89 @@ def process_image(
416
  traceback.print_exc()
417
  raise gr.Error(f"Generation failed: {str(e)}")
418
 
419
- # Create Gradio interface
420
- with gr.Blocks(title="RetroArt Converter", theme=gr.themes.Soft()) as demo:
421
  gr.Markdown("""
422
- # 🎮 RetroArt Converter - Quality Enhanced
423
 
424
- Convert any image into retro game art style with improved quality!
425
 
426
- **Features:**
427
- - High-quality depth estimation and preprocessing
428
- - Enhanced prompts for better results
429
- - Custom SDXL checkpoint (Horizon)
430
- - Pixelate VAE for authentic retro look
431
- - RetroArt LORA for style enhancement
432
- - Face preservation with InstantID
433
  """)
434
 
435
- # Model status display
436
  if converter.models_loaded:
437
- status_text = "**Loaded Models:**\n"
438
- status_text += f"- Custom Checkpoint: {'✓' if converter.models_loaded['custom_checkpoint'] else '✗ (using SDXL base)'}\n"
439
- status_text += f"- Custom VAE: {'✓' if converter.models_loaded['custom_vae'] else '✗ (using default VAE)'}\n"
440
- status_text += f"- LORA: {'✓' if converter.models_loaded['lora'] else '✗ (disabled)'}\n"
441
- status_text += f"- InstantID: {'✓' if converter.models_loaded['instantid'] else '✗ (disabled)'}\n"
442
  gr.Markdown(status_text)
443
 
444
  with gr.Row():
445
  with gr.Column():
446
  input_image = gr.Image(label="Input Image", type="pil")
447
 
448
  prompt = gr.Textbox(
449
- label="Prompt",
450
- value="masterpiece, best quality, retro pixel art game, 16-bit style, vibrant colors, highly detailed",
451
- lines=3
 
452
  )
453
 
454
  negative_prompt = gr.Textbox(
455
  label="Negative Prompt",
456
- value="blurry, low quality, modern, photorealistic, 3d render, ugly, distorted, deformed",
457
  lines=2
458
  )
459
 
460
- enhance_quality = gr.Checkbox(
461
- label="Enable Quality Enhancement",
462
- value=True,
463
- info="Sharpen and enhance input image before processing"
464
- )
465
-
466
- with gr.Accordion("Quality Settings", open=True):
467
  steps = gr.Slider(
468
- minimum=20,
469
- maximum=70,
470
- value=40,
471
- step=5,
472
- label="Inference Steps (more = better quality but slower)"
473
  )
474
 
475
  guidance_scale = gr.Slider(
476
- minimum=3,
477
- maximum=15,
478
- value=7.5,
479
- step=0.5,
480
- label="Guidance Scale (how closely to follow prompt)"
481
  )
482
 
483
  controlnet_scale = gr.Slider(
484
- minimum=0,
485
- maximum=1.5,
486
- value=0.6,
487
  step=0.05,
488
- label="ControlNet Depth Scale (lower = more creative)"
489
  )
490
 
491
  lora_scale = gr.Slider(
492
- minimum=0,
493
- maximum=2,
494
- value=0.85,
495
  step=0.05,
496
  label="RetroArt LORA Scale"
497
  )
498
 
499
- with gr.Accordion("Identity Settings (for portraits)", open=False):
500
  identity_preservation = gr.Slider(
501
  minimum=0,
502
  maximum=1.5,
@@ -519,43 +482,33 @@ with gr.Blocks(title="RetroArt Converter", theme=gr.themes.Soft()) as demo:
519
  output_image = gr.Image(label="Retro Art Output")
520
 
521
  gr.Markdown("""
522
- ### Tips for Best Quality:
523
- 1. **Use high-resolution input images** (at least 512x512)
524
- 2. **Increase inference steps** to 50-60 for maximum quality
525
- 3. **Lower ControlNet scale** (0.5-0.6) for more stylization
526
- 4. **Adjust guidance scale:** 7-9 for balanced results
527
- 5. **Enable quality enhancement** for sharper inputs
528
- 6. Try different prompts with quality keywords: "masterpiece, best quality, highly detailed"
 
529
  """)
530
 
531
- gr.Examples(
532
- examples=[
533
- [
534
- "example_portrait.jpg",
535
- "masterpiece, best quality, retro pixel art portrait, 16-bit game character, vibrant colors",
536
- "blurry, modern, low quality",
537
- 40, 7.5, 0.6, 0.85, 0.8, 0.2, True
538
- ],
539
- ],
540
- inputs=[
541
- input_image, prompt, negative_prompt, steps, guidance_scale,
542
- controlnet_scale, lora_scale, identity_preservation, image_scale, enhance_quality
543
- ],
544
- outputs=[output_image],
545
- fn=process_image,
546
- cache_examples=False
547
- )
548
-
549
  generate_btn.click(
550
  fn=process_image,
551
  inputs=[
552
- input_image, prompt, negative_prompt, steps, guidance_scale,
553
- controlnet_scale, lora_scale, identity_preservation, image_scale, enhance_quality
554
  ],
555
  outputs=[output_image]
556
  )
557
 
558
- # Launch with API enabled
559
  if __name__ == "__main__":
560
  demo.queue(max_size=20)
561
  demo.launch(
 
6
  StableDiffusionXLControlNetPipeline,
7
  ControlNetModel,
8
  AutoencoderKL,
9
+ LCMScheduler  # LCM checkpoints require the LCM scheduler
 
10
  )
11
  from diffusers.models.attention_processor import AttnProcessor2_0
12
  from insightface.app import FaceAnalysis
13
+ from PIL import Image
14
  import numpy as np
15
  import cv2
16
  from transformers import pipeline as transformers_pipeline
 
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  dtype = torch.float16 if device == "cuda" else torch.float32
24
 
25
+ # LORA trigger word
26
+ TRIGGER_WORD = "p1x3l4rt, pixel art"
27
+
28
  print(f"Using device: {device}")
29
  print(f"Loading models from: {MODEL_REPO}")
30
+ print(f"LORA Trigger Word: {TRIGGER_WORD}")
31
 
32
  class RetroArtConverter:
33
  def __init__(self):
 
35
  self.dtype = dtype
36
  self.models_loaded = {
37
  'custom_checkpoint': False,
 
38
  'lora': False,
39
  'instantid': False
40
  }
 
52
  self.face_detection_enabled = True
53
  except Exception as e:
54
  print(f"⚠️ Face detection not available: {e}")
 
55
  self.face_app = None
56
  self.face_detection_enabled = False
57
 
 
62
  torch_dtype=self.dtype
63
  ).to(self.device)
64
 
65
+ # Load InstantID ControlNet (optional)
66
  print("Loading InstantID ControlNet...")
67
  try:
68
  self.controlnet_instantid = ControlNetModel.from_pretrained(
 
75
  self.models_loaded['instantid'] = True
76
  except Exception as e:
77
  print(f"⚠️ InstantID ControlNet not available: {e}")
 
78
  self.controlnet_instantid = None
79
  self.instantid_enabled = False
80
 
81
+ # Load depth estimator
 
82
  print("Loading depth estimator...")
83
  self.depth_estimator = transformers_pipeline(
84
  'depth-estimation',
 
95
  print(f"Initializing with single ControlNet: Depth only")
96
 
97
  # Load SDXL checkpoint from HuggingFace Hub
98
+ # NOTE: VAE is bundled in the checkpoint, don't load separately!
99
+ print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
100
  try:
101
  model_path = hf_hub_download(
102
  repo_id=MODEL_REPO,
 
106
  self.pipe = StableDiffusionXLControlNetPipeline.from_single_file(
107
  model_path,
108
  controlnet=controlnets,
 
109
  torch_dtype=self.dtype,
110
  use_safetensors=True
111
  ).to(self.device)
112
+ print("✓ Custom checkpoint loaded successfully (VAE bundled)")
113
  self.models_loaded['custom_checkpoint'] = True
114
  except Exception as e:
115
  print(f"⚠️ Could not load custom checkpoint: {e}")
 
117
  self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
118
  "stabilityai/stable-diffusion-xl-base-1.0",
119
  controlnet=controlnets,
 
120
  torch_dtype=self.dtype,
121
  use_safetensors=True
122
  ).to(self.device)
 
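The NOTE in this hunk is the heart of the commit: from_single_file already restores the VAE baked into the checkpoint file, so the old code's explicit vae= argument was silently overriding it. A minimal sketch of the two behaviours, with an illustrative file name:

    from diffusers import AutoencoderKL, StableDiffusionXLPipeline

    # Bundled VAE: the weights stored inside the checkpoint file are used as-is.
    pipe = StableDiffusionXLPipeline.from_single_file("checkpoint.safetensors")

    # Explicit override: vae= replaces the bundled VAE; only do this deliberately.
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
    pipe = StableDiffusionXLPipeline.from_single_file("checkpoint.safetensors", vae=vae)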
131
  repo_type="model"
132
  )
133
  self.pipe.load_lora_weights(lora_path)
134
+ print("✓ LORA loaded successfully")
135
+ print(f" Trigger word: '{TRIGGER_WORD}'")
136
  self.models_loaded['lora'] = True
137
  except Exception as e:
138
  print(f"⚠️ Could not load LORA: {e}")
 
139
  self.models_loaded['lora'] = False
140
 
141
+ # CRITICAL: Use LCM Scheduler for this model!
142
+ print("Setting up LCM scheduler...")
143
+ self.pipe.scheduler = LCMScheduler.from_config(
144
  self.pipe.scheduler.config
145
  )
146
 
147
+ # Enable attention optimizations
 
148
  self.pipe.unet.set_attn_processor(AttnProcessor2_0())
149
 
150
+ # Try to enable xformers
151
  if self.device == "cuda":
152
  try:
153
  self.pipe.enable_xformers_memory_efficient_attention()
 
155
  except Exception as e:
156
  print(f"⚠️ xformers not available: {e}")
157
 
158
+ # Set CLIP skip to 2
159
+ if hasattr(self.pipe, 'text_encoder'):
160
+ self.clip_skip = 2
161
+ print(f"✓ CLIP skip set to {self.clip_skip}")
162
+
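In diffusers, CLIP skip takes effect per call rather than as pipeline state, which is why the attribute stored here is also forwarded as pipe_kwargs["clip_skip"] in generate_retro_art below. A one-line sketch, assuming any SDXL pipeline pipe:

    # clip_skip is applied while computing prompt embeddings (higher = earlier CLIP layer).
    image = pipe(prompt, clip_skip=2, num_inference_steps=12, guidance_scale=1.0).images[0]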
163
+ # Track controlnet configuration
164
  self.using_multiple_controlnets = isinstance(controlnets, list)
165
  print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")
166
 
 
170
  print(f"{model}: {status}")
171
  print("===================\n")
172
 
173
+ print("✓ Model initialization complete!")
174
+ print("\n=== LCM CONFIGURATION ===")
175
+ print("Scheduler: LCM")
176
+ print("Recommended Steps: 12")
177
+ print("Recommended CFG: 1.0-1.5")
178
+ print("Recommended Resolution: 896x1152 or 832x1216")
179
+ print("CLIP Skip: 2")
180
+ print(f"LORA Trigger: '{TRIGGER_WORD}'")
181
+ print("=========================\n")
182
 
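For reference, the scheduler swap above plus the matching sampler settings can be reproduced standalone. A minimal sketch, using the public SDXL base model as an illustrative stand-in for the Space's own LCM checkpoint:

    import torch
    from diffusers import StableDiffusionXLPipeline, LCMScheduler

    # Illustrative base model; the Space loads its own checkpoint from MODEL_REPO.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    # Same swap as in __init__ above: rebuild the scheduler from the existing config.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

    # LCM-style settings: few steps, CFG near 1.0 (high CFG over-saturates LCM output).
    image = pipe(
        "p1x3l4rt, pixel art, retro game character",
        num_inference_steps=12,
        guidance_scale=1.0,
    ).images[0]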
183
+ def get_depth_map(self, image):
184
+ """Generate depth map from input image"""
 
185
  depth = self.depth_estimator(image)
186
  depth_image = depth['depth']
187
 
188
  depth_array = np.array(depth_image)
189
 
190
+ # Normalize with percentile clipping
191
  depth_min, depth_max = np.percentile(depth_array, [2, 98])
192
  depth_normalized = np.clip((depth_array - depth_min) / (depth_max - depth_min + 1e-8), 0, 1) * 255
193
  depth_normalized = depth_normalized.astype(np.uint8)
194
 
195
+ # Slight blur to reduce noise
196
  depth_normalized = cv2.GaussianBlur(depth_normalized, (3, 3), 0)
197
 
198
+ # Convert to RGB
199
  depth_colored = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
200
 
201
  return Image.fromarray(depth_colored)
202
 
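The percentile stretch above is what gives the depth map its contrast: clipping to the 2nd-98th percentile keeps a single outlier from compressing the whole range. A toy run of the same arithmetic:

    import numpy as np

    depth_array = np.array([[0.0, 5.0], [10.0, 100.0]])  # 100.0 is an outlier

    # Same formula as get_depth_map: percentile clip, then stretch to 0-255.
    depth_min, depth_max = np.percentile(depth_array, [2, 98])
    out = np.clip((depth_array - depth_min) / (depth_max - depth_min + 1e-8), 0, 1) * 255
    out = out.astype(np.uint8)  # the outlier saturates at 255 instead of dominating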
203
+ def calculate_optimal_size(self, original_width, original_height):
204
+ """Calculate optimal size from recommended resolutions"""
 
205
  aspect_ratio = original_width / original_height
206
 
207
+ # Recommended resolutions for this model
208
+ recommended_sizes = [
209
+ (896, 1152), # Portrait
210
+ (1152, 896), # Landscape
211
+ (832, 1216), # Tall portrait
212
+ (1216, 832), # Wide landscape
213
+ (1024, 1024) # Square
214
+ ]
215
+
216
+ # Find closest matching aspect ratio
217
+ best_match = None
218
+ best_diff = float('inf')
219
+
220
+ for width, height in recommended_sizes:
221
+ rec_aspect = width / height
222
+ diff = abs(rec_aspect - aspect_ratio)
223
+ if diff < best_diff:
224
+ best_diff = diff
225
+ best_match = (width, height)
226
+
227
+ # Ensure dimensions are multiples of 8
228
+ width, height = best_match
229
+ width = (width // 8) * 8
230
+ height = (height // 8) * 8
231
+
232
+ return width, height
233
+
234
+ def add_trigger_word(self, prompt):
235
+ """Add trigger word to prompt if not present"""
236
+ if TRIGGER_WORD.lower() not in prompt.lower():
237
+ return f"{TRIGGER_WORD}, {prompt}"
238
+ return prompt
239
 
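A quick usage sketch of the two helpers above; the expected values follow from the recommended-size table and the TRIGGER_WORD constant defined at the top of the file:

    # Aspect 4000/3000 = 1.33 sits closest to 1152/896 = 1.29, so a 4000x3000
    # photo snaps to the landscape preset (already a multiple of 8).
    converter.calculate_optimal_size(4000, 3000)  # -> (1152, 896)

    # The trigger word is prepended only when missing (checked case-insensitively).
    converter.add_trigger_word("retro game character")
    # -> "p1x3l4rt, pixel art, retro game character"
    converter.add_trigger_word("P1X3L4RT, pixel art, knight")  # returned unchanged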
240
  def generate_retro_art(
241
  self,
242
  input_image,
243
+ prompt="retro game character, vibrant colors, detailed",
244
+ negative_prompt="blurry, low quality, ugly, distorted",
245
+ num_inference_steps=12, # LCM recommended: 12 steps
246
+ guidance_scale=1.0, # LCM recommended: 1.0-1.5
247
+ controlnet_conditioning_scale=0.8,
248
+ lora_scale=1.0,
249
  identity_preservation=0.8,
250
+ image_scale=0.2
 
251
  ):
252
+ """Generate retro art with correct LCM settings"""
253
+
254
+ # Add trigger word to prompt
255
+ prompt = self.add_trigger_word(prompt)
256
 
257
+ # Calculate optimal size
258
  original_width, original_height = input_image.size
259
+ target_width, target_height = self.calculate_optimal_size(original_width, original_height)
260
 
261
  print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
262
+ print(f"Prompt: {prompt}")
263
 
264
+ # Resize with high quality
265
  resized_image = input_image.resize((target_width, target_height), Image.LANCZOS)
266
 
267
+ # Generate depth map
 
268
  print("Generating depth map...")
269
+ depth_image = self.get_depth_map(resized_image)
270
  depth_image = depth_image.resize((target_width, target_height), Image.LANCZOS)
271
 
272
+ # Handle face detection for InstantID
273
  using_multiple_controlnets = self.using_multiple_controlnets
 
 
274
  face_embeddings = None
275
  has_detected_faces = False
276
 
277
  if using_multiple_controlnets:
278
+ print("Checking for faces...")
279
  img_array = np.array(resized_image)
280
  faces = self.face_app.get(img_array) if self.face_app is not None else []
281
 
282
  if len(faces) > 0:
283
  has_detected_faces = True
284
+ print(f"Detected {len(faces)} face(s)")
 
285
  face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
286
  face_embeddings = torch.from_numpy(face.normed_embedding).unsqueeze(0).to(self.device, dtype=self.dtype)
 
 
 
287
 
288
  # Set LORA scale
289
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
290
  try:
291
  self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
292
+ print(f"LORA scale: {lora_scale}")
293
  except Exception as e:
294
+ print(f"Could not set LORA scale: {e}")
 
295
 
296
+ # Prepare generation kwargs
297
  pipe_kwargs = {
298
  "prompt": prompt,
299
+ "negative_prompt": negative_prompt,
300
  "num_inference_steps": num_inference_steps,
301
  "guidance_scale": guidance_scale,
302
  "width": target_width,
 
304
  "generator": torch.Generator(device=self.device).manual_seed(42)
305
  }
306
 
307
+ # Add CLIP skip
308
+ if hasattr(self.pipe, 'text_encoder'):
309
+ pipe_kwargs["clip_skip"] = 2
310
+
311
+ # Configure ControlNet inputs
312
  if using_multiple_controlnets and has_detected_faces:
313
+ print("Using Depth + InstantID ControlNets")
314
  control_images = [depth_image, resized_image]
315
  conditioning_scales = [controlnet_conditioning_scale, image_scale]
316
 
317
  pipe_kwargs["image"] = control_images
318
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
319
 
 
320
  if face_embeddings is not None:
321
  pipe_kwargs["cross_attention_kwargs"] = {"ip_adapter_image_embeds": [face_embeddings]}
322
 
323
  elif using_multiple_controlnets and not has_detected_faces:
324
+ print("Multiple ControlNets available but no faces detected")
325
  control_images = [depth_image, depth_image]
326
  conditioning_scales = [controlnet_conditioning_scale, 0.0]
327
 
 
329
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
330
 
331
  else:
332
+ print("Using Depth ControlNet only")
333
  pipe_kwargs["image"] = depth_image
334
  pipe_kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
335
 
336
+ # Generate
337
+ print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}")
 
338
  result = self.pipe(**pipe_kwargs)
339
 
340
  return result.images[0]
341
 
342
+ # Initialize converter
343
  print("Initializing RetroArt Converter...")
344
  converter = RetroArtConverter()
345
 
 
346
  @spaces.GPU
347
  def process_image(
348
  image,
 
353
  controlnet_scale,
354
  lora_scale,
355
  identity_preservation,
356
+ image_scale
 
357
  ):
358
  if image is None:
359
  return None
 
368
  controlnet_conditioning_scale=controlnet_scale,
369
  lora_scale=lora_scale,
370
  identity_preservation=identity_preservation,
371
+ image_scale=image_scale
 
372
  )
373
  return result
374
  except Exception as e:
 
377
  traceback.print_exc()
378
  raise gr.Error(f"Generation failed: {str(e)}")
379
 
380
+ # Gradio UI
381
+ with gr.Blocks(title="RetroArt Converter - LCM", theme=gr.themes.Soft()) as demo:
382
  gr.Markdown("""
383
+ # 🎮 RetroArt Converter (LCM Optimized)
384
 
385
+ Convert images into retro pixel art style using LCM (Latent Consistency Model) for fast, high-quality generation!
386
 
387
+ **✨ Features:**
388
+ - ⚡ Ultra-fast generation (12 steps!)
389
+ - 🎨 Custom pixel art LORA with trigger word: `p1x3l4rt, pixel art`
390
+ - 📐 Optimized resolutions: 896x1152 / 832x1216
391
+ - 🖼️ Bundled VAE for authentic retro look
392
+ - 🎯 CLIP Skip 2 for better style
 
393
  """)
394
 
395
+ # Model status
396
  if converter.models_loaded:
397
+ status_text = "**📦 Loaded Models:**\n"
398
+ status_text += f"- Custom Checkpoint (Horizon): {'✓ Loaded' if converter.models_loaded['custom_checkpoint'] else '✗ Using SDXL base'}\n"
399
+ status_text += f"- LORA (RetroArt): {'✓ Loaded' if converter.models_loaded['lora'] else '✗ Disabled'}\n"
400
+ status_text += f"- InstantID: {'✓ Loaded' if converter.models_loaded['instantid'] else '✗ Disabled'}\n"
 
401
  gr.Markdown(status_text)
402
 
403
+ gr.Markdown(f"""
404
+ **⚙️ LCM Configuration:**
405
+ - Scheduler: LCM (Latent Consistency Model)
406
+ - Recommended Steps: **12** (fast!)
407
+ - Recommended CFG: **1.0-1.5** (lower than normal)
408
+ - CLIP Skip: **2**
409
+ - LORA Trigger: `{TRIGGER_WORD}` (auto-added)
410
+ """)
411
+
412
  with gr.Row():
413
  with gr.Column():
414
  input_image = gr.Image(label="Input Image", type="pil")
415
 
416
  prompt = gr.Textbox(
417
+ label="Prompt (trigger word auto-added)",
418
+ value="retro game character, vibrant colors, highly detailed",
419
+ lines=3,
420
+ info=f"'{TRIGGER_WORD}' will be automatically added"
421
  )
422
 
423
  negative_prompt = gr.Textbox(
424
  label="Negative Prompt",
425
+ value="blurry, low quality, ugly, distorted, deformed, bad anatomy",
426
  lines=2
427
  )
428
 
429
+ with gr.Accordion("⚡ LCM Settings (Optimized)", open=True):
 
430
  steps = gr.Slider(
431
+ minimum=4,
432
+ maximum=20,
433
+ value=12,
434
+ step=1,
435
+ label="Inference Steps (LCM works great with just 12!)"
436
  )
437
 
438
  guidance_scale = gr.Slider(
439
+ minimum=0.5,
440
+ maximum=3.0,
441
+ value=1.0,
442
+ step=0.1,
443
+ label="Guidance Scale (CFG) - LCM uses 1.0-1.5"
444
  )
445
 
446
  controlnet_scale = gr.Slider(
447
+ minimum=0.3,
448
+ maximum=1.2,
449
+ value=0.8,
450
  step=0.05,
451
+ label="ControlNet Depth Scale"
452
  )
453
 
454
  lora_scale = gr.Slider(
455
+ minimum=0.5,
456
+ maximum=1.5,
457
+ value=1.0,
458
  step=0.05,
459
  label="RetroArt LORA Scale"
460
  )
461
 
462
+ with gr.Accordion("🎭 Identity Settings (for portraits)", open=False):
463
  identity_preservation = gr.Slider(
464
  minimum=0,
465
  maximum=1.5,
 
482
  output_image = gr.Image(label="Retro Art Output")
483
 
484
  gr.Markdown("""
485
+ ### 💡 Tips for Best Results:
486
+
487
+ **For LCM Models:**
488
+ - ✅ Use **12 steps** (already optimized!)
489
+ - ✅ Keep CFG at **1.0-1.5** (not 7.5!)
490
+ - ✅ LORA trigger word is **auto-added**
491
+ - ✅ Resolution auto-optimized to 896x1152 or 832x1216
492
+
493
+ **For Quality:**
494
+ - Use high-resolution input images
495
+ - Be specific in prompts: "16-bit game character" vs "character"
496
+ - Adjust ControlNet scale: lower = more creative, higher = more faithful
497
+
498
+ **For Style:**
499
+ - Increase LORA scale (1.0-1.5) for stronger pixel art effect
500
+ - Try prompts like: "SNES style", "16-bit RPG", "Game Boy Advance style"
501
  """)
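Putting the tips together, a direct call to the converter with the recommended LCM settings would look like this (file paths are illustrative):

    from PIL import Image

    img = Image.open("photo.jpg").convert("RGB")
    result = converter.generate_retro_art(
        img,
        prompt="16-bit RPG hero, vibrant colors",  # trigger word is auto-added
        num_inference_steps=12,             # LCM sweet spot
        guidance_scale=1.0,                 # keep CFG in the 1.0-1.5 band
        controlnet_conditioning_scale=0.8,
        lora_scale=1.0,
    )
    result.save("retro_art.png")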
502
 
 
503
  generate_btn.click(
504
  fn=process_image,
505
  inputs=[
506
+ input_image, prompt, negative_prompt, steps, guidance_scale,
507
+ controlnet_scale, lora_scale, identity_preservation, image_scale
508
  ],
509
  outputs=[output_image]
510
  )
511
 
 
512
  if __name__ == "__main__":
513
  demo.queue(max_size=20)
514
  demo.launch(