eeuuia commited on
Commit
0df7d55
·
verified ·
1 Parent(s): 5b467cd

Delete managers/vae_manager.py

Browse files
Files changed (1) hide show
  1. managers/vae_manager.py +0 -96
managers/vae_manager.py DELETED
@@ -1,96 +0,0 @@
1
- # FILE: managers/vae_manager.py (Versão Final com vae_decode corrigido)
2
-
3
- import torch
4
- import contextlib
5
- import logging
6
- import sys
7
- from pathlib import Path
8
- import os
9
- import io
10
-
11
- LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
12
-
13
- def add_deps_to_path():
14
- """
15
- Adiciona o diretório do repositório LTX ao sys.path para garantir que suas
16
- bibliotecas possam ser importadas.
17
- """
18
- repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
19
- if repo_path not in sys.path:
20
- sys.path.insert(0, repo_path)
21
- logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
22
-
23
- # Executa a função imediatamente para configurar o ambiente antes de qualquer importação.
24
- add_deps_to_path()
25
-
26
-
27
- # --- IMPORTAÇÃO CRÍTICA ---
28
- # Importa a função helper oficial da biblioteca LTX para decodificação.
29
- try:
30
- from ltx_video.models.autoencoders.vae_encode import vae_decode
31
- except ImportError:
32
- raise ImportError("Could not import 'vae_decode' from LTX-Video library. Check sys.path and repo integrity.")
33
-
34
-
35
- class _SimpleVAEManager:
36
- """
37
- Manages VAE decoding, now using the official 'vae_decode' helper function
38
- for maximum compatibility.
39
- """
40
- def __init__(self):
41
- self.pipeline = None
42
- self.device = torch.device("cpu")
43
- self.autocast_dtype = torch.float32
44
-
45
- def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
46
- self.pipeline = pipeline
47
- if device is not None:
48
- self.device = torch.device(device)
49
- logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")
50
- if autocast_dtype is not None:
51
- self.autocast_dtype = autocast_dtype
52
-
53
- @torch.no_grad()
54
- def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
55
- """
56
- Decodes a latent tensor into a pixel tensor using the 'vae_decode' helper.
57
- """
58
- if self.pipeline is None:
59
- raise RuntimeError("VAEManager: No pipeline has been attached.")
60
-
61
- # Move os latentes para o dispositivo VAE dedicado.
62
- latent_tensor_on_vae_device = latent_tensor.to(self.device)
63
-
64
- # Prepara o tensor de timesteps no mesmo dispositivo.
65
- num_items_in_batch = latent_tensor_on_vae_device.shape[0]
66
- timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device)
67
-
68
- autocast_device_type = self.device.type
69
- ctx = torch.autocast(
70
- device_type=autocast_device_type,
71
- dtype=self.autocast_dtype,
72
- enabled=(autocast_device_type == 'cuda')
73
- )
74
-
75
- with ctx:
76
- logging.debug(f"[VAEManager] Decoding latents with shape {latent_tensor_on_vae_device.shape} on {self.device}.")
77
-
78
- # --- CORREÇÃO PRINCIPAL ---
79
- # Usa a função helper `vae_decode` em vez de chamar `vae.decode` diretamente.
80
- # Esta função sabe como lidar com o argumento 'timestep'.
81
- pixels = vae_decode(
82
- latents=latent_tensor_on_vae_device,
83
- vae=self.pipeline.vae,
84
- is_video=True,
85
- timestep=timestep_tensor,
86
- vae_per_channel_normalize=True, # Importante manter este parâmetro consistente
87
- )
88
-
89
- # A função vae_decode já retorna no intervalo [0, 1], mas um clamp extra não faz mal.
90
- pixels = pixels.clamp(0, 1)
91
-
92
- logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
93
- return pixels.cpu()
94
-
95
- # Singleton global
96
- vae_manager_singleton = _SimpleVAEManager()