# models.py
import spaces
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from PIL import Image
import numpy as np
from typing import List, Optional
# Uncomment the following line if you plan to use FP8 quantization with torchao
# from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
# Gradio's Progress class (exposed at the package top level as gr.Progress) is used for type hints below
from gradio import Progress as gr_Progress
from config import MODEL_ID_BASE, MODEL_ID_CANNY, MODEL_ID_DEPTH, DEVICE, DTYPE
from utils import apply_canny, apply_depth
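# For reference, a minimal sketch of what utils.apply_canny might look like (an assumption, not the
# actual utils.py implementation), using OpenCV's Canny detector and stacking the edge map to 3 channels
# as ControlNet expects; apply_depth would similarly wrap a monocular depth estimator:
#
#   import cv2
#   def apply_canny(image: Image.Image, low: int = 100, high: int = 200) -> Image.Image:
#       edges = cv2.Canny(np.array(image.convert("RGB")), low, high)   # single-channel edge map
#       return Image.fromarray(np.stack([edges] * 3, axis=-1))         # 3-channel image for ControlNet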
print(f"Loading ControlNet models and base pipeline on {DEVICE} with {DTYPE}...")
# Initialize ControlNet models
controlnet_canny_model = ControlNetModel.from_pretrained(MODEL_ID_CANNY, torch_dtype=DTYPE)
controlnet_depth_model = ControlNetModel.from_pretrained(MODEL_ID_DEPTH, torch_dtype=DTYPE)
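# Optional: load the fp16-safe SDXL VAE referenced below before creating the pipeline
# (a hedged sketch; uncomment together with the `vae=vae` argument):
# from diffusers import AutoencoderKL
# vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=DTYPE)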
# Initialize the main pipeline with multiple ControlNets
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    MODEL_ID_BASE,
    controlnet=[controlnet_canny_model, controlnet_depth_model],
    torch_dtype=DTYPE,
    # vae=vae,  # Uncomment if using a custom VAE, e.g., "madebyollin/sdxl-vae-fp16-fix"
)
pipe.to(DEVICE)
print("Pipeline and ControlNets loaded.")
# ZeroGPU Ahead-of-Time (AoT) compilation for UNet and ControlNets
@spaces.GPU(duration=1500) # Maximum duration allowed for startup compilation
def compile_optimized_pipeline():
print("Starting AoT compilation for UNet and ControlNets...")
# Optional: Apply FP8 quantization (requires H200 GPU and torchao)
# try:
# # Example for UNet
# quantize_(pipe.unet, Float8DynamicActivationFloat8WeightConfig())
# print("FP8 Quantization applied to UNet.")
# # Example for each ControlNet
# for i, cn in enumerate(pipe.controlnet):
# quantize_(cn, Float8DynamicActivationFloat8WeightConfig())
# print(f"FP8 Quantization applied to ControlNet {i}.")
# except Exception as e:
# print(f"Could not apply FP8 quantization: {e}. Falling back to default dtype.")
# 1. Compile UNet
with spaces.aoti_capture(pipe.unet) as unet_call:
# Dummy inputs for UNet's forward pass (common SDXL latents, timestep, text embeddings, and added conditions)
dummy_latents = torch.randn(1, 4, 128, 128, device=DEVICE, dtype=DTYPE) # For 1024x1024 output
dummy_timestep = torch.tensor(999, device=DEVICE, dtype=torch.long)
dummy_encoder_hidden_states = torch.randn(1, 77, 2048, device=DEVICE, dtype=DTYPE) # text_embeddings
dummy_added_cond_kwargs = {
"text_embeds": torch.randn(1, 1280, device=DEVICE, dtype=DTYPE), # pooled_prompt_embeds
"time_ids": torch.randn(1, 6, device=DEVICE, dtype=DTYPE), # original_size, crops_coords, target_size
}
unet_call(dummy_latents, dummy_timestep, dummy_encoder_hidden_states, added_cond_kwargs=dummy_added_cond_kwargs)
compiled_unet = spaces.aoti_compile(torch.export.export(pipe.unet, args=unet_call.args, kwargs=unet_call.kwargs))
print("UNet compiled.")
# 2. Compile each ControlNet
compiled_controlnets = []
for cn_idx, cn_model in enumerate(pipe.controlnet):
with spaces.aoti_capture(cn_model) as cn_call:
# Dummy inputs for ControlNet's forward pass
dummy_controlnet_cond = torch.randn(1, 3, 1024, 1024, device=DEVICE, dtype=DTYPE) # Control image
cn_call(dummy_latents, dummy_timestep, dummy_encoder_hidden_states, controlnet_cond=dummy_controlnet_cond)
compiled_cn = spaces.aoti_compile(torch.export.export(cn_model, args=cn_call.args, kwargs=cn_call.kwargs))
compiled_controlnets.append(compiled_cn)
print(f"ControlNet {cn_idx} compiled.")
print("AoT compilation completed for UNet and ControlNets.")
return compiled_unet, compiled_controlnets
# Apply compiled models to the pipeline
compiled_unet, compiled_controlnets = compile_optimized_pipeline()
spaces.aoti_apply(compiled_unet, pipe.unet)
for compiled_cn, original_cn in zip(compiled_controlnets, pipe.controlnet.nets):
    spaces.aoti_apply(compiled_cn, original_cn)
print("AoT compilation applied to pipeline modules.")
@spaces.GPU(duration=120) # Standard generation duration (e.g., up to 2 minutes)
def remix_images(
    prompt: str,
    canny_image: Optional[Image.Image],
    depth_image: Optional[Image.Image],
    base_image: Optional[Image.Image],
    guidance_scale: float,
    num_inference_steps: int,
    progress: gr_Progress,
) -> Image.Image:
"""
Remixes three input images with a text prompt using ControlNet (Canny, Depth) for structural control,
and an optional base image for guidance and dimensions.
Args:
prompt (str): The text prompt for generation.
canny_image (Optional[Image.Image]): Image to generate Canny edges from for ControlNet.
depth_image (Optional[Image.Image]): Image to generate depth map from for ControlNet.
base_image (Optional[Image.Image]): The base image. Its dimensions will be used as the output size.
If provided, it also implicitly guides the generation as a starting point.
guidance_scale (float): Classifier-free guidance scale.
num_inference_steps (int): Number of diffusion steps.
progress (gr.Progress): Gradio progress object for tracking generation.
Returns:
Image.Image: The remixed image.
"""
progress(0, desc="Preprocessing control images...")
control_images_list = []
# Process Canny image
if canny_image:
canny_input_for_cn = apply_canny(canny_image)
else:
# Provide a blank black image as a placeholder for ControlNet Canny if no input
canny_input_for_cn = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
control_images_list.append(canny_input_for_cn)
# Process Depth image
if depth_image:
depth_input_for_cn = apply_depth(depth_image)
# Ensure depth map is 3-channel for ControlNet input
depth_input_for_cn = depth_input_for_cn.convert("RGB")
else:
# Provide a blank black image as a placeholder for ControlNet Depth if no input
depth_input_for_cn = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
control_images_list.append(depth_input_for_cn)
# Determine output image dimensions based on base_image or default to SDXL standard
if base_image:
width, height = base_image.size
else:
width, height = 1024, 1024
# Resize control images to match target output dimensions for consistency
final_control_images = []
for img in control_images_list:
if img.size != (width, height):
final_control_images.append(img.resize((width, height), Image.BICUBIC))
else:
final_control_images.append(img)
progress(0.2, desc="Generating remixed image...")
# Generate the image
# Use a random seed for varied outputs as requested (no fixed seed)
generator = torch.Generator(device=DEVICE).manual_seed(np.random.randint(0, 10**9))
output = pipe(
prompt=prompt,
image=final_control_images, # This list matches the order of ControlNet models
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
# Callback to update Gradio progress bar during diffusion steps
callback_on_step_end=lambda step, timestep, latents: progress((0.2 + 0.8 * step / num_inference_steps), desc=f"Generating step {step}/{num_inference_steps}..."),
).images[0]
progress(1.0, desc="Done!")
return output
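# Example (hypothetical) local smoke test for remix_images; the function is normally called from a
# Gradio event handler, and "structure.jpg" is an assumed test asset that is not part of this repo:
# if __name__ == "__main__":
#     ref = Image.open("structure.jpg")
#     result = remix_images(
#         prompt="a futuristic city at dusk, cinematic lighting",
#         canny_image=ref,
#         depth_image=ref,
#         base_image=ref,
#         guidance_scale=7.5,
#         num_inference_steps=30,
#         progress=gr_Progress(),
#     )
#     result.save("remixed.png")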