# vae_manager.py — simple version (beta 1.0)
# Decodes latents (B, C, T, H, W) → pixels (B, C, T, H', W') in [0, 1].
import contextlib
import subprocess
import sys
from pathlib import Path

import torch
from huggingface_hub import logging

# Only the last verbosity call takes effect, so keep a single, quiet default.
logging.set_verbosity_error()
DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"

def run_setup():
    # Assumption: "setup" here means cloning the LTX-Video repo into DEPS_DIR.
    # Adjust the URL/revision if your deployment uses a fork or a pinned commit.
    DEPS_DIR.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        ["git", "clone", "https://github.com/Lightricks/LTX-Video", str(LTX_VIDEO_REPO_DIR)],
        check=True,
    )

if not LTX_VIDEO_REPO_DIR.exists():
    print(f"[DEBUG] Repository not found at {LTX_VIDEO_REPO_DIR}. Running setup...")
    run_setup()
def add_deps_to_path():
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        print(f"[DEBUG] Repo added to sys.path: {repo_path}")

add_deps_to_path()
from ltx_video.models.autoencoders.vae_encode import vae_decode
class _SimpleVAEManager:
    def __init__(self, pipeline=None, device=None, autocast_dtype=torch.float32):
        """
        pipeline: LTX object exposing decode_latents(...) or .vae.decode(...)
        device: "cuda" or "cpu", where decoding should run
        autocast_dtype: autocast dtype on CUDA (bf16/fp16/fp32)
        """
        self.pipeline = pipeline
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.autocast_dtype = autocast_dtype
    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        self.pipeline = pipeline
        if device is not None:
            self.device = device
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        if self.pipeline is None:
            raise RuntimeError("No pipeline attached; call attach_pipeline() first.")
        # Accept "cuda", "cuda:0", or a torch.device, not just the literal string "cuda".
        on_cuda = torch.device(self.device).type == "cuda"
        # Move to the runtime device/dtype.
        latent_tensor_gpu = latent_tensor.to(
            self.device,
            dtype=self.autocast_dtype if on_cuda else latent_tensor.dtype,
        )
        # Build the timestep vector (one entry per batch item B).
        num_items_in_batch = latent_tensor_gpu.shape[0]
        timestep_tensor = torch.tensor(
            [decode_timestep] * num_items_in_batch,
            device=self.device,
            dtype=latent_tensor_gpu.dtype,
        )
        ctx = (
            torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
            if on_cuda
            else contextlib.nullcontext()
        )
        with ctx:
            pixels = vae_decode(
                latent_tensor_gpu,
                self.pipeline.vae if hasattr(self.pipeline, "vae") else self.pipeline,  # compat
                is_video=True,
                timestep=timestep_tensor,
                vae_per_channel_normalize=True,
            )
        # Map to [0, 1] if the decoder returned [-1, 1].
        if pixels.min() < 0:
            pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
        else:
            pixels = pixels.clamp(0, 1)
        return pixels
# Global singleton for simple use
vae_manager_singleton = _SimpleVAEManager()
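
# Minimal usage sketch. Assumptions not taken from this file: `pipe` is an
# LTX-Video pipeline loaded elsewhere, and the latent channel count / shape
# below are illustrative placeholders.
if __name__ == "__main__":
    pipe = None  # replace with a real LTX-Video pipeline
    if pipe is not None:
        vae_manager_singleton.attach_pipeline(
            pipe, device="cuda", autocast_dtype=torch.bfloat16
        )
        # Dummy latent batch shaped (B, C, T, H, W).
        latents = torch.randn(1, 128, 8, 16, 16)
        frames = vae_manager_singleton.decode(latents, decode_timestep=0.05)
        # Expect pixel values in [0, 1] after normalization.
        print(frames.shape, frames.min().item(), frames.max().item())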