# FILE: api/ltx_server_refactored_complete.py
# DESCRIPTION: Final orchestrator for LTX-Video generation.
# This version includes the fix for the narrative generation overlap bug and
# consolidates all previous refactoring and debugging improvements.

import gc
import json
import logging
import os
import random
import shutil
import sys
import tempfile
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import yaml
import numpy as np
from huggingface_hub import hf_hub_download
# ==============================================================================
# --- SETUP AND PROJECT IMPORTS ---
# ==============================================================================
# Logging configuration and warning suppression
import warnings
warnings.filterwarnings("ignore")
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)

log_level = os.environ.get("ADUC_LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=log_level, format='[%(levelname)s] [%(name)s] %(message)s')
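# Example: export ADUC_LOG_LEVEL=DEBUG in the environment before starting the
# service to get the most verbose output; without it the level defaults to INFO.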
# --- Configuration Constants ---
DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
RESULTS_DIR = Path("/app/output")
DEFAULT_FPS = 24.0
FRAMES_ALIGNMENT = 8
LTX_REPO_ID = "Lightricks/LTX-Video"

# Ensure the LTX-Video library is importable
def add_deps_to_path():
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logging.info(f"[ltx_server] LTX-Video repository added to sys.path: {repo_path}")

add_deps_to_path()
# --- Modules from our architecture ---
try:
    from api.gpu_manager import gpu_manager
    from managers.vae_manager import vae_manager_singleton
    from tools.video_encode_tool import video_encode_tool_singleton
    from api.ltx.ltx_utils import (
        build_ltx_pipeline_on_cpu,
        seed_everything,
        load_image_to_tensor_with_resize_and_crop,
        ConditioningItem,
    )
    from api.utils.debug_utils import log_function_io
except ImportError as e:
    logging.critical(f"A crucial import from the local API/architecture failed. Error: {e}", exc_info=True)
    sys.exit(1)
# ==============================================================================
# --- ORCHESTRATOR HELPER FUNCTIONS ---
# ==============================================================================
def calculate_padding(orig_h: int, orig_w: int, target_h: int, target_w: int) -> Tuple[int, int, int, int]:
    """Calculates the symmetric padding required to reach the target dimensions."""
    pad_h = target_h - orig_h
    pad_w = target_w - orig_w
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left
    return (pad_left, pad_right, pad_top, pad_bottom)
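# Worked example: calculate_padding(480, 704, 512, 768) returns (32, 32, 16, 16),
# i.e. 32 px on each horizontal side and 16 px on top and bottom, in the
# (left, right, top, bottom) order that torch.nn.functional.pad expects for the
# last two tensor dimensions.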
# ==============================================================================
# --- SERVICE CLASS (THE ORCHESTRATOR) ---
# ==============================================================================
class VideoService:
    """
    Orchestrates the high-level logic of video generation, delegating low-level
    tasks to specialized managers and utility modules.
    """
    def __init__(self):
        t0 = time.perf_counter()
        logging.info("Initializing VideoService Orchestrator...")
        RESULTS_DIR.mkdir(parents=True, exist_ok=True)
        target_main_device_str = str(gpu_manager.get_ltx_device())
        target_vae_device_str = str(gpu_manager.get_ltx_vae_device())
        logging.info(f"LTX allocated to devices: Main='{target_main_device_str}', VAE='{target_vae_device_str}'")
        self.config = self._load_config()
        self._resolve_model_paths_from_cache()
        self.pipeline, self.latent_upsampler = build_ltx_pipeline_on_cpu(self.config)
        self.main_device = torch.device("cpu")
        self.vae_device = torch.device("cpu")
        self.move_to_device(main_device_str=target_main_device_str, vae_device_str=target_vae_device_str)
        self._apply_precision_policy()
        vae_manager_singleton.attach_pipeline(self.pipeline, device=self.vae_device, autocast_dtype=self.runtime_autocast_dtype)
        logging.info(f"VideoService ready. Startup time: {time.perf_counter()-t0:.2f}s")
    def _load_config(self) -> Dict:
        """Loads the YAML configuration file."""
        config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
        logging.info(f"Loading config from: {config_path}")
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    def _resolve_model_paths_from_cache(self):
        """Finds the absolute paths to model files in the cache and updates the in-memory config."""
        logging.info("Resolving model paths from Hugging Face cache...")
        cache_dir = os.environ.get("HF_HOME")
        try:
            main_ckpt_path = hf_hub_download(repo_id=LTX_REPO_ID, filename=self.config["checkpoint_path"], cache_dir=cache_dir)
            self.config["checkpoint_path"] = main_ckpt_path
            logging.info(f" -> Main checkpoint resolved to: {main_ckpt_path}")
            if self.config.get("spatial_upscaler_model_path"):
                upscaler_path = hf_hub_download(repo_id=LTX_REPO_ID, filename=self.config["spatial_upscaler_model_path"], cache_dir=cache_dir)
                self.config["spatial_upscaler_model_path"] = upscaler_path
                logging.info(f" -> Spatial upscaler resolved to: {upscaler_path}")
        except Exception as e:
            logging.critical(f"Failed to resolve model paths. Ensure setup.py ran correctly. Error: {e}", exc_info=True)
            sys.exit(1)
    def move_to_device(self, main_device_str: str, vae_device_str: str):
        """Moves pipeline components to their designated target devices."""
        target_main_device = torch.device(main_device_str)
        target_vae_device = torch.device(vae_device_str)
        logging.info(f"Moving LTX models -> Main Pipeline: {target_main_device}, VAE: {target_vae_device}")
        self.main_device = target_main_device
        self.pipeline.to(self.main_device)
        self.vae_device = target_vae_device
        self.pipeline.vae.to(self.vae_device)
        if self.latent_upsampler:
            self.latent_upsampler.to(self.main_device)
        logging.info("LTX models successfully moved to target devices.")

    def move_to_cpu(self):
        """Moves all LTX components to CPU to free VRAM for other services."""
        self.move_to_device(main_device_str="cpu", vae_device_str="cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def finalize(self):
        """Cleans up GPU memory after a generation task."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            try:
                torch.cuda.ipc_collect()
            except Exception:
                pass
    # ==========================================================================
    # --- BUSINESS LOGIC: UNIFIED PUBLIC ORCHESTRATOR ---
    # ==========================================================================
    def generate_low_resolution(self, prompt: str, **kwargs) -> Tuple[Optional[str], Optional[str], Optional[int]]:
        """
        [UNIFIED ORCHESTRATOR] Generates a low-resolution video from a prompt.
        Handles both single-line and multi-line prompts transparently.
        """
        logging.info("Starting unified low-resolution generation (random seed)...")
        used_seed = self._get_random_seed()
        seed_everything(used_seed)
        logging.info(f"Using randomly generated seed: {used_seed}")

        prompt_list = [p.strip() for p in prompt.splitlines() if p.strip()]
        if not prompt_list:
            raise ValueError("Prompt is empty or contains no valid lines.")

        is_narrative = len(prompt_list) > 1
        logging.info(f"Generation mode detected: {'Narrative' if is_narrative else 'Simple'} ({len(prompt_list)} scene(s)).")

        num_chunks = len(prompt_list)
        total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
        frames_per_chunk = max(FRAMES_ALIGNMENT, (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT)

        # The overlap must be N*8+1 frames; 9 is the smallest practical value.
        overlap_frames = 9 if is_narrative else 0
        if is_narrative:
            logging.info(f"Narrative mode: Using overlap of {overlap_frames} frames between chunks.")
        temp_latent_paths = []
        overlap_condition_item = None
        try:
            for i, chunk_prompt in enumerate(prompt_list):
                logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")

                if i < num_chunks - 1:
                    current_frames_base = frames_per_chunk
                else:  # The last chunk takes all remaining frames
                    processed_frames_base = (num_chunks - 1) * frames_per_chunk
                    current_frames_base = total_frames - processed_frames_base

                current_frames = current_frames_base + (overlap_frames if i > 0 else 0)
                # Ensure the final frame count for generation is N*8+1
                current_frames = self._align(current_frames, alignment_rule='n*8+1')

                # Copy the list so appending the overlap item never mutates the caller's kwargs.
                current_conditions = list(kwargs.get("initial_conditions", [])) if i == 0 else []
                if overlap_condition_item:
                    current_conditions.append(overlap_condition_item)

                chunk_latents = self._generate_single_chunk_low(
                    prompt=chunk_prompt, num_frames=current_frames, seed=used_seed + i,
                    conditioning_items=current_conditions, **kwargs
                )
                if chunk_latents is None:
                    raise RuntimeError(f"Failed to generate latents for scene {i+1}.")
                if is_narrative and i < num_chunks - 1:
                    # 1. Create the latent overlap tensor
                    overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
                    logging.info(f"Created latent overlap with shape: {list(overlap_latents.shape)}")

                    # 2. DECODE the latent back into a PIXEL tensor
                    logging.info("Decoding overlap latent into a pixel tensor...")
                    overlap_pixel_tensor = vae_manager_singleton.decode(
                        overlap_latents,
                        decode_timestep=float(self.config.get("decode_timestep", 0.05))
                    )
                    # The result of decode() lives on the CPU, in (B, C, F, H, W) format and in [0, 1].
                    # It must be normalized to [-1, 1], which is what the pipeline expects.
                    overlap_pixel_tensor_normalized = (overlap_pixel_tensor * 2.0) - 1.0
                    logging.info(f"Overlap pixel tensor created with shape: {list(overlap_pixel_tensor_normalized.shape)}")

                    # 3. Build the ConditioningItem from the PIXEL tensor, not from the latent.
                    overlap_condition_item = ConditioningItem(
                        media_item=overlap_pixel_tensor_normalized,
                        media_frame_number=0,
                        conditioning_strength=1.0
                    )

                if i > 0:
                    chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]

                chunk_path = RESULTS_DIR / f"temp_chunk_{i}_{used_seed}.pt"
                torch.save(chunk_latents.cpu(), chunk_path)
                temp_latent_paths.append(chunk_path)

            base_filename = "narrative_video" if is_narrative else "single_video"
            return self._finalize_generation(temp_latent_paths, base_filename, used_seed)
        except Exception as e:
            logging.error(f"Error during unified generation: {e}", exc_info=True)
            return None, None, None
        finally:
            for path in temp_latent_paths:
                if path.exists():
                    path.unlink()
            self.finalize()
    # ==========================================================================
    # --- INTERNAL WORK UNITS AND HELPERS ---
    # ==========================================================================

    # --- NEW DEDICATED LOGGING FUNCTION ---
    def _log_conditioning_items(self, items: List[ConditioningItem]):
        """
        Logs detailed information about a list of ConditioningItem objects.
        This is a dedicated debug helper function.
        """
        # Only builds the report if INFO-level logging is enabled.
        if logging.getLogger().isEnabledFor(logging.INFO):
            log_str = ["\n" + "="*25 + " INFO: Conditioning Items " + "="*25]
            if not items:
                log_str.append(" -> The conditioning_items list is empty.")
            else:
                for i, item in enumerate(items):
                    if hasattr(item, 'media_item') and isinstance(item.media_item, torch.Tensor):
                        t = item.media_item
                        log_str.append(
                            f" -> Item [{i}]: "
                            f"Tensor(shape={list(t.shape)}, "
                            f"device='{t.device}', "
                            f"dtype={t.dtype}), "
                            f"Target Frame = {item.media_frame_number}, "
                            f"Strength = {item.conditioning_strength:.2f}"
                        )
                    else:
                        log_str.append(f" -> Item [{i}]: Does not contain a valid tensor.")
            log_str.append("="*75 + "\n")
            # Emit the full report in a single log call
            logging.info("\n".join(log_str))
    def _generate_single_chunk_low(self, **kwargs) -> Optional[torch.Tensor]:
        """[WORKER] Calls the pipeline to generate a single chunk of latents."""
        height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
        downscale_factor = self.config.get("downscale_factor", 0.6666666)
        vae_scale_factor = self.pipeline.vae_scale_factor
        downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
        downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)
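        # Illustrative numbers, assuming vae_scale_factor == 32 (the real value comes
        # from the loaded pipeline): a 704x480 request stays 704x480 after 8-alignment,
        # scales by ~0.667 to 469x319, and is then aligned up to 480x320 for this
        # low-resolution first pass.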
        # 1. Start from the default first-pass configuration
        first_pass_config = self.config.get("first_pass", {}).copy()

        # 2. Apply the UI overrides, if any
        if kwargs.get("ltx_configs_override"):
            self._apply_ui_overrides(first_pass_config, kwargs.get("ltx_configs_override"))

        # 3. Assemble the argument dictionary WITHOUT conditioning_items first
        pipeline_kwargs = {
            "prompt": kwargs['prompt'],
            "negative_prompt": kwargs['negative_prompt'],
            "height": downscaled_height,
            "width": downscaled_width,
            "num_frames": kwargs['num_frames'],
            "frame_rate": int(DEFAULT_FPS),
            "generator": torch.Generator(device=self.main_device).manual_seed(kwargs['seed']),
            "output_type": "latent",
            #"conditioning_items": conditioning_items if conditioning_items else None,
            "media_items": None,
            "decode_timestep": self.config["decode_timestep"],
            "decode_noise_scale": self.config["decode_noise_scale"],
            "stochastic_sampling": self.config["stochastic_sampling"],
            "image_cond_noise_scale": 0.01,
            "is_video": True,
            "vae_per_channel_normalize": True,
            "mixed_precision": (self.config["precision"] == "mixed_precision"),
            "offload_to_cpu": False,
            "enhance_prompt": False,
            #"skip_layer_strategy": SkipLayerStrategy.AttentionValues,
            **first_pass_config
        }
        # --- Debug logging block ---
        # 4. Log the pipeline arguments (without the conditioning tensors)
        logging.info(f"\n[Info] Pipeline Arguments (BASE):\n {json.dumps(pipeline_kwargs, indent=2, default=str)}\n")
        # Log the conditioning_items separately with the dedicated helper
        conditioning_items_list = kwargs.get('conditioning_items')
        self._log_conditioning_items(conditioning_items_list)
        # --- End of logging block ---

        # 5. Add the conditioning_items to the dictionary
        pipeline_kwargs['conditioning_items'] = conditioning_items_list

        # 6. Run the pipeline with the complete argument dictionary
        with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
            latents_raw = self.pipeline(**pipeline_kwargs).images
        return latents_raw.to(self.main_device)
    def _finalize_generation(self, temp_latent_paths: List[Path], base_filename: str, seed: int) -> Tuple[str, str, int]:
        """Consolidates latents, decodes them to video, and saves final artifacts."""
        logging.info("Finalizing generation: decoding latents to video.")
        all_tensors_cpu = [torch.load(p) for p in temp_latent_paths]
        final_latents = torch.cat(all_tensors_cpu, dim=2)
        final_latents_path = RESULTS_DIR / f"latents_{base_filename}_{seed}.pt"
        torch.save(final_latents, final_latents_path)
        logging.info(f"Final latents saved to: {final_latents_path}")

        pixel_tensor = vae_manager_singleton.decode(
            final_latents, decode_timestep=float(self.config.get("decode_timestep", 0.05))
        )
        video_path = self._save_and_log_video(pixel_tensor, f"{base_filename}_{seed}")
        return str(video_path), str(final_latents_path), seed
    def prepare_condition_items(self, items_list: List, height: int, width: int, num_frames: int) -> List[ConditioningItem]:
        """
        [FIXED] Prepares ConditioningItems, ensuring the final tensor lives on the
        pipeline's main device (main_device).
        """
        if not items_list:
            return []
        height_padded, width_padded = self._align(height), self._align(width)
        padding_values = calculate_padding(height, width, height_padded, width_padded)

        conditioning_items = []
        for media_item, frame, weight in items_list:
            final_tensor = None
            if isinstance(media_item, str):
                # 1. Load the image. The loader may use the VAE, so the returned
                #    tensor can live on any device.
                tensor = load_image_to_tensor_with_resize_and_crop(media_item, height, width)
                # 2. Apply padding.
                tensor = torch.nn.functional.pad(tensor, padding_values)
                # 3. ENSURE the final tensor is on the main device.
                final_tensor = tensor.to(self.main_device, dtype=self.runtime_autocast_dtype)
            elif isinstance(media_item, torch.Tensor):
                # Already a tensor (e.g. an overlap); just make sure it is on the main device.
                final_tensor = media_item.to(self.main_device, dtype=self.runtime_autocast_dtype)
            else:
                logging.warning(f"Unknown conditioning media type: {type(media_item)}. Skipping.")
                continue

            safe_frame = max(0, min(int(frame), num_frames - 1))
            conditioning_items.append(ConditioningItem(final_tensor, safe_frame, float(weight)))

        self._log_conditioning_items(conditioning_items)
        return conditioning_items
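    # Expected `items_list` layout (inferred from the unpacking loop above): an
    # iterable of (media, frame, weight) tuples, where `media` is either a path to
    # an image on disk or an already-normalized pixel tensor, e.g.
    #   [("/app/input/first_frame.png", 0, 1.0)]
    # The example path is illustrative only.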
    def _apply_ui_overrides(self, config_dict: Dict, overrides: Dict):
        """Applies advanced settings from the UI to a config dictionary."""
        # Override step counts
        for key in ["num_inference_steps", "skip_initial_inference_steps", "skip_final_inference_steps"]:
            ui_value = overrides.get(key)
            if ui_value and ui_value > 0:
                config_dict[key] = ui_value
                logging.info(f"Override: '{key}' set to {ui_value} by UI.")
    def _save_and_log_video(self, pixel_tensor: torch.Tensor, base_filename: str) -> Path:
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = os.path.join(temp_dir, f"{base_filename}.mp4")
            video_encode_tool_singleton.save_video_from_tensor(pixel_tensor, temp_path, fps=DEFAULT_FPS)
            final_path = RESULTS_DIR / f"{base_filename}.mp4"
            shutil.move(temp_path, final_path)
            logging.info(f"Video saved successfully to: {final_path}")
            return final_path
    def _apply_precision_policy(self):
        precision = str(self.config.get("precision", "bfloat16")).lower()
        if precision in ["float8_e4m3fn", "bfloat16"]:
            self.runtime_autocast_dtype = torch.bfloat16
        elif precision == "mixed_precision":
            self.runtime_autocast_dtype = torch.float16
        else:
            self.runtime_autocast_dtype = torch.float32
        logging.info(f"Runtime precision policy set for autocast: {self.runtime_autocast_dtype}")

    def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT, alignment_rule: str = 'default') -> int:
        """Rounds a dimension up to a multiple of `alignment`, or down to the nearest N*alignment+1 value when the 'n*8+1' rule is requested."""
        if alignment_rule == 'n*8+1':
            return ((dim - 1) // alignment) * alignment + 1
        return ((dim - 1) // alignment + 1) * alignment
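    # Quick sanity checks: _align(96) == 96 and _align(97) == 104 with the default
    # rule, while _align(48, alignment_rule='n*8+1') == 41 and
    # _align(41, alignment_rule='n*8+1') == 41.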
    def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
        num_frames = int(round(duration_s * DEFAULT_FPS))
        # For the total duration, always round up to the nearest multiple of 8
        aligned_frames = self._align(num_frames, alignment=FRAMES_ALIGNMENT)
        return max(aligned_frames, min_frames)

    def _get_random_seed(self) -> int:
        """Always generates and returns a new random seed."""
        return random.randint(0, 2**32 - 1)
# ==============================================================================
# --- SINGLETON INSTANTIATION ---
# ==============================================================================
try:
    video_generation_service = VideoService()
    logging.info("Global VideoService orchestrator instance created successfully.")
except Exception as e:
    logging.critical(f"Failed to initialize VideoService: {e}", exc_info=True)
    sys.exit(1)
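# Sketch of how a caller is expected to use the singleton (keyword names follow
# the kwargs read inside generate_low_resolution; values and the image path are
# illustrative only):
#
#   video_path, latents_path, seed = video_generation_service.generate_low_resolution(
#       prompt="A ship sets sail at dawn\nThe ship rides out a storm",
#       duration=6.0,
#       height=480,
#       width=704,
#       negative_prompt="blurry, distorted, low quality",
#       initial_conditions=video_generation_service.prepare_condition_items(
#           [("/app/input/first_frame.png", 0, 1.0)], 480, 704, 144),
#       ltx_configs_override=None,
#   )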