# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A dedicated, "hot" VAE service specialist.
# It loads the VAE model onto a dedicated GPU (managed by GPUManager)
# and keeps it in memory to handle all encoding and decoding requests
# with minimal latency, using the instance pre-loaded by LTXAducManager.

import os
import sys
import time
import logging
from pathlib import Path
from typing import List, Union, Tuple

import torch
import numpy as np
from PIL import Image, ImageOps

# GPU manager and the main LTX manager (source of the pre-loaded pipeline).
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_aduc_manager import LatentConditioningItem, ltx_aduc_manager

# --- Architecture and LTX imports ---
try:
    # Make the LTX-Video checkout importable before importing from it.
    LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
    if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
        sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))

    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
except ImportError as e:
    # Chain the original error so the failing module stays visible in the traceback.
    raise ImportError(
        f"A crucial import failed for VaeLtxAducPipeline. Check dependencies. Error: {e}"
    ) from e


class VaeLtxAducPipeline:
    """Singleton wrapper around the VAE owned by ``ltx_aduc_manager``'s pipeline.

    The instance does not load a model itself: it takes the ``vae`` attribute
    of ``ltx_aduc_manager.get_pipeline()``, makes sure it sits on the device
    reported by ``gpu_manager.get_ltx_vae_device()``, and exposes encode
    (``generate_conditioning_items``) and decode (``decode_to_pixels``)
    helpers that return their results on the CPU.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Lets __init__ tell a fresh instance from a repeated construction.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Bind the shared VAE and record its device and dtype.

        Does nothing when the singleton has already been initialized.

        Raises:
            Exception: whatever ``ltx_aduc_manager.get_pipeline().vae`` raises
                is logged at CRITICAL level and re-raised unchanged.
        """
        if self._initialized:
            return

        logging.info("⚙️ Initializing VaeLtxAducPipeline Singleton...")
        t0 = time.time()

        # 1. Dedicated VAE device, as decided by the central GPU manager.
        self.device = gpu_manager.get_ltx_vae_device()

        # 2. Reuse the VAE that the LTX manager's pipeline already holds.
        try:
            self.vae = ltx_aduc_manager.get_pipeline().vae
        except Exception as e:
            logging.critical(
                f"Failed to get VAE from LTXAducManager. Is it initialized first? Error: {e}",
                exc_info=True,
            )
            raise

        # 3. Safety check: make sure the VAE is on the expected device.
        # The expected device is normalised to a torch.device for the
        # comparison only, so a manager that hands back a string such as
        # "cuda:0" does not make the check report a mismatch every time.
        # NOTE(review): assumes get_ltx_vae_device() returns something
        # torch.device() accepts (torch.device / str / int) — confirm.
        if self.vae.device != torch.device(self.device):
            logging.warning(
                f"VAE device mismatch! Expected {self.device} but found {self.vae.device}. Forcing move."
            )
            self.vae.to(self.device)

        self.vae.eval()
        self.dtype = self.vae.dtype

        self._initialized = True
        logging.info(
            f"✅ VaeLtxAducPipeline ready. VAE model is 'hot' on {self.device} with dtype {self.dtype}. "
            f"Startup time: {time.time() - t0:.2f}s"
        )

    def _cleanup_gpu(self):
        """Release cached CUDA memory on the VAE device.

        This is a no-op when CUDA is unavailable or when the VAE device is not
        a CUDA device. ``torch.cuda.device`` rejects non-CUDA devices, and
        because this method runs inside ``finally`` blocks, an exception here
        would replace the caller's real result or error.
        """
        if not torch.cuda.is_available():
            return
        vae_device = torch.device(self.device)
        if vae_device.type != "cuda":
            return
        with torch.cuda.device(vae_device):
            torch.cuda.empty_cache()

    def _preprocess_input(
        self,
        item: Union[Image.Image, torch.Tensor],
        target_resolution: Tuple[int, int],
    ) -> torch.Tensor:
        """Turn a PIL image or a pixel tensor into a 5D tensor scaled by ``x * 2 - 1``.

        Args:
            item: A PIL image, or a tensor that is 3D or 4D with a leading
                dimension of size 1 (which is squeezed away).
            target_resolution: Size passed to ``ImageOps.fit`` for PIL inputs.
                Tensor inputs are NOT resized.

        Returns:
            The input with two singleton dimensions inserted (at positions 0
            and 2), multiplied by 2 and shifted by -1. PIL inputs are first
            converted to RGB, fitted to ``target_resolution``, scaled to
            ``[0, 1]`` float32 and permuted from HWC to CHW.
            # NOTE(review): tensor inputs are presumably expected in [0, 1]
            # CHW layout already — nothing here verifies that.

        Raises:
            ValueError: If a tensor is neither 3D nor 4D-with-batch-1.
            TypeError: If ``item`` is neither a PIL image nor a tensor.
        """
        if isinstance(item, Image.Image):
            img = item.convert("RGB")
            # Resize keeping the aspect ratio and cropping the excess.
            processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
            image_np = np.array(processed_img).astype(np.float32) / 255.0
            tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW
        elif isinstance(item, torch.Tensor):
            # Already a tensor: only strip a singleton batch dimension if present.
            if item.ndim == 4 and item.shape[0] == 1:
                tensor = item.squeeze(0)
            elif item.ndim == 3:
                tensor = item
            else:
                raise ValueError(f"Input tensor must have 3 or 4 dimensions (CHW or BCHW), but got {item.ndim}")
        else:
            raise TypeError(f"Input must be a PIL Image or a torch.Tensor, but got {type(item)}")

        # Expand to 5D (B, C, F, H, W) with B=1 and F=1, then map to [-1, 1].
        tensor_5d = tensor.unsqueeze(0).unsqueeze(2)
        return (tensor_5d * 2.0) - 1.0

    @torch.no_grad()
    def generate_conditioning_items(
        self,
        media_items: List[Union[Image.Image, torch.Tensor]],
        target_frames: List[int],
        strengths: List[float],
        target_resolution: Tuple[int, int],
    ) -> List[LatentConditioningItem]:
        """Encode images or pixel tensors into ``LatentConditioningItem`` objects.

        Each media item is preprocessed with ``_preprocess_input``, moved to
        the VAE device/dtype, encoded with ``vae_encode`` (per-channel
        normalization enabled), and wrapped together with its frame and
        strength. Latents are moved to the CPU before being stored.

        Args:
            media_items: PIL images or pixel tensors to encode.
            target_frames: One frame value per media item.
            strengths: One strength value per media item.
            target_resolution: Forwarded to ``_preprocess_input``.

        Returns:
            One ``LatentConditioningItem(latents_on_cpu, frame, strength)``
            per input, in input order.

        Raises:
            ValueError: If the three lists do not have the same length.
        """
        t0 = time.time()
        logging.info(f"Generating {len(media_items)} latent conditioning items on device {self.device}...")

        if not (len(media_items) == len(target_frames) == len(strengths)):
            raise ValueError("As listas de media_items, target_frames e strengths devem ter o mesmo tamanho.")

        conditioning_items = []
        try:
            for item, frame, strength in zip(media_items, target_frames, strengths):
                # 1. Bring the image/tensor into the pixel format used for encoding.
                pixel_tensor = self._preprocess_input(item, target_resolution)

                # 2. Move the pixels to the VAE device and encode to latents.
                pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
                latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)

                # 3. Store the latent on the CPU so it does not stay in VRAM.
                conditioning_items.append(LatentConditioningItem(latents.cpu(), frame, strength))

            logging.info(f"Generated {len(conditioning_items)} items in {time.time() - t0:.2f}s.")
            return conditioning_items
        finally:
            # Runs on success and on failure alike.
            self._cleanup_gpu()

    @torch.no_grad()
    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decode a latent tensor with ``vae_decode`` and return the result on the CPU.

        Args:
            latent_tensor: Latents to decode; its first dimension is used as
                the batch size for the timestep tensor.
            decode_timestep: Value repeated once per batch item and passed to
                ``vae_decode`` as ``timestep``.

        Returns:
            The output of ``vae_decode`` moved to the CPU.
        """
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items_in_batch = latent_tensor_gpu.shape[0]
            timestep_tensor = torch.tensor(
                [decode_timestep] * num_items_in_batch,
                device=self.device,
                dtype=self.dtype,
            )

            pixels = vae_decode(
                latent_tensor_gpu,
                self.vae,
                is_video=True,
                timestep=timestep_tensor,
                vae_per_channel_normalize=True,
            )
            logging.info(f"Decoded latents with shape {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            # Returned on the CPU so the result does not occupy the VAE GPU.
            return pixels.cpu()
        finally:
            self._cleanup_gpu()


# --- Singleton instance ---
# Initialization happens the first time this module is imported. On failure
# the error is logged and the module-level name is set to None instead of
# propagating the exception to the importer.
try:
    vae_ltx_aduc_pipeline = VaeLtxAducPipeline()
except Exception:
    logging.critical("CRITICAL: Failed to initialize VaeLtxAducPipeline singleton.", exc_info=True)
    vae_ltx_aduc_pipeline = None