# vae_manager.py — simple version (beta 1.0)
# Responsible for decoding latents (B, C, T, H, W) → pixels (B, C, T, H', W') in [0, 1].

import torch
import contextlib
import os
import subprocess
import sys
from pathlib import Path

from huggingface_hub import logging


logging.set_verbosity_debug()




DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
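

def run_setup():
    # Assumed setup step (a minimal sketch): clone the LTX-Video repository into
    # DEPS_DIR. Swap in your project's real setup routine if it lives elsewhere.
    DEPS_DIR.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        ["git", "clone", "https://github.com/Lightricks/LTX-Video.git", str(LTX_VIDEO_REPO_DIR)],
        check=True,
    )
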
if not LTX_VIDEO_REPO_DIR.exists():
    print(f"[DEBUG] Repository not found at {LTX_VIDEO_REPO_DIR}. Running setup...")
    run_setup()

def add_deps_to_path():
    # Make the cloned LTX-Video repo importable without installing it.
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        print(f"[DEBUG] Repo added to sys.path: {repo_path}")

add_deps_to_path()



# This import only resolves after add_deps_to_path() has put the repo on sys.path.
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode


class _SimpleVAEManager:
    def __init__(self, pipeline=None, device=None, autocast_dtype=torch.float32):
        """
        pipeline: objeto do LTX que expõe decode_latents(...) ou .vae.decode(...)
        device: "cuda" ou "cpu" onde a decodificação deve ocorrer
        autocast_dtype: dtype de autocast quando em CUDA (bf16/fp16/fp32)
        """
        self.pipeline = pipeline
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.autocast_dtype = autocast_dtype

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        self.pipeline = pipeline
        if device is not None:
            self.device = device
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    @torch.no_grad()
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decode latents (B, C, T, H, W) into pixels (B, C, T, H', W') in [0, 1]."""
        is_cuda = torch.device(self.device).type == "cuda"

        # Move latents to the target device and dtype expected at runtime.
        latent_tensor_gpu = latent_tensor.to(
            self.device,
            dtype=self.autocast_dtype if is_cuda else latent_tensor.dtype,
        )

        # Build the timestep vector (one entry per item in the batch B).
        num_items_in_batch = latent_tensor_gpu.shape[0]
        timestep_tensor = torch.tensor(
            [decode_timestep] * num_items_in_batch,
            device=self.device,
            dtype=latent_tensor_gpu.dtype,
        )

        ctx = (
            torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
            if is_cuda
            else contextlib.nullcontext()
        )
        with ctx:
            pixels = vae_decode(
                latent_tensor_gpu,
                self.pipeline.vae if hasattr(self.pipeline, "vae") else self.pipeline,  # compat
                is_video=True,
                timestep=timestep_tensor,
                vae_per_channel_normalize=True,
            )

        # Normalize to [0, 1] if the decoder returns values in [-1, 1].
        if pixels.min() < 0:
            pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0
        else:
            pixels = pixels.clamp(0, 1)
        return pixels


# Global singleton for simple use
vae_manager_singleton = _SimpleVAEManager()
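
# Example usage (a minimal sketch; `pipe` is a placeholder for whatever LTX pipeline
# your app builds; any object exposing a `.vae` attribute works with decode() above):
#
#   vae_manager_singleton.attach_pipeline(pipe, device="cuda", autocast_dtype=torch.bfloat16)
#   latents = torch.randn(1, 128, 8, 16, 16, device="cuda")  # (B, C, T, H, W); shape is illustrative
#   frames = vae_manager_singleton.decode(latents, decode_timestep=0.05)
#   # frames: (B, C, T, H', W') float tensor in [0, 1], ready to be saved or encoded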