Eueuiaa commited on
Commit
cf3d948
·
verified ·
1 Parent(s): ad120b9

Update LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py

Browse files
LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py CHANGED
@@ -24,7 +24,6 @@ from transformers import (
24
  AutoTokenizer,
25
  )
26
 
27
- from ltx_video.models.pipelines.spy_latent import SpyLatent
28
 
29
  from ltx_video.models.autoencoders.causal_video_autoencoder import (
30
  CausalVideoAutoencoder,
@@ -62,6 +61,116 @@ logging.set_verbosity_debug()
62
  #logger = logging.get_logger(__name__) # pylint: disable=invalid-name
63
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  ASPECT_RATIO_1024_BIN = {
66
  "0.25": [512.0, 2048.0],
67
  "0.28": [512.0, 1856.0],
 
24
  AutoTokenizer,
25
  )
26
 
 
27
 
28
  from ltx_video.models.autoencoders.causal_video_autoencoder import (
29
  CausalVideoAutoencoder,
 
61
  #logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
 
63
 
64
+ class SpyLatent:
65
+
66
+ """
67
+ Uma classe para inspecionar tensores latentes em vários estágios de um pipeline.
68
+ Imprime estatísticas e pode salvar visualizações decodificadas por um VAE.
69
+ """
70
+
71
+ import torch
72
+ import os
73
+ import traceback
74
+ from einops import rearrange
75
+ from torchvision.utils import save_image
76
+
77
+ def __init__(self, vae=None, output_dir: str = "/app/output"):
78
+ """
79
+ Inicializa o espião.
80
+
81
+ Args:
82
+ vae: A instância do modelo VAE para decodificar os latentes. Se for None,
83
+ a visualização será desativada.
84
+ output_dir (str): O diretório padrão para salvar as imagens de visualização.
85
+ """
86
+ self.vae = vae
87
+ self.output_dir = output_dir
88
+ self.device = vae.device if hasattr(vae, 'device') else torch.device("cpu")
89
+
90
+ if self.vae is None:
91
+ print("[SpyLatent] AVISO: VAE não fornecido. A funcionalidade de visualização de imagem está desativada.")
92
+
93
+ def inspect(
94
+ self,
95
+ tensor: torch.Tensor,
96
+ tag: str,
97
+ reference_shape_5d: tuple = None,
98
+ save_visual: bool = True,
99
+ ):
100
+ """
101
+ Inspeciona um tensor latente.
102
+
103
+ Args:
104
+ tensor (torch.Tensor): O tensor a ser inspecionado.
105
+ tag (str): Um rótulo para identificar o ponto de inspeção nos logs.
106
+ reference_shape_5d (tuple, optional): A forma 5D de referência (B, C, F, H, W)
107
+ necessária se o tensor de entrada for 3D.
108
+ save_visual (bool): Se True, decodifica com o VAE e salva uma imagem.
109
+ """
110
+ print(f"\n--- [INSPEÇÃO DE LATENTE: {tag}] ---")
111
+ if not isinstance(tensor, torch.Tensor):
112
+ print(f" AVISO: O objeto fornecido para '{tag}' não é um tensor.")
113
+ print("--- [FIM DA INSPEÇÃO] ---\n")
114
+ return
115
+
116
+ try:
117
+ # --- Imprime Estatísticas do Tensor Original ---
118
+ self._print_stats("Tensor Original", tensor)
119
+
120
+ # --- Converte para 5D se necessário ---
121
+ tensor_5d = self._to_5d(tensor, reference_shape_5d)
122
+ if tensor_5d is not None and tensor.ndim == 3:
123
+ self._print_stats("Convertido para 5D", tensor_5d)
124
+
125
+ save_visual = False
126
+ # --- Visualização com VAE ---
127
+ if save_visual and self.vae is not None and tensor_5d is not None:
128
+ os.makedirs(self.output_dir, exist_ok=True)
129
+ print(f" VISUALIZAÇÃO (VAE): Salvando imagem em {self.output_dir}...")
130
+
131
+ frame_idx_to_viz = min(1, tensor_5d.shape[2] - 1)
132
+ if frame_idx_to_viz < 0:
133
+ print(" VISUALIZAÇÃO (VAE): Tensor não tem frames para visualizar.")
134
+ else:
135
+ print(f" VISUALIZAÇÃO (VAE): Usando frame de índice {frame_idx_to_viz}.")
136
+ latent_slice = tensor_5d[:, :, frame_idx_to_viz:frame_idx_to_viz+1, :, :]
137
+
138
+ with torch.no_grad(), torch.autocast(device_type=self.device.type):
139
+ pixel_slice = self.vae.decode(latent_slice / self.vae.config.scaling_factor).sample
140
+
141
+ save_image((pixel_slice / 2 + 0.5).clamp(0, 1), os.path.join(self.output_dir, f"inspect_{tag.lower()}.png"))
142
+ print(" VISUALIZAÇÃO (VAE): Imagem salva.")
143
+
144
+ except Exception as e:
145
+ print(f" ERRO na inspeção: {e}")
146
+ traceback.print_exc()
147
+ finally:
148
+ print("--- [FIM DA INSPEÇÃO] ---\n")
149
+
150
+ def _to_5d(self, tensor: torch.Tensor, shape_5d: tuple) -> torch.Tensor:
151
+ """Converte um tensor 3D patchificado de volta para 5D."""
152
+ if tensor.ndim == 5:
153
+ return tensor
154
+ if tensor.ndim == 3 and shape_5d:
155
+ try:
156
+ b, c, f, h, w = shape_5d
157
+ return rearrange(tensor, "b (f h w) c -> b c f h w", c=c, f=f, h=h, w=w)
158
+ except Exception as e:
159
+ print(f" AVISO: Erro ao rearranjar tensor 3D para 5D: {e}. A visualização pode falhar.")
160
+ return None
161
+ return None
162
+
163
+ def _print_stats(self, prefix: str, tensor: torch.Tensor):
164
+ """Helper para imprimir estatísticas de um tensor."""
165
+ mean = tensor.mean().item()
166
+ std = tensor.std().item()
167
+ min_val = tensor.min().item()
168
+ max_val = tensor.max().item()
169
+ print(f" {prefix}: Shape={list(tensor.shape)}, Mean={mean:.4f}, Std={std:.4f}, Min={min_val:.4f}, Max={max_val:.4f}")
170
+
171
+
172
+
173
+
174
  ASPECT_RATIO_1024_BIN = {
175
  "0.25": [512.0, 2048.0],
176
  "0.28": [512.0, 1856.0],