eeuuia commited on
Commit
a6c6f33
·
verified ·
1 Parent(s): 9916d25

Upload vae_manager (1).py

Browse files
Files changed (1) hide show
  1. managers/vae_manager (1).py +96 -0
managers/vae_manager (1).py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FILE: managers/vae_manager.py (Versão Final com vae_decode corrigido)
2
+
3
+ import torch
4
+ import contextlib
5
+ import logging
6
+ import sys
7
+ from pathlib import Path
8
+ import os
9
+ import io
10
+
11
+ LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
12
+
13
+ def add_deps_to_path():
14
+ """
15
+ Adiciona o diretório do repositório LTX ao sys.path para garantir que suas
16
+ bibliotecas possam ser importadas.
17
+ """
18
+ repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
19
+ if repo_path not in sys.path:
20
+ sys.path.insert(0, repo_path)
21
+ logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
22
+
23
+ # Executa a função imediatamente para configurar o ambiente antes de qualquer importação.
24
+ add_deps_to_path()
25
+
26
+
27
+ # --- IMPORTAÇÃO CRÍTICA ---
28
+ # Importa a função helper oficial da biblioteca LTX para decodificação.
29
+ try:
30
+ from ltx_video.models.autoencoders.vae_encode import vae_decode
31
+ except ImportError:
32
+ raise ImportError("Could not import 'vae_decode' from LTX-Video library. Check sys.path and repo integrity.")
33
+
34
+
35
+ class _SimpleVAEManager:
36
+ """
37
+ Manages VAE decoding, now using the official 'vae_decode' helper function
38
+ for maximum compatibility.
39
+ """
40
+ def __init__(self):
41
+ self.pipeline = None
42
+ self.device = torch.device("cpu")
43
+ self.autocast_dtype = torch.float32
44
+
45
+ def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
46
+ self.pipeline = pipeline
47
+ if device is not None:
48
+ self.device = torch.device(device)
49
+ logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")
50
+ if autocast_dtype is not None:
51
+ self.autocast_dtype = autocast_dtype
52
+
53
+ @torch.no_grad()
54
+ def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
55
+ """
56
+ Decodes a latent tensor into a pixel tensor using the 'vae_decode' helper.
57
+ """
58
+ if self.pipeline is None:
59
+ raise RuntimeError("VAEManager: No pipeline has been attached.")
60
+
61
+ # Move os latentes para o dispositivo VAE dedicado.
62
+ latent_tensor_on_vae_device = latent_tensor.to(self.device)
63
+
64
+ # Prepara o tensor de timesteps no mesmo dispositivo.
65
+ num_items_in_batch = latent_tensor_on_vae_device.shape[0]
66
+ timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device)
67
+
68
+ autocast_device_type = self.device.type
69
+ ctx = torch.autocast(
70
+ device_type=autocast_device_type,
71
+ dtype=self.autocast_dtype,
72
+ enabled=(autocast_device_type == 'cuda')
73
+ )
74
+
75
+ with ctx:
76
+ logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.")
77
+
78
+ # --- CORREÇÃO PRINCIPAL ---
79
+ # Usa a função helper `vae_decode` em vez de chamar `vae.decode` diretamente.
80
+ # Esta função sabe como lidar com o argumento 'timestep'.
81
+ pixels = vae_decode(
82
+ latents=latent_tensor_on_vae_device,
83
+ vae=self.pipeline.vae,
84
+ is_video=True,
85
+ timestep=timestep_tensor,
86
+ vae_per_channel_normalize=True, # Importante manter este parâmetro consistente
87
+ )
88
+
89
+ # A função vae_decode já retorna no intervalo [0, 1], mas um clamp extra não faz mal.
90
+ pixels = pixels.clamp(0, 1)
91
+
92
+ logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
93
+ return pixels.cpu()
94
+
95
+ # Singleton global
96
+ vae_manager_singleton = _SimpleVAEManager()