import spaces # MUST be first, before any CUDA-related imports
import gradio as gr
import torch
from diffusers import (
ControlNetModel,
AutoencoderKL,
DPMSolverMultistepScheduler,
LCMScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from insightface.app import FaceAnalysis
from PIL import Image
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
import os
# Import the custom img2img pipeline with InstantID
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
# Import ZoeDetector for better depth maps
from controlnet_aux import ZoeDetector
# Configuration
MODEL_REPO = "primerz/pixagram"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# LORA trigger word
TRIGGER_WORD = "p1x3l4rt, pixel art"
print(f"Using device: {device}")
print(f"Loading models from: {MODEL_REPO}")
print(f"LORA Trigger Word: {TRIGGER_WORD}")
class RetroArtConverter:
    def __init__(self, use_lcm=False):
        self.device = device
        self.dtype = dtype
        self.use_lcm = use_lcm
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False
        }
        # Initialize face analysis for InstantID
        print("Loading face analysis model (antelopev2)...")
        try:
            self.face_app = FaceAnalysis(
                name='antelopev2',
                root='./models/insightface',
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            self.face_app.prepare(ctx_id=0, det_size=(640, 640))
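            # det_size sets the detector's input resolution; 640x640 is the
            # commonly used InsightFace default, trading speed against recall
            # on small faces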
print("✓ Face analysis model loaded successfully")
self.face_detection_enabled = True
except Exception as e:
print(f"⚠️ Face detection not available: {e}")
self.face_app = None
self.face_detection_enabled = False
# Load ControlNet for InstantID
print("Loading InstantID ControlNet...")
try:
self.controlnet_instantid = ControlNetModel.from_pretrained(
"InstantX/InstantID",
subfolder="ControlNetModel",
torch_dtype=self.dtype
).to(self.device)
print("✓ InstantID ControlNet loaded successfully")
self.instantid_enabled = True
self.models_loaded['instantid'] = True
except Exception as e:
print(f"⚠️ InstantID ControlNet not available: {e}")
self.controlnet_instantid = None
self.instantid_enabled = False
# Load ControlNet for Zoe depth
print("Loading Zoe Depth ControlNet...")
self.controlnet_depth = ControlNetModel.from_pretrained(
"diffusers/controlnet-zoe-depth-sdxl-1.0",
torch_dtype=self.dtype
).to(self.device)
# Load Zoe depth detector (better than DPT)
print("Loading Zoe depth detector...")
try:
self.zoe_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators")
self.zoe_detector.to(self.device)
print("✓ Zoe detector loaded successfully")
except Exception as e:
print(f"⚠️ Could not load Zoe detector: {e}")
self.zoe_detector = None
# Determine which controlnets to use
if self.instantid_enabled and self.controlnet_instantid is not None:
controlnets = [self.controlnet_instantid, self.controlnet_depth]
print(f"Initializing with multiple ControlNets: InstantID + Zoe Depth")
else:
controlnets = self.controlnet_depth
print(f"Initializing with single ControlNet: Zoe Depth only")
        # Load VAE
        print("Loading VAE...")
        self.vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=self.dtype
        ).to(self.device)
        # Load SDXL checkpoint from HuggingFace Hub
        print("Loading SDXL checkpoint (horizon) from HuggingFace Hub...")
        try:
            model_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename="horizon.safetensors",
                repo_type="model"
            )
            # Use the custom img2img pipeline for better results
            self.pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
                model_path,
                controlnet=controlnets,
                vae=self.vae,
                torch_dtype=self.dtype,
                use_safetensors=True
            ).to(self.device)
            print("✓ Custom checkpoint loaded successfully")
            self.models_loaded['custom_checkpoint'] = True
        except Exception as e:
            print(f"⚠️ Could not load custom checkpoint: {e}")
            print("Using default SDXL base model")
            self.pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0",
                controlnet=controlnets,
                vae=self.vae,
                torch_dtype=self.dtype,
                use_safetensors=True
            ).to(self.device)
            self.models_loaded['custom_checkpoint'] = False
        # Load InstantID IP-Adapter
        if self.instantid_enabled:
            print("Loading InstantID IP-Adapter...")
            try:
                ip_adapter_path = hf_hub_download(
                    repo_id="InstantX/InstantID",
                    filename="ip-adapter.bin"
                )
                self.pipe.load_ip_adapter_instantid(ip_adapter_path)
                self.pipe.set_ip_adapter_scale(0.8)
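                # 0.8 is a fixed starting point for identity strength; the
                # keypoint ControlNet weight is adjusted per request via the
                # face_strength parameter instead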
print("✓ InstantID IP-Adapter loaded successfully")
except Exception as e:
print(f"⚠️ Could not load IP-Adapter: {e}")
# Load LORA from HuggingFace Hub
print("Loading LORA (retroart) from HuggingFace Hub...")
try:
lora_path = hf_hub_download(
repo_id=MODEL_REPO,
filename="retroart.safetensors",
repo_type="model"
)
self.pipe.load_lora_weights(lora_path)
print(f"✓ LORA loaded successfully")
print(f" Trigger word: '{TRIGGER_WORD}'")
self.models_loaded['lora'] = True
except Exception as e:
print(f"⚠️ Could not load LORA: {e}")
self.models_loaded['lora'] = False
        # Choose scheduler based on mode
        if use_lcm:
            print("Setting up LCM scheduler for fast generation...")
            self.pipe.scheduler = LCMScheduler.from_config(
                self.pipe.scheduler.config
            )
        else:
            print("Setting up DPMSolverMultistep scheduler with Karras sigmas for quality...")
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config,
                use_karras_sigmas=True
            )
        # Enable attention optimizations
        self.pipe.unet.set_attn_processor(AttnProcessor2_0())
        # Try to enable xformers
        if self.device == "cuda":
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
                print("✓ xformers enabled")
            except Exception as e:
                print(f"⚠️ xformers not available: {e}")
        # Track ControlNet configuration
        self.using_multiple_controlnets = isinstance(controlnets, list)
        print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "✓ LOADED" if loaded else "✗ FALLBACK"
            print(f"{model}: {status}")
        print("===================\n")
        print("✓ Model initialization complete!")
        if use_lcm:
            print("\n=== LCM CONFIGURATION ===")
            print("Scheduler: LCM")
            print("Recommended Steps: 8-12")
            print("Recommended CFG: 1.0-1.5")
            print("Recommended Strength: 0.6-0.8")
        else:
            print("\n=== QUALITY CONFIGURATION ===")
            print("Scheduler: DPMSolverMultistep + Karras")
            print("Recommended Steps: 25-40")
            print("Recommended CFG: 5.0-7.5")
            print("Recommended Strength: 0.4-0.7")
            print(f"LORA Trigger: '{TRIGGER_WORD}'")
        print("=========================\n")
    def get_depth_map(self, image):
        """Generate a depth map from the input image, preferring Zoe."""
        if self.zoe_detector is not None:
            # Use the Zoe detector for higher-quality depth maps
            depth_image = self.zoe_detector(image)
            return depth_image
        else:
            # Fallback to basic grayscale conversion
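            # NOTE: this luminance-based fallback is only a rough proxy for
            # depth; results will be noticeably weaker than with Zoe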
            img_array = np.array(image.convert('L'))
            depth_colored = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
            return Image.fromarray(depth_colored)
    def calculate_optimal_size(self, original_width, original_height):
        """Pick the closest SDXL-recommended resolution for the input aspect ratio."""
        aspect_ratio = original_width / original_height
        # Recommended resolutions for SDXL
        recommended_sizes = [
            (896, 1152),   # Portrait
            (1152, 896),   # Landscape
            (832, 1216),   # Tall portrait
            (1216, 832),   # Wide landscape
            (1024, 1024)   # Square
        ]
        # Find the closest matching aspect ratio
        best_match = None
        best_diff = float('inf')
        for width, height in recommended_sizes:
            rec_aspect = width / height
            diff = abs(rec_aspect - aspect_ratio)
            if diff < best_diff:
                best_diff = diff
                best_match = (width, height)
        # Ensure dimensions are multiples of 8
        width, height = best_match
        width = (width // 8) * 8
        height = (height // 8) * 8
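        # Worked example: a 3000x2000 input has aspect ~1.50, closest to
        # 1216/832 ≈ 1.46, so it is bucketed to 1216x832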
        return width, height
    def add_trigger_word(self, prompt):
        """Prepend the LORA trigger word to the prompt if it is not already present."""
        if TRIGGER_WORD.lower() not in prompt.lower():
            return f"{TRIGGER_WORD}, {prompt}"
        return prompt
    def generate_retro_art(
        self,
        input_image,
        prompt="retro game character, vibrant colors, detailed",
        negative_prompt="blurry, low quality, ugly, distorted",
        num_inference_steps=25,
        guidance_scale=5.0,
        strength=0.6,  # img2img strength
        controlnet_conditioning_scale=0.8,
        lora_scale=1.0,
        face_strength=0.85,  # InstantID face strength
        depth_control_scale=0.8  # Zoe depth strength
    ):
        """Generate retro art using the img2img pipeline with face keypoints."""
        # Add trigger word to prompt
        prompt = self.add_trigger_word(prompt)
        # Calculate optimal size
        original_width, original_height = input_image.size
        target_width, target_height = self.calculate_optimal_size(original_width, original_height)
        print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
        print(f"Prompt: {prompt}")
        # Resize with a high-quality filter
        resized_image = input_image.resize((target_width, target_height), Image.LANCZOS)
        # Generate depth map using Zoe
        print("Generating Zoe depth map...")
        depth_image = self.get_depth_map(resized_image)
        if depth_image.size != (target_width, target_height):
            depth_image = depth_image.resize((target_width, target_height), Image.LANCZOS)
        # Handle face detection for InstantID
        using_multiple_controlnets = self.using_multiple_controlnets
        face_kps = None
        face_embeddings = None
        has_detected_faces = False
        if using_multiple_controlnets and self.face_app is not None:
            print("Detecting faces and extracting keypoints...")
            img_array = np.array(resized_image)
            faces = self.face_app.get(img_array)
            if len(faces) > 0:
                has_detected_faces = True
                print(f"Detected {len(faces)} face(s)")
                # Keep the largest face by bounding-box area
                face = sorted(faces,
                              key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
                # Extract the face identity embedding
                face_embeddings = torch.from_numpy(face.normed_embedding).unsqueeze(0).to(
                    self.device, dtype=self.dtype
                )
                # Draw the facial keypoints used by the InstantID ControlNet;
                # identity traits (age, gender, likeness) are carried by the
                # embedding extracted above
                face_kps = draw_kps(resized_image, face.kps)
                print("Face keypoints drawn (pose and identity will be preserved)")
            else:
                print("No faces detected in image")
        # Set LORA scale
        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
            try:
                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
                print(f"LORA scale: {lora_scale}")
            except Exception as e:
                print(f"Could not set LORA scale: {e}")
        # Prepare generation kwargs
        pipe_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "image": resized_image,  # Original image for img2img
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "strength": strength,  # img2img denoising strength
            "generator": torch.Generator(device=self.device).manual_seed(42)
        }
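        # NOTE: the seed is fixed at 42, so identical settings always reproduce
        # the same image; expose it as a parameter if per-run variation is wanted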
        # Configure ControlNet inputs
        if using_multiple_controlnets and has_detected_faces and face_kps is not None:
            print("Using InstantID + Zoe Depth ControlNets with face keypoints")
            control_images = [face_kps, depth_image]
            conditioning_scales = [face_strength, depth_control_scale]
            pipe_kwargs["control_image"] = control_images
            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
            # Add face embeddings through the IP-Adapter
            if face_embeddings is not None and hasattr(self.pipe, 'set_ip_adapter_scale'):
                pipe_kwargs["ip_adapter_image_embeds"] = [face_embeddings]
        elif using_multiple_controlnets:
            print("Multiple ControlNets available but no faces detected - using depth only")
            # Use depth for both slots to maintain structure
            control_images = [depth_image, depth_image]
            conditioning_scales = [0.0, depth_control_scale]  # Disable InstantID
            pipe_kwargs["control_image"] = control_images
            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
        else:
            print("Using Zoe Depth ControlNet only")
            pipe_kwargs["control_image"] = depth_image
            pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale
        # Generate
        mode = "LCM" if self.use_lcm else "Quality"
        print(f"Generating with {mode} mode: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
        result = self.pipe(**pipe_kwargs)
        return result.images[0]
# Initialize converter
print("Initializing RetroArt Converter...")
print("Choose mode: LCM (fast) or Quality (better)")
converter_lcm = RetroArtConverter(use_lcm=True)
converter_quality = RetroArtConverter(use_lcm=False)
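# NOTE: two full pipelines are kept in memory, one per scheduler mode; on
# smaller GPUs, consider sharing a single pipeline and swapping schedulers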
@spaces.GPU
def process_image(
    image,
    prompt,
    negative_prompt,
    steps,
    guidance_scale,
    strength,
    controlnet_scale,
    lora_scale,
    face_strength,
    depth_control_scale,
    use_lcm_mode
):
    if image is None:
        return None
    try:
        # Choose the right converter based on mode
        converter = converter_lcm if use_lcm_mode else converter_quality
        result = converter.generate_retro_art(
            input_image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance_scale,
            strength=strength,
            controlnet_conditioning_scale=controlnet_scale,
            lora_scale=lora_scale,
            face_strength=face_strength,
            depth_control_scale=depth_control_scale
        )
        return result
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")
# Gradio UI
with gr.Blocks(title="RetroArt Converter - Improved", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🎮 RetroArt Converter (Improved with True Img2Img)
Convert images into retro pixel art style with **proper face detection** and **gender/age preservation**!
**✨ Key Improvements:**
- 🎯 **True img2img pipeline** for better structure preservation
- 👤 **draw_kps**: Detects and preserves age, gender, expression
- 🗺️ **Zoe Depth**: Superior depth estimation
- ⚡ **Dual Mode**: Fast LCM or Quality DPM++
- 🎨 Custom pixel art LORA with trigger: `p1x3l4rt, pixel art`
""")
# Model status
status_text = "**📦 Loaded Models (LCM Mode):**\n"
status_text += f"- Custom Checkpoint: {'✓ Loaded' if converter_lcm.models_loaded['custom_checkpoint'] else '✗ Using SDXL base'}\n"
status_text += f"- LORA (RetroArt): {'✓ Loaded' if converter_lcm.models_loaded['lora'] else '✗ Disabled'}\n"
status_text += f"- InstantID: {'✓ Loaded' if converter_lcm.models_loaded['instantid'] else '✗ Disabled'}\n"
gr.Markdown(status_text)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            prompt = gr.Textbox(
                label="Prompt (trigger word auto-added)",
                value="retro game character, vibrant colors, highly detailed",
                lines=3,
                info=f"'{TRIGGER_WORD}' will be automatically added"
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="blurry, low quality, ugly, distorted, deformed, bad anatomy",
                lines=2
            )
            use_lcm_mode = gr.Checkbox(
                label="Use LCM Mode (Fast)",
                value=True,
                info="Uncheck for Quality mode (slower but better)"
            )
            with gr.Accordion("⚙️ Generation Settings", open=True):
                steps = gr.Slider(
                    minimum=4,
                    maximum=50,
                    value=12,
                    step=1,
                    label="Inference Steps (12 for LCM, 25-40 for Quality)"
                )
                guidance_scale = gr.Slider(
                    minimum=0.5,
                    maximum=15.0,
                    value=1.0,
                    step=0.1,
                    label="Guidance Scale (1.0-1.5 for LCM, 5-7.5 for Quality)"
                )
                strength = gr.Slider(
                    minimum=0.3,
                    maximum=1.0,
                    value=0.7,
                    step=0.05,
                    label="Img2Img Strength (how much to change)"
                )
            with gr.Accordion("🎨 Style Settings", open=True):
                lora_scale = gr.Slider(
                    minimum=0.5,
                    maximum=1.5,
                    value=1.0,
                    step=0.05,
                    label="RetroArt LORA Scale"
                )
                controlnet_scale = gr.Slider(
                    minimum=0.3,
                    maximum=1.2,
                    value=0.8,
                    step=0.05,
                    label="Overall ControlNet Scale"
                )
            with gr.Accordion("👤 Face & Depth Settings", open=False):
                face_strength = gr.Slider(
                    minimum=0,
                    maximum=2.0,
                    value=0.85,
                    step=0.05,
                    label="Face Preservation (InstantID)",
                    info="Higher = better face likeness"
                )
                depth_control_scale = gr.Slider(
                    minimum=0,
                    maximum=1.0,
                    value=0.8,
                    step=0.05,
                    label="Zoe Depth Control Scale",
                    info="Higher = more structure preservation"
                )
            generate_btn = gr.Button("🎨 Generate Retro Art", variant="primary", size="lg")
        with gr.Column():
            output_image = gr.Image(label="Retro Art Output")
gr.Markdown("""
### 💡 Tips for Best Results:
**Mode Selection:**
- ✅ **LCM Mode**: 12 steps, CFG 1.0-1.5, Strength 0.6-0.8 (⚡ fast!)
- ✅ **Quality Mode**: 25-40 steps, CFG 5-7.5, Strength 0.4-0.7 (🎨 better!)
**Face Preservation:**
- System automatically detects faces and draws keypoints
- Preserves age, gender, and expression characteristics
- Adjust "Face Preservation" slider for control
**For Best Quality:**
- Use high-resolution input images (min 512px)
- For portraits: enable Quality mode + high face strength
- For scenes: lower img2img strength for more creativity
- Adjust depth control for structure vs creativity balance
**Style Control:**
- LORA trigger word auto-added for pixel art style
- Increase LORA scale (1.2-1.5) for stronger retro effect
- Try: "SNES style", "16-bit RPG", "Game Boy advance style"
""")
    # Update defaults when switching modes
    def update_mode_defaults(use_lcm):
        if use_lcm:
            return (
                gr.update(value=12),   # steps
                gr.update(value=1.0),  # guidance_scale
                gr.update(value=0.7)   # strength
            )
        else:
            return (
                gr.update(value=30),   # steps
                gr.update(value=6.0),  # guidance_scale
                gr.update(value=0.6)   # strength
            )
    use_lcm_mode.change(
        fn=update_mode_defaults,
        inputs=[use_lcm_mode],
        outputs=[steps, guidance_scale, strength]
    )
    generate_btn.click(
        fn=process_image,
        inputs=[
            input_image, prompt, negative_prompt, steps, guidance_scale, strength,
            controlnet_scale, lora_scale, face_strength, depth_control_scale, use_lcm_mode
        ],
        outputs=[output_image]
    )
if __name__ == "__main__":
demo.queue(max_size=20)
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_api=True
)