# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A dedicated, "hot" VAE service specialist.
# It holds the VAE model on a dedicated GPU and provides high-level services
# for encoding images/tensors into conditioning items and decoding latents back to pixels.
import sys
import time
import threading
import logging
import warnings
from pathlib import Path
from typing import List, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageOps

from managers.gpu_manager import gpu_manager
from api.ltx.ltx_aduc_manager import LatentConditioningItem, ltx_aduc_manager

# Silence all library warnings emitted during model loading and inference.
warnings.filterwarnings("ignore")

# Keep the huggingface_hub logger quiet; the stdlib `logging` module above is
# used for this module's own messages.
from huggingface_hub import logging as hf_logging
hf_logging.set_verbosity_error()
# ==============================================================================
# --- LTX / ARCHITECTURE IMPORTS ---
# ==============================================================================
# Add the LTX-Video repository to the import path.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
    sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))

from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode

# ==============================================================================
# --- VAE SERVICE CLASS ---
# ==============================================================================
class VaeAducPipeline:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Thread-safe singleton: only the first call constructs the instance.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
        return cls._instance
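
    # Illustrative (not executed): repeated construction returns the one
    # "hot" instance, so the module-level instantiation at the bottom of
    # this file is safe to import from multiple places.
    #
    #     a = VaeAducPipeline()
    #     b = VaeAducPipeline()
    #     assert a is b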
    def __init__(self):
        if self._initialized:
            return
        with self._lock:
            if self._initialized:
                return
            logging.info("⚙️ Initializing VaeServer Singleton...")
            t0 = time.time()

            # 1. Get the dedicated VAE device from the central manager.
            self.device = gpu_manager.get_ltx_vae_device()

            # 2. Reuse the VAE model already loaded by the LTXPoolManager.
            #    This guarantees consistency and avoids loading the model twice.
            try:
                if ltx_aduc_manager is None or ltx_aduc_manager.get_pipeline() is None:
                    raise RuntimeError("LTXPoolManager is not initialized yet. VaeServer must be initialized after it.")
                self.vae = ltx_aduc_manager.get_pipeline().vae
            except Exception as e:
                logging.critical(f"Failed to get VAE from LTXPoolManager. Error: {e}", exc_info=True)
                raise

            # 3. Ensure the VAE is on the correct device and in eval mode.
            self.vae.to(self.device)
            self.vae.eval()
            self.dtype = self.vae.dtype

            self._initialized = True
            logging.info(
                f"✅ VaeServer ready. VAE model is 'hot' on {self.device} with dtype {self.dtype}. "
                f"Startup time: {time.time() - t0:.2f}s"
            )
    def _cleanup_gpu(self):
        """Frees VRAM on the VAE's GPU."""
        if torch.cuda.is_available():
            with torch.cuda.device(self.device):
                torch.cuda.empty_cache()
    def _preprocess_input(self, item: Union[Image.Image, torch.Tensor], target_resolution: Tuple[int, int]) -> torch.Tensor:
        """Prepares a PIL image or a tensor in the pixel format the VAE expects for encoding."""
        if isinstance(item, Image.Image):
            img = item.convert("RGB")
            processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
            image_np = np.array(processed_img).astype(np.float32) / 255.0
            tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW
        elif isinstance(item, torch.Tensor):
            if item.ndim == 4 and item.shape[0] == 1:
                tensor = item.squeeze(0)
            elif item.ndim == 3:
                tensor = item
            else:
                raise ValueError(f"Input tensor must have 3 or 4 dimensions (CHW or BCHW), but got {item.ndim}")
        else:
            raise TypeError(f"Input must be a PIL Image or a torch.Tensor, but got {type(item)}")

        # Expand to 5D (B, C, F, H, W) and rescale [0, 1] -> [-1, 1].
        tensor_5d = tensor.unsqueeze(0).unsqueeze(2)
        return (tensor_5d * 2.0) - 1.0
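
    # Shape walk-through for _preprocess_input (illustrative; assumes a PIL
    # input and target_resolution=(768, 512), i.e. PIL's (width, height)):
    #   PIL RGB 768x512      -> np.float32 (512, 768, 3) in [0, 1]
    #   permute(2, 0, 1)     -> (3, 512, 768)          (CHW)
    #   unsqueeze(0)         -> (1, 3, 512, 768)       (BCHW)
    #   unsqueeze(2)         -> (1, 3, 1, 512, 768)    (B, C, F=1, H, W)
    #   * 2.0 - 1.0          -> same shape, values in [-1, 1]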
    def generate_conditioning_items(
        self,
        media_items: List[Union[Image.Image, torch.Tensor]],
        target_frames: List[int],
        strengths: List[float],
        target_resolution: Tuple[int, int],
    ) -> List[LatentConditioningItem]:
        """
        [MAIN ENTRY POINT] Converts a list of images (PIL or pixel tensors)
        into a list of LatentConditioningItem, ready for the patched LTX pipeline.
        """
        t0 = time.time()
        logging.info(f"VaeServer: Generating {len(media_items)} latent conditioning items...")
        if not (len(media_items) == len(target_frames) == len(strengths)):
            raise ValueError("Input lists for conditioning items must have the same length.")

        conditioning_items = []
        try:
            for item, frame, strength in zip(media_items, target_frames, strengths):
                pixel_tensor = self._preprocess_input(item, target_resolution)
                pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
                # Inference only: skip autograd bookkeeping to save VRAM.
                with torch.no_grad():
                    latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)
                conditioning_items.append(LatentConditioningItem(latents.cpu(), frame, strength))
            logging.info(f"VaeServer: Generated {len(conditioning_items)} items in {time.time() - t0:.2f}s.")
            return conditioning_items
        finally:
            self._cleanup_gpu()
    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decodes a latent tensor into a pixel tensor, returned on the CPU."""
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items_in_batch = latent_tensor_gpu.shape[0]
            timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=self.dtype)
            # Inference only: skip autograd bookkeeping to save VRAM.
            with torch.no_grad():
                pixels = vae_decode(
                    latent_tensor_gpu, self.vae, is_video=True,
                    timestep=timestep_tensor, vae_per_channel_normalize=True,
                )
            logging.info(f"VaeServer: Decoded latents with shape {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            return pixels.cpu()
        finally:
            self._cleanup_gpu()

# Module-level singleton instance, created at import time.
vae_aduc_pipeline = VaeAducPipeline()
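
# Minimal usage sketch (illustrative; assumes the LTXPoolManager is already
# initialized and that "frame_a.png" / "frame_b.png" are hypothetical files):
if __name__ == "__main__":
    img_a = Image.open("frame_a.png")
    img_b = Image.open("frame_b.png")

    # Encode two keyframes as latent conditioning anchored at frames 0 and 48.
    items = vae_aduc_pipeline.generate_conditioning_items(
        media_items=[img_a, img_b],
        target_frames=[0, 48],
        strengths=[1.0, 0.8],
        target_resolution=(768, 512),
    )

    # Round-trip the first item's latents back to pixels as a sanity check.
    # (The attribute name `latent_tensor` is assumed from the constructor's
    # first positional argument; adjust to the real field name if it differs.)
    pixels = vae_aduc_pipeline.decode_to_pixels(items[0].latent_tensor)
    print(f"Decoded pixel tensor shape: {tuple(pixels.shape)}")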