# FILE: api/ltx/ltx_aduc_pipeline.py
# DESCRIPTION: A unified high-level client for submitting ALL LTX-related jobs
#              (generation and VAE) to the LTXAducManager pool.

import logging
import random
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image

from api.ltx.ltx_utils import load_image_to_tensor_with_resize_and_crop  # image-loading helper from ltx_utils

# The client imports the MANAGER, which is used to submit every job.
from api.ltx.ltx_aduc_manager import ltx_aduc_manager

# Add the LTX-Video repository to sys.path for low-level imports and types.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")


def add_deps_to_path():
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)


add_deps_to_path()

from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode

# ==============================================================================
# --- DATA STRUCTURES ---
# ==============================================================================


@dataclass
class LatentConditioningItem:
    """Data structure for passing conditioning latents to the generation job."""
    latent_tensor: torch.Tensor
    media_frame_number: int
    conditioning_strength: float


# ==============================================================================
# --- JOB FUNCTIONS (executed on the LTX pool) ---
# ==============================================================================


def _job_encode_media(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, pixel_tensor: torch.Tensor) -> torch.Tensor:
    """Job that uses the pipeline's VAE to encode a pixel tensor."""
    vae = pipeline.vae
    pixel_tensor_gpu = pixel_tensor.to(vae.device, dtype=vae.dtype)
    latents = vae_encode(pixel_tensor_gpu, vae, vae_per_channel_normalize=True)
    return latents.cpu()


def _job_decode_latent(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, latent_tensor: torch.Tensor) -> torch.Tensor:
    """Job that uses the pipeline's VAE to decode a latent tensor."""
    vae = pipeline.vae
    latent_tensor_gpu = latent_tensor.to(vae.device, dtype=vae.dtype)
    pixels = vae_decode(latent_tensor_gpu, vae, is_video=True, vae_per_channel_normalize=True)
    return pixels.cpu()


def _job_generate_latent_chunk(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, **kwargs) -> torch.Tensor:
    """Job that uses the main pipeline to generate one latent video chunk."""
    # Consume the seed here so it is not forwarded to the pipeline call;
    # the pipeline receives the seeded generator instead.
    seed = kwargs.pop('seed')
    generator = torch.Generator(device=pipeline.device).manual_seed(seed)
    pipeline_kwargs = {"generator": generator, "output_type": "latent", **kwargs}
    with torch.autocast(device_type=pipeline.device.type, dtype=autocast_dtype):
        latents_raw = pipeline(**pipeline_kwargs).images
    return latents_raw.cpu()


# ==============================================================================
# --- THE UNIFIED CLIENT CLASS ---
# ==============================================================================


class LtxAducPipeline:
    """
    Unified client that orchestrates all LTX tasks, including generation and
    VAE encoding/decoding, by submitting jobs to the LTXAducManager pool.
""" def __init__(self): logging.info("✅ Unified LTX/VAE ADUC Pipeline (Client) initialized.") self.FRAMES_ALIGNMENT = 8 def _get_random_seed(self) -> int: return random.randint(0, 2**32 - 1) def _align(self, dim: int, alignment: int = 8) -> int: return ((dim + alignment - 1) // alignment) * alignment # --- Métodos de API para o Orquestrador --- def encode_to_conditioning_items(self, media_list: List, params: List, resolution: Tuple[int, int]) -> List[LatentConditioningItem]: """Converte uma lista de imagens em uma lista de LatentConditioningItem.""" pixel_tensors = [load_image_to_tensor_with_resize_and_crop(m, resolution[0], resolution[1]) for m in media_list] items = [] for i, pt in enumerate(pixel_tensors): latent_tensor = ltx_aduc_manager.submit_job(_job_encode_media, pixel_tensor=pt) frame_number, strength = params[i] items.append(LatentConditioningItem( latent_tensor=latent_tensor, media_frame_number=frame_number, conditioning_strength=strength )) return items def decode_to_pixels(self, latent_tensor: torch.Tensor) -> torch.Tensor: """Decodifica um tensor latente em um tensor de pixels.""" return ltx_aduc_manager.submit_job(_job_decode_latent, latent_tensor=latent_tensor) def generate_latents( self, prompt_list: List[str], duration_in_seconds: float, common_ltx_args: Dict, initial_conditioning_items: Optional[List[LatentConditioningItem]] = None ) -> Tuple[Optional[torch.Tensor], Optional[int]]: """Gera um vídeo latente completo a partir de uma lista de prompts.""" t0 = time.time() logging.info(f"LTX Client received a generation job for {len(prompt_list)} scenes.") used_seed = self._get_random_seed() num_chunks = len(prompt_list) total_frames = self._align(int(duration_in_seconds * 24)) frames_per_chunk_base = total_frames // num_chunks if num_chunks > 0 else total_frames overlap_frames = self._align(9) if num_chunks > 1 else 0 final_latents_list = [] overlap_condition_item = None for i, chunk_prompt in enumerate(prompt_list): current_conditions = [] if i == 0 and initial_conditioning_items: current_conditions.extend(initial_conditioning_items) if overlap_condition_item: current_conditions.append(overlap_condition_item) num_frames_for_chunk = frames_per_chunk_base if i == num_chunks - 1: processed_frames = sum(f.shape[2] for f in final_latents_list) num_frames_for_chunk = total_frames - processed_frames num_frames_for_chunk = self._align(num_frames_for_chunk) if num_frames_for_chunk <= 0: continue job_specific_args = { "prompt": chunk_prompt, "num_frames": num_frames_for_chunk, "seed": used_seed + i, "conditioning_items": current_conditions } final_job_args = {**common_ltx_args, **job_specific_args} chunk_latents = ltx_aduc_manager.submit_job(_job_generate_latent_chunk, **final_job_args) if chunk_latents is None: logging.error(f"Failed to generate latents for scene {i+1}. Aborting.") return None, used_seed if i < num_chunks - 1: overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone() overlap_condition_item = LatentConditioningItem( latent_tensor=overlap_latents, media_frame_number=0, conditioning_strength=1.0) final_latents_list.append(chunk_latents[:, :, :-overlap_frames, :, :]) else: final_latents_list.append(chunk_latents) if not final_latents_list: logging.warning("No latent chunks were generated.") return None, used_seed final_latents = torch.cat(final_latents_list, dim=2) logging.info(f"LTX Client job finished in {time.time() - t0:.2f}s. 
        return final_latents, used_seed


# --- SINGLETON CLIENT INSTANCE ---
ltx_aduc_pipeline = LtxAducPipeline()
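

# --- USAGE SKETCH (illustrative only) ---
# A minimal sketch of how an orchestrator might drive this client. It assumes the
# LTXAducManager pool is already initialized, that `media_list` entries are paths or
# PIL images accepted by load_image_to_tensor_with_resize_and_crop, and that the keys
# in `common_ltx_args` match what LTXVideoPipeline actually expects; all concrete
# values below are hypothetical.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # One conditioning image pinned to frame 0 at full strength (hypothetical path).
    conditioning_items = ltx_aduc_pipeline.encode_to_conditioning_items(
        media_list=["/data/inputs/start_frame.png"],
        params=[(0, 1.0)],
        resolution=(512, 768),
    )

    latents, seed = ltx_aduc_pipeline.generate_latents(
        prompt_list=[
            "A slow pan across a foggy forest at dawn",
            "The camera rises above the treetops",
        ],
        duration_in_seconds=6.0,
        common_ltx_args={"height": 512, "width": 768, "num_inference_steps": 30},
        initial_conditioning_items=conditioning_items,
    )

    if latents is not None:
        pixels = ltx_aduc_pipeline.decode_to_pixels(latents)
        logging.info(f"Decoded pixel tensor shape: {pixels.shape} (seed={seed})")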