File size: 3,617 Bytes
a6c6f33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# FILE: managers/vae_manager.py (Versão Final com vae_decode corrigido)

import torch
import contextlib
import logging
import sys
from pathlib import Path
import os
import io

LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")

def add_deps_to_path():
    """
    Adiciona o diretório do repositório LTX ao sys.path para garantir que suas
    bibliotecas possam ser importadas.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")

# Executa a função imediatamente para configurar o ambiente antes de qualquer importação.
add_deps_to_path()


# --- IMPORTAÇÃO CRÍTICA ---
# Importa a função helper oficial da biblioteca LTX para decodificação.
try:
    from ltx_video.models.autoencoders.vae_encode import vae_decode
except ImportError:
    raise ImportError("Could not import 'vae_decode' from LTX-Video library. Check sys.path and repo integrity.")


class _SimpleVAEManager:
    """
    Manages VAE decoding, now using the official 'vae_decode' helper function
    for maximum compatibility.
    """
    def __init__(self):
        self.pipeline = None
        self.device = torch.device("cpu")
        self.autocast_dtype = torch.float32

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        self.pipeline = pipeline
        if device is not None:
            self.device = torch.device(device)
            logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    @torch.no_grad()
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """
        Decodes a latent tensor into a pixel tensor using the 'vae_decode' helper.
        """
        if self.pipeline is None:
            raise RuntimeError("VAEManager: No pipeline has been attached.")
        
        # Move os latentes para o dispositivo VAE dedicado.
        latent_tensor_on_vae_device = latent_tensor.to(self.device)

        # Prepara o tensor de timesteps no mesmo dispositivo.
        num_items_in_batch = latent_tensor_on_vae_device.shape[0]
        timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device)
    
        autocast_device_type = self.device.type
        ctx = torch.autocast(
            device_type=autocast_device_type,
            dtype=self.autocast_dtype,
            enabled=(autocast_device_type == 'cuda')
        )
        
        with ctx:
            logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.")
            
            # --- CORREÇÃO PRINCIPAL ---
            # Usa a função helper `vae_decode` em vez de chamar `vae.decode` diretamente.
            # Esta função sabe como lidar com o argumento 'timestep'.
            pixels = vae_decode(
                latents=latent_tensor_on_vae_device,
                vae=self.pipeline.vae,
                is_video=True,
                timestep=timestep_tensor,
                vae_per_channel_normalize=True, # Importante manter este parâmetro consistente
            )
    
        # A função vae_decode já retorna no intervalo [0, 1], mas um clamp extra não faz mal.
        pixels = pixels.clamp(0, 1)
        
        logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
        return pixels.cpu()

# Singleton global
vae_manager_singleton = _SimpleVAEManager()