# Test4 / managers/vae_manager.py
# Author: EuuIia
# Commit 90eb72d (verified): "Update managers/vae_manager.py"
# vae_manager.py — simple version (beta 1.0)
# Decodes latents (B,C,T,H,W) into pixels (B,C,T,H',W') in [0,1].
import torch
import contextlib
import os
import subprocess
import sys
from pathlib import Path
from huggingface_hub import logging
# Each huggingface_hub set_verbosity_* call overwrites the library logger level,
# so a chain of error/warning/info/debug calls only ever left the last one in
# effect. Set the effective level once.
# NOTE(review): DEBUG is very verbose; confirm it is wanted outside debugging.
logging.set_verbosity_debug()
# Root directory where third-party source checkouts are kept.
DEPS_DIR = Path("/data")
# LTX-Video checkout; its root is what gets put on sys.path so the
# `ltx_video` package can be imported.
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"


def run_setup():
    """Clone the LTX-Video repository into ``LTX_VIDEO_REPO_DIR``.

    NOTE(review): the public Lightricks repository URL is an assumption —
    confirm it matches the project's intended setup procedure.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero.
        FileNotFoundError: if the ``git`` executable is not available.
    """
    # Argument list (no shell) so the path is never interpreted by a shell.
    subprocess.run(
        [
            "git",
            "clone",
            "https://github.com/Lightricks/LTX-Video.git",
            str(LTX_VIDEO_REPO_DIR),
        ],
        check=True,
    )


if not LTX_VIDEO_REPO_DIR.exists():
    print(f"[DEBUG] Repositório não encontrado em {LTX_VIDEO_REPO_DIR}. Rodando setup...")
    run_setup()
def add_deps_to_path(repo_dir=None):
    """Prepend a repository checkout to ``sys.path`` if it is not already there.

    Args:
        repo_dir: directory to add (``str`` or ``Path``). Defaults to the
            module-level ``LTX_VIDEO_REPO_DIR``.
    """
    target = LTX_VIDEO_REPO_DIR if repo_dir is None else Path(repo_dir)
    # Resolve once so the membership test and the inserted entry are the
    # same absolute string.
    repo_path = str(target.resolve())
    if repo_path not in sys.path:
        # Front of sys.path: searched before any other entry on import.
        sys.path.insert(0, repo_path)
        print(f"[DEBUG] Repo adicionado ao sys.path: {repo_path}")
# Must run before the `ltx_video` import that follows: that package is
# presumably only importable once the repo checkout is on sys.path.
add_deps_to_path()
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
class _SimpleVAEManager:
def __init__(self, pipeline=None, device=None, autocast_dtype=torch.float32):
"""
pipeline: objeto do LTX que expõe decode_latents(...) ou .vae.decode(...)
device: "cuda" ou "cpu" onde a decodificação deve ocorrer
autocast_dtype: dtype de autocast quando em CUDA (bf16/fp16/fp32)
"""
self.pipeline = pipeline
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.autocast_dtype = autocast_dtype
def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
self.pipeline = pipeline
if device is not None:
self.device = device
if autocast_dtype is not None:
self.autocast_dtype = autocast_dtype
@torch.no_grad()
def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
# Garante device e dtype conforme runtime
latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.autocast_dtype if self.device == "cuda" else latent_tensor.dtype)
# Constrói o vetor de timesteps (um por item no batch B)
num_items_in_batch = latent_tensor_gpu.shape[0]
timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=latent_tensor_gpu.dtype)
ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
with ctx:
pixels = vae_decode(
latent_tensor_gpu,
self.pipeline.vae if hasattr(self.pipeline, "vae") else self.pipeline, # compat
is_video=True,
timestep=timestep_tensor,
vae_per_channel_normalize=True,
)
# Normaliza para [0,1] se vier em [-1,1]
if pixels.min() < 0:
pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
else:
pixels = pixels.clamp(0, 1)
return pixels
# Shared module-level instance. It starts without a pipeline, so one must be
# supplied through attach_pipeline() before decode() can be used.
vae_manager_singleton = _SimpleVAEManager(
    pipeline=None,
    device=None,
    autocast_dtype=torch.float32,
)