# models.py
import torch
import numpy as np
from diffusers import DiffusionPipeline
from typing import Tuple, Union
import spaces
from PIL import Image
import imageio
import os
from scipy.io import wavfile
from config import MODEL_ID_T2V, MAX_DURATION_SECONDS
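# config.py (not shown here) is expected to provide MODEL_ID_T2V and MAX_DURATION_SECONDS.
# Illustrative values only; the 576x320 resolution and ZeroScope log messages below
# suggest something like the following, but the real config.py is authoritative:
#   MODEL_ID_T2V = "cerspense/zeroscope_v2_576w"
#   MAX_DURATION_SECONDS = 10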
# --- Model Loading (ZeroGPU Setup) ---
pipe_t2v = None
MODEL_LOADED = False
try:
    # Use bfloat16 if available (recommended for modern GPUs, compute capability >= 8)
    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8 else torch.float16
    pipe_t2v = DiffusionPipeline.from_pretrained(
        MODEL_ID_T2V,
        torch_dtype=dtype,
        variant="fp16"
    )
    # Move to CUDA and enable CPU offload for large models
    pipe_t2v.to("cuda")
    pipe_t2v.enable_model_cpu_offload()
    MODEL_LOADED = True
    print(f"✅ Loaded model {MODEL_ID_T2V} to CUDA.")
except Exception as e:
    print(f"❌ Failed to load ZeroScope model for GPU: {e}")
    MODEL_LOADED = False
# Fallback generator function
def fallback_video_generator(prompt: str, duration: int) -> str:
    """Generates a simple animated gradient clip when the real pipeline is unavailable."""
    print(f"⚠️ Using CPU Fallback Generator for '{prompt}'.")
    # Simulate generation time so the user waits, mirroring the real process time
    import time
    time.sleep(duration * 1.5)
    num_frames = duration * 10  # 10 FPS
    frames = []
    # Simple gradient animation
    width, height = 576, 320
    for i in range(num_frames):
        # Create a simple color based on the frame index
        r = (128 + 100 * np.sin(i * 0.1)).astype(np.uint8)
        g = (128 + 100 * np.sin(i * 0.15)).astype(np.uint8)
        b = (128 + 100 * np.sin(i * 0.2)).astype(np.uint8)
        frame = np.zeros((height, width, 3), dtype=np.uint8)
        frame[:, :] = [r, g, b]
        frames.append(frame)
    output_path = "output_fallback.mp4"
    imageio.mimsave(output_path, frames, fps=10)
    return output_path
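# Illustrative usage (assumes imageio's ffmpeg backend, e.g. imageio-ffmpeg, is installed):
#   fallback_video_generator("a slow colour fade", duration=2)  # -> "output_fallback.mp4"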
def synthesize_audio(prompt: str) -> Tuple[int, np.ndarray]:
    """Synthesizes placeholder audio whose pitch scales with the prompt complexity."""
    try:
        base_freq = 200 + len(prompt.split()) * 15  # Frequency scales with word count
        duration = 4.0  # seconds (fixed length for simplicity)
        sample_rate = 22050
        t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
        # Complex waveform: multiple sine waves + envelope
        waveform = 0.6 * np.sin(2 * np.pi * base_freq * t)
        waveform += 0.3 * np.sin(2 * np.pi * (base_freq * 1.5) * t)
        # Apply a gentle attack/decay envelope
        envelope = np.ones_like(t)
        attack_len = int(sample_rate * 0.5)                 # 0.5 s fade-in
        decay_start = int(sample_rate * (duration - 0.5))   # fade-out over the last 0.5 s
        envelope[:attack_len] = np.linspace(0, 1, attack_len)
        envelope[decay_start:] = np.linspace(1, 0, len(t) - decay_start)
        waveform *= envelope
        # Scale to 16-bit PCM
        audio_data = (waveform * 32767).astype(np.int16)
        return sample_rate, audio_data
    except Exception as e:
        print(f"Audio synthesis error: {e}")
        return 22050, np.zeros(22050 * 4, dtype=np.int16)
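# Sketch: the (sample_rate, int16 array) tuple returned above is the format gr.Audio
# expects, and it can also be written to disk with the scipy import at the top of this
# file. Hypothetical helper for local inspection only; it is not called anywhere in the app.
def _save_audio_preview(path: str, audio: Tuple[int, np.ndarray]) -> str:
    """Writes synthesized audio to a WAV file so it can be played back locally."""
    sample_rate, data = audio
    wavfile.write(path, sample_rate, data)
    return path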
@spaces.GPU(duration=300)  # Generous duration for video generation
def generate_video(
    prompt: str,
    input_image: Union[Image.Image, None],
    duration: int,
    is_image_to_video: bool
) -> Tuple[str, Tuple[int, np.ndarray]]:
    """
    Generates a video (and synthesized audio) based on the input parameters.
    """
    # 1. Video generation logic
    if not MODEL_LOADED or pipe_t2v is None:
        video_path = fallback_video_generator(prompt, duration)
    else:
        actual_duration = min(duration, MAX_DURATION_SECONDS)
        # Using a fixed frame rate common for ZeroScope
        fps = 10
        num_frames = int(actual_duration * fps)
        print(f"Using ZeroScope T2V. Duration: {actual_duration}s, Frames: {num_frames}")
        if is_image_to_video and input_image is not None:
            # For I2V on a T2V pipeline, we can only guide the model via the prompt
            # and rely on future model iterations (or LoRA/ControlNet) for true image conditioning.
            prompt = f"video starting from a visual of the following: {prompt}"
            # In a real I2V setup, input_image would condition the VAE/UNet.
        try:
            # Generate frames (note: on recent diffusers versions the output is
            # batched, so `.frames[0]` may be needed instead of `.frames`)
            video_frames = pipe_t2v(
                prompt,
                num_frames=num_frames,
                height=320,
                width=576
            ).frames
            video_path = "output_video.mp4"
            # Use the H.264 (libx264) codec for better compatibility in web browsers
            imageio.mimsave(
                video_path,
                [np.array(f) for f in video_frames],
                fps=fps,
                quality=8,
                codec="libx264",
                pixelformat="yuv420p"
            )
        except Exception as e:
            print(f"Critical Error during ZeroScope generation: {e}")
            video_path = fallback_video_generator(prompt, duration)
    # 2. Synthesize audio
    audio_output = synthesize_audio(prompt)
    return video_path, audio_output
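# Minimal local smoke test (illustrative): exercises only the CPU-safe paths,
# i.e. the fallback video generator and the placeholder audio synthesis, so it
# runs without a GPU or the ZeroScope weights.
if __name__ == "__main__":
    sr, audio = synthesize_audio("a calm piano melody over rain")
    print(f"Audio: {sr} Hz, {audio.shape[0]} samples")
    clip = fallback_video_generator("test gradient clip", duration=2)
    print(f"Fallback video written to {clip}")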