# vae_manager.py — simple version (beta 1.0)
# Decodes latents (B, C, T, H, W) → pixels (B, C, T, H', W') in [0, 1].
import contextlib

import torch

# NOTE: vae_decode is assumed to come from the LTX-Video package; adjust the
# import path below if your installation lays the modules out differently.
from ltx_video.models.autoencoders.vae_encode import vae_decode


class _SimpleVAEManager:
    def __init__(self, pipeline=None, device=None, autocast_dtype=torch.float32):
        """
        pipeline: LTX object exposing decode_latents(...) or .vae.decode(...)
        device: "cuda" or "cpu", where decoding should run
        autocast_dtype: autocast dtype when on CUDA (bf16/fp16/fp32)
        """
        self.pipeline = pipeline
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.autocast_dtype = autocast_dtype

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        """Bind (or rebind) the pipeline; optionally override device/dtype."""
        self.pipeline = pipeline
        if device is not None:
            self.device = device
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    def _on_cuda(self) -> bool:
        # Accepts "cuda", "cuda:0", or a torch.device instance.
        return str(self.device).startswith("cuda")

    @torch.no_grad()
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        if self.pipeline is None:
            raise RuntimeError("No pipeline attached; call attach_pipeline(...) first.")

        # Move latents to the runtime device/dtype.
        latent_tensor_gpu = latent_tensor.to(
            self.device,
            dtype=self.autocast_dtype if self._on_cuda() else latent_tensor.dtype,
        )

        # Build the timestep vector (one entry per item in batch B).
        num_items_in_batch = latent_tensor_gpu.shape[0]
        timestep_tensor = torch.tensor(
            [decode_timestep] * num_items_in_batch,
            device=self.device,
            dtype=latent_tensor_gpu.dtype,
        )

        ctx = (
            torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
            if self._on_cuda()
            else contextlib.nullcontext()
        )
        with ctx:
            pixels = vae_decode(
                latent_tensor_gpu,
                # Compat: accept either a full pipeline (with .vae) or a bare VAE.
                self.pipeline.vae if hasattr(self.pipeline, "vae") else self.pipeline,
                is_video=True,
                timestep=timestep_tensor,
                vae_per_channel_normalize=True,
            )

        # Normalize to [0, 1] if the decoder returned [-1, 1].
        if pixels.min() < 0:
            pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
        else:
            pixels = pixels.clamp(0, 1)
        return pixels


# Global singleton for simple use.
vae_manager_singleton = _SimpleVAEManager()
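

# Usage sketch (assumptions: `pipe` is an already-loaded LTX-Video pipeline and
# `latents` is a (B, C, T, H, W) latent tensor it produced — both hypothetical
# names, not defined in this module):
#
#     import torch
#     from vae_manager import vae_manager_singleton
#
#     vae_manager_singleton.attach_pipeline(
#         pipe, device="cuda", autocast_dtype=torch.bfloat16
#     )
#     pixels = vae_manager_singleton.decode(latents, decode_timestep=0.05)
#     # pixels: (B, C, T, H', W') float tensor in [0, 1], ready for export.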