# FILE: api/ltx_server_refactored_complete.py
# DESCRIPTION: Final backend service for LTX-Video generation.
# Features dedicated VAE device logic, robust initialization, and narrative chunking.

import gc
import io
import json
import logging
import os
import random
import shutil
import subprocess
import sys
import tempfile
import time
import traceback
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import yaml
import numpy as np
from einops import rearrange
from huggingface_hub import hf_hub_download

# ==============================================================================
# --- INITIAL SETUP & CONFIGURATION ---
# ==============================================================================

warnings.filterwarnings("ignore")
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')

# --- CONSTANTS ---
DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
BASE_CONFIG_PATH = LTX_VIDEO_REPO_DIR / "configs"
DEFAULT_CONFIG_FILE = BASE_CONFIG_PATH / "ltxv-13b-0.9.8-distilled-fp8.yaml"
LTX_REPO_ID = "Lightricks/LTX-Video"
RESULTS_DIR = Path("/app/output")
DEFAULT_FPS = 24.0
FRAMES_ALIGNMENT = 8


# --- CRITICAL: DEPENDENCY PATH INJECTION ---
def add_deps_to_path():
    """Adds the LTX repository directory to the Python system path for imports."""
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logging.info(f"LTX-Video repository added to sys.path: {repo_path}")


# Must run before the project imports below so `ltx_video` is importable.
add_deps_to_path()

# --- PROJECT IMPORTS ---
try:
    from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, create_latent_upsampler  # and others...
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.transformers.transformer3d import Transformer3DModel
    from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
    from ltx_video.schedulers.rf import RectifiedFlowScheduler
    from transformers import T5EncoderModel, T5Tokenizer
    from safetensors import safe_open
    from api.gpu_manager import gpu_manager
    from ltx_video.models.autoencoders.vae_encode import (normalize_latents, un_normalize_latents)
    from ltx_video.pipelines.pipeline_ltx_video import (ConditioningItem, LTXMultiScalePipeline,
                                                        adain_filter_latent, create_latent_upsampler)
    from ltx_video.utils.inference_utils import load_image_to_tensor_with_resize_and_crop
    from managers.vae_manager import vae_manager_singleton
    from tools.video_encode_tool import video_encode_tool_singleton
except ImportError as e:
    logging.critical(f"A crucial LTX import failed. Check LTX-Video repo integrity. Error: {e}")
    sys.exit(1)


# ==============================================================================
# --- UTILITY & HELPER FUNCTIONS ---
# ==============================================================================

def seed_everything(seed: int):
    """Sets the seed for reproducibility.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices), sets
    PYTHONHASHSEED, and switches cuDNN to deterministic, non-benchmark mode.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def calculate_padding(orig_h: int, orig_w: int, target_h: int, target_w: int) -> Tuple[int, int, int, int]:
    """Calculates symmetric padding values.

    Returns:
        (pad_left, pad_right, pad_top, pad_bottom) -- the ordering expected by
        `torch.nn.functional.pad` for the last two dims (W, then H). Any odd
        remainder goes to the right/bottom side.
    """
    pad_h = target_h - orig_h
    pad_w = target_w - orig_w
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left
    return (pad_left, pad_right, pad_top, pad_bottom)


def log_tensor_info(tensor: torch.Tensor, name: str = "Tensor"):
    """Logs detailed debug information about a PyTorch tensor.

    Returns immediately unless DEBUG is enabled on the root logger, so the
    min/max/mean reductions (each `.item()` copies a value back to the host)
    are not paid on every call during normal INFO-level operation.
    """
    if not logging.getLogger().isEnabledFor(logging.DEBUG):
        return
    if not isinstance(tensor, torch.Tensor):
        logging.debug(f"'{name}' is not a tensor.")
        return

    info_str = (
        f"--- Tensor: {name} ---\n"
        f"  - Shape: {tuple(tensor.shape)}\n"
        f"  - Dtype: {tensor.dtype}\n"
        f"  - Device: {tensor.device}\n"
    )
    if tensor.numel() > 0:
        try:
            info_str += (
                f"  - Min: {tensor.min().item():.4f} | "
                f"Max: {tensor.max().item():.4f} | "
                f"Mean: {tensor.mean().item():.4f}\n"
            )
        except Exception:
            pass  # Fails on some dtypes (e.g. mean() on integer/bool tensors); stats are optional.
    logging.debug(info_str + "----------------------")


# ==============================================================================
# --- VIDEO SERVICE CLASS ---
# ==============================================================================

class VideoService:
    """Backend service for orchestrating video generation using the LTX-Video pipeline."""

    # Keyword arguments the orchestrators pass to `_generate_single_chunk_low`
    # explicitly; they are stripped from the forwarded **kwargs to avoid
    # "got multiple values for keyword argument" errors.
    _WORKER_EXPLICIT_KWARGS = frozenset({"num_frames", "seed", "conditioning_items"})

    def __init__(self):
        """Initializes the service with dedicated GPU logic for main pipeline and VAE.

        Loads config and all models on CPU, then moves them to the devices
        reported by `gpu_manager`, sets the autocast policy and attaches the
        pipeline to the shared VAE manager.
        """
        t0 = time.perf_counter()
        logging.info("Initializing VideoService...")
        RESULTS_DIR.mkdir(parents=True, exist_ok=True)

        target_main_device_str = str(gpu_manager.get_ltx_device())
        target_vae_device_str = str(gpu_manager.get_ltx_vae_device())
        logging.info(f"LTX allocated to devices: Main='{target_main_device_str}', VAE='{target_vae_device_str}'")

        self.config = self._load_config()
        self.pipeline, self.latent_upsampler = self._load_models()

        # Models start on CPU; move_to_device() updates these.
        self.main_device = torch.device("cpu")
        self.vae_device = torch.device("cpu")
        self.move_to_device(main_device_str=target_main_device_str, vae_device_str=target_vae_device_str)

        self._apply_precision_policy()
        vae_manager_singleton.attach_pipeline(
            self.pipeline,
            device=self.vae_device,
            autocast_dtype=self.runtime_autocast_dtype
        )
        self._tmp_dirs = set()
        logging.info(f"VideoService ready. Startup time: {time.perf_counter()-t0:.2f}s")

    # ==========================================================================
    # --- LIFECYCLE & MODEL MANAGEMENT ---
    # ==========================================================================

    def _load_config(self) -> Dict:
        """Loads the YAML configuration file at DEFAULT_CONFIG_FILE."""
        config_path = DEFAULT_CONFIG_FILE
        logging.info(f"Loading config from: {config_path}")
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    def _load_models(self) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
        """Loads every sub-model of the pipeline on the CPU.

        Assembles `LTXVideoPipeline` directly from its components instead of
        calling the external `create_ltx_video_pipeline`, giving full control
        over the process.

        Returns:
            (pipeline, latent_upsampler); the upsampler is None when the config
            has no `spatial_upscaler_model_path`.

        Raises:
            FileNotFoundError: if the configured checkpoint file does not exist.
        """
        t0 = time.perf_counter()
        logging.info("Carregando sub-modelos do LTX para a CPU...")

        ckpt_path = Path(self.config["checkpoint_path"])
        if not ckpt_path.is_file():
            raise FileNotFoundError(f"Arquivo de checkpoint principal não encontrado em: {ckpt_path}")

        # 1. Read the checkpoint metadata (a JSON string under the "config" key).
        with safe_open(ckpt_path, framework="pt") as f:
            metadata = f.metadata() or {}
        config_str = metadata.get("config", "{}")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps")

        # 2. Load the individual components (all on CPU).
        # `.from_pretrained(ckpt_path)` loads the matching weights from the .safetensors file.
        logging.info("Carregando VAE...")
        vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")

        logging.info("Carregando Transformer...")
        transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")

        logging.info("Carregando Scheduler...")
        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

        logging.info("Carregando Text Encoder e Tokenizer...")
        text_encoder_path = self.config["text_encoder_model_name_or_path"]
        text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
        tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")

        patchifier = SymmetricPatchifier(patch_size=1)

        # 3. Cast to the configured precision while still on CPU (devices are assigned later).
        # Lower-cased to match _apply_precision_policy, which also normalizes case.
        precision = str(self.config.get("precision", "bfloat16")).lower()
        if precision == "bfloat16":
            vae.to(torch.bfloat16)
            transformer.to(torch.bfloat16)
            text_encoder.to(torch.bfloat16)

        # 4. Assemble the pipeline object from the loaded components.
        logging.info("Montando o objeto LTXVideoPipeline...")
        submodel_dict = {
            "transformer": transformer,
            "patchifier": patchifier,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "scheduler": scheduler,
            "vae": vae,
            "allowed_inference_steps": allowed_inference_steps,
            # Prompt enhancers are optional and are not loaded, to save memory.
            "prompt_enhancer_image_caption_model": None,
            "prompt_enhancer_image_caption_processor": None,
            "prompt_enhancer_llm_model": None,
            "prompt_enhancer_llm_tokenizer": None,
        }
        pipeline = LTXVideoPipeline(**submodel_dict)

        # 5. Load the latent upsampler (also on CPU), if configured.
        latent_upsampler = None
        if self.config.get("spatial_upscaler_model_path"):
            logging.info("Carregando Latent Upsampler...")
            spatial_path = self.config["spatial_upscaler_model_path"]
            latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
            if precision == "bfloat16":
                latent_upsampler.to(torch.bfloat16)

        logging.info(f"Modelos LTX carregados na CPU em {time.perf_counter()-t0:.2f}s")
        return pipeline, latent_upsampler

    def move_to_device(self, main_device_str: str, vae_device_str: str):
        """Moves pipeline components to their target devices.

        The whole pipeline (and the latent upsampler, if any) goes to the main
        device; the VAE is then moved separately to its dedicated device.
        """
        target_main_device = torch.device(main_device_str)
        target_vae_device = torch.device(vae_device_str)
        logging.info(f"Moving LTX models -> Main Pipeline: {target_main_device}, VAE: {target_vae_device}")

        self.main_device = target_main_device
        self.pipeline.to(self.main_device)
        # Must come after pipeline.to(), which would otherwise pull the VAE back to the main device.
        self.vae_device = target_vae_device
        self.pipeline.vae.to(self.vae_device)
        if self.latent_upsampler:
            self.latent_upsampler.to(self.main_device)
        logging.info("LTX models successfully moved to target devices.")

    def move_to_cpu(self):
        """Moves all LTX components to CPU to free VRAM."""
        self.move_to_device(main_device_str="cpu", vae_device_str="cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def finalize(self):
        """Cleans up GPU memory after a generation task (best-effort)."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            try:
                torch.cuda.ipc_collect()
            except Exception as e:
                # Best-effort cleanup: never fail a request over this, but leave a trace.
                logging.debug(f"torch.cuda.ipc_collect() failed: {e}")

    # ==========================================================================
    # --- PUBLIC ORCHESTRATORS ---
    # ==========================================================================

    def generate_narrative_low(self, prompt: str, **kwargs) -> Tuple[Optional[str], Optional[str], Optional[int]]:
        """[ORCHESTRATOR] Generates a video from a multi-line prompt (sequence of scenes).

        Each non-blank line of `prompt` becomes one chunk. Every chunk after the
        first is conditioned on the tail of the previous chunk's latents, and that
        overlapping head is trimmed before the chunks are concatenated.

        Args:
            prompt: Multi-line prompt, one scene per non-blank line.
            **kwargs: Forwarded to `_generate_single_chunk_low` (e.g.
                negative_prompt, height, width, ltx_configs_override). Also read
                here: `seed`, `duration` (seconds, default 4.0) and
                `initial_conditions` (applied to the first chunk only).

        Returns:
            (video_path, latents_path, seed) on success; (None, None, None) if
            generation fails.

        Raises:
            ValueError: if the prompt has no non-blank lines, or the duration is
                too short to give every line at least FRAMES_ALIGNMENT frames.
        """
        logging.info("Starting narrative low-res generation...")
        used_seed = self._resolve_seed(kwargs.get("seed"))
        seed_everything(used_seed)

        prompt_list = [p.strip() for p in prompt.splitlines() if p.strip()]
        if not prompt_list:
            raise ValueError("Prompt is empty or contains no valid lines.")

        num_chunks = len(prompt_list)
        total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
        frames_per_chunk = (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT
        if frames_per_chunk <= 0:
            # Too many scenes for the requested duration: a chunk would get zero frames.
            raise ValueError(
                f"Duration too short for {num_chunks} prompt lines "
                f"({total_frames} frames in total; need at least {FRAMES_ALIGNMENT} per line)."
            )
        overlap_frames = self.config.get("overlap_frames", 8)
        worker_kwargs = self._worker_kwargs(kwargs)

        all_latents_paths = []
        overlap_condition_item = None
        try:
            for i, chunk_prompt in enumerate(prompt_list):
                logging.info(f"Generating narrative chunk {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")

                current_frames = frames_per_chunk
                if i > 0:
                    # Extra frames that will be trimmed again after generation.
                    current_frames += overlap_frames

                # Copy so the caller's list is never mutated by the append below.
                initial_conditions = kwargs.get("initial_conditions") or []
                current_conditions = list(initial_conditions) if i == 0 else []
                if overlap_condition_item:
                    current_conditions.append(overlap_condition_item)

                chunk_latents = self._generate_single_chunk_low(
                    prompt=chunk_prompt,
                    num_frames=current_frames,
                    seed=used_seed + i,
                    conditioning_items=current_conditions,
                    **worker_kwargs
                )
                if chunk_latents is None:
                    raise RuntimeError(f"Failed to generate latents for chunk {i+1}.")

                # NOTE(review): `overlap_frames` is added to the pixel-frame count above but is
                # used below to slice the latents' temporal axis (dim 2). If the VAE compresses
                # time, these units differ -- confirm. Also confirm ConditioningItem accepts
                # latent-space media here rather than pixel frames.
                if overlap_frames > 0:
                    if i < num_chunks - 1:
                        # Guarded: with overlap_frames == 0, `[-0:]` would select the whole chunk.
                        overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
                        overlap_condition_item = ConditioningItem(
                            media_item=overlap_latents, media_frame_number=0, conditioning_strength=1.0
                        )
                    if i > 0:
                        chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]

                chunk_path = RESULTS_DIR / f"temp_chunk_{i}_{used_seed}.pt"
                torch.save(chunk_latents.cpu(), chunk_path)
                all_latents_paths.append(chunk_path)

            return self._finalize_generation(all_latents_paths, "narrative_video", used_seed)
        except Exception as e:
            logging.error(f"Error during narrative generation: {e}", exc_info=True)
            return None, None, None
        finally:
            # Temp chunk files are only inputs to _finalize_generation, which saves its own copy.
            for path in all_latents_paths:
                if path.exists():
                    path.unlink()
            self.finalize()

    def generate_single_low(self, **kwargs) -> Tuple[Optional[str], Optional[str], Optional[int]]:
        """[ORCHESTRATOR] Generates a video from a single prompt in one go.

        Args:
            **kwargs: Forwarded to `_generate_single_chunk_low` (prompt,
                negative_prompt, height, width, ltx_configs_override, ...). Also
                read here: `seed`, `duration` (seconds, default 4.0) and
                `initial_conditions`.

        Returns:
            (video_path, latents_path, seed) on success; (None, None, None) if
            generation fails.
        """
        logging.info("Starting single-prompt low-res generation...")
        used_seed = self._resolve_seed(kwargs.get("seed"))
        seed_everything(used_seed)

        latents_path = None
        try:
            total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0), min_frames=9)
            final_latents = self._generate_single_chunk_low(
                num_frames=total_frames,
                seed=used_seed,
                conditioning_items=kwargs.get("initial_conditions", []),
                **self._worker_kwargs(kwargs)
            )
            if final_latents is None:
                raise RuntimeError("Failed to generate latents.")

            latents_path = RESULTS_DIR / f"temp_single_{used_seed}.pt"
            torch.save(final_latents.cpu(), latents_path)
            return self._finalize_generation([latents_path], "single_video", used_seed)
        except Exception as e:
            logging.error(f"Error during single generation: {e}", exc_info=True)
            return None, None, None
        finally:
            # The temp file is only an input to _finalize_generation, which saves its own copy;
            # delete it so RESULTS_DIR does not accumulate temp_single_*.pt files.
            if latents_path is not None and latents_path.exists():
                latents_path.unlink()
            self.finalize()

    # ==========================================================================
    # --- INTERNAL WORKER & HELPER METHODS ---
    # ==========================================================================

    def _worker_kwargs(self, kwargs: Dict) -> Dict:
        """Returns a copy of `kwargs` without the keys the orchestrators pass to the worker explicitly.

        The orchestrators read `seed` from their own kwargs and then pass `seed=`
        explicitly; forwarding the original key as well would raise
        `TypeError: got multiple values for keyword argument 'seed'`.
        """
        return {key: value for key, value in kwargs.items() if key not in self._WORKER_EXPLICIT_KWARGS}

    def _generate_single_chunk_low(
        self, prompt: str, negative_prompt: str, height: int, width: int, num_frames: int, seed: int,
        conditioning_items: List[ConditioningItem], ltx_configs_override: Optional[Dict] = None, **kwargs
    ) -> Optional[torch.Tensor]:
        """[WORKER] Generates a single chunk of latents. This is the core generation unit.

        Runs the pipeline's first pass at a downscaled resolution
        (`downscale_factor` from config, aligned to the VAE scale factor) and
        returns the raw latents (`output_type="latent"`).

        Args:
            prompt / negative_prompt: Text conditioning.
            height / width: Target output size in pixels (before downscaling).
            num_frames: Number of frames requested from the pipeline.
            seed: Seed for the torch.Generator on the main device.
            conditioning_items: Conditioning passed straight to the pipeline.
            ltx_configs_override: Optional UI guidance preset; see
                `_prepare_guidance_overrides`.
            **kwargs: Ignored (lets orchestrators forward their full kwargs).
        """
        height_padded, width_padded = (self._align(d) for d in (height, width))
        downscale_factor = self.config.get("downscale_factor", 0.6666666)
        vae_scale_factor = self.pipeline.vae_scale_factor
        downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
        downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)

        first_pass_config = self.config.get("first_pass", {}).copy()
        if ltx_configs_override:
            first_pass_config.update(self._prepare_guidance_overrides(ltx_configs_override))

        pipeline_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "height": downscaled_height,
            "width": downscaled_width,
            "num_frames": num_frames,
            "frame_rate": DEFAULT_FPS,
            "generator": torch.Generator(device=self.main_device).manual_seed(seed),
            "output_type": "latent",
            "conditioning_items": conditioning_items,
            # Config values come last, so they override any of the keys above.
            **first_pass_config
        }

        with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype,
                            enabled="cuda" in self.main_device.type):
            latents_raw = self.pipeline(**pipeline_kwargs).images

        log_tensor_info(latents_raw, f"Raw Latents for '{prompt[:40]}...'")
        return latents_raw

    def _finalize_generation(self, latents_paths: List[Path], base_filename: str, seed: int) -> Tuple[str, str, int]:
        """Loads latents, concatenates, decodes to video, and saves both.

        Latent files are concatenated along dim 2 in the given order.

        Returns:
            (video_path, final_latents_path, seed) as strings/int.
        """
        logging.info("Finalizing generation: decoding latents to video.")
        # map_location keeps loading on CPU regardless of where the tensors were saved from.
        all_tensors_cpu = [torch.load(p, map_location="cpu") for p in latents_paths]
        final_latents = torch.cat(all_tensors_cpu, dim=2)

        final_latents_path = RESULTS_DIR / f"latents_{base_filename}_{seed}.pt"
        torch.save(final_latents, final_latents_path)
        logging.info(f"Final latents saved to: {final_latents_path}")

        # The decode method in vae_manager now handles moving the tensor to the correct VAE device.
        pixel_tensor = vae_manager_singleton.decode(
            final_latents,
            decode_timestep=float(self.config.get("decode_timestep", 0.05))
        )

        video_path = self._save_and_log_video(pixel_tensor, f"{base_filename}_{seed}")
        return str(video_path), str(final_latents_path), seed

    def prepare_condition_items(self, items_list: List, height: int, width: int, num_frames: int) -> List[ConditioningItem]:
        """Prepares a list of ConditioningItem objects from (media, frame, weight) triples.

        Each image is loaded at (height, width) and symmetrically padded up to the
        FRAMES_ALIGNMENT-aligned size; the frame index is clamped to
        [0, num_frames - 1].
        """
        if not items_list:
            return []
        height_padded, width_padded = self._align(height), self._align(width)
        padding_values = calculate_padding(height, width, height_padded, width_padded)

        conditioning_items = []
        for media, frame, weight in items_list:
            tensor = self._prepare_conditioning_tensor(media, height, width, padding_values)
            safe_frame = max(0, min(int(frame), num_frames - 1))
            conditioning_items.append(ConditioningItem(tensor, safe_frame, float(weight)))
        return conditioning_items

    def _prepare_conditioning_tensor(self, media_path: str, height: int, width: int, padding: Tuple) -> torch.Tensor:
        """Loads and processes an image to be a conditioning tensor.

        `padding` is the (left, right, top, bottom) tuple from `calculate_padding`.
        """
        tensor = load_image_to_tensor_with_resize_and_crop(media_path, height, width)
        tensor = torch.nn.functional.pad(tensor, padding)
        # Conditioning tensors are needed on the main device for the transformer pass
        return tensor.to(self.main_device, dtype=self.runtime_autocast_dtype)

    def _prepare_guidance_overrides(self, ltx_configs: Dict) -> Dict:
        """Parses UI presets for guidance into pipeline-compatible arguments.

        Presets: "Agressivo" and "Suave" use fixed schedules; "Customizado" parses
        JSON lists from `guidance_scale_list` / `stg_scale_list`. Any other preset
        (including the default) returns no overrides. A custom preset is applied
        all-or-nothing: if either list fails to parse, neither is applied.
        """
        overrides = {}
        preset = ltx_configs.get("guidance_preset", "Padrão (Recomendado)")

        if preset == "Agressivo":
            overrides["guidance_scale"] = [1, 2, 8, 12, 8, 2, 1]
            overrides["stg_scale"] = [0, 0, 5, 6, 5, 3, 2]
        elif preset == "Suave":
            overrides["guidance_scale"] = [1, 1, 4, 5, 4, 1, 1]
            overrides["stg_scale"] = [0, 0, 2, 2, 2, 1, 0]
        elif preset == "Customizado":
            try:
                guidance_scale = json.loads(ltx_configs["guidance_scale_list"])
                stg_scale = json.loads(ltx_configs["stg_scale_list"])
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                logging.warning(f"Failed to parse custom guidance values: {e}. Falling back to defaults.")
            else:
                # Only commit once both parsed, so a failure never leaves a half-applied preset.
                overrides["guidance_scale"] = guidance_scale
                overrides["stg_scale"] = stg_scale

        if overrides:
            logging.info(f"Applying '{preset}' guidance preset overrides.")
        return overrides

    def _save_and_log_video(self, pixel_tensor: torch.Tensor, base_filename: str) -> Path:
        """Saves a pixel tensor (on CPU) to an MP4 file in RESULTS_DIR.

        Encodes into a temporary directory first, then moves the finished file
        into place.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = os.path.join(temp_dir, f"{base_filename}.mp4")
            video_encode_tool_singleton.save_video_from_tensor(
                pixel_tensor, temp_path, fps=DEFAULT_FPS
            )
            final_path = RESULTS_DIR / f"{base_filename}.mp4"
            shutil.move(temp_path, final_path)
            logging.info(f"Video saved successfully to: {final_path}")
            return final_path

    def _apply_precision_policy(self):
        """Sets the autocast dtype based on the configuration file.

        "float8_e4m3fn" and "bfloat16" -> bfloat16, "mixed_precision" -> float16,
        anything else -> float32.
        """
        precision = str(self.config.get("precision", "bfloat16")).lower()
        if precision in ["float8_e4m3fn", "bfloat16"]:
            self.runtime_autocast_dtype = torch.bfloat16
        elif precision == "mixed_precision":
            self.runtime_autocast_dtype = torch.float16
        else:
            self.runtime_autocast_dtype = torch.float32
        logging.info(f"Runtime precision policy set for autocast: {self.runtime_autocast_dtype}")

    def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT) -> int:
        """Rounds `dim` up to the nearest multiple of `alignment`."""
        return ((dim - 1) // alignment + 1) * alignment

    def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
        """Calculates total frames based on duration, ensuring alignment.

        Returns (duration * DEFAULT_FPS rounded up to a multiple of
        FRAMES_ALIGNMENT) + 1, but never less than `min_frames`.
        """
        num_frames = int(round(duration_s * DEFAULT_FPS))
        aligned_frames = self._align(num_frames)
        return max(aligned_frames + 1, min_frames)

    def _resolve_seed(self, seed: Optional[int]) -> int:
        """Returns the given seed or generates a new random one in [0, 2**32 - 1]."""
        return random.randint(0, 2**32 - 1) if seed is None else int(seed)


# ==============================================================================
# --- SINGLETON INSTANTIATION ---
# ==============================================================================
try:
    video_generation_service = VideoService()
    logging.info("Global VideoService instance created successfully.")
except Exception as e:
    logging.critical(f"Failed to initialize VideoService: {e}", exc_info=True)
    sys.exit(1)