# NOTE(review): removed web-viewer scrape artifacts (pause banners, file-size
# line, commit-hash gutter, and line-number gutter) — they were not part of
# the Python source and are syntax errors if left in the file.
# FILE: managers/vae_manager.py (final version with vae_decode fixed)
import torch
import contextlib
import logging

# --- CRITICAL IMPORT ---
# Import the official helper function from the LTX library for decoding.
# Fail fast with a clear message if the LTX-Video repo is not on sys.path.
try:
    from ltx_video.models.autoencoders.vae_encode import vae_decode
except ImportError:
    raise ImportError("Could not import 'vae_decode' from LTX-Video library. Check sys.path and repo integrity.")
class _SimpleVAEManager:
"""
Manages VAE decoding, now using the official 'vae_decode' helper function
for maximum compatibility.
"""
def __init__(self):
self.pipeline = None
self.device = torch.device("cpu")
self.autocast_dtype = torch.float32
def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
self.pipeline = pipeline
if device is not None:
self.device = torch.device(device)
logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")
if autocast_dtype is not None:
self.autocast_dtype = autocast_dtype
@torch.no_grad()
def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
"""
Decodes a latent tensor into a pixel tensor using the 'vae_decode' helper.
"""
if self.pipeline is None:
raise RuntimeError("VAEManager: No pipeline has been attached.")
# Move os latentes para o dispositivo VAE dedicado.
latent_tensor_on_vae_device = latent_tensor.to(self.device)
# Prepara o tensor de timesteps no mesmo dispositivo.
num_items_in_batch = latent_tensor_on_vae_device.shape[0]
timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device)
autocast_device_type = self.device.type
ctx = torch.autocast(
device_type=autocast_device_type,
dtype=self.autocast_dtype,
enabled=(autocast_device_type == 'cuda')
)
with ctx:
logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.")
# --- CORREÇÃO PRINCIPAL ---
# Usa a função helper `vae_decode` em vez de chamar `vae.decode` diretamente.
# Esta função sabe como lidar com o argumento 'timestep'.
pixels = vae_decode(
latents=latent_tensor_on_vae_device,
vae=self.pipeline.vae,
is_video=True,
timestep=timestep_tensor,
vae_per_channel_normalize=True, # Importante manter este parâmetro consistente
)
# A função vae_decode já retorna no intervalo [0, 1], mas um clamp extra não faz mal.
pixels = pixels.clamp(0, 1)
logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
return pixels.cpu()
# Global singleton shared by the application (comment translated from
# Portuguese; a stray trailing '|' scrape artifact — a syntax error — was
# removed from the assignment line).
vae_manager_singleton = _SimpleVAEManager()