import spaces
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from PIL import Image
import numpy as np
from typing import List, Optional

# Uncomment the following line if you plan to use FP8 quantization with torchao
# from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

# Gradio is imported for the Progress tracker used in remix_images
import gradio as gr

from config import MODEL_ID_BASE, MODEL_ID_CANNY, MODEL_ID_DEPTH, DEVICE, DTYPE
from utils import apply_canny, apply_depth

print(f"Loading ControlNet models and base pipeline on {DEVICE} with {DTYPE}...")

# Initialize ControlNet models
controlnet_canny_model = ControlNetModel.from_pretrained(MODEL_ID_CANNY, torch_dtype=DTYPE)
controlnet_depth_model = ControlNetModel.from_pretrained(MODEL_ID_DEPTH, torch_dtype=DTYPE)

# Initialize the main pipeline with multiple ControlNets
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    MODEL_ID_BASE,
    controlnet=[controlnet_canny_model, controlnet_depth_model],
    torch_dtype=DTYPE,
    # vae=vae,  # Uncomment if using a custom VAE, e.g., "madebyollin/sdxl-vae-fp16-fix"
)
pipe.to(DEVICE)
print("Pipeline and ControlNets loaded.")


# ZeroGPU Ahead-of-Time (AoT) compilation for the UNet and ControlNets
@spaces.GPU(duration=1500)  # Maximum duration allowed for startup compilation
def compile_optimized_pipeline():
    print("Starting AoT compilation for UNet and ControlNets...")

    # Optional: Apply FP8 quantization (requires an H200 GPU and torchao)
    # try:
    #     # Example for the UNet
    #     quantize_(pipe.unet, Float8DynamicActivationFloat8WeightConfig())
    #     print("FP8 Quantization applied to UNet.")
    #     # Example for each ControlNet (the pipeline wraps the list in a MultiControlNetModel)
    #     for i, cn in enumerate(pipe.controlnet.nets):
    #         quantize_(cn, Float8DynamicActivationFloat8WeightConfig())
    #         print(f"FP8 Quantization applied to ControlNet {i}.")
    # except Exception as e:
    #     print(f"Could not apply FP8 quantization: {e}. Falling back to default dtype.")

    # 1. Compile the UNet
    # Dummy inputs for the UNet forward pass (SDXL latents, timestep, text embeddings, and added conditions)
    dummy_latents = torch.randn(1, 4, 128, 128, device=DEVICE, dtype=DTYPE)  # For 1024x1024 output
    dummy_timestep = torch.tensor(999, device=DEVICE, dtype=torch.long)
    dummy_encoder_hidden_states = torch.randn(1, 77, 2048, device=DEVICE, dtype=DTYPE)  # text embeddings
    dummy_added_cond_kwargs = {
        "text_embeds": torch.randn(1, 1280, device=DEVICE, dtype=DTYPE),  # pooled_prompt_embeds
        "time_ids": torch.randn(1, 6, device=DEVICE, dtype=DTYPE),  # original_size, crops_coords, target_size
    }
    # Call the UNet inside the capture context so spaces.aoti_capture records its args/kwargs.
    # NOTE: at inference the pipeline also passes ControlNet residuals to the UNet; capturing a
    # full pipe(...) call inside this context would record those inputs as well.
    with spaces.aoti_capture(pipe.unet) as unet_call:
        pipe.unet(
            dummy_latents,
            dummy_timestep,
            dummy_encoder_hidden_states,
            added_cond_kwargs=dummy_added_cond_kwargs,
        )
    compiled_unet = spaces.aoti_compile(
        torch.export.export(pipe.unet, args=unet_call.args, kwargs=unet_call.kwargs)
    )
    print("UNet compiled.")
    # 2. Compile each ControlNet
    compiled_controlnets = []
    # The pipeline wraps the ControlNet list in a MultiControlNetModel; the individual models live in .nets
    for cn_idx, cn_model in enumerate(pipe.controlnet.nets):
        # Dummy inputs for the ControlNet forward pass
        dummy_controlnet_cond = torch.randn(1, 3, 1024, 1024, device=DEVICE, dtype=DTYPE)  # Control image
        with spaces.aoti_capture(cn_model) as cn_call:
            cn_model(
                dummy_latents,
                dummy_timestep,
                dummy_encoder_hidden_states,
                controlnet_cond=dummy_controlnet_cond,
                added_cond_kwargs=dummy_added_cond_kwargs,  # SDXL ControlNets also expect the added conditions
            )
        compiled_cn = spaces.aoti_compile(
            torch.export.export(cn_model, args=cn_call.args, kwargs=cn_call.kwargs)
        )
        compiled_controlnets.append(compiled_cn)
        print(f"ControlNet {cn_idx} compiled.")

    print("AoT compilation completed for UNet and ControlNets.")
    return compiled_unet, compiled_controlnets


# Apply the compiled models to the pipeline
compiled_unet, compiled_controlnets = compile_optimized_pipeline()
spaces.aoti_apply(compiled_unet, pipe.unet)
for compiled_cn, original_cn in zip(compiled_controlnets, pipe.controlnet.nets):
    spaces.aoti_apply(compiled_cn, original_cn)
print("AoT compilation applied to pipeline modules.")


@spaces.GPU(duration=120)  # Standard generation duration (e.g., up to 2 minutes)
def remix_images(
    prompt: str,
    canny_image: Optional[Image.Image],
    depth_image: Optional[Image.Image],
    base_image: Optional[Image.Image],
    guidance_scale: float,
    num_inference_steps: int,
    progress: gr.Progress = gr.Progress(),
) -> Image.Image:
    """
    Remixes up to three input images with a text prompt, using ControlNet (Canny, Depth)
    for structural control and an optional base image to set the output dimensions.

    Args:
        prompt (str): The text prompt for generation.
        canny_image (Optional[Image.Image]): Image to generate Canny edges from for ControlNet.
        depth_image (Optional[Image.Image]): Image to generate a depth map from for ControlNet.
        base_image (Optional[Image.Image]): The base image. Only its dimensions are used, as the
            output size; the image itself is not passed to the pipeline.
        guidance_scale (float): Classifier-free guidance scale.
        num_inference_steps (int): Number of diffusion steps.
        progress (gr.Progress): Gradio progress object for tracking generation.

    Returns:
        Image.Image: The remixed image.
""" progress(0, desc="Preprocessing control images...") control_images_list = [] # Process Canny image if canny_image: canny_input_for_cn = apply_canny(canny_image) else: # Provide a blank black image as a placeholder for ControlNet Canny if no input canny_input_for_cn = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)) control_images_list.append(canny_input_for_cn) # Process Depth image if depth_image: depth_input_for_cn = apply_depth(depth_image) # Ensure depth map is 3-channel for ControlNet input depth_input_for_cn = depth_input_for_cn.convert("RGB") else: # Provide a blank black image as a placeholder for ControlNet Depth if no input depth_input_for_cn = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)) control_images_list.append(depth_input_for_cn) # Determine output image dimensions based on base_image or default to SDXL standard if base_image: width, height = base_image.size else: width, height = 1024, 1024 # Resize control images to match target output dimensions for consistency final_control_images = [] for img in control_images_list: if img.size != (width, height): final_control_images.append(img.resize((width, height), Image.BICUBIC)) else: final_control_images.append(img) progress(0.2, desc="Generating remixed image...") # Generate the image # Use a random seed for varied outputs as requested (no fixed seed) generator = torch.Generator(device=DEVICE).manual_seed(np.random.randint(0, 10**9)) output = pipe( prompt=prompt, image=final_control_images, # This list matches the order of ControlNet models width=width, height=height, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, # Callback to update Gradio progress bar during diffusion steps callback_on_step_end=lambda step, timestep, latents: progress((0.2 + 0.8 * step / num_inference_steps), desc=f"Generating step {step}/{num_inference_steps}..."), ).images[0] progress(1.0, desc="Done!") return output