import spaces
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from PIL import Image
import numpy as np
from typing import List, Optional

# Uncomment the following line if you plan to use FP8 quantization with torchao
# from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
# Gradio is imported under its top-level alias; gr.Progress is the documented way to get
# the progress tracker (gradio.components does not export Progress).
import gradio as gr
from config import MODEL_ID_BASE, MODEL_ID_CANNY, MODEL_ID_DEPTH, DEVICE, DTYPE
from utils import apply_canny, apply_depth
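# NOTE (assumption): config.py and utils.py live alongside this file in the Space. The
# identifiers below are only an illustrative sketch of what config.py might contain; the
# actual model IDs, device, and dtype are defined in that file.
#   import torch
#   MODEL_ID_BASE  = "stabilityai/stable-diffusion-xl-base-1.0"
#   MODEL_ID_CANNY = "diffusers/controlnet-canny-sdxl-1.0"
#   MODEL_ID_DEPTH = "diffusers/controlnet-depth-sdxl-1.0"
#   DEVICE = "cuda"
#   DTYPE = torch.float16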
| print(f"Loading ControlNet models and base pipeline on {DEVICE} with {DTYPE}...") | |
| # Initialize ControlNet models | |
| controlnet_canny_model = ControlNetModel.from_pretrained(MODEL_ID_CANNY, torch_dtype=DTYPE) | |
| controlnet_depth_model = ControlNetModel.from_pretrained(MODEL_ID_DEPTH, torch_dtype=DTYPE) | |
| # Initialize the main pipeline with multiple ControlNets | |
| pipe = StableDiffusionXLControlNetPipeline.from_pretrained( | |
| MODEL_ID_BASE, | |
| controlnet=[controlnet_canny_model, controlnet_depth_model], | |
| torch_dtype=DTYPE, | |
| # vae=vae, # Uncomment if using a custom VAE, e.g., "madebyollin/sdxl-vae-fp16-fix" | |
| ) | |
| pipe.to(DEVICE) | |
| print("Pipeline and ControlNets loaded.") | |
# ZeroGPU Ahead-of-Time (AoT) compilation for UNet and ControlNets.
# The duration below is the maximum allowed for startup compilation (assumption: 1500 s,
# the value used in the ZeroGPU AoT examples; adjust if this Space uses a different limit).
@spaces.GPU(duration=1500)
def compile_optimized_pipeline():
| print("Starting AoT compilation for UNet and ControlNets...") | |
| # Optional: Apply FP8 quantization (requires H200 GPU and torchao) | |
| # try: | |
| # # Example for UNet | |
| # quantize_(pipe.unet, Float8DynamicActivationFloat8WeightConfig()) | |
| # print("FP8 Quantization applied to UNet.") | |
| # # Example for each ControlNet | |
| # for i, cn in enumerate(pipe.controlnet): | |
| # quantize_(cn, Float8DynamicActivationFloat8WeightConfig()) | |
| # print(f"FP8 Quantization applied to ControlNet {i}.") | |
| # except Exception as e: | |
| # print(f"Could not apply FP8 quantization: {e}. Falling back to default dtype.") | |
    # 1. Compile UNet
    with spaces.aoti_capture(pipe.unet) as unet_call:
        # Dummy inputs for UNet's forward pass (SDXL latents, timestep, text embeddings, and added conditions)
        dummy_latents = torch.randn(1, 4, 128, 128, device=DEVICE, dtype=DTYPE)  # For 1024x1024 output
        dummy_timestep = torch.tensor(999, device=DEVICE, dtype=torch.long)
        dummy_encoder_hidden_states = torch.randn(1, 77, 2048, device=DEVICE, dtype=DTYPE)  # text_embeddings
        dummy_added_cond_kwargs = {
            "text_embeds": torch.randn(1, 1280, device=DEVICE, dtype=DTYPE),  # pooled_prompt_embeds
            "time_ids": torch.randn(1, 6, device=DEVICE, dtype=DTYPE),  # original_size, crops_coords, target_size
        }
        # Run the UNet once inside the capture context so its args/kwargs are recorded
        pipe.unet(dummy_latents, dummy_timestep, dummy_encoder_hidden_states, added_cond_kwargs=dummy_added_cond_kwargs)
    compiled_unet = spaces.aoti_compile(torch.export.export(pipe.unet, args=unet_call.args, kwargs=unet_call.kwargs))
    print("UNet compiled.")
    # 2. Compile each ControlNet (pipe.controlnet is a MultiControlNetModel wrapping the list)
    compiled_controlnets = []
    for cn_idx, cn_model in enumerate(pipe.controlnet.nets):
        with spaces.aoti_capture(cn_model) as cn_call:
            # Dummy inputs for ControlNet's forward pass; SDXL ControlNets also need added_cond_kwargs
            dummy_controlnet_cond = torch.randn(1, 3, 1024, 1024, device=DEVICE, dtype=DTYPE)  # Control image
            cn_model(
                dummy_latents,
                dummy_timestep,
                dummy_encoder_hidden_states,
                controlnet_cond=dummy_controlnet_cond,
                added_cond_kwargs=dummy_added_cond_kwargs,
            )
        compiled_cn = spaces.aoti_compile(torch.export.export(cn_model, args=cn_call.args, kwargs=cn_call.kwargs))
        compiled_controlnets.append(compiled_cn)
        print(f"ControlNet {cn_idx} compiled.")

    print("AoT compilation completed for UNet and ControlNets.")
    return compiled_unet, compiled_controlnets
# Apply compiled models to the pipeline
compiled_unet, compiled_controlnets = compile_optimized_pipeline()
spaces.aoti_apply(compiled_unet, pipe.unet)
for compiled_cn, original_cn in zip(compiled_controlnets, pipe.controlnet.nets):
    spaces.aoti_apply(compiled_cn, original_cn)
print("AoT compilation applied to pipeline modules.")
# Standard generation duration (e.g., up to 2 minutes)
@spaces.GPU(duration=120)
def remix_images(
    prompt: str,
    canny_image: Optional[Image.Image],
    depth_image: Optional[Image.Image],
    base_image: Optional[Image.Image],
    guidance_scale: float,
    num_inference_steps: int,
    progress: gr.Progress = gr.Progress(),
) -> Image.Image:
| """ | |
| Remixes three input images with a text prompt using ControlNet (Canny, Depth) for structural control, | |
| and an optional base image for guidance and dimensions. | |
| Args: | |
| prompt (str): The text prompt for generation. | |
| canny_image (Optional[Image.Image]): Image to generate Canny edges from for ControlNet. | |
| depth_image (Optional[Image.Image]): Image to generate depth map from for ControlNet. | |
| base_image (Optional[Image.Image]): The base image. Its dimensions will be used as the output size. | |
| If provided, it also implicitly guides the generation as a starting point. | |
| guidance_scale (float): Classifier-free guidance scale. | |
| num_inference_steps (int): Number of diffusion steps. | |
| progress (gr.Progress): Gradio progress object for tracking generation. | |
| Returns: | |
| Image.Image: The remixed image. | |
| """ | |
    progress(0, desc="Preprocessing control images...")
    control_images_list = []

    # Process Canny image
    if canny_image:
        canny_input_for_cn = apply_canny(canny_image)
    else:
        # Provide a blank black image as a placeholder for ControlNet Canny if no input
        canny_input_for_cn = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
    control_images_list.append(canny_input_for_cn)

    # Process Depth image
    if depth_image:
        depth_input_for_cn = apply_depth(depth_image)
        # Ensure depth map is 3-channel for ControlNet input
        depth_input_for_cn = depth_input_for_cn.convert("RGB")
    else:
        # Provide a blank black image as a placeholder for ControlNet Depth if no input
        depth_input_for_cn = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
    control_images_list.append(depth_input_for_cn)
    # Determine output image dimensions based on base_image or default to SDXL standard
    if base_image:
        width, height = base_image.size
        # SDXL needs dimensions divisible by 8; round down to avoid pipeline errors
        width, height = width - width % 8, height - height % 8
    else:
        width, height = 1024, 1024

    # Resize control images to match target output dimensions for consistency
    final_control_images = []
    for img in control_images_list:
        if img.size != (width, height):
            final_control_images.append(img.resize((width, height), Image.BICUBIC))
        else:
            final_control_images.append(img)
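    # NOTE (assumption, not part of the original logic): a black placeholder still conditions
    # its ControlNet at full strength. If that proves too restrictive, per-ControlNet weights
    # can neutralise a missing input, e.g.:
    #   scales = [1.0 if canny_image else 0.0, 1.0 if depth_image else 0.0]
    #   output = pipe(..., controlnet_conditioning_scale=scales, ...)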
    progress(0.2, desc="Generating remixed image...")

    # Use a random seed for varied outputs as requested (no fixed seed)
    generator = torch.Generator(device=DEVICE).manual_seed(int(np.random.randint(0, 10**9)))

    # Callback to update the Gradio progress bar during diffusion steps.
    # callback_on_step_end is called as (pipeline, step, timestep, callback_kwargs) and must
    # return the callback_kwargs dict.
    def step_progress(pipeline, step, timestep, callback_kwargs):
        progress(0.2 + 0.8 * (step + 1) / num_inference_steps, desc=f"Generating step {step + 1}/{num_inference_steps}...")
        return callback_kwargs

    # Generate the image
    output = pipe(
        prompt=prompt,
        image=final_control_images,  # This list matches the order of the ControlNet models
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        callback_on_step_end=step_progress,
    ).images[0]

    progress(1.0, desc="Done!")
    return output
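
# NOTE (assumption): the UI wiring is not shown in this file; the sketch below is one way
# remix_images could be exposed in a Gradio Blocks app. Component labels and slider ranges
# are illustrative only; `progress` is injected automatically via its gr.Progress() default.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        prompt_box = gr.Textbox(label="Prompt")
        with gr.Row():
            canny_in = gr.Image(type="pil", label="Structure image (Canny)")
            depth_in = gr.Image(type="pil", label="Depth image")
            base_in = gr.Image(type="pil", label="Base image (sets output size)")
        guidance = gr.Slider(1.0, 15.0, value=7.5, label="Guidance scale")
        steps = gr.Slider(10, 50, value=30, step=1, label="Inference steps")
        result = gr.Image(label="Remixed image")
        gr.Button("Remix").click(
            remix_images,
            inputs=[prompt_box, canny_in, depth_in, base_in, guidance, steps],
            outputs=result,
        )
    demo.launch()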