# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A dedicated, "hot" VAE service specialist.
# It holds the VAE model on a dedicated GPU and provides high-level services
# for encoding images/tensors into conditioning items and decoding latents back to pixels.
import os
import sys
import time
import copy
import logging
import warnings
import threading
from pathlib import Path
from typing import List, Union, Tuple, Optional
from dataclasses import dataclass

import torch
import numpy as np
import torch.nn.functional as F
from PIL import Image
from einops import rearrange

from managers.gpu_manager import gpu_manager
from utils.debug_utils import log_function_io
from diffusers.utils.torch_utils import randn_tensor

# --- Logging and warnings configuration ---
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*")
try:
    from huggingface_hub import logging as hf_logging
    hf_logging.set_verbosity_error()
except ImportError:
    pass

# ==============================================================================
# --- IMPORTS AND TYPE DEFINITIONS ---
# ==============================================================================
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
    sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))

from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode, latent_to_pixel_coords
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline

@dataclass
class LatentConditioningItem:
    """A conditioning item whose media is already encoded as VAE latents."""
    latent_tensor: torch.Tensor
    media_frame_number: int
    conditioning_strength: float
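
# Illustrative example (not executed here): a latent clip pinned to the start
# of the sequence at full strength, assuming `latents` came from
# VaeAducPipeline.encode_video:
#   item = LatentConditioningItem(latent_tensor=latents,
#                                 media_frame_number=0,
#                                 conditioning_strength=1.0)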

# ==============================================================================
# --- MAIN VAE SERVICE CLASS ---
# ==============================================================================
class VaeAducPipeline:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Thread-safe singleton: only the first caller allocates the instance.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self):
        if hasattr(self, '_initialized') and self._initialized:
            return
        with self._lock:
            if hasattr(self, '_initialized') and self._initialized:
                return
            logging.info("⚙️ Initializing VaeAducPipeline Singleton...")
            t0 = time.time()
            self.device = gpu_manager.get_ltx_vae_device()
            try:
                # Imported here, not at module level, to avoid a circular import.
                from api.ltx.ltx_aduc_manager import ltx_aduc_manager
                main_pipeline = ltx_aduc_manager.get_pipeline()
                if main_pipeline is None:
                    raise RuntimeError("LTXPoolManager must be initialized before VaeAducPipeline.")
                self.vae: CausalVideoAutoencoder = main_pipeline.vae
                self.patchifier = main_pipeline.patchifier
                self.transformer = main_pipeline.transformer
                self.vae_scale_factor = main_pipeline.vae_scale_factor
            except Exception as e:
                logging.critical(f"Failed to get components from LTXPoolManager. Error: {e}", exc_info=True)
                raise
            self.vae.to(self.device).eval()
            self.dtype = self.vae.dtype
            self._initialized = True
            logging.info(f"✅ VaeAducPipeline ready. Components are 'hot' on {self.device}. Startup time: {time.time() - t0:.2f}s")

    # --- PUBLIC SERVICE METHODS ---
    def encode_video(self, video_tensor: torch.Tensor, vae_per_channel_normalize: bool = True) -> torch.Tensor:
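        """Encode a pixel-space video into VAE latents.

        `video_tensor` must be (B, C, F, H, W) with values in [0, 1]; it is
        rescaled to [-1, 1] before encoding. The result is returned on CPU.
        """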
| logging.info(f"VaeAducPipeline: Encoding video with shape {video_tensor.shape}") | |
| if not (video_tensor.ndim == 5): | |
| raise ValueError(f"Input video tensor must be 5D (B, C, F, H, W), but got shape {video_tensor.shape}") | |
| video_tensor_normalized = (video_tensor * 2.0) - 1.0 | |
| try: | |
| video_gpu = video_tensor_normalized.to(self.device, dtype=self.dtype) | |
| with torch.no_grad(): | |
| latents = vae_encode(video_gpu, self.vae, vae_per_channel_normalize=vae_per_channel_normalize) | |
| logging.info(f"VaeAducPipeline: Successfully encoded video to latents of shape {latents.shape}") | |
| return latents.cpu() | |
| finally: | |
| self._cleanup_gpu() | |

    def decode_and_resize_video(self, latent_tensor: torch.Tensor, target_height: int, target_width: int, decode_timestep: float = 0.05) -> torch.Tensor:
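        """Decode latents to pixels, then bilinearly resize each frame to (target_height, target_width)."""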
        logging.info(f"VaeAducPipeline: Decoding latents {latent_tensor.shape} and resizing to {target_height}x{target_width}")
        pixel_video = self.decode_to_pixels(latent_tensor, decode_timestep)
        num_frames = pixel_video.shape[2]
        current_height, current_width = pixel_video.shape[3:]
        if current_height == target_height and current_width == target_width:
            logging.info("VaeAducPipeline: Resizing skipped, already at target resolution.")
            return pixel_video
        videos_flat = rearrange(pixel_video, "b c f h w -> (b f) c h w")
        videos_resized = F.interpolate(videos_flat, size=(target_height, target_width), mode="bilinear", align_corners=False)
        final_video = rearrange(videos_resized, "(b f) c h w -> b c f h w", f=num_frames)
        logging.info(f"VaeAducPipeline: Resized video to final shape {final_video.shape}")
        return final_video

    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
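        """Decode VAE latents into a pixel-space video, returned on CPU."""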
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items = latent_tensor_gpu.shape[0]
            timestep_tensor = torch.tensor([decode_timestep] * num_items, device=self.device, dtype=self.dtype)
            with torch.no_grad():
                pixels = vae_decode(latent_tensor_gpu, self.vae, is_video=True, timestep=timestep_tensor, vae_per_channel_normalize=True)
            logging.info(f"VaeAducPipeline: Decoded latents {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            return pixels.cpu()
        finally:
            self._cleanup_gpu()

    def prepare_conditioning(
        self,
        conditioning_items: Optional[List[Union[ConditioningItem, LatentConditioningItem]]],
        init_latents: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        vae_per_channel_normalize: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
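        """Blend conditioning items into `init_latents`, patchify everything,
        and return (latents, pixel coordinates, conditioning mask, number of
        extra conditioning tokens), all on CPU.
        """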
        init_latents = init_latents.to(self.device, dtype=self.dtype)
        if not conditioning_items:
            latents_p, coords_l = self.patchifier.patchify(latents=init_latents)
            coords_p = self._latent_to_pixel_coords(coords_l)
            return latents_p.cpu(), coords_p.cpu(), None, 0
        mask = torch.zeros(init_latents.shape[0], *init_latents.shape[2:], dtype=torch.float32, device=self.device)
        extra_latents, extra_coords, extra_masks = [], [], []
        num_extra_latents = 0
        is_latent_mode = isinstance(conditioning_items[0], LatentConditioningItem)
        with torch.no_grad():
            if is_latent_mode:
                for item in conditioning_items:
                    latents = item.latent_tensor.to(device=self.device, dtype=self.dtype)
                    if item.media_frame_number == 0:
                        f, h, w = latents.shape[-3:]
                        init_latents[..., :f, :h, :w] = torch.lerp(init_latents[..., :f, :h, :w], latents, item.conditioning_strength)
                        mask[..., :f, :h, :w] = item.conditioning_strength
                    else:
                        if latents.shape[2] > 1:
                            init_latents, mask, latents = self._handle_non_first_sequence(
                                init_latents, mask, latents, item.media_frame_number, item.conditioning_strength
                            )
                        if latents is not None:
                            latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
                            extra_latents.append(latents_p)
                            extra_coords.append(coords_p)
                            extra_masks.append(new_mask)
                            num_extra_latents += num_new
            else:
                for item in conditioning_items:
                    item_resized = self._resize_conditioning_item(item, height, width)
                    media_item = item_resized.media_item.to(self.device, dtype=self.dtype)
                    latents = vae_encode(media_item, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)
                    if item.media_frame_number == 0:
                        latents_pos, lx, ly = self._get_latent_spatial_position(latents, item_resized, height, width)
                        f, h, w = latents_pos.shape[-3:]
                        init_latents[..., :f, ly:ly + h, lx:lx + w] = torch.lerp(init_latents[..., :f, ly:ly + h, lx:lx + w], latents_pos, item.conditioning_strength)
                        mask[..., :f, ly:ly + h, lx:lx + w] = item.conditioning_strength
                    else:
                        if media_item.shape[2] > 1:
                            init_latents, mask, latents = self._handle_non_first_sequence(
                                init_latents, mask, latents, item.media_frame_number, item.conditioning_strength
                            )
                        if latents is not None:
                            latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
                            extra_latents.append(latents_p)
                            extra_coords.append(coords_p)
                            extra_masks.append(new_mask)
                            num_extra_latents += num_new

        # --- Final consolidation ---
        latents_p, coords_l = self.patchifier.patchify(latents=init_latents)
        coords_p = self._latent_to_pixel_coords(coords_l)
        mask_p, _ = self.patchifier.patchify(latents=mask.unsqueeze(1))
        mask_p = mask_p.squeeze(-1)
        if extra_latents:
            latents_p = torch.cat([*extra_latents, latents_p], dim=1)
            coords_p = torch.cat([*extra_coords, coords_p], dim=2)
            mask_p = torch.cat([*extra_masks, mask_p], dim=1)
            use_flash = getattr(self.transformer.config, 'use_tpu_flash_attention', False)
            if use_flash:
                # The extra items were prepended above, so drop them from the
                # front; slicing from the end would remove main-sequence tokens,
                # and "[:-0]" would empty the tensors when there are no extras.
                latents_p = latents_p[:, num_extra_latents:]
                coords_p = coords_p[:, :, num_extra_latents:]
                mask_p = mask_p[:, num_extra_latents:]
        return latents_p.cpu(), coords_p.cpu(), mask_p.cpu(), num_extra_latents

    # --- PRIVATE HELPER METHODS ---
    def _cleanup_gpu(self):
        if torch.cuda.is_available():
            with torch.cuda.device(self.device):
                torch.cuda.empty_cache()

    def _latent_to_pixel_coords(self, c):
        return latent_to_pixel_coords(c, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)

    def _resize_tensor(self, m, h, w):
        if m.shape[-2:] != (h, w):
            n = m.shape[2]
            flat = rearrange(m, "b c n h w -> (b n) c h w")
            resized = F.interpolate(flat, (h, w), mode="bilinear", align_corners=False)
            return rearrange(resized, "(b n) c h w -> b c n h w", n=n)
        return m

    def _resize_conditioning_item(self, i, h, w):
        n = copy.copy(i)
        n.media_item = self._resize_tensor(i.media_item, h, w)
        return n

    def _get_latent_spatial_position(self, l, i, h, w, strip=True):
        s, hi, wi = self.vae_scale_factor, i.media_item.shape[-2], i.media_item.shape[-1]
        xs = (w - wi) // 2 if i.media_x is None else i.media_x
        ys = (h - hi) // 2 if i.media_y is None else i.media_y
        if strip:
            # Trim boundary latents on the sides that do not touch the frame edge.
            if xs > 0:
                xs += s
                l = l[..., :, 1:]
            if ys > 0:
                ys += s
                l = l[..., 1:, :]
            if (xs + wi) < w:
                l = l[..., :, :-1]
            if (ys + hi) < h:
                l = l[..., :-1, :]
        return l, xs // s, ys // s

    def _handle_non_first_sequence(
        self,
        init_latents: torch.Tensor,
        mask: torch.Tensor,
        latents: torch.Tensor,
        media_frame_number: int,
        conditioning_strength: float,
        num_prefix=2,
        mode="concat",
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
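        """Blend the latent frames after the first `num_prefix` into `init_latents`
        at the item's temporal offset. In "concat" mode the prefix frames are
        returned so the caller can attach them as extra conditioning tokens;
        otherwise None is returned in their place.
        """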
        fl, flp = latents.shape[2], num_prefix
        if fl > flp:
            # // 8 maps a pixel-frame index to a latent-frame index
            # (the VAE's temporal downscale factor).
            start = media_frame_number // 8 + flp
            end = start + fl - flp
            init_latents[..., start:end, :, :] = torch.lerp(init_latents[..., start:end, :, :], latents[..., flp:, :, :], conditioning_strength)
            mask[..., start:end, :, :] = conditioning_strength
            if mode == "concat":
                latents = latents[..., :flp, :, :]
            else:
                latents = None
        return init_latents, mask, latents

    def _process_extra_item(self, l, i, g):
        # Blend noise and the item's latents (strength 1.0 keeps the latents
        # unchanged), then patchify and shift the temporal coordinate to the
        # item's frame number.
        n = randn_tensor(l.shape, generator=g, device=self.device, dtype=self.dtype)
        l = torch.lerp(n, l, i.conditioning_strength)
        lp, cl = self.patchifier.patchify(l)
        cp = self._latent_to_pixel_coords(cl)
        cp[:, 0] += i.media_frame_number
        nl = lp.shape[1]
        nm = torch.full(lp.shape[:2], i.conditioning_strength, dtype=torch.float32, device=self.device)
        return lp, cp, nm, nl

# --- Singleton instantiation ---
vae_aduc_pipeline = VaeAducPipeline()
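
# A minimal usage sketch (illustrative, not part of the service). It assumes
# gpu_manager and ltx_aduc_manager have already been initialized by the host
# application; shapes follow the (B, C, F, H, W) convention enforced above.
if __name__ == "__main__":
    demo_clip = torch.rand(1, 3, 9, 256, 256)  # synthetic video in [0, 1]
    demo_latents = vae_aduc_pipeline.encode_video(demo_clip)
    demo_pixels = vae_aduc_pipeline.decode_and_resize_video(demo_latents, target_height=512, target_width=512)
    print(f"latents: {tuple(demo_latents.shape)}, pixels: {tuple(demo_pixels.shape)}")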