Eueuiaa commited on
Commit
bdfe4dc
·
verified ·
1 Parent(s): 481129e

Upload 2 files

Browse files
LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py CHANGED
@@ -24,6 +24,8 @@ from transformers import (
24
  AutoTokenizer,
25
  )
26
 
 
 
27
  from ltx_video.models.autoencoders.causal_video_autoencoder import (
28
  CausalVideoAutoencoder,
29
  )
@@ -250,6 +252,8 @@ class LTXVideoPipeline(DiffusionPipeline):
250
  scheduler ([`SchedulerMixin`]):
251
  A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
252
  """
 
 
253
 
254
  bad_punct_regex = re.compile(
255
  r"["
@@ -312,6 +316,8 @@ class LTXVideoPipeline(DiffusionPipeline):
312
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
313
 
314
  self.allowed_inference_steps = allowed_inference_steps
 
 
315
 
316
  def mask_text_embeddings(self, emb, mask):
317
  if emb.shape[0] == 1:
@@ -1116,12 +1122,14 @@ class LTXVideoPipeline(DiffusionPipeline):
1116
  vae_per_channel_normalize=vae_per_channel_normalize,
1117
  )
1118
 
1119
-
1120
  try:
1121
  print(f"[LTX4]LATENTS {latents.shape}")
 
1122
  except Exception:
1123
- pass
1124
-
 
 
1125
  # Update the latents with the conditioning items and patchify them into (b, n, c)
1126
  latents, pixel_coords, conditioning_mask, num_cond_latents = (
1127
  self.prepare_conditioning(
@@ -1136,6 +1144,21 @@ class LTXVideoPipeline(DiffusionPipeline):
1136
  )
1137
  init_latents = latents.clone() # Used for image_cond_noise_update
1138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1139
 
1140
 
1141
 
@@ -1209,8 +1232,11 @@ class LTXVideoPipeline(DiffusionPipeline):
1209
 
1210
  try:
1211
  print(f"[LTX6]LATENTS {latents.shape}")
 
1212
  except Exception:
1213
  pass
 
 
1214
 
1215
  latent_model_input = (
1216
  torch.cat([latents] * num_conds) if num_conds > 1 else latents
@@ -1221,6 +1247,7 @@ class LTXVideoPipeline(DiffusionPipeline):
1221
 
1222
  try:
1223
  print(f"[LTX7]LATENTS {latent_model_input.shape}")
 
1224
  except Exception:
1225
  pass
1226
 
@@ -1341,6 +1368,7 @@ class LTXVideoPipeline(DiffusionPipeline):
1341
 
1342
  try:
1343
  print(f"[LTX8]LATENTS {latents.shape}")
 
1344
  except Exception:
1345
  pass
1346
 
@@ -1357,9 +1385,12 @@ class LTXVideoPipeline(DiffusionPipeline):
1357
 
1358
  try:
1359
  print(f"[LTX9]LATENTS {latents.shape}")
 
 
1360
  except Exception:
1361
  pass
1362
 
 
1363
  if offload_to_cpu:
1364
  self.transformer = self.transformer.cpu()
1365
  if self._execution_device == "cuda":
@@ -1371,6 +1402,7 @@ class LTXVideoPipeline(DiffusionPipeline):
1371
 
1372
  try:
1373
  print(f"[LTX10]LATENTS {latents.shape}")
 
1374
  except Exception:
1375
  pass
1376
 
@@ -1402,7 +1434,7 @@ class LTXVideoPipeline(DiffusionPipeline):
1402
  decode_timestep = None
1403
  latents = self.tone_map_latents(latents, tone_map_compression_ratio)
1404
  image = vae_decode(
1405
- latents,
1406
  self.vae,
1407
  is_video,
1408
  vae_per_channel_normalize=kwargs["vae_per_channel_normalize"],
 
24
  AutoTokenizer,
25
  )
26
 
27
+ from ltx_video.pipelines.spy_latent import SpyLatent
28
+
29
  from ltx_video.models.autoencoders.causal_video_autoencoder import (
30
  CausalVideoAutoencoder,
31
  )
 
252
  scheduler ([`SchedulerMixin`]):
253
  A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
254
  """
255
+
256
+
257
 
258
  bad_punct_regex = re.compile(
259
  r"["
 
316
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
317
 
318
  self.allowed_inference_steps = allowed_inference_steps
319
+
320
+ self.spy = SpyLatent(vae=self.vae)
321
 
322
  def mask_text_embeddings(self, emb, mask):
323
  if emb.shape[0] == 1:
 
1122
  vae_per_channel_normalize=vae_per_channel_normalize,
1123
  )
1124
 
 
1125
  try:
1126
  print(f"[LTX4]LATENTS {latents.shape}")
1127
+ original_shape = latents.shape
1128
  except Exception:
1129
+ pass
1130
+
1131
+
1132
+
1133
  # Update the latents with the conditioning items and patchify them into (b, n, c)
1134
  latents, pixel_coords, conditioning_mask, num_cond_latents = (
1135
  self.prepare_conditioning(
 
1144
  )
1145
  init_latents = latents.clone() # Used for image_cond_noise_update
1146
 
1147
+ try:
1148
+ print(f"[LTXCond]conditioning_mask {conditioning_mask.shape}")
1149
+ except Exception:
1150
+ pass
1151
+
1152
+ try:
1153
+ print(f"[LTXCond]pixel_coords {pixel_coords.shape}")
1154
+ except Exception:
1155
+ pass
1156
+
1157
+ try:
1158
+ print(f"[LTXCond]pixel_coords {pixel_coords.shape}")
1159
+ except Exception:
1160
+ pass
1161
+
1162
 
1163
 
1164
 
 
1232
 
1233
  try:
1234
  print(f"[LTX6]LATENTS {latents.shape}")
1235
+ self.spy.inspect(latents, "LTX6_After_Patchify", reference_shape_5d=original_shape)
1236
  except Exception:
1237
  pass
1238
+
1239
+
1240
 
1241
  latent_model_input = (
1242
  torch.cat([latents] * num_conds) if num_conds > 1 else latents
 
1247
 
1248
  try:
1249
  print(f"[LTX7]LATENTS {latent_model_input.shape}")
1250
+ self.spy.inspect(latents, "LTX7_After_Patchify", reference_shape_5d=original_shape)
1251
  except Exception:
1252
  pass
1253
 
 
1368
 
1369
  try:
1370
  print(f"[LTX8]LATENTS {latents.shape}")
1371
+ self.spy.inspect(latents, "LTX8_After_Patchify", reference_shape_5d=original_shape)
1372
  except Exception:
1373
  pass
1374
 
 
1385
 
1386
  try:
1387
  print(f"[LTX9]LATENTS {latents.shape}")
1388
+ self.spy.inspect(latents, "LTX9_After_Patchify", reference_shape_5d=original_shape)
1389
+
1390
  except Exception:
1391
  pass
1392
 
1393
+
1394
  if offload_to_cpu:
1395
  self.transformer = self.transformer.cpu()
1396
  if self._execution_device == "cuda":
 
1402
 
1403
  try:
1404
  print(f"[LTX10]LATENTS {latents.shape}")
1405
+ self.spy.inspect(latents, "LTX10_After_Patchify", reference_shape_5d=original_shape)
1406
  except Exception:
1407
  pass
1408
 
 
1434
  decode_timestep = None
1435
  latents = self.tone_map_latents(latents, tone_map_compression_ratio)
1436
  image = vae_decode(
1437
+ latents,
1438
  self.vae,
1439
  is_video,
1440
  vae_per_channel_normalize=kwargs["vae_per_channel_normalize"],
LTX-Video/ltx_video/pipelines/spy_late.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# spy_latent.py

import os
import traceback

import torch

# Visualization is optional: if torchvision is missing, image saving is
# disabled instead of crashing the whole module at import time.
try:
    from torchvision.utils import save_image
except ImportError:
    save_image = None

# Try to import the pipeline's VAE class. If unavailable, visualization is disabled.
try:
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
except ImportError:
    CausalVideoAutoencoder = None


class SpyLatent:
    """Inspect latent tensors at various stages of a pipeline.

    Prints tensor statistics and can optionally save a VAE-decoded
    visualization of a single latent frame to disk.
    """

    def __init__(self, vae=None, output_dir: str = "/app/output"):
        """Initialize the spy.

        Args:
            vae: VAE model instance used to decode latents. If None,
                image visualization is disabled.
            output_dir: Directory where visualization images are saved.
        """
        self.vae = vae
        self.output_dir = output_dir
        # Fall back to CPU when the VAE (or its `device` attribute) is absent.
        self.device = vae.device if hasattr(vae, 'device') else torch.device("cpu")

        if self.vae is None:
            print("[SpyLatent] AVISO: VAE não fornecido. A funcionalidade de visualização de imagem está desativada.")

    def inspect(
        self,
        tensor: torch.Tensor,
        tag: str,
        reference_shape_5d: tuple = None,
        save_visual: bool = True,
    ):
        """Inspect a latent tensor: print stats and optionally save a decoded frame.

        Args:
            tensor: The tensor to inspect.
            tag: Label identifying this inspection point in the logs.
            reference_shape_5d: Reference (B, C, F, H, W) shape — either a
                tuple/torch.Size or a 5D reference tensor whose shape is used.
                Required when `tensor` is a patchified 3D tensor.
            save_visual: If True, decode one frame with the VAE and save a PNG.
        """
        print(f"\n--- [INSPEÇÃO DE LATENTE: {tag}] ---")
        if not isinstance(tensor, torch.Tensor):
            print(f" AVISO: O objeto fornecido para '{tag}' não é um tensor.")
            print("--- [FIM DA INSPEÇÃO] ---\n")
            return

        try:
            # --- Stats of the original tensor ---
            self._print_stats("Tensor Original", tensor)

            # --- Convert to 5D if needed ---
            tensor_5d = self._to_5d(tensor, reference_shape_5d)
            if tensor_5d is not None and tensor.ndim == 3:
                self._print_stats("Convertido para 5D", tensor_5d)

            # --- VAE visualization (requires a VAE, torchvision, and a 5D view) ---
            if save_visual and self.vae is not None and save_image is not None and tensor_5d is not None:
                os.makedirs(self.output_dir, exist_ok=True)
                print(f" VISUALIZAÇÃO (VAE): Salvando imagem em {self.output_dir}...")

                # Prefer frame 1 when it exists, otherwise fall back to frame 0.
                frame_idx_to_viz = min(1, tensor_5d.shape[2] - 1)
                if frame_idx_to_viz < 0:
                    print(" VISUALIZAÇÃO (VAE): Tensor não tem frames para visualizar.")
                else:
                    print(f" VISUALIZAÇÃO (VAE): Usando frame de índice {frame_idx_to_viz}.")
                    latent_slice = tensor_5d[:, :, frame_idx_to_viz:frame_idx_to_viz+1, :, :]

                    with torch.no_grad(), torch.autocast(device_type=self.device.type):
                        # Un-scale before decoding, per the VAE's configured scaling factor.
                        pixel_slice = self.vae.decode(latent_slice / self.vae.config.scaling_factor).sample

                    # Map [-1, 1] pixels back to [0, 1] for saving.
                    save_image((pixel_slice / 2 + 0.5).clamp(0, 1), os.path.join(self.output_dir, f"inspect_{tag.lower()}.png"))
                    print(" VISUALIZAÇÃO (VAE): Imagem salva.")

        except Exception as e:
            print(f" ERRO na inspeção: {e}")
            traceback.print_exc()
        finally:
            print("--- [FIM DA INSPEÇÃO] ---\n")

    def _to_5d(self, tensor: torch.Tensor, shape_5d) -> torch.Tensor:
        """Convert a patchified 3D tensor (B, F*H*W, C) back to 5D (B, C, F, H, W).

        `shape_5d` may be a (B, C, F, H, W) tuple/torch.Size, or a 5D reference
        tensor whose shape is used (callers sometimes pass the pre-patchify
        latents directly). Returns the tensor unchanged when already 5D, or
        None when the conversion is impossible.
        """
        if tensor.ndim == 5:
            return tensor
        # Accept a reference tensor: a multi-element tensor can be neither
        # truth-tested nor unpacked into five ints, so use its shape instead.
        if isinstance(shape_5d, torch.Tensor):
            shape_5d = tuple(shape_5d.shape)
        if tensor.ndim == 3 and shape_5d is not None:
            try:
                b, c, f, h, w = shape_5d
                # Equivalent to einops rearrange "b (f h w) c -> b c f h w",
                # without the third-party dependency.
                return tensor.reshape(b, f, h, w, c).permute(0, 4, 1, 2, 3)
            except Exception as e:
                print(f" AVISO: Erro ao rearranjar tensor 3D para 5D: {e}. A visualização pode falhar.")
                return None
        return None

    def _print_stats(self, prefix: str, tensor: torch.Tensor):
        """Print shape and basic statistics (mean/std/min/max) of a tensor."""
        mean = tensor.mean().item()
        std = tensor.std().item()
        min_val = tensor.min().item()
        max_val = tensor.max().item()
        print(f" {prefix}: Shape={list(tensor.shape)}, Mean={mean:.4f}, Std={std:.4f}, Min={min_val:.4f}, Max={max_val:.4f}")

# Example of a global instance (if desired)
# spy = SpyLatent()
# Best practice is to instantiate inside your main class, passing the VAE.