Spaces:
Sleeping
Sleeping
| import spaces | |
| import torch | |
| from diffusers import StableDiffusionXLPipeline, StableVideoDiffusionPipeline | |
| from PIL import Image | |
| from typing import List | |
| # Configuration based on typical high-performance T2I and I2V models | |
| SVD_MODEL_ID = "stabilityai/stable-video-diffusion-img2vid-xt" | |
| SDXL_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0" | |
| SVD_PIPE = None | |
| SDXL_PIPE = None | |
| # Allocate maximum time for heavy model loading and initialization | |
def load_models():
    """Load the SDXL and SVD pipelines onto CUDA and cache them in module globals.

    Each pipeline is loaded only if its global (``SDXL_PIPE`` / ``SVD_PIPE``)
    is still ``None``, so a repeated call, or a retry after only one of the
    two loads succeeded, never re-creates a pipeline that is already cached.

    Returns:
        A ``(sdxl_pipe, svd_pipe)`` tuple holding the cached pipelines.

    Raises:
        Whatever ``from_pretrained`` / ``.to("cuda")`` raise; a pipeline that
        finished loading before the failure stays cached in its global.
    """
    global SVD_PIPE, SDXL_PIPE

    # 1. SDXL (produces the starting image for text-to-video)
    if SDXL_PIPE is None:
        print("Loading SDXL pipeline...")
        sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
            SDXL_MODEL_ID,
            torch_dtype=torch.float16,
            use_safetensors=True,
            # Ask for the half-precision checkpoint files, matching the SVD
            # load below, instead of loading fp32 weights and casting them.
            variant="fp16",
        ).to("cuda")
        sdxl_pipe.enable_vae_slicing()
        SDXL_PIPE = sdxl_pipe
        print("SDXL loaded.")

    # 2. SVD (Stable Video Diffusion, image-to-video)
    if SVD_PIPE is None:
        print("Loading SVD pipeline...")
        svd_pipe = StableVideoDiffusionPipeline.from_pretrained(
            SVD_MODEL_ID,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        ).to("cuda")
        SVD_PIPE = svd_pipe
        print("SVD loaded.")

    return SDXL_PIPE, SVD_PIPE
| # Load models upon module import (startup) | |
# Build both pipelines eagerly at import time. A failure is reported on stdout
# and leaves both globals as None, which the generate_* functions check before
# running inference.
try:
    SDXL_PIPE, SVD_PIPE = load_models()
except Exception as e:
    print(f"Error during initial model load: {e}")
    SDXL_PIPE = SVD_PIPE = None
| # Standard duration for inference | |
def generate_t2v(
    prompt: str,
    motion_bucket_id: int,
    frames: int,
    fps: int
) -> List[Image.Image]:
    """Generate video frames from a text prompt (SDXL text-to-image, then SVD).

    Args:
        prompt: Text prompt passed to the SDXL pipeline for the first frame.
        motion_bucket_id: Forwarded to the SVD pipeline as ``motion_bucket_id``.
        frames: Number of frames to generate; also used as the decode chunk
            size, so all frames are decoded in one pass.
        fps: Forwarded to the SVD pipeline as ``fps``.

    Returns:
        The frames of the first generated video, as PIL images.

    Raises:
        ConnectionError: If either pipeline is not loaded.
    """
    if SDXL_PIPE is None or SVD_PIPE is None:
        raise ConnectionError("Models are not ready. Please wait for startup.")

    # Resolution the SVD stage is fed with (16:9).
    svd_width, svd_height = 1024, 576

    # 1. Generate the starting image with SDXL directly at the SVD resolution.
    #    Rendering at the pipeline's default size and resizing afterwards
    #    would stretch the picture to a different aspect ratio.
    print(f"Generating starting image for prompt: {prompt}")
    with torch.no_grad():
        image = SDXL_PIPE(
            prompt,
            guidance_scale=6.5,
            width=svd_width,
            height=svd_height,
        ).images[0]

    # 2. Safety net: only resize if the pipeline returned a different size.
    if image.size != (svd_width, svd_height):
        image = image.resize((svd_width, svd_height))

    # 3. Generate the video from the starting image.
    print(f"Generating video with {frames} frames at {fps} fps...")
    # Generator is created on the CUDA device; no seed is set here.
    generator = torch.Generator(device="cuda")
    with torch.no_grad():
        video_frames = SVD_PIPE(
            image,
            decode_chunk_size=frames,
            motion_bucket_id=motion_bucket_id,
            num_frames=frames,
            fps=fps,
            noise_aug_strength=0.02,
            generator=generator,
            num_inference_steps=25,
            output_type="pil",
        ).frames[0]
    return video_frames
| # Shorter duration for I2V as it skips T2I | |
def generate_i2v(
    input_image: Image.Image,
    motion_bucket_id: int,
    frames: int,
    fps: int
) -> List[Image.Image]:
    """Generate video frames from an input image with the SVD pipeline.

    Args:
        input_image: Conditioning image. It is converted to RGB if needed and
            resized to 1024x576.
        motion_bucket_id: Forwarded to the SVD pipeline as ``motion_bucket_id``.
        frames: Number of frames to generate; also used as the decode chunk
            size, so all frames are decoded in one pass.
        fps: Forwarded to the SVD pipeline as ``fps``.

    Returns:
        The frames of the first generated video, as PIL images.

    Raises:
        ConnectionError: If the SVD pipeline is not loaded.
        ValueError: If ``input_image`` is ``None``.
    """
    if SVD_PIPE is None:
        raise ConnectionError("SVD Model is not ready. Please wait for startup.")
    if input_image is None:
        # Fail with a clear message instead of an AttributeError on .resize().
        raise ValueError("An input image is required for image-to-video generation.")

    # Drop alpha / palette / grayscale modes so the pipeline gets 3 channels.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    # Resize to the SVD resolution.
    # NOTE: a plain resize does not preserve the source aspect ratio.
    input_image = input_image.resize((1024, 576))

    print(f"Generating video with {frames} frames at {fps} fps...")
    # Generator is created on the CUDA device; no seed is set here.
    generator = torch.Generator(device="cuda")
    with torch.no_grad():
        video_frames = SVD_PIPE(
            input_image,
            decode_chunk_size=frames,
            motion_bucket_id=motion_bucket_id,
            num_frames=frames,
            fps=fps,
            noise_aug_strength=0.05,
            generator=generator,
            num_inference_steps=25,
            output_type="pil",
        ).frames[0]
    return video_frames