# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A dedicated, "hot" VAE service specialist.
# It keeps the VAE model resident on a dedicated device and provides high-level
# services for encoding images/tensors into conditioning items and decoding
# latents back to pixels.

# --- Standard library ---
import copy
import logging
import os
import sys
import threading
import time
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

# --- Third-party and project modules (relative order intentionally preserved) ---
import torch
import numpy as np
from PIL import Image
from einops import rearrange
import torch.nn.functional as F

from managers.gpu_manager import gpu_manager
from utils.debug_utils import log_function_io
from diffusers.utils.torch_utils import randn_tensor

# --- Logging and warnings configuration ---
# NOTE(review): the catch-all `message=".*"` filter already ignores every
# warning category, so the two category-specific filters are redundant.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*")

# Quiet huggingface_hub when it is installed; its absence is not an error.
try:
    from huggingface_hub import logging as hf_logging
    hf_logging.set_verbosity_error()
except ImportError:
    pass

# ==============================================================================
# --- IMPORTS AND TYPE DEFINITIONS ---
# ==============================================================================

# Make the local LTX-Video checkout importable. It is prepended to sys.path so
# it is searched before any other location, and this must happen before the
# `ltx_video` imports below.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
_ltx_repo_entry = str(LTX_VIDEO_REPO_DIR.resolve())
if _ltx_repo_entry not in sys.path:
    sys.path.insert(0, _ltx_repo_entry)
del _ltx_repo_entry

from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode, latent_to_pixel_coords
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline


@dataclass
class LatentConditioningItem:
    """A conditioning input that is already encoded into VAE latent space.

    Used by ``VaeAducPipeline.prepare_conditioning`` as the latent-space
    alternative to a pixel-space ``ConditioningItem``.
    """

    # Pre-encoded latents. The consumer reads frames from dim 2 and (F, H, W)
    # from the last three dims -- presumably (B, C, F, H, W); TODO confirm.
    latent_tensor: torch.Tensor
    # Frame index the latents are anchored at; 0 means "first frame". Appears
    # to be in pixel-frame units -- verify against callers.
    media_frame_number: int
    # Blend weight used with torch.lerp and stored in the conditioning mask
    # (presumably in [0, 1]).
    conditioning_strength: float
# ==============================================================================
# --- MAIN VAE SERVICE CLASS ---
# ==============================================================================

class VaeAducPipeline:
    """Process-wide singleton that keeps the LTX VAE resident on one device.

    The VAE, patchifier, transformer (used for its config) and
    ``vae_scale_factor`` are borrowed from the main LTX pipeline returned by
    ``ltx_aduc_manager.get_pipeline()``. The VAE is moved to the device given
    by ``gpu_manager.get_ltx_vae_device()``. Every public service method
    returns CPU tensors.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # The first call allocates the shared instance; later calls reuse it.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Wire up the shared VAE components once (later calls are no-ops).

        Raises:
            RuntimeError: if the main LTX pipeline has not been created yet.
            Exception: any error raised while fetching the components is
                logged and re-raised; ``_initialized`` then stays False, so a
                later construction attempt retries.
        """
        # __init__ runs on every VaeAducPipeline() call, so guard twice: a cheap
        # unlocked check, then a re-check under the lock.
        if hasattr(self, '_initialized') and self._initialized:
            return
        with self._lock:
            if hasattr(self, '_initialized') and self._initialized:
                return
            logging.info("⚙️ Initializing VaeAducPipeline Singleton...")
            t0 = time.time()
            self.device = gpu_manager.get_ltx_vae_device()
            try:
                # Imported lazily -- presumably to avoid a circular import with
                # the manager module.
                from api.ltx.ltx_aduc_manager import ltx_aduc_manager
                main_pipeline = ltx_aduc_manager.get_pipeline()
                if main_pipeline is None:
                    raise RuntimeError("LTXPoolManager must be initialized before VaeAducPipeline.")
                self.vae: CausalVideoAutoencoder = main_pipeline.vae
                self.patchifier = main_pipeline.patchifier
                self.transformer = main_pipeline.transformer
                self.vae_scale_factor = main_pipeline.vae_scale_factor
            except Exception as e:
                logging.critical(f"Failed to get components from LTXPoolManager. Error: {e}", exc_info=True)
                raise
            self.vae.to(self.device).eval()
            self.dtype = self.vae.dtype
            self._initialized = True
            logging.info(f"✅ VaeAducPipeline ready. Components are 'hot' on {self.device}. Startup time: {time.time() - t0:.2f}s")

    # --- PUBLIC SERVICE METHODS ---

    @log_function_io
    def encode_video(self, video_tensor: torch.Tensor, vae_per_channel_normalize: bool = True) -> torch.Tensor:
        """Encode a pixel-space video into VAE latents.

        Args:
            video_tensor: 5D tensor (B, C, F, H, W). It is mapped with
                ``x * 2 - 1`` before encoding, so values are assumed to be in [0, 1].
            vae_per_channel_normalize: Forwarded to ``vae_encode``.

        Returns:
            The latents produced by ``vae_encode``, on CPU.

        Raises:
            ValueError: if ``video_tensor`` is not 5-dimensional.
        """
        logging.info(f"VaeAducPipeline: Encoding video with shape {video_tensor.shape}")
        if video_tensor.ndim != 5:
            raise ValueError(f"Input video tensor must be 5D (B, C, F, H, W), but got shape {video_tensor.shape}")
        # Map [0, 1] pixels to [-1, 1] before handing them to the VAE.
        video_tensor_normalized = (video_tensor * 2.0) - 1.0
        try:
            video_gpu = video_tensor_normalized.to(self.device, dtype=self.dtype)
            with torch.no_grad():
                latents = vae_encode(video_gpu, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)
            logging.info(f"VaeAducPipeline: Successfully encoded video to latents of shape {latents.shape}")
            return latents.cpu()
        finally:
            self._cleanup_gpu()

    @log_function_io
    def decode_and_resize_video(self, latent_tensor: torch.Tensor, target_height: int, target_width: int, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decode latents to pixels, then bilinearly resize every frame.

        Args:
            latent_tensor: Latents accepted by ``decode_to_pixels``.
            target_height, target_width: Output spatial size in pixels.
            decode_timestep: Forwarded to ``decode_to_pixels``.

        Returns:
            A CPU video tensor (B, C, F, target_height, target_width). The decoded
            video is returned as-is when it already has the target size.
        """
        logging.info(f"VaeAducPipeline: Decoding latents {latent_tensor.shape} and resizing to {target_height}x{target_width}")
        pixel_video = self.decode_to_pixels(latent_tensor, decode_timestep)
        num_frames = pixel_video.shape[2]
        current_height, current_width = pixel_video.shape[3:]
        if current_height == target_height and current_width == target_width:
            logging.info("VaeAducPipeline: Resizing skipped, already at target resolution.")
            return pixel_video
        # F.interpolate works on 4D batches: fold frames into the batch dim and back.
        videos_flat = rearrange(pixel_video, "b c f h w -> (b f) c h w")
        videos_resized = F.interpolate(videos_flat, size=(target_height, target_width), mode="bilinear", align_corners=False)
        final_video = rearrange(videos_resized, "(b f) c h w -> b c f h w", f=num_frames)
        logging.info(f"VaeAducPipeline: Resized video to final shape {final_video.shape}")
        return final_video

    @log_function_io
    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decode a batch of video latents with the VAE.

        Args:
            latent_tensor: Latents whose dim 0 is the batch.
            decode_timestep: The same timestep value is passed for every batch item.

        Returns:
            The output of ``vae_decode`` on CPU. The value range is whatever
            ``vae_decode`` produces; it is not rescaled here.
        """
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items = latent_tensor_gpu.shape[0]
            timestep_tensor = torch.tensor([decode_timestep] * num_items, device=self.device, dtype=self.dtype)
            with torch.no_grad():
                pixels = vae_decode(
                    latent_tensor_gpu, self.vae, is_video=True,
                    timestep=timestep_tensor, vae_per_channel_normalize=True,
                )
            logging.info(f"VaeAducPipeline: Decoded latents {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            return pixels.cpu()
        finally:
            self._cleanup_gpu()

    @log_function_io
    def prepare_conditioning(
        self,
        conditioning_items: Optional[List[Union[ConditioningItem, LatentConditioningItem]]],
        init_latents: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        vae_per_channel_normalize: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Blend conditioning into the initial latents and patchify the result.

        Args:
            conditioning_items: Either ``LatentConditioningItem`` objects (already
                encoded) or pixel-space ``ConditioningItem`` objects. The mode is
                chosen from the FIRST element only, so lists must not be mixed.
            init_latents: Initial latents. Dim 0 is the batch and the last three
                dims are treated as (F, H, W) -- presumably (B, C, F, H, W). The
                tensor is copied, so the caller's tensor is never modified.
            num_frames: Not used by this implementation.
            height, width: Target pixel size; pixel-space media is resized to it.
            vae_per_channel_normalize: Forwarded to ``vae_encode`` (pixel mode).
            generator: RNG for the noise blended into extra conditioning tokens.

        Returns:
            ``(latents, pixel_coords, mask, num_extra_tokens)`` on CPU. With no
            conditioning items, ``mask`` is None and ``num_extra_tokens`` is 0.
            Extra conditioning tokens are prepended to the token sequence.

        Raises:
            ValueError: (pixel mode) if an item sets ``media_x``/``media_y``.
        """
        # copy=True: the lerp/mask writes below are in place, and `.to()` returns
        # the caller's own tensor when device and dtype already match.
        init_latents = init_latents.to(self.device, dtype=self.dtype, copy=True)

        if not conditioning_items:
            latents_p, coords_l = self.patchifier.patchify(latents=init_latents)
            coords_p = self._latent_to_pixel_coords(coords_l)
            return latents_p.cpu(), coords_p.cpu(), None, 0

        # Conditioning strength per latent position (0 = unconditioned).
        mask = torch.zeros(init_latents.shape[0], *init_latents.shape[2:], dtype=torch.float32, device=self.device)
        extra_latents, extra_coords, extra_masks = [], [], []
        num_extra_latents = 0
        is_latent_mode = isinstance(conditioning_items[0], LatentConditioningItem)

        with torch.no_grad():
            for item in conditioning_items:
                if is_latent_mode:
                    latents = item.latent_tensor.to(device=self.device, dtype=self.dtype)
                    num_media_frames = latents.shape[2]
                else:
                    item_resized = self._resize_conditioning_item(item, height, width)
                    media_item = item_resized.media_item.to(self.device, dtype=self.dtype)
                    latents = vae_encode(media_item, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)
                    num_media_frames = media_item.shape[2]

                if item.media_frame_number == 0:
                    # First-frame conditioning is blended directly into init_latents.
                    if is_latent_mode:
                        # Latent items carry no spatial offset: anchor at the origin.
                        placed, lx, ly = latents, 0, 0
                    else:
                        placed, lx, ly = self._get_latent_spatial_position(latents, item_resized, height, width)
                    f, h, w = placed.shape[-3:]
                    region = (..., slice(None, f), slice(ly, ly + h), slice(lx, lx + w))
                    init_latents[region] = torch.lerp(init_latents[region], placed, item.conditioning_strength)
                    mask[region] = item.conditioning_strength
                    continue

                # Later frames: multi-frame media is partly spliced into init_latents;
                # whatever remains becomes extra (noised) conditioning tokens.
                if num_media_frames > 1:
                    init_latents, mask, latents = self._handle_non_first_sequence(
                        init_latents, mask, latents, item.media_frame_number, item.conditioning_strength
                    )
                if latents is not None:
                    extra_p, extra_c, extra_m, num_new = self._process_extra_item(latents, item, generator)
                    extra_latents.append(extra_p)
                    extra_coords.append(extra_c)
                    extra_masks.append(extra_m)
                    num_extra_latents += num_new

        # --- Final consolidation ---
        latents_p, coords_l = self.patchifier.patchify(latents=init_latents)
        coords_p = self._latent_to_pixel_coords(coords_l)
        mask_p, _ = self.patchifier.patchify(latents=mask.unsqueeze(1))
        mask_p = mask_p.squeeze(-1)

        if extra_latents:
            latents_p = torch.cat([*extra_latents, latents_p], dim=1)
            coords_p = torch.cat([*extra_coords, coords_p], dim=2)
            mask_p = torch.cat([*extra_masks, mask_p], dim=1)
            # With TPU flash attention, drop as many tokens from the end as were
            # prepended, so the sequence keeps its original length. This stays
            # inside this branch: num_extra_latents is positive here, whereas a
            # `[:-0]` slice would empty every tensor.
            if getattr(self.transformer.config, 'use_tpu_flash_attention', False):
                latents_p = latents_p[:, :-num_extra_latents]
                coords_p = coords_p[:, :, :-num_extra_latents]
                mask_p = mask_p[:, :-num_extra_latents]

        return latents_p.cpu(), coords_p.cpu(), mask_p.cpu(), num_extra_latents

    # --- PRIVATE HELPERS ---

    def _cleanup_gpu(self):
        """Release cached CUDA allocator memory on the VAE device.

        This is a no-op when CUDA is unavailable or the VAE device is not a CUDA
        device. Entering ``torch.cuda.device`` with a CPU device would raise,
        and this helper runs inside ``finally`` blocks, so it must not throw.
        """
        if not torch.cuda.is_available():
            return
        device = torch.device(self.device)
        if device.type != "cuda":
            return
        with torch.cuda.device(device):
            torch.cuda.empty_cache()

    def _latent_to_pixel_coords(self, c):
        """Convert latent-grid coordinates from the patchifier into pixel coordinates."""
        return latent_to_pixel_coords(c, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)

    @staticmethod
    def _resize_tensor(m, h, w):
        """Bilinearly resize a (B, C, N, H, W) tensor to spatial size (h, w).

        Returns ``m`` itself when it is already (h, w).
        """
        if m.shape[-2:] != (h, w):
            n = m.shape[2]
            flat = rearrange(m, "b c n h w -> (b n) c h w")
            resized = F.interpolate(flat, (h, w), mode="bilinear", align_corners=False)
            return rearrange(resized, "(b n) c h w -> b c n h w", n=n)
        return m

    def _resize_conditioning_item(self, i, h, w):
        """Return a shallow copy of item ``i`` whose media is resized to (h, w).

        Raises:
            ValueError: if the item sets a non-zero ``media_x``/``media_y``.
                Stretching media to the full frame would make that offset
                meaningless: the latents would either fail to fit the target
                region or be silently shifted.
        """
        if getattr(i, "media_x", None) or getattr(i, "media_y", None):
            raise ValueError("Provide media_item in the target size for spatial conditioning.")
        n = copy.copy(i)
        n.media_item = self._resize_tensor(i.media_item, h, w)
        return n

    def _get_latent_spatial_position(self, l, i, h, w, strip=True):
        """Locate conditioning latents ``l`` inside an (h, w) pixel frame.

        The media is centered when ``media_x``/``media_y`` are None. With
        ``strip``, one latent column/row is removed on every side that does not
        touch the frame edge; the start is then advanced by one latent cell.

        Returns:
            ``(latents, latent_x, latent_y)``, where the offsets are the pixel
            offsets integer-divided by ``vae_scale_factor``.
        """
        s = self.vae_scale_factor
        hi, wi = i.media_item.shape[-2], i.media_item.shape[-1]
        xs = (w - wi) // 2 if i.media_x is None else i.media_x
        ys = (h - hi) // 2 if i.media_y is None else i.media_y
        # Bug fix: end positions come from the ORIGINAL start. Computing them
        # after the `+= s` shift below skipped the right/bottom strip whenever
        # the gap to the frame edge was at most one latent cell.
        xe, ye = xs + wi, ys + hi
        if strip:
            if xs > 0:
                xs += s
                l = l[..., :, 1:]
            if ys > 0:
                ys += s
                l = l[..., 1:, :]
            if xe < w:
                l = l[..., :, :-1]
            if ye < h:
                l = l[..., :-1, :]
        return l, xs // s, ys // s

    def _handle_non_first_sequence(
        self,
        init_latents: torch.Tensor,
        mask: torch.Tensor,
        latents: torch.Tensor,
        media_frame_number: int,
        conditioning_strength: float,
        num_prefix=2,
        mode="concat",
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Splice a multi-frame conditioning sequence that starts after frame 0.

        Latent frames after the first ``num_prefix`` are lerped into
        ``init_latents`` (and marked in ``mask``) in place. With
        ``mode == "concat"``, the prefix frames are returned so the caller can
        add them as extra tokens. Any other mode returns None for them.

        Returns:
            ``(init_latents, mask, prefix_latents_or_None)``.
        """
        fl, flp = latents.shape[2], num_prefix
        if fl > flp:
            # NOTE(review): 8 is presumably the VAE temporal compression factor;
            # TODO derive it from the VAE config instead of hard-coding it.
            start = media_frame_number // 8 + flp
            end = start + fl - flp
            init_latents[..., start:end, :, :] = torch.lerp(
                init_latents[..., start:end, :, :], latents[..., flp:, :, :], conditioning_strength
            )
            mask[..., start:end, :, :] = conditioning_strength
        if mode == "concat":
            latents = latents[..., :flp, :, :]
        else:
            latents = None
        return init_latents, mask, latents

    def _process_extra_item(self, l, i, g):
        """Turn latents ``l`` of item ``i`` into extra conditioning tokens.

        The latents are lerped with Gaussian noise (drawn from generator ``g``)
        by the item's strength, then patchified. Their temporal pixel coordinate
        is offset by ``media_frame_number``.

        Returns:
            ``(tokens, pixel_coords, mask, num_tokens)``; ``mask`` is filled with
            the item's strength.
        """
        n = randn_tensor(l.shape, generator=g, device=self.device, dtype=self.dtype)
        l = torch.lerp(n, l, i.conditioning_strength)
        lp, cl = self.patchifier.patchify(l)
        cp = self._latent_to_pixel_coords(cl)
        cp[:, 0] += i.media_frame_number
        nl = lp.shape[1]
        nm = torch.full(lp.shape[:2], i.conditioning_strength, dtype=torch.float32, device=self.device)
        return lp, cp, nm, nl


# --- Singleton instantiation ---
vae_aduc_pipeline = VaeAducPipeline()