EceMotion_Pictures / config.py
drvsbrkcn's picture
Upload 10 files
b12e499 verified
"""
Configuration management for EceMotion Pictures.
Centralized settings for models, parameters, and deployment.
"""
import os
from typing import Dict, Any, Optional
# Environment-backed settings.
# Every constant below is read once, at import time, from the environment
# variable of the same name and falls back to the default given here.


def _env_int(name: str, default: str) -> int:
    """Return environment variable *name* (or *default*) parsed as an int."""
    return int(os.getenv(name, default))


def _env_float(name: str, default: str) -> float:
    """Return environment variable *name* (or *default*) parsed as a float."""
    return float(os.getenv(name, default))


def _env_flag(name: str, default: str) -> bool:
    """Return True only when *name* (or *default*) equals "true", ignoring case."""
    return os.getenv(name, default).lower() == "true"


# Model selection - defaults are the lighter models, chosen as fallbacks for
# HuggingFace Spaces.
MODEL_VIDEO = os.getenv("MODEL_VIDEO", "damo-vilab/text-to-video-ms-1.7b")
MODEL_AUDIO = os.getenv("MODEL_AUDIO", "parler-tts/parler-tts-mini-v1")
MODEL_LLM = os.getenv("MODEL_LLM", "microsoft/DialoGPT-medium")

# Video settings.
MAX_DURATION = _env_int("MAX_DURATION", "15")
MIN_DURATION = _env_int("MIN_DURATION", "5")
DEFAULT_FPS = _env_int("DEFAULT_FPS", "8")
# The default of 64 frames is 8 seconds at the default 8 fps.
DEFAULT_FRAMES = _env_int("DEFAULT_FRAMES", "64")

# Audio settings.
AUDIO_SAMPLE_RATE = _env_int("AUDIO_SAMPLE_RATE", "22050")
# Kept as a string (e.g. "128k"); a lower bitrate was chosen for stability.
AUDIO_BITRATE = os.getenv("AUDIO_BITRATE", "128k")
MUSIC_GAIN = _env_float("MUSIC_GAIN", "0.3")

# GPU settings.
GPU_MEMORY_THRESHOLD = _env_float("GPU_MEMORY_THRESHOLD", "0.8")
USE_QUANTIZATION = _env_flag("USE_QUANTIZATION", "true")
QUANTIZATION_BITS = _env_int("QUANTIZATION_BITS", "8")

# Audio/video sync settings - lenient tolerance, forced sync off by default.
SYNC_TOLERANCE_MS = _env_int("SYNC_TOLERANCE_MS", "200")
FORCE_SYNC = _env_flag("FORCE_SYNC", "false")

# Retro filter settings.
VHS_INTENSITY = _env_float("VHS_INTENSITY", "0.5")
SCANLINE_OPACITY = _env_float("SCANLINE_OPACITY", "0.2")
CHROMATIC_ABERRATION = _env_float("CHROMATIC_ABERRATION", "0.05")
FILM_GRAIN = _env_float("FILM_GRAIN", "0.1")

# UI settings.
UI_THEME = os.getenv("UI_THEME", "default")
SHOW_PROGRESS = _env_flag("SHOW_PROGRESS", "true")
ENABLE_EXAMPLES = _env_flag("ENABLE_EXAMPLES", "true")

# Logging settings - plain-text format is the default for HuggingFace Spaces.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT = os.getenv("LOG_FORMAT", "text")
# Per-model settings, keyed by model id. The numbers are deliberately
# conservative; entries with "stable": False are treated as experimental.
MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
    # Video models (frame-count limits).
    "damo-vilab/text-to-video-ms-1.7b": {
        "max_frames": 64,
        "min_frames": 8,
        "default_frames": 32,
        "memory_usage_gb": 6,
        "supports_quantization": False,
        "stable": True,
    },
    "THUDM/CogVideoX-5b": {
        "max_frames": 48,  # reduced for stability
        "min_frames": 16,
        "default_frames": 32,
        "memory_usage_gb": 16,  # conservative estimate
        "supports_quantization": True,
        "stable": False,  # experimental
    },
    # Speech models (text-length limits and default voice).
    "parler-tts/parler-tts-mini-v1": {
        "max_text_length": 500,
        "min_text_length": 10,
        "default_voice": "Announcer '80s",
        "memory_usage_gb": 2,
        "stable": True,
    },
    "SWivid/F5-TTS": {
        "max_text_length": 300,
        "min_text_length": 10,
        "default_voice": "announcer",
        "memory_usage_gb": 4,
        "stable": False,  # experimental
    },
    # Language models (token limit and sampling parameters).
    "microsoft/DialoGPT-medium": {
        "max_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
        "memory_usage_gb": 2,
        "stable": True,
    },
    "Qwen/Qwen2.5-7B-Instruct": {
        "max_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "memory_usage_gb": 8,
        "stable": False,  # experimental
    },
}
# TTS voice styles: display name -> description of how the voice should sound.
VOICE_STYLES: Dict[str, str] = {
    "Announcer '80s": (
        "A confident, upbeat 1980s TV announcer with warm AM-radio tone."
    ),
    "Mall PA": "Casual, slightly echoey mall public-address vibe.",
    "Late Night": "Low energy, sly late-night infomercial style.",
    "News Anchor": "Professional, authoritative news anchor delivery.",
    "Infomercial": "Enthusiastic, persuasive infomercial host style.",
    "Radio DJ": "Smooth, charismatic radio disc jockey voice.",
}
# Three-beat scene outlines used for script generation; the beats of each
# outline are separated by " → ".
STRUCTURE_TEMPLATES = [
    "Montage → Close-up → Logo stinger",
    "Before/After → Feature highlight → CTA",
    "Testimonial → B-roll → Price tag reveal",
    "Unboxing → Demo → Deal countdown",
    "Retro news bulletin → Product shot → Tagline",
    "Opening hook → Problem/Solution → Call to action",
    "Brand story → Product showcase → Final tagline",
]
# Closing taglines to pick from for the end of a commercial.
TAGLINES = [
    "So retro, it's the future.",
    "Pixels you can trust.",
    "VHS vibes. Modern results.",
    "Old-school cool. New-school sales.",
    "Where nostalgia meets innovation.",
    "Rewind to the future.",
    "Classic style. Modern performance.",
    "The past perfected.",
    "EceMotion Pictures - Bringing the '80s back to life.",
    "Your story, our vision, timeless memories.",
]
def get_model_config(model_name: str) -> Dict[str, Any]:
    """Look up the settings for *model_name*.

    Args:
        model_name: Key to look up in ``MODEL_CONFIGS``.

    Returns:
        The ``MODEL_CONFIGS`` entry itself (not a copy) when the model is
        listed; otherwise a freshly built generic fallback with frame limits
        and ``"stable": True``.
    """
    try:
        return MODEL_CONFIGS[model_name]
    except KeyError:
        # Unlisted model: hand back generic defaults.
        return {
            "max_frames": 32,
            "min_frames": 8,
            "default_frames": 16,
            "memory_usage_gb": 4,
            "supports_quantization": False,
            "stable": True,
        }
def get_device() -> str:
    """Determine the best available device.

    Returns:
        ``"cuda"`` when PyTorch is importable and reports a usable CUDA
        device, otherwise ``"cpu"``.

    An unset ``CUDA_VISIBLE_DEVICES`` is no longer treated as "no GPU":
    leaving the variable unset is the normal case on GPU hosts, so requiring
    it to be set forced those hosts onto the CPU. Setting it to an empty
    string remains an explicit way to force CPU.
    """
    # An explicitly empty value hides every GPU; honour that opt-out before
    # paying for the torch import.
    if os.getenv("CUDA_VISIBLE_DEVICES") == "":
        return "cpu"
    try:
        import torch

        if torch.cuda.is_available():
            return "cuda"
    except ImportError:
        # PyTorch is not installed: fall through to the CPU default.
        pass
    return "cpu"
def validate_config() -> bool:
    """Validate configuration settings.

    The checks are explicit comparisons rather than ``assert`` statements,
    because assertions are removed when Python runs with ``-O`` and the
    function would then report success unconditionally.

    Returns:
        True when every check passes. On the first failing check, prints
        ``Configuration validation failed: <reason>`` and returns False.
    """
    checks = (
        (MIN_DURATION < MAX_DURATION, "MIN_DURATION must be less than MAX_DURATION"),
        (DEFAULT_FPS > 0, "DEFAULT_FPS must be positive"),
        (AUDIO_SAMPLE_RATE > 0, "AUDIO_SAMPLE_RATE must be positive"),
        (0 <= VHS_INTENSITY <= 1, "VHS_INTENSITY must be between 0 and 1"),
        (0 <= SCANLINE_OPACITY <= 1, "SCANLINE_OPACITY must be between 0 and 1"),
    )
    for passed, reason in checks:
        if not passed:
            # Only the first failure is reported.
            print(f"Configuration validation failed: {reason}")
            return False
    return True
def get_safe_model_name(model_name: str, model_type: str) -> str:
    """Return *model_name*, or a stable substitute when it is not stable.

    Args:
        model_name: Requested model id.
        model_type: One of ``"video"``, ``"audio"`` or ``"llm"``; selects
            which stable substitute to use.

    Returns:
        *model_name* unchanged when its config has a truthy ``"stable"``
        entry, or when *model_type* is not one of the three known types;
        otherwise the stable model for that type.
    """
    if get_model_config(model_name).get("stable", False):
        return model_name

    stable_by_type = {
        "video": "damo-vilab/text-to-video-ms-1.7b",
        "audio": "parler-tts/parler-tts-mini-v1",
        "llm": "microsoft/DialoGPT-medium",
    }
    # Unknown model types have no substitute, so the requested name is kept.
    return stable_by_type.get(model_type, model_name)
def log_config():
    """Print the active models, device and timing settings to stdout."""

    def _report_lines():
        # Yielded lazily so that get_device() runs only once the preceding
        # lines have been printed.
        yield "EceMotion Pictures Configuration:"
        yield f" Video Model: {MODEL_VIDEO}"
        yield f" Audio Model: {MODEL_AUDIO}"
        yield f" LLM Model: {MODEL_LLM}"
        yield f" Device: {get_device()}"
        yield f" Duration Range: {MIN_DURATION}-{MAX_DURATION}s"
        yield f" FPS: {DEFAULT_FPS}"
        yield f" Sync Tolerance: {SYNC_TOLERANCE_MS}ms"

    for line in _report_lines():
        print(line)