pixagram-dev / app.py
primerz's picture
Update app.py
d2e5a40 verified
raw
history blame
24.6 kB
import spaces # MUST be first, before any CUDA-related imports
import gradio as gr
import torch
from diffusers import (
StableDiffusionXLControlNetImg2ImgPipeline, # Changed to img2img
ControlNetModel,
AutoencoderKL,
LCMScheduler,
DPMSolverMultistepScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from insightface.app import FaceAnalysis
from PIL import Image
import numpy as np
import cv2
import math
from controlnet_aux import ZoeDetector # Better depth detection
from huggingface_hub import hf_hub_download
import os
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Hugging Face Hub repository that hosts the custom checkpoint and the LoRA.
MODEL_REPO = "primerz/pixagram"

# Half precision on the GPU when CUDA is available, full precision on the CPU.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

# Trigger phrase for the pixel-art LoRA.
TRIGGER_WORD = "p1x3l4rt, pixel art"

# True selects the LCM scheduler; False selects DPM++ 2M Karras.
USE_LCM = True

print(f"Using device: {device}")
print(f"Loading models from: {MODEL_REPO}")
print(f"LORA Trigger Word: {TRIGGER_WORD}")
print(f"Scheduler: {'LCM' if USE_LCM else 'DPM++ 2M Karras'}")
def draw_kps(image_pil, kps, color_list=None, stickwidth=4):
    """Draw facial keypoints on a black canvas for the InstantID ControlNet.

    A filled ellipse ("limb") is drawn from each of keypoints 0, 1, 3 and 4 to
    keypoint 2, and the limbs are dimmed to 60 % intensity. Then a filled
    circle of radius 10 is drawn at every keypoint.

    Args:
        image_pil: PIL image; only its ``size`` is used, to size the canvas.
        kps: Sequence of (x, y) keypoints. The limb table indexes points 0-4,
            so at least 5 points are required. This presumably matches
            InsightFace's 5-point layout (TODO confirm).
        color_list: One RGB colour per keypoint, indexed by keypoint position.
            Defaults to red, green, blue, yellow, magenta.
        stickwidth: Semi-minor axis, in pixels, of each limb ellipse.

    Returns:
        A new RGB ``PIL.Image`` with the same size as ``image_pil``.
    """
    # Build the default per call instead of using a shared mutable default.
    if color_list is None:
        color_list = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]

    limb_seq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)
    w, h = image_pil.size
    out_img = np.zeros([h, w, 3])

    for index in limb_seq:
        # Each limb is coloured after its first endpoint.
        color = color_list[index[0]]
        x = kps[index][:, 0]
        y = kps[index][:, 1]
        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
        # Ellipse centred on the segment midpoint, spanning the segment length.
        polygon = cv2.ellipse2Poly(
            (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
        )
        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)

    # Dim the limbs so the keypoint dots drawn next stand out.
    out_img = (out_img * 0.6).astype(np.uint8)

    for idx_kp, kp in enumerate(kps):
        color = color_list[idx_kp]
        x, y = kp
        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)

    return Image.fromarray(out_img.astype(np.uint8))
class RetroArtConverter:
    """Convert photos to pixel-art style with an SDXL ControlNet img2img pipeline.

    Construction loads the following, in order, logging each step with ``print``:

    * an InsightFace ``FaceAnalysis`` model (``antelopev2``) for face keypoints
      (optional),
    * a Zoe depth estimator (optional),
    * the SDXL Zoe-depth ControlNet (required),
    * the InstantID ControlNet (optional),
    * ``horizon.safetensors`` from ``MODEL_REPO``; if it fails to load,
      ``stabilityai/stable-diffusion-xl-base-1.0`` is used instead,
    * the ``retroart.safetensors`` LoRA from ``MODEL_REPO`` (optional),
    * an LCM or DPM++ 2M Karras scheduler, chosen by ``USE_LCM``.

    Optional components that fail to load are disabled, not raised, and
    ``models_loaded`` records which of them succeeded. Errors from the depth
    ControlNet or from the SDXL-base fallback are not caught.
    """

    # Name the LoRA is registered under. ``generate_retro_art`` passes the same
    # name to ``set_adapters`` to apply ``lora_scale``. Without an explicit name,
    # diffusers auto-generates one and that lookup fails.
    LORA_ADAPTER_NAME = "retroart"

    # Output resolutions recommended for the checkpoint, as (width, height).
    RECOMMENDED_SIZES = (
        (896, 1152),   # Portrait
        (1152, 896),   # Landscape
        (832, 1216),   # Tall portrait
        (1216, 832),   # Wide landscape
        (1024, 1024),  # Square
    )

    def __init__(self):
        self.device = device
        self.dtype = dtype
        self.use_lcm = USE_LCM
        # Which optional components loaded; reported in the log and the UI.
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False,
            'zoe_depth': False,
        }

        # Face analysis for InstantID keypoints (optional).
        print("Loading face analysis model...")
        try:
            self.face_app = FaceAnalysis(
                name='antelopev2',
                root='./models/insightface',
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
            )
            self.face_app.prepare(ctx_id=0, det_size=(640, 640))
            print("✓ Face analysis model loaded successfully")
            self.face_detection_enabled = True
        except Exception as e:
            print(f"⚠️ Face detection not available: {e}")
            self.face_app = None
            self.face_detection_enabled = False

        # Zoe depth estimator (optional; get_depth_map falls back to grayscale).
        print("Loading Zoe Depth detector...")
        try:
            self.zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
            self.zoe_depth.to(self.device)
            print("✓ Zoe Depth loaded successfully")
            self.models_loaded['zoe_depth'] = True
        except Exception as e:
            print(f"⚠️ Zoe Depth not available: {e}")
            self.zoe_depth = None

        # Depth ControlNet (required: errors propagate).
        print("Loading ControlNet Zoe Depth model...")
        self.controlnet_depth = ControlNetModel.from_pretrained(
            "diffusers/controlnet-zoe-depth-sdxl-1.0",
            torch_dtype=self.dtype,
        ).to(self.device)

        # InstantID ControlNet (optional).
        print("Loading InstantID ControlNet...")
        try:
            self.controlnet_instantid = ControlNetModel.from_pretrained(
                "InstantX/InstantID",
                subfolder="ControlNetModel",
                torch_dtype=self.dtype,
            ).to(self.device)
            print("✓ InstantID ControlNet loaded successfully")
            self.instantid_enabled = True
            self.models_loaded['instantid'] = True
        except Exception as e:
            print(f"⚠️ InstantID ControlNet not available: {e}")
            self.controlnet_instantid = None
            self.instantid_enabled = False

        # The order [InstantID, Depth] must match the order of control images
        # and scales built in generate_retro_art.
        if self.instantid_enabled and self.controlnet_instantid is not None:
            controlnets = [self.controlnet_instantid, self.controlnet_depth]
            print("Initializing with multiple ControlNets: InstantID + Depth")
        else:
            controlnets = self.controlnet_depth
            print("Initializing with single ControlNet: Depth only")

        # Custom SDXL checkpoint, falling back to SDXL base.
        print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
        try:
            model_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename="horizon.safetensors",
                repo_type="model",
            )
            self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
                model_path,
                controlnet=controlnets,
                torch_dtype=self.dtype,
                use_safetensors=True,
            ).to(self.device)
            print("✓ Custom checkpoint loaded successfully (VAE bundled)")
            self.models_loaded['custom_checkpoint'] = True
        except Exception as e:
            print(f"⚠️ Could not load custom checkpoint: {e}")
            print("Using default SDXL base model")
            self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0",
                controlnet=controlnets,
                torch_dtype=self.dtype,
                use_safetensors=True,
            ).to(self.device)
            self.models_loaded['custom_checkpoint'] = False

        self._load_lora()

        # Scheduler selected by USE_LCM.
        if self.use_lcm:
            print("Setting up LCM scheduler...")
            self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
        else:
            print("Setting up DPM++ 2M Karras scheduler...")
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config,
                use_karras_sigmas=True,
            )

        # PyTorch 2 SDPA attention for the UNet. On CUDA, xformers (when
        # installed) is enabled right after and installs its own processors.
        self.pipe.unet.set_attn_processor(AttnProcessor2_0())
        if self.device == "cuda":
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
                print("✓ xformers enabled")
            except Exception as e:
                print(f"⚠️ xformers not available: {e}")

        # CLIP skip passed to the pipeline at generation time (None = not passed).
        self.clip_skip = 2 if hasattr(self.pipe, 'text_encoder') else None
        if self.clip_skip is not None:
            print(f"✓ CLIP skip set to {self.clip_skip}")

        self.using_multiple_controlnets = isinstance(controlnets, list)
        print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")

        self._print_summary()

    def _load_lora(self):
        """Download the RetroArt LoRA and register it; failures are logged, not raised."""
        print("Loading LORA (retroart) from HuggingFace Hub...")
        try:
            lora_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename="retroart.safetensors",
                repo_type="model",
            )
            # Explicit adapter name so set_adapters() can address it later.
            self.pipe.load_lora_weights(lora_path, adapter_name=self.LORA_ADAPTER_NAME)
            print("✓ LORA loaded successfully")
            print(f" Trigger word: '{TRIGGER_WORD}'")
            self.models_loaded['lora'] = True
        except Exception as e:
            print(f"⚠️ Could not load LORA: {e}")
            self.models_loaded['lora'] = False

    def _print_summary(self):
        """Log which components loaded and the recommended generation settings."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "✓ LOADED" if loaded else "✗ FALLBACK"
            print(f"{model}: {status}")
        print("===================\n")
        print("✓ Model initialization complete!")
        print("\n=== CONFIGURATION ===")
        print(f"Scheduler: {'LCM' if self.use_lcm else 'DPM++ 2M Karras'}")
        if self.use_lcm:
            print("Recommended Steps: 12")
            print("Recommended CFG: 1.0-1.5")
        else:
            print("Recommended Steps: 30-50")
            print("Recommended CFG: 7.0-8.0")
        print("Recommended Resolution: 896x1152 or 832x1216")
        print(f"CLIP Skip: {self.clip_skip}")
        print(f"LORA Trigger: '{TRIGGER_WORD}'")
        print("=====================\n")

    @staticmethod
    def _grayscale_depth(image):
        """Crude depth stand-in: the RGB image's luminance replicated to 3 channels."""
        gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
        return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))

    def get_depth_map(self, image):
        """Return a depth-conditioning image for ``image``.

        Uses the Zoe detector when it loaded. If Zoe is unavailable or raises,
        falls back to the image's grayscale luminance. The result's size is not
        guaranteed to match ``image``; ``generate_retro_art`` resizes it when
        the sizes differ.
        """
        # Both Zoe and the grayscale fallback expect 3-channel RGB input.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if self.zoe_depth is None:
            return self._grayscale_depth(image)

        try:
            # Hand Zoe a freshly built uint8 image. This reportedly avoids a
            # numpy.int64 error inside nn.functional.interpolate.
            clean_image = Image.fromarray(np.array(image).astype(np.uint8))
            return self.zoe_depth(clean_image)
        except Exception as e:
            print(f"Warning: ZoeDetector failed ({e}), falling back to grayscale depth")
            return self._grayscale_depth(image)

    def calculate_optimal_size(self, original_width, original_height):
        """Pick the recommended size whose aspect ratio is closest to the input's.

        Args:
            original_width: Input width in pixels.
            original_height: Input height in pixels (must be non-zero).

        Returns:
            ``(width, height)`` as Python ints, each a multiple of 8. On an
            aspect-ratio tie, the size listed first in ``RECOMMENDED_SIZES`` wins.
        """
        aspect_ratio = original_width / original_height
        width, height = min(
            self.RECOMMENDED_SIZES,
            key=lambda size: abs(size[0] / size[1] - aspect_ratio),
        )
        # Snap to multiples of 8 and return plain Python ints.
        return int((width // 8) * 8), int((height // 8) * 8)

    def add_trigger_word(self, prompt):
        """Return ``prompt`` with ``TRIGGER_WORD`` prepended unless already present.

        The presence check is case-insensitive. A ``None`` or whitespace-only
        prompt yields ``TRIGGER_WORD`` alone, so no dangling ", " is left behind.
        """
        if prompt is None or not prompt.strip():
            return TRIGGER_WORD
        if TRIGGER_WORD.lower() in prompt.lower():
            return prompt
        return f"{TRIGGER_WORD}, {prompt}"

    def _detect_face_keypoints(self, image):
        """Return an InstantID keypoint image for the largest face, or None if no face."""
        print("Detecting faces and extracting keypoints...")
        # The face analyser is fed a BGR array (OpenCV channel order).
        faces = self.face_app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
        if not faces:
            return None
        print(f"Detected {len(faces)} face(s)")

        # Keep the face with the largest bounding-box area.
        face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
        # NOTE(review): face.normed_embedding is not used anywhere; only the
        # keypoint ControlNet is conditioned, not an identity embedding.
        kps_image = draw_kps(image, face.kps)

        # Read age/gender defensively; either may be absent or None.
        age = getattr(face, 'age', None)
        gender = getattr(face, 'gender', None)
        age_label = 'N/A' if age is None else age
        if gender is None:
            gender_label = 'N/A'
        else:
            gender_label = 'M' if gender == 1 else 'F'
        print(f"Face info: bbox={face.bbox}, age={age_label}, gender={gender_label}")
        return kps_image

    def _apply_lora_scale(self, lora_scale):
        """Set the RetroArt LoRA weight when the LoRA is loaded; failures are logged."""
        if not (hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']):
            return
        try:
            self.pipe.set_adapters([self.LORA_ADAPTER_NAME], adapter_weights=[lora_scale])
            print(f"LORA scale: {lora_scale}")
        except Exception as e:
            print(f"Could not set LORA scale: {e}")

    def generate_retro_art(
        self,
        input_image,
        prompt="retro game character, vibrant colors, detailed",
        negative_prompt="blurry, low quality, ugly, distorted",
        num_inference_steps=12,
        guidance_scale=1.0,
        controlnet_conditioning_scale=0.8,
        lora_scale=1.0,
        identity_preservation=0.8,
        strength=0.75,  # img2img strength
    ):
        """Generate one retro-art image from ``input_image``.

        The method:

        1. Adds the LoRA trigger word to the prompt.
        2. Resizes the input to the closest recommended resolution.
        3. Builds a depth map.
        4. When both ControlNets are active and a face is found, draws a
           keypoint image.
        5. Runs the img2img pipeline with a fixed seed of 42.

        Args:
            input_image: Source PIL image; converted to RGB when in another mode.
            prompt: Positive prompt; the trigger word is added if missing.
            negative_prompt: Negative prompt.
            num_inference_steps: Step count passed to the pipeline.
            guidance_scale: Classifier-free guidance scale.
            controlnet_conditioning_scale: Weight of the depth ControlNet.
            lora_scale: Weight for the RetroArt LoRA adapter, when loaded.
            identity_preservation: Weight of the InstantID keypoint ControlNet
                when a face is detected.
            strength: img2img strength passed to the pipeline.

        Returns:
            The first generated ``PIL.Image``.
        """
        # Normalise to 3-channel RGB. Single-channel 'L'/'P' images would break
        # the RGB->BGR face-detection conversion and the RGB->GRAY depth fallback.
        if input_image.mode != 'RGB':
            input_image = input_image.convert('RGB')

        prompt = self.add_trigger_word(prompt)

        original_width, original_height = input_image.size
        target_width, target_height = self.calculate_optimal_size(original_width, original_height)
        target_size = (int(target_width), int(target_height))
        print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
        print(f"Prompt: {prompt}")
        print(f"Img2Img Strength: {strength}")

        resized_image = input_image.resize(target_size, Image.LANCZOS)

        print("Generating Zoe depth map...")
        depth_image = self.get_depth_map(resized_image)
        if depth_image.size != target_size:
            depth_image = depth_image.resize(target_size, Image.LANCZOS)

        # Keypoint image is only needed when the InstantID ControlNet is active.
        using_multiple_controlnets = self.using_multiple_controlnets
        face_kps_image = None
        if using_multiple_controlnets and self.face_app is not None:
            face_kps_image = self._detect_face_keypoints(resized_image)

        self._apply_lora_scale(lora_scale)

        pipe_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "image": resized_image,  # img2img source
            "strength": strength,    # how much to transform
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "generator": torch.Generator(device=self.device).manual_seed(42),
        }

        clip_skip = getattr(self, 'clip_skip', None)
        if clip_skip is not None:
            pipe_kwargs["clip_skip"] = clip_skip

        if using_multiple_controlnets:
            # Order must match the ControlNet list: [InstantID, Depth].
            if face_kps_image is not None:
                print("Using InstantID (keypoints) + Depth ControlNets")
                control_images = [face_kps_image, depth_image]
                conditioning_scales = [identity_preservation, controlnet_conditioning_scale]
            else:
                print("Multiple ControlNets available but no faces detected, using depth only")
                # One image per ControlNet is required. InstantID gets the depth
                # map with weight 0.0 so it contributes nothing.
                control_images = [depth_image, depth_image]
                conditioning_scales = [0.0, controlnet_conditioning_scale]
            pipe_kwargs["control_image"] = control_images
            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
        else:
            print("Using Depth ControlNet only")
            pipe_kwargs["control_image"] = depth_image
            pipe_kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale

        scheduler_name = "LCM" if self.use_lcm else "DPM++"
        print(f"Generating with {scheduler_name}: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
        result = self.pipe(**pipe_kwargs)
        return result.images[0]
# Initialize converter: a single shared instance built at import time. Its
# constructor loads every model, and process_image() reuses it per request.
print("Initializing RetroArt Converter...")
converter = RetroArtConverter()


@spaces.GPU
def process_image(
    image,
    prompt,
    negative_prompt,
    steps,
    guidance_scale,
    controlnet_scale,
    lora_scale,
    identity_preservation,
    strength,
):
    """Gradio click handler: run one retro-art generation on the shared converter.

    Args:
        image: Input PIL image from the UI, or ``None`` if nothing was uploaded.
        prompt: Positive prompt (the converter adds the LoRA trigger word).
        negative_prompt: Negative prompt.
        steps: Inference steps; converted to ``int`` before use.
        guidance_scale: Classifier-free guidance scale.
        controlnet_scale: Depth ControlNet weight.
        lora_scale: RetroArt LoRA weight.
        identity_preservation: InstantID keypoint ControlNet weight.
        strength: img2img strength.

    Returns:
        The generated PIL image, or ``None`` when ``image`` is ``None``.

    Raises:
        gr.Error: Wraps any exception from generation so the message reaches
            the UI. The full traceback is printed to the server log first, and
            the original exception is chained as the cause.
    """
    import traceback

    if image is None:
        return None
    try:
        return converter.generate_retro_art(
            input_image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_scale,
            lora_scale=lora_scale,
            identity_preservation=identity_preservation,
            strength=strength,
        )
    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {e}") from e
# Gradio UI
with gr.Blocks(title="RetroArt Converter - Img2Img", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
# 🎮 RetroArt Converter (Img2Img + InstantID)

Convert images into retro pixel art style using img2img with face preservation!

**✨ Features:**
- 🖼️ **True Img2Img**: Transforms your image while preserving structure
- 👤 **InstantID**: Facial keypoint detection with age/gender detection
- 🎨 Custom pixel art LORA with trigger word: `{TRIGGER_WORD}`
- 🏔️ **Zoe Depth**: Better depth map quality
- ⚡ **{'LCM' if USE_LCM else 'DPM++ 2M Karras'}** scheduler
- 📐 Optimized resolutions: 896x1152 / 832x1216
- 🎯 CLIP Skip 2 for better style
""")

    # Model status. converter.models_loaded always contains the same four
    # keys, so the panel is rendered unconditionally.
    status_text = "**📦 Loaded Models:**\n"
    status_text += f"- Custom Checkpoint (Horizon): {'✓ Loaded' if converter.models_loaded['custom_checkpoint'] else '✗ Using SDXL base'}\n"
    status_text += f"- LORA (RetroArt): {'✓ Loaded' if converter.models_loaded['lora'] else '✗ Disabled'}\n"
    status_text += f"- InstantID: {'✓ Loaded' if converter.models_loaded['instantid'] else '✗ Disabled'}\n"
    status_text += f"- Zoe Depth: {'✓ Loaded' if converter.models_loaded['zoe_depth'] else '✗ Fallback'}\n"
    gr.Markdown(status_text)

    scheduler_info = f"""
**⚙️ Configuration:**
- Pipeline: **Img2Img** (better structure preservation)
- Scheduler: **{'LCM' if USE_LCM else 'DPM++ 2M Karras'}**
- Recommended Steps: **{12 if USE_LCM else '30-50'}**
- Recommended CFG: **{1.0 if USE_LCM else '7.0-8.0'}**
- CLIP Skip: **2**
- LORA Trigger: `{TRIGGER_WORD}` (auto-added)
- Face Detection: **Age & Gender detection enabled**
"""
    gr.Markdown(scheduler_info)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            prompt = gr.Textbox(
                label="Prompt (trigger word auto-added)",
                value=" ",
                lines=3,
                info=f"'{TRIGGER_WORD}' will be automatically added",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value=" ",
                lines=2,
            )

            with gr.Accordion(f"⚡ {'LCM' if USE_LCM else 'DPM++'} Settings", open=True):
                steps = gr.Slider(
                    minimum=4,
                    maximum=50,
                    value=12 if USE_LCM else 30,
                    step=1,
                    label=f"Inference Steps ({'LCM works with 12' if USE_LCM else 'DPM++ uses 30-50'})",
                )
                guidance_scale = gr.Slider(
                    minimum=0.5,
                    maximum=15.0,
                    value=1.5 if USE_LCM else 7.5,
                    step=0.1,
                    label=f"Guidance Scale (CFG) - {'LCM uses 1.0-2.0' if USE_LCM else 'DPM++ uses 7-8'}",
                )
                strength = gr.Slider(
                    minimum=0.3,
                    maximum=0.95,
                    value=0.50,
                    step=0.05,
                    label="Img2Img Strength (how much to transform)",
                )
                controlnet_scale = gr.Slider(
                    minimum=0.3,
                    maximum=1.2,
                    value=0.75,
                    step=0.05,
                    label="Zoe Depth ControlNet Scale",
                )
                lora_scale = gr.Slider(
                    minimum=0.5,
                    maximum=2.0,
                    value=1.25,
                    step=0.05,
                    label="RetroArt LORA Scale",
                )

            with gr.Accordion("👤 InstantID Settings (for portraits)", open=False):
                identity_preservation = gr.Slider(
                    minimum=0,
                    maximum=1.5,
                    value=1.0,
                    step=0.1,
                    label="Identity/Keypoint Preservation",
                )

            generate_btn = gr.Button("🎨 Generate Retro Art", variant="primary", size="lg")

        with gr.Column():
            output_image = gr.Image(label="Retro Art Output")
            gr.Markdown(f"""
### 💡 Tips for Best Results:

**For Img2Img:**
- ✅ **Strength 0.7-0.8**: Good balance of transformation and structure
- ✅ **Strength 0.5-0.6**: More faithful to original
- ✅ **Strength 0.8-0.9**: More creative/stylized

**For {'LCM' if USE_LCM else 'DPM++'}:**
- {'✅ Use **12 steps** (optimized for speed)' if USE_LCM else '✅ Use **30-50 steps** (better quality)'}
- {'✅ Keep CFG at **1.0-2.0**' if USE_LCM else '✅ Keep CFG at **7.0-8.0**'}
- ✅ LORA trigger word is **auto-added**
- ✅ Resolution auto-optimized to the closest of 896x1152, 832x1216, their landscape versions, or 1024x1024

**For Portraits:**
- The system detects **age and gender** automatically
- Facial **keypoints** are used for better face preservation
- Adjust Identity Preservation: lower = more stylized, higher = more realistic face

**For Quality:**
- Use high-resolution input images
- Be specific in prompts: "16-bit game character" vs "character"
- Adjust Depth scale: lower = more creative, higher = more faithful depth

**For Style:**
- Increase LORA scale (1.0-1.5) for stronger pixel art effect
- Try prompts like: "SNES style", "16-bit RPG", "Game Boy Advance style"
""")

    # Argument order must match process_image's parameter order.
    generate_btn.click(
        fn=process_image,
        inputs=[
            input_image, prompt, negative_prompt, steps, guidance_scale,
            controlnet_scale, lora_scale, identity_preservation, strength,
        ],
        outputs=[output_image],
    )
if __name__ == "__main__":
    # Serve on every interface at port 7860 with at most 20 queued requests;
    # no public share link, API page visible.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True,
    )
    demo.queue(max_size=20).launch(**launch_options)