# Hugging Face file view: Test / api/ltx/vae_aduc_pipeline.py
# Last commit by eeuuia: "Update api/ltx/vae_aduc_pipeline.py" (5969983, verified), 7.27 kB
# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A dedicated, "hot" VAE service specialist.
# It holds the VAE model on a dedicated GPU and provides high-level services
# for encoding images/tensors into conditioning items and decoding latents back to pixels.
import os
import sys
import time
import threading
from pathlib import Path
from typing import List, Union, Tuple, Dict, Optional
import yaml
import torch
import numpy as np
from PIL import Image
from api.ltx.ltx_aduc_manager import LatentConditioningItem
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_aduc_manager import ltx_aduc_manager
import logging
import warnings

# Imported under an alias: binding it as `logging` would shadow the standard
# library module, and huggingface_hub's logging module only offers verbosity
# helpers (no `info` / `critical`), so `logging.info(...)` calls elsewhere in
# this file would fail with AttributeError.
from huggingface_hub import logging as hf_logging

# Silence Python warnings. The message=".*" filter (default category Warning)
# already matches every warning; the category-specific filters are kept for
# explicitness.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*")

# huggingface_hub library logger at DEBUG. This is the level that a sequence
# of set_verbosity_error/_warning/_info/_debug calls ends up at, since each
# call simply overwrites the previous level.
hf_logging.set_verbosity_debug()
# ==============================================================================
# --- IMPORTAÇÕES DA ARQUITETURA E DO LTX ---
# ==============================================================================
# Make the local LTX-Video checkout importable so the `ltx_video` imports below
# resolve against it. Nothing is added if the resolved path is already listed.
LTX_VIDEO_REPO_DIR = Path("/data") / "LTX-Video"
if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
    # Prepend (rather than append) so this checkout takes precedence over
    # other entries on sys.path.
    sys.path[:0] = [str(LTX_VIDEO_REPO_DIR.resolve())]
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
# ==============================================================================
# --- CLASSE DO SERVIÇO VAE ---
# ==============================================================================
import logging as _stdlib_logging

# Module logger taken explicitly from the standard library. In this file the
# bare name `logging` may be bound to `huggingface_hub.logging`, which only
# exposes verbosity helpers (no `info` / `critical`), so the class must not
# rely on that name.
_logger = _stdlib_logging.getLogger(__name__)


class VaeAducPipeline:
    """Process-wide singleton that keeps the LTX VAE resident on its own device.

    It reuses the VAE instance already held by ``ltx_aduc_manager``'s pipeline
    (rather than loading a second copy) and provides two services:

    * ``generate_conditioning_items``: encode PIL images / pixel tensors into
      ``LatentConditioningItem`` objects whose latents live on the CPU.
    * ``decode_to_pixels``: decode latents back to a pixel tensor on the CPU.

    Cached CUDA memory on the VAE device is released after every call.
    """

    _instance = None
    # Guards singleton creation in __new__ and the one-time setup in __init__.
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Attach to the shared VAE and move it onto the dedicated VAE device.

        Runs on every ``VaeAducPipeline()`` call, but only the first successful
        call performs the setup. If setup raises, ``_initialized`` stays False,
        so a later call retries.

        Raises:
            RuntimeError: if ``ltx_aduc_manager`` or its pipeline is not
                available yet. This and any other error while fetching the
                VAE is logged at CRITICAL level and re-raised.
        """
        if self._initialized:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished setup.
            if self._initialized:
                return
            _logger.info("⚙️ Initializing VaeServer Singleton...")
            t0 = time.time()

            # 1. Dedicated VAE device from the central GPU manager.
            self.device = gpu_manager.get_ltx_vae_device()

            # 2. Reuse the VAE already loaded by the LTX pipeline, so the model
            #    is not loaded twice and both users share one instance.
            try:
                # Imported here (not only at module level) so that the current
                # value of the module attribute is read at call time. Presumably
                # it can be assigned after this file is imported; TODO confirm.
                from api.ltx.ltx_aduc_manager import ltx_aduc_manager
                if ltx_aduc_manager is None or ltx_aduc_manager.get_pipeline() is None:
                    raise RuntimeError("LTXPoolManager is not initialized yet. VaeServer must be initialized after.")
                self.vae = ltx_aduc_manager.get_pipeline().vae
            except Exception as e:
                _logger.critical(f"Failed to get VAE from LTXPoolManager. Error: {e}", exc_info=True)
                raise

            # 3. Place the VAE on its device in eval mode. nn.Module.to() works in
            #    place, so the LTX pipeline's VAE object is moved as well.
            self.vae.to(self.device)
            self.vae.eval()
            self.dtype = self.vae.dtype

            self._initialized = True
            _logger.info(f"✅ VaeServer ready. VAE model is 'hot' on {self.device} with dtype {self.dtype}. Startup time: {time.time() - t0:.2f}s")

    def _cleanup_gpu(self):
        """Release cached CUDA memory on the VAE device.

        This is a no-op when CUDA is unavailable or when the VAE device is not a
        CUDA device. ``torch.cuda.device`` rejects non-CUDA devices, and this
        method runs inside ``finally`` blocks, where an exception would mask
        the caller's result or original error.
        """
        if not torch.cuda.is_available():
            return
        device = torch.device(self.device)
        if device.type != "cuda":
            return
        with torch.cuda.device(device):
            torch.cuda.empty_cache()

    def _preprocess_input(self, item: Union[Image.Image, torch.Tensor], target_resolution: Tuple[int, int]) -> torch.Tensor:
        """Convert a PIL image or pixel tensor into the 5D input the VAE encoder expects.

        Args:
            item: A PIL image, or a tensor shaped (C, H, W) or (1, C, H, W).
                Tensor values are presumably already in [0, 1]; they are not
                rescaled before the [-1, 1] mapping (TODO confirm with callers).
            target_resolution: Output size for PIL inputs, passed to
                ``ImageOps.fit``, which interprets it as (width, height). It is
                NOT applied to tensor inputs.

        Returns:
            A tensor shaped (1, C, 1, H, W), i.e. batch and single-frame axes
            added, with values mapped from [0, 1] to [-1, 1].

        Raises:
            ValueError: for tensors that are not 3D or 4D with batch size 1.
            TypeError: for inputs that are neither PIL images nor tensors.
        """
        if isinstance(item, Image.Image):
            from PIL import ImageOps
            img = item.convert("RGB")
            # Scale and center-crop to exactly target_resolution.
            processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
            image_np = np.array(processed_img).astype(np.float32) / 255.0
            tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW
        elif isinstance(item, torch.Tensor):
            if item.ndim == 4 and item.shape[0] == 1:
                tensor = item.squeeze(0)
            elif item.ndim == 3:
                tensor = item
            elif item.ndim == 4:
                # Separate message: the generic one below would claim a 4D
                # tensor has the wrong number of dimensions.
                raise ValueError(f"4D input tensor must have batch size 1 (1CHW), but got batch size {item.shape[0]}")
            else:
                raise ValueError(f"Input tensor must have 3 or 4 dimensions (CHW or BCHW), but got {item.ndim}")
        else:
            raise TypeError(f"Input must be a PIL Image or a torch.Tensor, but got {type(item)}")

        # CHW -> (B=1, C, F=1, H, W), then map [0, 1] -> [-1, 1].
        tensor_5d = tensor.unsqueeze(0).unsqueeze(2)
        return (tensor_5d * 2.0) - 1.0

    @torch.no_grad()
    def generate_conditioning_items(
        self,
        media_items: List[Union[Image.Image, torch.Tensor]],
        target_frames: List[int],
        strengths: List[float],
        target_resolution: Tuple[int, int]
    ) -> List[LatentConditioningItem]:
        """Encode media into latent conditioning items for the LTX pipeline.

        Args:
            media_items: PIL images or pixel tensors (see ``_preprocess_input``).
            target_frames: Frame index for each item, in the same order.
            strengths: Conditioning strength for each item, in the same order.
            target_resolution: Resize target for PIL inputs.

        Returns:
            One ``LatentConditioningItem(latents_on_cpu, frame, strength)`` per
            input, in input order. An empty list if the inputs are empty.

        Raises:
            ValueError: if the three input lists differ in length.
        """
        t0 = time.time()
        _logger.info(f"VaeServer: Generating {len(media_items)} latent conditioning items...")
        if not (len(media_items) == len(target_frames) == len(strengths)):
            raise ValueError("Input lists for conditioning items must have the same length.")

        conditioning_items = []
        try:
            for item, frame, strength in zip(media_items, target_frames, strengths):
                pixel_tensor = self._preprocess_input(item, target_resolution)
                pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
                latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)
                # Latents are moved to CPU so the VAE device only holds transient buffers.
                conditioning_items.append(LatentConditioningItem(latents.cpu(), frame, strength))
            _logger.info(f"VaeServer: Generated {len(conditioning_items)} items in {time.time() - t0:.2f}s.")
            return conditioning_items
        finally:
            self._cleanup_gpu()

    @torch.no_grad()
    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decode a latent tensor to pixels on the VAE device and return them on the CPU.

        Args:
            latent_tensor: Latents to decode. Dim 0 is treated as the batch
                dimension (one timestep value is created per entry).
            decode_timestep: Timestep value passed to ``vae_decode`` for every
                batch entry.

        Returns:
            The decoded pixel tensor, moved to the CPU.
        """
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items_in_batch = latent_tensor_gpu.shape[0]
            timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=self.dtype)
            pixels = vae_decode(
                latent_tensor_gpu, self.vae, is_video=True,
                timestep=timestep_tensor, vae_per_channel_normalize=True
            )
            _logger.info(f"VaeServer: Decoded latents with shape {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            return pixels.cpu()
        finally:
            self._cleanup_gpu()
# Module-level singleton, constructed at import time. Construction raises
# RuntimeError when ltx_aduc_manager or its pipeline is not available yet,
# so the LTX manager must be initialized before this module is imported.
vae_aduc_pipeline = VaeAducPipeline()