Spaces:
Paused
Paused
| # FILE: api/ltx/ltx_aduc_pipeline.py | |
| # DESCRIPTION: Final high-level orchestrator with robust, intelligent memory and file cleanup. | |
| import warnings | |
| import gc | |
| import json | |
| import os | |
| import shutil | |
| import sys | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import random | |
| import torch | |
| import yaml | |
| import numpy as np | |
| from PIL import Image | |
| from api.ltx.ltx_utils import seed_everything | |
| from utils.debug_utils import log_function_io | |
| from managers.gpu_manager import gpu_manager | |
| from api.ltx.ltx_aduc_manager import ltx_aduc_manager, LatentConditioningItem | |
| from api.ltx.vae_aduc_pipeline import vae_aduc_pipeline | |
| from tools.video_encode_tool import video_encode_tool_singleton | |
| # (O resto das importações e configurações iniciais permanecem as mesmas) | |
| import logging | |
# Silence Python warnings process-wide. The catch-all ``message=".*"`` filter
# already ignores every warning; the category-specific filters are redundant
# but harmless and document the intent.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*")

from huggingface_hub import logging as ll

# Each set_verbosity_* call simply overrides the previous level, so calling
# error/warning/info/debug in sequence only ever left DEBUG in effect.
# Set the effective level once, explicitly.
ll.set_verbosity_debug()

logger = logging.getLogger("AducDebug")
logging.basicConfig(level=logging.DEBUG)
logger.setLevel(logging.DEBUG)

DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
RESULTS_DIR = Path("/app/output")
DEFAULT_FPS = 24.0
FRAMES_ALIGNMENT = 8

# Make the LTX-Video checkout importable; it must be on sys.path before the
# ltx_video import below, which is why that import is not at the top.
repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
if repo_path not in sys.path:
    sys.path.insert(0, repo_path)

from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy  # noqa: E402
class LtxAducPipeline:
    """
    Orchestrates the high-level logic of video generation with robust cleanup.

    Wraps the LTX pipeline obtained from ``ltx_aduc_manager``:

    - generates latents chunk by chunk (one chunk per prompt);
    - concatenates the chunks on CPU;
    - delegates decoding to ``vae_aduc_pipeline`` and MP4 encoding to
      ``video_encode_tool_singleton``.

    Temporary per-chunk latent files and GPU caches are released after every
    generation call.
    """

    def __init__(self):
        """
        Bind the shared LTX pipeline/config and derive the autocast dtype.

        Raises:
            RuntimeError: if ``ltx_aduc_manager`` or ``vae_aduc_pipeline`` is None.
        """
        t0 = time.time()
        logging.info("Initializing VideoService Orchestrator...")
        if ltx_aduc_manager is None or vae_aduc_pipeline is None:
            raise RuntimeError("A required manager (LTX or VAE) failed to initialize. Aborting.")
        self.pipeline = ltx_aduc_manager.get_pipeline()
        self.main_device = self.pipeline.device
        self.vae_device = self.pipeline.vae.device
        self.config = ltx_aduc_manager.config
        # Temporary files created during a generation call; removed by _cleanup().
        self._temp_files: List[Path] = []
        self._apply_precision_policy()
        logging.info(f"VideoService ready. Using Main: {self.main_device}, VAE: {self.vae_device}. Startup time: {time.time() - t0:.2f}s")

    def _cleanup(self):
        """
        Delete tracked temporary files and release GPU memory.

        This runs from the ``finally`` clause of generation, so it must not
        raise. An exception here would replace the caller's return value or
        mask the original error. Failures are therefore logged, not raised.
        """
        logging.info("--- Iniciando Limpeza Inteligente (Cleanup) ---")
        # 1. Remove temporary files.
        logging.info(f"Removendo {len(self._temp_files)} arquivo(s) temporário(s)...")
        for f_path in self._temp_files:
            try:
                if os.path.exists(f_path):
                    os.remove(f_path)
                    logging.info(f" - Removido: {f_path}")
            except OSError as e:
                logging.error(f" - Erro ao remover {f_path}: {e}")
        self._temp_files.clear()  # Reset for the next call.

        # 2. Release memory.
        logging.info("Limpando memória (GC e Cache da GPU)...")
        gc.collect()
        if torch.cuda.is_available():
            for device in (self.main_device, self.vae_device):
                # torch.cuda.device() rejects non-CUDA devices (e.g. a VAE kept
                # on CPU), so only enter the context for GPU devices.
                if torch.device(device).type != "cuda":
                    continue
                try:
                    with torch.cuda.device(device):
                        torch.cuda.empty_cache()
                except Exception as e:
                    logging.warning(f"Failed to empty CUDA cache on {device}: {e}")
            try:
                torch.cuda.ipc_collect()
                logging.info("Cache da GPU e memória IPC limpos.")
            except Exception as e:
                logging.warning(f"Falha ao limpar memória IPC da GPU: {e}")
        logging.info("--- Limpeza Inteligente Concluída ---")

    def _resolve_skip_layer_strategy(self) -> Optional["SkipLayerStrategy"]:
        """Map the configured ``stg_mode`` to a SkipLayerStrategy, or None if unrecognized."""
        stg_mode = str(self.config.get("stg_mode", "attention_values")).lower()
        strategies = {
            "stg_av": SkipLayerStrategy.AttentionValues,
            "attention_values": SkipLayerStrategy.AttentionValues,
            "stg_as": SkipLayerStrategy.AttentionSkip,
            "attention_skip": SkipLayerStrategy.AttentionSkip,
            "stg_r": SkipLayerStrategy.Residual,
            "residual": SkipLayerStrategy.Residual,
            "stg_t": SkipLayerStrategy.TransformerBlock,
            "transformer_block": SkipLayerStrategy.TransformerBlock,
        }
        return strategies.get(stg_mode)

    def _downscaled_dimensions(self, height: int, width: int) -> Tuple[int, int]:
        """
        Compute the low-resolution ``(height, width)`` for the first pass.

        Each dimension is:

        1. padded up to a multiple of 8;
        2. scaled by ``config["downscale_factor"]`` (default 0.6666666);
        3. floored to a multiple of the pipeline's ``vae_scale_factor``.
        """
        downscale_factor = self.config.get("downscale_factor", 0.6666666)
        vae_scale_factor = self.pipeline.vae_scale_factor

        def _scale(dim: int) -> int:
            padded = self._align(dim, alignment=8)
            scaled = int(padded * downscale_factor)
            return scaled - (scaled % vae_scale_factor)

        return _scale(height), _scale(width)

    def generate_low_resolution(
        self,
        prompt_list: List[str],
        initial_media_items: Optional[List[Tuple[Union[str, Image.Image, torch.Tensor], int, float]]] = None,
        **kwargs
    ) -> Tuple[Optional[str], Optional[str], Optional[int]]:
        """
        Generate a low-resolution video, one latent chunk per prompt.

        With more than one prompt the generation is "narrative":

        - every chunk except the last hands its trailing latents to the next
          chunk as conditioning;
        - every chunk after the first has its leading overlap trimmed before
          concatenation.

        Args:
            prompt_list: One prompt per scene. Must be non-empty.
            initial_media_items: Optional ``(media, target_frame, strength)``
                tuples, converted into conditioning items for the first chunk.
            **kwargs: Required ``height``, ``width`` and ``negative_prompt``.
                Optional ``duration`` in seconds (default 1.0) and
                ``ltx_configs_override``, a dict merged into the pipeline call.

        Returns:
            ``(video_path, latents_path, seed)``.

        Raises:
            ValueError: if ``prompt_list`` is empty, or if the duration is too
                short to give the last scene at least one frame.
            RuntimeError: if the pipeline returns no latents for a scene.
            KeyError: if ``height``, ``width`` or ``negative_prompt`` is missing.
        """
        # try/finally guarantees _cleanup() always runs.
        try:
            logging.info("Starting unified low-resolution generation...")
            used_seed = self._get_random_seed()
            seed_everything(used_seed)
            logging.info(f"Using randomly generated seed: {used_seed}")

            if not prompt_list:
                raise ValueError("Prompt list cannot be empty.")
            is_narrative = len(prompt_list) > 1
            num_chunks = len(prompt_list)

            # Round duration * fps to the nearest 8n+1 frame count, minimum 9.
            total_frames = max(9, int(round((round(kwargs.get("duration", 1.0) * DEFAULT_FPS) - 1) / 8.0) * 8 + 1))
            frames_per_chunk = max(FRAMES_ALIGNMENT, (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT)
            overlap_frames = 4 if is_narrative else 0

            # The last chunk receives whatever frames remain. With many prompts
            # and a short duration this can go non-positive; fail early.
            last_chunk_frames = total_frames - (num_chunks - 1) * frames_per_chunk
            if last_chunk_frames <= 0:
                raise ValueError(
                    f"Duration too short for {num_chunks} scenes: {total_frames} frames "
                    f"cannot be split into chunks of {frames_per_chunk} frames."
                )

            initial_conditions = []
            if initial_media_items:
                initial_conditions = vae_aduc_pipeline.generate_conditioning_items(
                    media_items=[item[0] for item in initial_media_items],
                    target_frames=[item[1] for item in initial_media_items],
                    strengths=[item[2] for item in initial_media_items],
                    target_resolution=(kwargs['height'], kwargs['width'])
                )

            downscaled_height, downscaled_width = self._downscaled_dimensions(kwargs['height'], kwargs['width'])
            stg_strategy = self._resolve_skip_layer_strategy()
            # Same normalization as _apply_precision_policy, so a missing or
            # differently-cased "precision" key no longer raises/mismatches.
            use_mixed_precision = str(self.config.get("precision", "bfloat16")).lower() == "mixed_precision"

            call_kwargs = {
                "height": downscaled_height,
                "width": downscaled_width,
                "skip_initial_inference_steps": 0, "skip_final_inference_steps": 0, "num_inference_steps": 20,
                "negative_prompt": kwargs['negative_prompt'],
                "guidance_scale": 4, "stg_scale": self.config.get("stg_scale", 4),
                "rescaling_scale": self.config.get("rescaling_scale", 0.7), "skip_layer_strategy": stg_strategy,
                "skip_block_list": self.config.get("skip_block_list", None), "frame_rate": int(DEFAULT_FPS),
                # Seed with the reported seed so the returned value identifies this run.
                "generator": torch.Generator(device=self.main_device).manual_seed(used_seed),
                "output_type": "latent", "media_items": None, "decode_timestep": self.config.get("decode_timestep", None),
                "decode_noise_scale": self.config.get("decode_noise_scale", None), "stochastic_sampling": self.config.get("stochastic_sampling", None),
                "image_cond_noise_scale": 0.15, "is_video": True, "vae_per_channel_normalize": True,
                "mixed_precision": use_mixed_precision, "offload_to_cpu": False,
                "enhance_prompt": False,
            }
            ltx_configs_override = kwargs.get("ltx_configs_override", {})
            if ltx_configs_override:
                call_kwargs.update(ltx_configs_override)
            if initial_conditions:
                call_kwargs["conditioning_items"] = initial_conditions

            # --- STEP 1: generate each chunk and save it to disk ---
            RESULTS_DIR.mkdir(parents=True, exist_ok=True)
            chunk_paths: List[Path] = []
            for i, chunk_prompt in enumerate(prompt_list):
                logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
                current_frames_base = frames_per_chunk if i < num_chunks - 1 else last_chunk_frames
                current_frames = current_frames_base + (overlap_frames if i > 0 else 0)
                current_frames = self._align(current_frames, alignment_rule='n*8+1')
                call_kwargs["prompt"] = chunk_prompt
                call_kwargs["num_frames"] = current_frames

                with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
                    chunk_latents = self.pipeline(**call_kwargs).images
                if chunk_latents is None:
                    raise RuntimeError(f"Failed to generate latents for scene {i+1}.")

                # NOTE(review): overlap_frames is added to the *pixel* frame
                # count above, but is used here to slice the latent time axis
                # (dim 2). If the VAE compresses time, these units differ —
                # confirm the intended overlap.
                if is_narrative and i < num_chunks - 1:
                    overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
                    call_kwargs["conditioning_items"] = [LatentConditioningItem(overlap_latents, 0, 1.0)]
                else:
                    call_kwargs.pop("conditioning_items", None)
                if i > 0:
                    chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]

                chunk_path = RESULTS_DIR / f"temp_chunk_{i}_{used_seed}.pt"
                # Register before saving so a partially written file is still cleaned up.
                self._temp_files.append(chunk_path)
                chunk_paths.append(chunk_path)
                torch.save(chunk_latents.cpu(), chunk_path)
                del chunk_latents

            # --- STEP 2: concatenate this call's latent chunks on CPU ---
            logging.info(f"Concatenating {len(chunk_paths)} latent chunks on CPU...")
            all_tensors_cpu = [torch.load(p) for p in chunk_paths]
            final_latents_cpu = torch.cat(all_tensors_cpu, dim=2)
            del all_tensors_cpu  # Drop per-chunk copies before decoding.
            logging.info(f"Concatenating SuperLat {final_latents_cpu.shape}")

            # --- STEPS 3 and 4: decode and encode ---
            base_filename = "narrative_video" if is_narrative else "single_video"
            video_path, latents_path = self._finalize_generation(final_latents_cpu, base_filename, used_seed)
            return video_path, latents_path, used_seed
        finally:
            self._cleanup()

    def _finalize_generation(self, final_latents_cpu: torch.Tensor, base_filename: str, seed: int) -> Tuple[str, str]:
        """
        Save the final latents, decode them to pixels and encode an MP4.

        Latents go to ``RESULTS_DIR/latents_{base_filename}_{seed}.pt``.
        Returns ``(video_path, latents_path)`` as strings.
        """
        final_latents_path = RESULTS_DIR / f"latents_{base_filename}_{seed}.pt"
        torch.save(final_latents_cpu, final_latents_path)
        logging.info(f"Final latents saved to: {final_latents_path}")

        # Fall back to 0.05 both when the key is absent and when it is an explicit None.
        decode_timestep = self.config.get("decode_timestep")
        if decode_timestep is None:
            decode_timestep = 0.05

        logging.info("Delegating to VaeServer for decoding latents to pixels...")
        pixel_tensor_cpu = vae_aduc_pipeline.decode_to_pixels(
            final_latents_cpu, decode_timestep=float(decode_timestep)
        )
        logging.info("Delegating to VideoEncodeTool to save pixel tensor as MP4...")
        video_path = self._save_and_log_video(pixel_tensor_cpu, f"{base_filename}_{seed}")
        return str(video_path), str(final_latents_path)

    def _save_and_log_video(self, pixel_tensor_cpu: torch.Tensor, base_filename: str) -> Path:
        """
        Encode ``pixel_tensor_cpu`` to MP4 and move it to ``RESULTS_DIR/{base_filename}.mp4``.

        Encoding happens in a temporary directory, presumably so that a
        partially written file never appears in RESULTS_DIR. Returns the final path.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = os.path.join(temp_dir, f"{base_filename}.mp4")
            video_encode_tool_singleton.save_video_from_tensor(pixel_tensor_cpu, temp_path, fps=DEFAULT_FPS)
            final_path = RESULTS_DIR / f"{base_filename}.mp4"
            shutil.move(temp_path, final_path)
            logging.info(f"Video saved successfully to: {final_path}")
            return final_path

    def _apply_precision_policy(self):
        """
        Set ``self.runtime_autocast_dtype`` from ``config["precision"]``.

        ``precision`` defaults to ``"bfloat16"`` and is compared
        case-insensitively:

        - ``float8_e4m3fn`` / ``bfloat16`` → ``torch.bfloat16``
        - ``mixed_precision`` → ``torch.float16``
        - anything else → ``torch.float32``
        """
        precision = str(self.config.get("precision", "bfloat16")).lower()
        if precision in ["float8_e4m3fn", "bfloat16"]:
            self.runtime_autocast_dtype = torch.bfloat16
        elif precision == "mixed_precision":
            self.runtime_autocast_dtype = torch.float16
        else:
            self.runtime_autocast_dtype = torch.float32
        logging.info(f"Runtime precision policy set for autocast: {self.runtime_autocast_dtype}")

    def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT, alignment_rule: str = 'default') -> int:
        """
        Align ``dim`` to ``alignment``.

        - ``'default'``: round up to the next multiple of ``alignment``
          (unchanged if already a multiple).
        - ``'n*8+1'``: round down to the largest ``k * alignment + 1 <= dim``
          (for ``dim >= 1``).
        """
        if alignment_rule == 'n*8+1':
            return ((dim - 1) // alignment) * alignment + 1
        return ((dim - 1) // alignment + 1) * alignment

    def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
        """
        Convert ``duration_s`` at DEFAULT_FPS to a frame count.

        The count is rounded up to a multiple of FRAMES_ALIGNMENT, with a
        floor of ``min_frames``. Not called within this class (the call in
        generate_low_resolution was replaced by an inline formula).
        """
        num_frames = int(round(duration_s * DEFAULT_FPS))
        aligned_frames = self._align(num_frames, alignment=FRAMES_ALIGNMENT)
        return max(aligned_frames, min_frames)

    def _get_random_seed(self) -> int:
        """Return a random seed in the unsigned 32-bit range [0, 2**32 - 1]."""
        return random.randint(0, 2**32 - 1)
# Module-level singleton shared by importers of this module. Building it calls
# ltx_aduc_manager.get_pipeline(), and LtxAducPipeline.__init__ raises
# RuntimeError if either the LTX manager or the VAE pipeline is None, so the
# import itself fails in that case.
ltx_aduc_pipeline: LtxAducPipeline = LtxAducPipeline()
logging.info("Global VideoService orchestrator instance created successfully.")