# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A dedicated, "hot" VAE service specialist.
# It holds the VAE model on a dedicated GPU and provides high-level services
# for encoding images/tensors into conditioning items and decoding latents back to pixels.

import os
import sys
import time
import logging
import threading
from pathlib import Path
from typing import List, Union, Tuple, Dict, Optional
import yaml
import torch
import numpy as np
from PIL import Image

from api.ltx.ltx_aduc_manager import LatentConditioningItem
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_aduc_manager import ltx_aduc_manager

# ==============================================================================
# --- ARCHITECTURE AND LTX IMPORTS ---
# ==============================================================================

# Make the LTX-Video checkout importable before importing from `ltx_video`.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
    sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))

from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode


# ==============================================================================
# --- VAE SERVICE CLASS ---
# ==============================================================================
class VaeAducPipeline:
    """Process-wide singleton exposing VAE encode/decode helpers.

    The instance borrows the VAE held by the pipeline returned from
    ``ltx_aduc_manager.get_pipeline()``, moves it to the device reported by
    ``gpu_manager.get_ltx_vae_device()`` and switches it to eval mode.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                # Flag read by __init__ so the expensive setup runs only once.
                cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        """Bind the shared VAE model to its device (runs once per process).

        Raises:
            RuntimeError: If ``ltx_aduc_manager`` or its pipeline is not
                available yet.
            Exception: Any other error raised while fetching the VAE is
                logged at CRITICAL level and re-raised unchanged.
        """
        if self._initialized:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished
            # initialization while this one was waiting.
            if self._initialized:
                return

            logging.info("⚙️ Initializing VaeServer Singleton...")
            t0 = time.time()

            # 1. Get the dedicated VAE device from the central GPU manager.
            self.device = gpu_manager.get_ltx_vae_device()

            # 2. Reuse the VAE already loaded by the LTX pool manager so the
            #    model is not loaded twice.
            try:
                # Local import mirrors the module-level one.
                from api.ltx.ltx_aduc_manager import ltx_aduc_manager

                if ltx_aduc_manager is None or ltx_aduc_manager.get_pipeline() is None:
                    raise RuntimeError("LTXPoolManager is not initialized yet. VaeServer must be initialized after.")
                self.vae = ltx_aduc_manager.get_pipeline().vae
            except Exception as e:
                logging.critical(f"Failed to get VAE from LTXPoolManager. Error: {e}", exc_info=True)
                raise

            # 3. Make sure the VAE is on the right device and in eval mode.
            self.vae.to(self.device)
            self.vae.eval()
            self.dtype = self.vae.dtype

            self._initialized = True
            logging.info(f"✅ VaeServer ready. VAE model is 'hot' on {self.device} with dtype {self.dtype}. Startup time: {time.time() - t0:.2f}s")

    def _cleanup_gpu(self) -> None:
        """Release cached CUDA memory on the VAE device, if it is a CUDA device.

        This runs inside ``finally`` blocks, so it must never raise for a
        non-CUDA device: ``torch.cuda.device`` rejects e.g. a CPU device with
        ``ValueError``, which would otherwise mask the caller's result or
        original exception.
        """
        if not torch.cuda.is_available():
            return
        # NOTE(review): assumes self.device is something torch.device() accepts
        # (torch.device or device string) — confirm against gpu_manager.
        if torch.device(self.device).type != "cuda":
            return
        with torch.cuda.device(self.device):
            torch.cuda.empty_cache()

    def _preprocess_input(self, item: Union[Image.Image, torch.Tensor], target_resolution: Tuple[int, int]) -> torch.Tensor:
        """Turn a PIL image or a pixel tensor into the 5-D tensor fed to the VAE encoder.

        Args:
            item: A PIL image, or a tensor shaped CHW or 1xCHW.
            target_resolution: Size handed to ``ImageOps.fit`` for PIL inputs
                (PIL interprets a size as ``(width, height)``). It is not
                applied to tensor inputs, which are used at their own size.

        Returns:
            A tensor with two singleton axes inserted (positions 0 and 2) and
            values mapped through ``x * 2 - 1``.

        Raises:
            ValueError: If a tensor is not 3-D, or is 4-D with a batch size
                other than 1.
            TypeError: If ``item`` is neither a PIL image nor a tensor.
        """
        if isinstance(item, Image.Image):
            from PIL import ImageOps

            img = item.convert("RGB")
            processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
            # uint8 [0, 255] -> float32 [0, 1]
            image_np = np.array(processed_img).astype(np.float32) / 255.0
            tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW
        elif isinstance(item, torch.Tensor):
            if item.ndim == 4:
                if item.shape[0] != 1:
                    # Previously reported as "must have 3 or 4 dimensions ...
                    # but got 4", which hid the real problem.
                    raise ValueError(
                        "Input tensor with 4 dimensions (BCHW) must have a batch size of 1, "
                        f"but got batch size {item.shape[0]}"
                    )
                tensor = item.squeeze(0)
            elif item.ndim == 3:
                tensor = item
            else:
                raise ValueError(f"Input tensor must have 3 or 4 dimensions (CHW or BCHW), but got {item.ndim}")
        else:
            raise TypeError(f"Input must be a PIL Image or a torch.Tensor, but got {type(item)}")

        # Expand to 5-D (B, C, F, H, W) and rescale; maps [0, 1] to [-1, 1].
        # NOTE(review): tensor inputs are presumably already in [0, 1] — confirm with callers.
        tensor_5d = tensor.unsqueeze(0).unsqueeze(2)
        return (tensor_5d * 2.0) - 1.0

    @torch.no_grad()
    def generate_conditioning_items(
        self,
        media_items: List[Union[Image.Image, torch.Tensor]],
        target_frames: List[int],
        strengths: List[float],
        target_resolution: Tuple[int, int]
    ) -> List[LatentConditioningItem]:
        """Encode images or pixel tensors into ``LatentConditioningItem`` objects.

        Each media item is preprocessed, moved to the VAE device/dtype and
        encoded with ``vae_encode``; the resulting latents are moved to the
        CPU and paired with the matching frame and strength.

        Args:
            media_items: PIL images or pixel tensors to encode.
            target_frames: One frame value per media item.
            strengths: One strength value per media item.
            target_resolution: Forwarded to ``_preprocess_input``.

        Returns:
            One ``LatentConditioningItem`` per input, in input order.

        Raises:
            ValueError: If the three input lists differ in length.
        """
        t0 = time.time()
        logging.info(f"VaeServer: Generating {len(media_items)} latent conditioning items...")
        if not (len(media_items) == len(target_frames) == len(strengths)):
            raise ValueError("Input lists for conditioning items must have the same length.")

        conditioning_items = []
        try:
            for item, frame, strength in zip(media_items, target_frames, strengths):
                pixel_tensor = self._preprocess_input(item, target_resolution)
                pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
                latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)
                # Latents are stored on the CPU so they do not pin VAE-device memory.
                conditioning_items.append(LatentConditioningItem(latents.cpu(), frame, strength))

            logging.info(f"VaeServer: Generated {len(conditioning_items)} items in {time.time() - t0:.2f}s.")
            return conditioning_items
        finally:
            # Always release cached device memory, on success or failure.
            self._cleanup_gpu()

    @torch.no_grad()
    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decode a latent tensor to pixels and return the result on the CPU.

        Args:
            latent_tensor: Latents to decode; the first axis is treated as the
                batch axis when building the timestep tensor.
            decode_timestep: Value repeated once per batch item and passed to
                ``vae_decode`` as ``timestep``.

        Returns:
            The output of ``vae_decode`` moved to the CPU.
        """
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items_in_batch = latent_tensor_gpu.shape[0]
            # One timestep entry per batch item.
            timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=self.dtype)

            pixels = vae_decode(
                latent_tensor_gpu,
                self.vae,
                is_video=True,
                timestep=timestep_tensor,
                vae_per_channel_normalize=True
            )
            logging.info(f"VaeServer: Decoded latents with shape {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            return pixels.cpu()
        finally:
            # Always release cached device memory, on success or failure.
            self._cleanup_gpu()


# Module-level singleton, created at import time.
vae_aduc_pipeline = VaeAducPipeline()