# demo-os56ddeg / models.py
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, StableVideoDiffusionPipeline
from PIL import Image
from typing import List
# Configuration based on typical high-performance T2I and I2V models
SVD_MODEL_ID = "stabilityai/stable-video-diffusion-img2vid-xt"
SDXL_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
SVD_PIPE = None
SDXL_PIPE = None
@spaces.GPU(duration=1500) # Allocate maximum time for heavy model loading and initialization
def load_models():
    """Initializes and loads the SDXL and SVD pipelines onto CUDA."""
    global SVD_PIPE, SDXL_PIPE
    if SDXL_PIPE is not None and SVD_PIPE is not None:
        return SDXL_PIPE, SVD_PIPE

    # 1. SDXL (generates the starting image for text-to-video)
    print("Loading SDXL pipeline...")
    sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
        SDXL_MODEL_ID,
        torch_dtype=torch.float16,
        use_safetensors=True,
    ).to("cuda")
    sdxl_pipe.enable_vae_slicing()  # Decode the VAE in slices to lower peak VRAM
    SDXL_PIPE = sdxl_pipe
    print("SDXL loaded.")

    # 2. SVD (Stable Video Diffusion, for image-to-video)
    print("Loading SVD pipeline...")
    svd_pipe = StableVideoDiffusionPipeline.from_pretrained(
        SVD_MODEL_ID,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to("cuda")
    SVD_PIPE = svd_pipe
    print("SVD loaded.")
    return sdxl_pipe, svd_pipe
# Load models upon module import (startup)
try:
    SDXL_PIPE, SVD_PIPE = load_models()
except Exception as e:
    print(f"Error during initial model load: {e}")
    SDXL_PIPE, SVD_PIPE = None, None
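# NOTE: @spaces.GPU functions are meant to run inside a Gradio request on
# ZeroGPU hardware, so this import-time call may fail in that environment;
# the except branch above then leaves both pipes as None, and the generate_*
# functions below surface that state as an explicit error.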
@spaces.GPU(duration=60) # Standard duration for inference
def generate_t2v(
    prompt: str,
    motion_bucket_id: int,
    frames: int,
    fps: int,
) -> List[Image.Image]:
    """Generates a video from a text prompt (T2I, then I2V)."""
    if SDXL_PIPE is None or SVD_PIPE is None:
        raise RuntimeError("Models are not ready. Please wait for startup.")

    # 1. Generate the starting image with SDXL
    print(f"Generating starting image for prompt: {prompt}")
    with torch.no_grad():
        image = SDXL_PIPE(prompt, guidance_scale=6.5).images[0]

    # 2. Resize to SVD's optimal 16:9 resolution (1024x576)
    image = image.resize((1024, 576))

    # 3. Generate the video
    print(f"Generating video with {frames} frames at {fps} fps...")
    generator = torch.Generator(device="cuda")
    with torch.no_grad():
        video_frames = SVD_PIPE(
            image,
            # Decoding all frames in one chunk is fastest but VRAM-heavy;
            # a smaller value (e.g. 8) trades speed for memory headroom.
            decode_chunk_size=frames,
            motion_bucket_id=motion_bucket_id,
            num_frames=frames,
            fps=fps,
            noise_aug_strength=0.02,
            generator=generator,
            num_inference_steps=25,
            output_type="pil",
        ).frames[0]
    return video_frames
@spaces.GPU(duration=45) # Shorter duration for I2V as it skips T2I
def generate_i2v(
    input_image: Image.Image,
    motion_bucket_id: int,
    frames: int,
    fps: int,
) -> List[Image.Image]:
    """Generates a video from an input image (I2V)."""
    if SVD_PIPE is None:
        raise RuntimeError("SVD model is not ready. Please wait for startup.")

    # Resize the input image to SVD's optimal resolution
    input_image = input_image.resize((1024, 576))

    print(f"Generating video with {frames} frames at {fps} fps...")
    generator = torch.Generator(device="cuda")
    with torch.no_grad():
        video_frames = SVD_PIPE(
            input_image,
            decode_chunk_size=frames,
            motion_bucket_id=motion_bucket_id,
            num_frames=frames,
            fps=fps,
            # Stronger noise augmentation than the T2V path (0.05 vs. 0.02)
            noise_aug_strength=0.05,
            generator=generator,
            num_inference_steps=25,
            output_type="pil",
        ).frames[0]
    return video_frames
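

# --- Usage sketch (illustrative, not part of the Space's request flow) ---
# A minimal local smoke test, assuming a CUDA machine where the pipelines
# loaded at import; outside ZeroGPU the @spaces.GPU decorator is a no-op.
# The gray placeholder image and the "out.mp4" path are hypothetical choices;
# the SVD defaults (motion_bucket_id=127, 25 frames, 7 fps) are used here.
if __name__ == "__main__":
    from diffusers.utils import export_to_video

    if SVD_PIPE is not None:
        placeholder = Image.new("RGB", (1024, 576), color=(128, 128, 128))
        clip = generate_i2v(placeholder, motion_bucket_id=127, frames=25, fps=7)
        export_to_video(clip, "out.mp4", fps=7)
        print("Wrote out.mp4")
    else:
        print("Pipelines not loaded; skipping smoke test.")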