"""
LLM-powered script generation for EceMotion Pictures.

Generates intelligent, structure-aware commercial scripts with timing markers.
"""
import logging
import random
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from config import (
    MODEL_LLM,
    MODEL_CONFIGS,
    VOICE_STYLES,
    STRUCTURE_TEMPLATES,
    TAGLINES,
    get_safe_model_name,
)
# Module-level logger named after this module; handlers and levels are left to
# whatever logging configuration the application installs.
logger = logging.getLogger(__name__)


@dataclass
class ScriptSegment:
    """One labelled section of a commercial script together with its timing.

    Attributes:
        text: The copy for this section.
        duration_estimate: Estimated length of the section (the generators in
            this module fill it with a value in seconds).
        segment_type: Lower-case section label, e.g. "hook", "flow",
            "benefit" or "cta".
        timing_marker: Optional bracketed marker such as "[HOOK]"; ``None``
            when no marker was assigned.
    """

    text: str
    duration_estimate: float
    segment_type: str
    timing_marker: Optional[str] = None


@dataclass
class GeneratedScript:
    """A complete generated script: its ordered segments plus summary metadata.

    Attributes:
        segments: The script sections in the order they were produced.
        total_duration: Sum of the segments' ``duration_estimate`` values.
        tagline: Closing tagline chosen or extracted for the script.
        voice_style: Voice style the script was requested for.
        word_count: Total number of whitespace-separated words across segments.
        raw_script: Unprocessed source text of the script (the raw LLM reply,
            or a descriptive placeholder for template-based scripts).
    """

    segments: List[ScriptSegment]
    total_duration: float
    tagline: str
    voice_style: str
    word_count: int
    raw_script: str


class LLMScriptGenerator:
    """Generate retro commercial scripts with an optional causal LLM and a template fallback.

    On construction the generator tries to load the configured model through
    ``transformers``. If loading fails (or the model family is not recognised)
    ``llm_available`` stays ``False`` and every public method falls back to
    deterministic, template-based output.
    """

    # The LLM is asked to label its sections with these markers.  Splitting on
    # a capture group keeps the markers in the result list.
    _SECTION_SPLIT_RE = re.compile(r'(HOOK:|FLOW:|BENEFIT:|CTA:)')
    # Body of the first CTA section: everything up to the next marker or the end.
    _CTA_BODY_RE = re.compile(r'CTA:(.*?)(?=HOOK:|FLOW:|BENEFIT:|CTA:|\Z)', re.DOTALL)
    _SENTENCE_END_RE = re.compile(r'[.!?]+')

    # Speaking rate used to turn a word count into a duration estimate.
    _WORDS_PER_MINUTE = 150
    # An extracted tagline must be strictly longer than this many characters.
    _MIN_TAGLINE_CHARS = 5

    # Vocabulary for the random idea generator in ``_fallback_suggestions``.
    _IDEA_STYLES = ("infomercial", "mall ad", "late-night", "newsflash", "arcade bumper")
    _IDEA_SHOTS = ("neon grid", "CRT scanlines", "vaporwave sunset", "shopping mall", "boombox close-up")
    _IDEA_HOOKS = (
        "Remember this sound?", "Back to '87.", "Deal of the decade.",
        "We paused time.", "Be kind, rewind your brand.",
    )
    _STRUCTURE_KEYWORDS = (
        "montage", "testimonial", "news", "unboxing", "before", "after",
        "countdown", "logo", "cta",
    )

    def __init__(self, model_name: str = MODEL_LLM):
        """Resolve the model name, read its config and attempt to load the model.

        Args:
            model_name: Requested LLM identifier; it is passed through
                ``get_safe_model_name(model_name, "llm")`` before use.
        """
        self.model_name = get_safe_model_name(model_name, "llm")
        self.model = None
        self.tokenizer = None
        self.model_config = MODEL_CONFIGS.get(self.model_name, {})
        self.llm_available = False

        self._try_init_llm()

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------

    def _try_init_llm(self):
        """Pick a loader from the model name; never raises.

        Names containing "dialo" or "qwen" (case-insensitive) are supported;
        anything else leaves the generator in template-only mode.
        """
        lowered = self.model_name.lower()
        try:
            if "dialo" in lowered:
                self._init_dialogpt()
            elif "qwen" in lowered:
                self._init_qwen()
            else:
                logger.warning("Unknown LLM model: %s, using fallback", self.model_name)
                self.llm_available = False
        except Exception as exc:
            logger.warning("Failed to initialize LLM %s: %s", self.model_name, exc)
            self.llm_available = False

    def _load_causal_lm(self, label: str, **hub_kwargs) -> None:
        """Load tokenizer and model for ``self.model_name``; log and disable the LLM on failure.

        Args:
            label: Human-readable model family name used in log messages.
            **hub_kwargs: Extra keyword arguments forwarded to both
                ``from_pretrained`` calls (e.g. ``trust_remote_code``).
        """
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM

            tokenizer = AutoTokenizer.from_pretrained(self.model_name, **hub_kwargs)
            # Padding is requested at tokenization time, so a pad token must exist.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype="auto",
                device_map="auto" if self._has_gpu() else "cpu",
                **hub_kwargs,
            )
        except Exception as exc:
            logger.error("Failed to load %s: %s", label, exc)
            self.llm_available = False
            return

        # Publish state only after both pieces loaded, so a failure never
        # leaves a tokenizer without a model.
        self.tokenizer = tokenizer
        self.model = model
        self.llm_available = True
        logger.info("%s model %s loaded successfully", label, self.model_name)

    def _init_dialogpt(self):
        """Initialize DialoGPT model."""
        self._load_causal_lm("DialoGPT")

    def _init_qwen(self):
        """Initialize Qwen model.

        NOTE: ``trust_remote_code=True`` lets the model repository run its own
        Python code during loading; only use it with model names you trust.
        """
        self._load_causal_lm("Qwen", trust_remote_code=True)

    def _has_gpu(self) -> bool:
        """Return True when torch is importable and reports an available CUDA device."""
        try:
            import torch
        except ImportError:
            return False
        return torch.cuda.is_available()

    # ------------------------------------------------------------------
    # Prompt construction
    # ------------------------------------------------------------------

    def _create_system_prompt(self) -> str:
        """Return the fixed system prompt describing the copywriter role and output format."""
        return (
            "You are a professional copywriter specializing in 1980s-style TV commercials.\n"
            "Your task is to create engaging, persuasive commercial scripts that capture the authentic retro aesthetic.\n"
            "\n"
            "Key requirements:\n"
            "- Use 1980s commercial language and style\n"
            "- Include clear hooks, benefits, and calls-to-action\n"
            "- Keep scripts concise and punchy\n"
            "- Use active voice and emotional appeals\n"
            "- End with a memorable tagline\n"
            "\n"
            "Format your response as:\n"
            "HOOK: [Opening attention-grabber]\n"
            "FLOW: [Main content following the structure]\n"
            "BENEFIT: [Key value proposition]\n"
            "CTA: [Call to action with tagline]\n"
            "\n"
            "Keep each segment under 2-3 sentences. Use enthusiastic, confident language typical of 1980s advertising."
        )

    def _create_user_prompt(self, brand: str, structure: str, script_prompt: str,
                            duration: int, voice_style: str) -> str:
        """Return the per-request prompt built from the caller's brand, structure and style."""
        return (
            f"Create a {duration}-second retro commercial script for {brand}.\n"
            "\n"
            f"Structure: {structure}\n"
            f"Script idea: {script_prompt}\n"
            f"Voice style: {voice_style}\n"
            "\n"
            "Make it authentic to 1980s TV commercials with the energy and style of that era."
        )

    # ------------------------------------------------------------------
    # Response parsing
    # ------------------------------------------------------------------

    def _parse_script_response(self, response: str) -> List[ScriptSegment]:
        """Split an LLM reply into segments at the HOOK:/FLOW:/BENEFIT:/CTA: markers.

        Text before the first marker is ignored, as are sections whose body is
        empty after stripping.  Each segment's duration is estimated from its
        word count at ``_WORDS_PER_MINUTE``.

        Returns:
            Segments in the order they appear in ``response``; empty when no
            marker is present.
        """
        pieces = self._SECTION_SPLIT_RE.split(response)
        segments = []

        # ``pieces`` is [preamble, marker, body, marker, body, ...].
        for marker, body in zip(pieces[1::2], pieces[2::2]):
            text = body.strip()
            if not text:
                continue

            segment_type = marker.rstrip(':').lower()
            seconds = (len(text.split()) / self._WORDS_PER_MINUTE) * 60
            segments.append(ScriptSegment(
                text=text,
                duration_estimate=seconds,
                segment_type=segment_type,
                timing_marker=f"[{segment_type.upper()}]",
            ))

        return segments

    def _extract_tagline(self, response: str) -> str:
        """Return the last sentence of the CTA section, or a random stock tagline.

        The sentence is returned without its terminating punctuation.  A random
        entry of ``TAGLINES`` is used when there is no CTA section, the section
        has no sentence, or its last sentence is not longer than
        ``_MIN_TAGLINE_CHARS`` characters.
        """
        cta_match = self._CTA_BODY_RE.search(response)
        if cta_match:
            # Splitting on terminators leaves an empty tail after the final
            # "." / "!" / "?", so empty fragments must be dropped before
            # picking the last sentence.
            fragments = (part.strip() for part in self._SENTENCE_END_RE.split(cta_match.group(1)))
            sentences = [part for part in fragments if part]
            if sentences and len(sentences[-1]) > self._MIN_TAGLINE_CHARS:
                return sentences[-1]

        return random.choice(TAGLINES)

    @staticmethod
    def _assemble_script(segments: List[ScriptSegment], tagline: str,
                         voice_style: str, raw_script: str) -> GeneratedScript:
        """Wrap segments in a GeneratedScript, deriving total duration and word count."""
        return GeneratedScript(
            segments=segments,
            total_duration=sum(segment.duration_estimate for segment in segments),
            tagline=tagline,
            voice_style=voice_style,
            word_count=sum(len(segment.text.split()) for segment in segments),
            raw_script=raw_script,
        )

    # ------------------------------------------------------------------
    # Script generation
    # ------------------------------------------------------------------

    def generate_script_with_llm(self, brand: str, structure: str, script_prompt: str,
                                 duration: int, voice_style: str, seed: int = 42) -> GeneratedScript:
        """Generate a script by sampling from the loaded LLM.

        Args:
            brand: Brand name inserted into the prompt.
            structure: Requested commercial structure.
            script_prompt: Free-form script idea.
            duration: Target length in seconds, mentioned in the prompt.
            voice_style: Voice style, mentioned in the prompt and stored on the result.
            seed: Seed applied to both Python's and torch's global RNGs.

        Returns:
            The parsed script.  ``segments`` is empty when the reply contains
            none of the expected section markers.

        Raises:
            RuntimeError: If no LLM is loaded.
        """
        if not self.llm_available:
            raise RuntimeError("LLM not available")

        # ``random`` drives the stock-tagline fallback; torch's RNG drives the
        # token sampling in ``generate``.  Both must be seeded for ``seed`` to
        # have any effect on the generated text.
        random.seed(seed)
        import torch
        torch.manual_seed(seed)

        user_prompt = self._create_user_prompt(brand, structure, script_prompt, duration, voice_style)
        if "dialo" in self.model_name.lower():
            prompt_text = f"{user_prompt}\n\nResponse:"
        else:
            system_prompt = self._create_system_prompt()
            prompt_text = f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"

        encoded = self.tokenizer(prompt_text, return_tensors="pt", padding=True, truncation=True, max_length=512)

        # Move the inputs to wherever the model's weights were placed.
        device = next(self.model.parameters()).device
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        self.model.eval()
        outputs = self.model.generate(
            **encoded,
            max_new_tokens=self.model_config.get("max_tokens", 256),
            temperature=self.model_config.get("temperature", 0.7),
            top_p=self.model_config.get("top_p", 0.9),
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            num_return_sequences=1,
        )

        # The output repeats the prompt tokens; decode only what was appended.
        prompt_length = encoded["input_ids"].shape[1]
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        logger.info("Generated script response: %s...", response[:200])

        segments = self._parse_script_response(response)
        tagline = self._extract_tagline(response)
        return self._assemble_script(segments, tagline, voice_style, response)

    def generate_script_with_template(self, brand: str, structure: str, script_prompt: str,
                                      duration: int, voice_style: str, seed: int = 42) -> GeneratedScript:
        """Generate a four-segment script from fixed templates (no model required).

        Args:
            brand: Brand name inserted into the copy.
            structure: Structure description; when empty or ``None`` a random
                entry of ``STRUCTURE_TEMPLATES`` is used instead.
            script_prompt: Used verbatim as the hook when non-empty.
            duration: Accepted for interface parity; the template uses fixed
                per-segment durations and does not read this value.
            voice_style: Stored on the result.
            seed: Seed for Python's global RNG, making the random choices repeatable.

        Returns:
            A script with hook, flow, benefit and CTA segments.
        """
        random.seed(seed)

        # ``structure or ""`` tolerates None, matching ``_fallback_suggestions``.
        structure_template = (structure or "").strip() or random.choice(STRUCTURE_TEMPLATES)
        hook_text = script_prompt or f"Introducing {brand} - the future is here!"
        flow_text = f"With {structure_template.lower()}, {brand} delivers results like never before."
        benefit_text = "Faster, simpler, cooler - just like your favorite retro tech."
        tagline = random.choice(TAGLINES)
        cta_text = f"Try {brand} today. {tagline}"

        plan = (
            ("hook", hook_text, 2.0),
            ("flow", flow_text, 3.0),
            ("benefit", benefit_text, 2.5),
            ("cta", cta_text, 2.5),
        )
        segments = [
            ScriptSegment(
                text=text,
                duration_estimate=seconds,
                segment_type=segment_type,
                timing_marker=f"[{segment_type.upper()}]",
            )
            for segment_type, text, seconds in plan
        ]

        return self._assemble_script(
            segments, tagline, voice_style, f"Template-based script for {brand}"
        )

    def generate_script(self, brand: str, structure: str, script_prompt: str,
                        duration: int, voice_style: str, seed: int = 42) -> GeneratedScript:
        """Generate a complete commercial script, preferring the LLM when it is usable.

        The template generator is used when no LLM is loaded, when LLM
        generation raises, or when the LLM reply contains no recognisable
        sections (which would otherwise yield an empty script).

        Returns:
            A script with at least one segment.
        """
        if self.llm_available:
            try:
                script = self.generate_script_with_llm(
                    brand, structure, script_prompt, duration, voice_style, seed
                )
            except Exception as exc:
                logger.error("Script generation failed: %s", exc)
            else:
                if script.segments:
                    return script
                logger.warning("LLM response contained no recognizable script sections")
            logger.info("Falling back to template-based generation")
        else:
            logger.info("Using template-based script generation (LLM not available)")

        return self.generate_script_with_template(
            brand, structure, script_prompt, duration, voice_style, seed
        )

    # ------------------------------------------------------------------
    # Suggestions
    # ------------------------------------------------------------------

    def suggest_scripts(self, structure: str, n: int = 6, seed: int = 0) -> List[str]:
        """Return ``n`` short script ideas for the given structure.

        With an LLM loaded, each idea is the first segment of a freshly
        generated script (seeds ``seed`` .. ``seed + n - 1``).  Without one, or
        if generation fails, the ideas come from ``_fallback_suggestions``; the
        template generator is deliberately not used here because it would echo
        the placeholder prompt back as every suggestion.
        """
        if not self.llm_available:
            return self._fallback_suggestions(structure, n, seed)

        try:
            suggestions = []
            for offset in range(n):
                script = self.generate_script_with_llm(
                    brand="YourBrand",
                    structure=structure,
                    script_prompt="Create an engaging hook",
                    duration=10,
                    voice_style="Announcer '80s",
                    seed=seed + offset,
                )
                if script.segments:
                    suggestions.append(script.segments[0].text)
                else:
                    suggestions.append("Back to '87 - the future is now!")
            return suggestions
        except Exception as exc:
            logger.warning("Script suggestion failed: %s", exc)
            return self._fallback_suggestions(structure, n, seed)

    def _fallback_suggestions(self, structure: str, n: int, seed: int) -> List[str]:
        """Build ``n`` random idea strings without using the LLM.

        Each idea combines a random hook, style and shot; any structure keyword
        found in ``structure`` and not already present in the idea is appended
        as "Includes <keyword>.".  Python's global RNG is seeded with ``seed``.
        """
        random.seed(seed)

        base = (structure or "").lower().strip()
        ideas = []

        for _ in range(n):
            # Draw order (style, shot, hook) determines the output for a given seed.
            style = random.choice(self._IDEA_STYLES)
            shot = random.choice(self._IDEA_SHOTS)
            hook = random.choice(self._IDEA_HOOKS)
            idea = f"{hook} {style} with {shot}."

            for keyword in self._STRUCTURE_KEYWORDS:
                if keyword in base and keyword not in idea:
                    idea += f" Includes {keyword}."

            ideas.append(idea)

        return ideas


def create_script_generator() -> LLMScriptGenerator:
    """Build an ``LLMScriptGenerator`` with its default model.

    Constructing the generator attempts to load the model immediately; the
    returned instance works in template-only mode if that fails.
    """
    generator = LLMScriptGenerator()
    return generator