import spaces  # MUST be first, before any CUDA-related imports
import gradio as gr
import torch
from diffusers import (
    ControlNetModel,
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    LCMScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from insightface.app import FaceAnalysis
from PIL import Image
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
import os

# Import the custom img2img pipeline with InstantID
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
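# NOTE (assumption): this module is not part of the diffusers package; it is expected
# to be a local copy of the InstantID img2img community pipeline placed next to app.py.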
# Import ZoeDetector for better depth maps
from controlnet_aux import ZoeDetector

# Configuration
MODEL_REPO = "primerz/pixagram"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# LORA trigger word
TRIGGER_WORD = "p1x3l4rt, pixel art"

print(f"Using device: {device}")
print(f"Loading models from: {MODEL_REPO}")
print(f"LORA Trigger Word: {TRIGGER_WORD}")


class RetroArtConverter:
    def __init__(self, use_lcm=False):
        self.device = device
        self.dtype = dtype
        self.use_lcm = use_lcm
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False
        }

        # Initialize face analysis for InstantID
        print("Loading face analysis model (antelopev2)...")
        try:
            self.face_app = FaceAnalysis(
                name='antelopev2',
                root='./models/insightface',
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            self.face_app.prepare(ctx_id=0, det_size=(640, 640))
            print("✓ Face analysis model loaded successfully")
            self.face_detection_enabled = True
        except Exception as e:
            print(f"⚠️ Face detection not available: {e}")
            self.face_app = None
            self.face_detection_enabled = False
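        # NOTE (assumption): insightface does not always auto-download antelopev2;
        # the model files may need to be fetched manually into
        # ./models/insightface/models/antelopev2 (see the InstantID README).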
        # Load ControlNet for InstantID
        print("Loading InstantID ControlNet...")
        try:
            self.controlnet_instantid = ControlNetModel.from_pretrained(
                "InstantX/InstantID",
                subfolder="ControlNetModel",
                torch_dtype=self.dtype
            ).to(self.device)
            print("✓ InstantID ControlNet loaded successfully")
            self.instantid_enabled = True
            self.models_loaded['instantid'] = True
        except Exception as e:
            print(f"⚠️ InstantID ControlNet not available: {e}")
            self.controlnet_instantid = None
            self.instantid_enabled = False

        # Load ControlNet for Zoe depth
        print("Loading Zoe Depth ControlNet...")
        self.controlnet_depth = ControlNetModel.from_pretrained(
            "diffusers/controlnet-zoe-depth-sdxl-1.0",
            torch_dtype=self.dtype
        ).to(self.device)

        # Load Zoe depth detector (better than DPT)
        print("Loading Zoe depth detector...")
        try:
            self.zoe_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators")
            self.zoe_detector.to(self.device)
            print("✓ Zoe detector loaded successfully")
        except Exception as e:
            print(f"⚠️ Could not load Zoe detector: {e}")
            self.zoe_detector = None

        # Determine which controlnets to use
        if self.instantid_enabled and self.controlnet_instantid is not None:
            controlnets = [self.controlnet_instantid, self.controlnet_depth]
            print("Initializing with multiple ControlNets: InstantID + Zoe Depth")
        else:
            controlnets = self.controlnet_depth
            print("Initializing with single ControlNet: Zoe Depth only")
        # Load VAE
        print("Loading VAE...")
        self.vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=self.dtype
        ).to(self.device)

        # Load SDXL checkpoint from HuggingFace Hub
        print("Loading SDXL checkpoint (horizon) from HuggingFace Hub...")
        try:
            model_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename="horizon.safetensors",
                repo_type="model"
            )
            # Use the custom img2img pipeline for better results
            self.pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
                model_path,
                controlnet=controlnets,
                vae=self.vae,
                torch_dtype=self.dtype,
                use_safetensors=True
            ).to(self.device)
            print("✓ Custom checkpoint loaded successfully")
            self.models_loaded['custom_checkpoint'] = True
        except Exception as e:
            print(f"⚠️ Could not load custom checkpoint: {e}")
            print("Using default SDXL base model")
            self.pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0",
                controlnet=controlnets,
                vae=self.vae,
                torch_dtype=self.dtype,
                use_safetensors=True
            ).to(self.device)
            self.models_loaded['custom_checkpoint'] = False

        # Load InstantID IP-Adapter
        if self.instantid_enabled:
            print("Loading InstantID IP-Adapter...")
            try:
                ip_adapter_path = hf_hub_download(
                    repo_id="InstantX/InstantID",
                    filename="ip-adapter.bin"
                )
                self.pipe.load_ip_adapter_instantid(ip_adapter_path)
                self.pipe.set_ip_adapter_scale(0.8)
                print("✓ InstantID IP-Adapter loaded successfully")
            except Exception as e:
                print(f"⚠️ Could not load IP-Adapter: {e}")
        # Load LORA from HuggingFace Hub
        print("Loading LORA (retroart) from HuggingFace Hub...")
        try:
            lora_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename="retroart.safetensors",
                repo_type="model"
            )
            # Register under an explicit adapter name so set_adapters(["retroart"])
            # in generate_retro_art() can find it (unnamed LoRAs get a default name).
            self.pipe.load_lora_weights(lora_path, adapter_name="retroart")
            print("✓ LORA loaded successfully")
            print(f"  Trigger word: '{TRIGGER_WORD}'")
            self.models_loaded['lora'] = True
        except Exception as e:
            print(f"⚠️ Could not load LORA: {e}")
            self.models_loaded['lora'] = False
        # Choose scheduler based on mode
        if use_lcm:
            print("Setting up LCM scheduler for fast generation...")
            self.pipe.scheduler = LCMScheduler.from_config(
                self.pipe.scheduler.config
            )
        else:
            print("Setting up DPMSolverMultistep scheduler with Karras sigmas for quality...")
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config,
                use_karras_sigmas=True
            )
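        # NOTE (assumption): LCMScheduler only gives good results with an LCM-distilled
        # checkpoint or an LCM-LoRA (e.g. latent-consistency/lcm-lora-sdxl); with a
        # plain SDXL checkpoint the fast preset below will look washed out.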
        # Enable attention optimizations
        self.pipe.unet.set_attn_processor(AttnProcessor2_0())

        # Try to enable xformers
        if self.device == "cuda":
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
                print("✓ xformers enabled")
            except Exception as e:
                print(f"⚠️ xformers not available: {e}")

        # Track controlnet configuration
        self.using_multiple_controlnets = isinstance(controlnets, list)
        print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")

        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "✓ LOADED" if loaded else "✗ FALLBACK"
            print(f"{model}: {status}")
        print("===================\n")
        print("✓ Model initialization complete!")

        if use_lcm:
            print("\n=== LCM CONFIGURATION ===")
            print("Scheduler: LCM")
            print("Recommended Steps: 8-12")
            print("Recommended CFG: 1.0-1.5")
            print("Recommended Strength: 0.6-0.8")
        else:
            print("\n=== QUALITY CONFIGURATION ===")
            print("Scheduler: DPMSolverMultistep + Karras")
            print("Recommended Steps: 25-40")
            print("Recommended CFG: 5.0-7.5")
            print("Recommended Strength: 0.4-0.7")
        print(f"LORA Trigger: '{TRIGGER_WORD}'")
        print("=========================\n")
    def get_depth_map(self, image):
        """Generate depth map from input image using Zoe"""
        if self.zoe_detector is not None:
            # Use Zoe detector for better depth maps
            depth_image = self.zoe_detector(image)
            return depth_image
        else:
            # Fallback to basic conversion
            img_array = np.array(image.convert('L'))
            depth_colored = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
            return Image.fromarray(depth_colored)
    def calculate_optimal_size(self, original_width, original_height):
        """Calculate optimal size from recommended resolutions"""
        aspect_ratio = original_width / original_height

        # Recommended resolutions for SDXL
        recommended_sizes = [
            (896, 1152),   # Portrait
            (1152, 896),   # Landscape
            (832, 1216),   # Tall portrait
            (1216, 832),   # Wide landscape
            (1024, 1024)   # Square
        ]

        # Find closest matching aspect ratio
        best_match = None
        best_diff = float('inf')
        for width, height in recommended_sizes:
            rec_aspect = width / height
            diff = abs(rec_aspect - aspect_ratio)
            if diff < best_diff:
                best_diff = diff
                best_match = (width, height)

        # Ensure dimensions are multiples of 8
        width, height = best_match
        width = (width // 8) * 8
        height = (height // 8) * 8
        return width, height
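    # Example: a 3000x2000 photo (aspect ratio 1.5) is closest to 1216/832 ≈ 1.46,
    # so calculate_optimal_size() places it in the 1216x832 bucket before img2img.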
    def add_trigger_word(self, prompt):
        """Add trigger word to prompt if not present"""
        if TRIGGER_WORD.lower() not in prompt.lower():
            return f"{TRIGGER_WORD}, {prompt}"
        return prompt
    def generate_retro_art(
        self,
        input_image,
        prompt="retro game character, vibrant colors, detailed",
        negative_prompt="blurry, low quality, ugly, distorted",
        num_inference_steps=25,
        guidance_scale=5.0,
        strength=0.6,                       # img2img strength
        controlnet_conditioning_scale=0.8,  # global multiplier for all ControlNet scales
        lora_scale=1.0,
        face_strength=0.85,                 # InstantID face strength
        depth_control_scale=0.8             # Zoe depth strength
    ):
| """Generate retro art using img2img pipeline with face keypoints""" | |
| # Add trigger word to prompt | |
| prompt = self.add_trigger_word(prompt) | |
| # Calculate optimal size | |
| original_width, original_height = input_image.size | |
| target_width, target_height = self.calculate_optimal_size(original_width, original_height) | |
| print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}") | |
| print(f"Prompt: {prompt}") | |
| # Resize with high quality | |
| resized_image = input_image.resize((target_width, target_height), Image.LANCZOS) | |
| # Generate depth map using Zoe | |
| print("Generating Zoe depth map...") | |
| depth_image = self.get_depth_map(resized_image) | |
| if depth_image.size != (target_width, target_height): | |
| depth_image = depth_image.resize((target_width, target_height), Image.LANCZOS) | |
| # Handle face detection for InstantID | |
| using_multiple_controlnets = self.using_multiple_controlnets | |
| face_kps = None | |
| face_embeddings = None | |
| has_detected_faces = False | |
| if using_multiple_controlnets and self.face_app is not None: | |
| print("Detecting faces and extracting keypoints...") | |
| img_array = np.array(resized_image) | |
| faces = self.face_app.get(img_array) | |
| if len(faces) > 0: | |
| has_detected_faces = True | |
| print(f"Detected {len(faces)} face(s)") | |
| # Get the largest face | |
| face = sorted(faces, | |
| key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1] | |
| # Extract face embeddings | |
| face_embeddings = torch.from_numpy(face.normed_embedding).unsqueeze(0).to( | |
| self.device, dtype=self.dtype | |
| ) | |
                # Draw the 5-point keypoint map used as the InstantID ControlNet
                # condition (it preserves face position and pose; identity comes
                # from the embedding extracted above)
                face_kps = draw_kps(resized_image, face.kps)
                print("Face keypoints drawn")
            else:
                print("No faces detected in image")
        # Set LORA scale
        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
            try:
                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
                print(f"LORA scale: {lora_scale}")
            except Exception as e:
                print(f"Could not set LORA scale: {e}")
        # Prepare generation kwargs
        pipe_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "image": resized_image,  # Original image for img2img
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "strength": strength,  # img2img denoising strength
            "generator": torch.Generator(device=self.device).manual_seed(42)
        }
        # Configure ControlNet inputs. controlnet_conditioning_scale is applied as a
        # global multiplier so the "Overall ControlNet Scale" slider has an effect
        # (previously the parameter was accepted but never used).
        if using_multiple_controlnets and has_detected_faces and face_kps is not None:
            print("Using InstantID + Zoe Depth ControlNets with face keypoints")
            pipe_kwargs["control_image"] = [face_kps, depth_image]
            pipe_kwargs["controlnet_conditioning_scale"] = [
                face_strength * controlnet_conditioning_scale,
                depth_control_scale * controlnet_conditioning_scale
            ]
            # Pass face embeddings to the InstantID IP-Adapter. NOTE: the InstantID
            # pipelines take these as `image_embeds`, not `ip_adapter_image_embeds`.
            if face_embeddings is not None:
                pipe_kwargs["image_embeds"] = face_embeddings
        elif using_multiple_controlnets:
            print("Multiple ControlNets available but no faces detected - using depth only")
            # Use depth for both to maintain structure; zero scale disables InstantID
            pipe_kwargs["control_image"] = [depth_image, depth_image]
            pipe_kwargs["controlnet_conditioning_scale"] = [
                0.0,
                depth_control_scale * controlnet_conditioning_scale
            ]
            # Assumption: the pipeline still expects an identity embedding whenever the
            # InstantID ControlNet is present; a zero vector (antelopev2 embeddings are
            # 512-dim) effectively disables identity guidance.
            pipe_kwargs["image_embeds"] = torch.zeros(1, 512, device=self.device, dtype=self.dtype)
        else:
            print("Using Zoe Depth ControlNet only")
            pipe_kwargs["control_image"] = depth_image
            pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale * controlnet_conditioning_scale
        # Generate
        mode = "LCM" if self.use_lcm else "Quality"
        print(f"Generating with {mode} mode: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
        result = self.pipe(**pipe_kwargs)
        return result.images[0]


# Initialize converter
print("Initializing RetroArt Converter...")
print("Choose mode: LCM (fast) or Quality (better)")
converter_lcm = RetroArtConverter(use_lcm=True)
converter_quality = RetroArtConverter(use_lcm=False)
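# NOTE: this keeps two full SDXL pipelines resident at once (~2x VRAM); on smaller
# GPUs it may be safer to build one converter and swap schedulers per request.


# Assumption: this Space targets ZeroGPU (hence `import spaces` above). There the
# GPU-bound entry point must carry @spaces.GPU; on a dedicated GPU the decorator
# is a no-op, so adding it is safe either way.
@spaces.GPU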
def process_image(
    image,
    prompt,
    negative_prompt,
    steps,
    guidance_scale,
    strength,
    controlnet_scale,
    lora_scale,
    face_strength,
    depth_control_scale,
    use_lcm_mode
):
    if image is None:
        return None
    try:
        # Choose the right converter based on mode
        converter = converter_lcm if use_lcm_mode else converter_quality
        result = converter.generate_retro_art(
            input_image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance_scale,
            strength=strength,
            controlnet_conditioning_scale=controlnet_scale,
            lora_scale=lora_scale,
            face_strength=face_strength,
            depth_control_scale=depth_control_scale
        )
        return result
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")


# Gradio UI
with gr.Blocks(title="RetroArt Converter - Improved", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎮 RetroArt Converter (Improved with True Img2Img)

    Convert images into retro pixel art style with **proper face detection** and **identity preservation**!

    **✨ Key Improvements:**
    - 🎯 **True img2img pipeline** for better structure preservation
    - 👤 **InstantID + draw_kps**: preserves identity (age, gender) and face pose
    - 🗺️ **Zoe Depth**: Superior depth estimation
    - ⚡ **Dual Mode**: Fast LCM or Quality DPM++
    - 🎨 Custom pixel art LORA with trigger: `p1x3l4rt, pixel art`
    """)

    # Model status
    status_text = "**📦 Loaded Models (LCM Mode):**\n"
    status_text += f"- Custom Checkpoint: {'✓ Loaded' if converter_lcm.models_loaded['custom_checkpoint'] else '✗ Using SDXL base'}\n"
    status_text += f"- LORA (RetroArt): {'✓ Loaded' if converter_lcm.models_loaded['lora'] else '✗ Disabled'}\n"
    status_text += f"- InstantID: {'✓ Loaded' if converter_lcm.models_loaded['instantid'] else '✗ Disabled'}\n"
    gr.Markdown(status_text)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            prompt = gr.Textbox(
                label="Prompt (trigger word auto-added)",
                value="retro game character, vibrant colors, highly detailed",
                lines=3,
                info=f"'{TRIGGER_WORD}' will be automatically added"
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="blurry, low quality, ugly, distorted, deformed, bad anatomy",
                lines=2
            )
            use_lcm_mode = gr.Checkbox(
                label="Use LCM Mode (Fast)",
                value=True,
                info="Uncheck for Quality mode (slower but better)"
            )

            with gr.Accordion("⚙️ Generation Settings", open=True):
                steps = gr.Slider(
                    minimum=4,
                    maximum=50,
                    value=12,
                    step=1,
                    label="Inference Steps (12 for LCM, 25-40 for Quality)"
                )
                guidance_scale = gr.Slider(
                    minimum=0.5,
                    maximum=15.0,
                    value=1.0,
                    step=0.1,
                    label="Guidance Scale (1.0-1.5 for LCM, 5-7.5 for Quality)"
                )
                strength = gr.Slider(
                    minimum=0.3,
                    maximum=1.0,
                    value=0.7,
                    step=0.05,
                    label="Img2Img Strength (how much to change)"
                )

            with gr.Accordion("🎨 Style Settings", open=True):
                lora_scale = gr.Slider(
                    minimum=0.5,
                    maximum=1.5,
                    value=1.0,
                    step=0.05,
                    label="RetroArt LORA Scale"
                )
                controlnet_scale = gr.Slider(
                    minimum=0.3,
                    maximum=1.2,
                    value=0.8,
                    step=0.05,
                    label="Overall ControlNet Scale"
                )

            with gr.Accordion("👤 Face & Depth Settings", open=False):
                face_strength = gr.Slider(
                    minimum=0,
                    maximum=2.0,
                    value=0.85,
                    step=0.05,
                    label="Face Preservation (InstantID)",
                    info="Higher = better face likeness"
                )
                depth_control_scale = gr.Slider(
                    minimum=0,
                    maximum=1.0,
                    value=0.8,
                    step=0.05,
                    label="Zoe Depth Control Scale",
                    info="Higher = more structure preservation"
                )

            generate_btn = gr.Button("🎨 Generate Retro Art", variant="primary", size="lg")
        with gr.Column():
            output_image = gr.Image(label="Retro Art Output")
            gr.Markdown("""
            ### 💡 Tips for Best Results:

            **Mode Selection:**
            - ✅ **LCM Mode**: 12 steps, CFG 1.0-1.5, Strength 0.6-0.8 (⚡ fast!)
            - ✅ **Quality Mode**: 25-40 steps, CFG 5-7.5, Strength 0.4-0.7 (🎨 better!)

            **Face Preservation:**
            - The system automatically detects faces and draws keypoints
            - Preserves age, gender, and expression characteristics
            - Adjust the "Face Preservation" slider for control

            **For Best Quality:**
            - Use high-resolution input images (min 512px)
            - For portraits: enable Quality mode + high face strength
            - For scenes: lower img2img strength for more creativity
            - Adjust depth control for the structure vs. creativity balance

            **Style Control:**
            - LORA trigger word auto-added for pixel art style
            - Increase LORA scale (1.2-1.5) for a stronger retro effect
            - Try: "SNES style", "16-bit RPG", "Game Boy Advance style"
            """)
    # Update defaults when switching modes
    def update_mode_defaults(use_lcm):
        if use_lcm:
            return (
                gr.update(value=12),   # steps
                gr.update(value=1.0),  # guidance_scale
                gr.update(value=0.7)   # strength
            )
        else:
            return (
                gr.update(value=30),   # steps
                gr.update(value=6.0),  # guidance_scale
                gr.update(value=0.6)   # strength
            )

    use_lcm_mode.change(
        fn=update_mode_defaults,
        inputs=[use_lcm_mode],
        outputs=[steps, guidance_scale, strength]
    )

    generate_btn.click(
        fn=process_image,
        inputs=[
            input_image, prompt, negative_prompt, steps, guidance_scale, strength,
            controlnet_scale, lora_scale, face_strength, depth_control_scale, use_lcm_mode
        ],
        outputs=[output_image]
    )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True
    )