euiiiia commited on
Commit
e8cfb14
·
verified ·
1 Parent(s): 3018b1d

Update api/aduc_ltx_latent_patch.py

Browse files
Files changed (1) hide show
  1. api/aduc_ltx_latent_patch.py +36 -44
api/aduc_ltx_latent_patch.py CHANGED
@@ -12,25 +12,20 @@ from typing import Optional, List, Tuple
12
  from pathlib import Path
13
  import os
14
  import sys
 
15
 
16
- DEPS_DIR = Path("/data")
17
- LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
18
- def add_deps_to_path(repo_path: Path):
19
- """Adiciona o diretório do repositório ao sys.path para importações locais."""
20
- resolved_path = str(repo_path.resolve())
21
- if resolved_path not in sys.path:
22
- sys.path.insert(0, resolved_path)
23
- if LTXV_DEBUG:
24
- print(f"[DEBUG] Adicionado ao sys.path: {resolved_path}")
25
-
26
- # --- Execução da configuração inicial ---
27
- if not LTX_VIDEO_REPO_DIR.exists():
28
- _run_setup_script()
29
- add_deps_to_path(LTX_VIDEO_REPO_DIR)
30
 
31
 
32
  # Tenta importar as dependências necessárias do módulo original que será modificado.
33
- # Isso requer que o ambiente Python tenha o pacote `ltx_video` acessível em seu sys.path.
34
  try:
35
  from ltx_video.pipelines.pipeline_ltx_video import (
36
  LTXVideoPipeline,
@@ -42,17 +37,14 @@ try:
42
  except ImportError as e:
43
  print(f"FATAL ERROR: Could not import dependencies from 'ltx_video'. "
44
  f"Please ensure the environment is correctly set up. Error: {e}")
45
- # Interrompe a execução se as dependências essenciais não puderem ser encontradas.
46
  raise
47
 
48
  print("[INFO] Patch module 'aduc_ltx_latent_patch' loaded successfully.")
49
 
50
  # ==============================================================================
51
- # 1. NOVA DEFINIÇÃO DA DATACLASS `ConditioningItem`
52
  # ==============================================================================
53
 
54
- from dataclasses import dataclass
55
-
56
  @dataclass
57
  class PatchedConditioningItem:
58
  """
@@ -103,7 +95,6 @@ def prepare_conditioning_with_latents(
103
  assert isinstance(self, LTXVideoPipeline), "This function must be called as a method of LTXVideoPipeline."
104
  assert isinstance(self.vae, CausalVideoAutoencoder), "VAE must be of type CausalVideoAutoencoder."
105
 
106
- # Se não há itens de condicionamento, apenas patchifica os latentes e retorna.
107
  if not conditioning_items:
108
  init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
109
  init_pixel_coords = latent_to_pixel_coords(
@@ -112,7 +103,6 @@ def prepare_conditioning_with_latents(
112
  )
113
  return init_latents, init_pixel_coords, None, 0
114
 
115
- # Inicializa tensores para acumular resultados
116
  init_conditioning_mask = torch.zeros(
117
  init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device
118
  )
@@ -124,36 +114,52 @@ def prepare_conditioning_with_latents(
124
  for item in conditioning_items:
125
  item_latents: Tensor
126
 
127
- # --- LÓGICA CENTRAL DO PATCH ---
128
  if item.latents is not None:
129
- # 1. Se latentes pré-calculados existem, use-os diretamente.
130
  item_latents = item.latents.to(dtype=init_latents.dtype, device=init_latents.device)
131
  if item_latents.ndim != 5:
132
  raise ValueError(f"Latents must have 5 dimensions (b, c, f, h, w), but got {item_latents.ndim}")
133
  elif item.media_item is not None:
134
- # 2. Caso contrário, volte para o fluxo original de codificação da VAE.
135
  resized_item = self._resize_conditioning_item(item, height, width)
136
  media_item = resized_item.media_item
137
  assert media_item.ndim == 5, f"media_item must have 5 dims, but got {media_item.ndim}"
138
-
139
  item_latents = vae_encode(
140
  media_item.to(dtype=self.vae.dtype, device=self.vae.device),
141
  self.vae,
142
  vae_per_channel_normalize=vae_per_channel_normalize,
143
  ).to(dtype=init_latents.dtype)
144
  else:
145
- # Este caso é prevenido pelo __post_init__ do dataclass, mas é bom ter uma checagem.
146
  raise ValueError("ConditioningItem is invalid: it has neither 'latents' nor 'media_item'.")
147
- # --- FIM DA LÓGICA DO PATCH ---
148
 
149
  media_frame_number = item.media_frame_number
150
  strength = item.conditioning_strength
151
 
152
- # O resto da lógica da função original é aplicado sobre `item_latents`.
153
  if media_frame_number == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  item_latents, l_x, l_y = self._get_latent_spatial_position(
155
- item_latents, item, height, width, strip_latent_border=True
156
  )
 
 
157
  _, _, f_l, h_l, w_l = item_latents.shape
158
  init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
159
  init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], item_latents, strength
@@ -186,7 +192,6 @@ def prepare_conditioning_with_latents(
186
  extra_conditioning_pixel_coords.append(pixel_coords)
187
  extra_conditioning_mask.append(conditioning_mask)
188
 
189
- # Patchifica os latentes principais e a máscara de condicionamento
190
  init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
191
  init_pixel_coords = latent_to_pixel_coords(
192
  init_latent_coords, self.vae,
@@ -195,7 +200,6 @@ def prepare_conditioning_with_latents(
195
  init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
196
  init_conditioning_mask = init_conditioning_mask.squeeze(-1)
197
 
198
- # Concatena os latentes extras (se houver)
199
  if extra_conditioning_latents:
200
  init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
201
  init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
@@ -208,7 +212,6 @@ def prepare_conditioning_with_latents(
208
 
209
  return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
210
 
211
-
212
  # ==============================================================================
213
  # 3. CLASSE DO MONKEY PATCHER
214
  # ==============================================================================
@@ -216,10 +219,6 @@ def prepare_conditioning_with_latents(
216
  class LTXLatentConditioningPatch:
217
  """
218
  Classe estática para aplicar e reverter o monkey patch na pipeline LTX-Video.
219
-
220
- Esta classe substitui o método `prepare_conditioning` da `LTXVideoPipeline`
221
- pela versão otimizada que suporta latentes pré-calculados, e implicitamente
222
- requer o uso da `PatchedConditioningItem`.
223
  """
224
  _original_prepare_conditioning = None
225
  _is_patched = False
@@ -228,21 +227,14 @@ class LTXLatentConditioningPatch:
228
  def apply():
229
  """
230
  Aplica o monkey patch à classe `LTXVideoPipeline`.
231
-
232
- Guarda o método original e o substitui pela nova implementação.
233
- É idempotente; aplicar múltiplas vezes não causa efeito adicional.
234
  """
235
  if LTXLatentConditioningPatch._is_patched:
236
  print("[WARNING] LTXLatentConditioningPatch has already been applied. Ignoring.")
237
  return
238
 
239
  print("[INFO] Applying monkey patch for latent-based conditioning...")
240
-
241
- # Guarda a implementação original para permitir a reversão.
242
  LTXLatentConditioningPatch._original_prepare_conditioning = LTXVideoPipeline.prepare_conditioning
243
-
244
- # Substitui o método na classe LTXVideoPipeline.
245
- # Todas as instâncias futuras e existentes da classe usarão este novo método.
246
  LTXVideoPipeline.prepare_conditioning = prepare_conditioning_with_latents
247
 
248
  LTXLatentConditioningPatch._is_patched = True
 
12
  from pathlib import Path
13
  import os
14
  import sys
15
+ from dataclasses import dataclass, replace
16
 
17
+ # --- CONFIGURAÇÃO DE PATH (Assume que LTXV_DEBUG e _run_setup_script existem no escopo que carrega este módulo) ---
18
+ # DEPS_DIR = Path("/data")
19
+ # LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
20
+ # def add_deps_to_path(repo_path: Path):
21
+ # """Adiciona o diretório do repositório ao sys.path para importações locais."""
22
+ # resolved_path = str(repo_path.resolve())
23
+ # if resolved_path not in sys.path:
24
+ # sys.path.insert(0, resolved_path)
25
+ # add_deps_to_path(LTX_VIDEO_REPO_DIR)
 
 
 
 
 
26
 
27
 
28
  # Tenta importar as dependências necessárias do módulo original que será modificado.
 
29
  try:
30
  from ltx_video.pipelines.pipeline_ltx_video import (
31
  LTXVideoPipeline,
 
37
  except ImportError as e:
38
  print(f"FATAL ERROR: Could not import dependencies from 'ltx_video'. "
39
  f"Please ensure the environment is correctly set up. Error: {e}")
 
40
  raise
41
 
42
  print("[INFO] Patch module 'aduc_ltx_latent_patch' loaded successfully.")
43
 
44
  # ==============================================================================
45
+ # 1. NOVA DEFINIÇÃO DA DATACLASS `PatchedConditioningItem`
46
  # ==============================================================================
47
 
 
 
48
  @dataclass
49
  class PatchedConditioningItem:
50
  """
 
95
  assert isinstance(self, LTXVideoPipeline), "This function must be called as a method of LTXVideoPipeline."
96
  assert isinstance(self.vae, CausalVideoAutoencoder), "VAE must be of type CausalVideoAutoencoder."
97
 
 
98
  if not conditioning_items:
99
  init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
100
  init_pixel_coords = latent_to_pixel_coords(
 
103
  )
104
  return init_latents, init_pixel_coords, None, 0
105
 
 
106
  init_conditioning_mask = torch.zeros(
107
  init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device
108
  )
 
114
  for item in conditioning_items:
115
  item_latents: Tensor
116
 
 
117
  if item.latents is not None:
 
118
  item_latents = item.latents.to(dtype=init_latents.dtype, device=init_latents.device)
119
  if item_latents.ndim != 5:
120
  raise ValueError(f"Latents must have 5 dimensions (b, c, f, h, w), but got {item_latents.ndim}")
121
  elif item.media_item is not None:
 
122
  resized_item = self._resize_conditioning_item(item, height, width)
123
  media_item = resized_item.media_item
124
  assert media_item.ndim == 5, f"media_item must have 5 dims, but got {media_item.ndim}"
 
125
  item_latents = vae_encode(
126
  media_item.to(dtype=self.vae.dtype, device=self.vae.device),
127
  self.vae,
128
  vae_per_channel_normalize=vae_per_channel_normalize,
129
  ).to(dtype=init_latents.dtype)
130
  else:
 
131
  raise ValueError("ConditioningItem is invalid: it has neither 'latents' nor 'media_item'.")
 
132
 
133
  media_frame_number = item.media_frame_number
134
  strength = item.conditioning_strength
135
 
 
136
  if media_frame_number == 0:
137
+ # --- INÍCIO DA MODIFICAÇÃO ---
138
+ # Se `item.media_item` for None (nosso caso de uso otimizado), a função original `_get_latent_spatial_position`
139
+ # quebraria. Para evitar isso, criamos um item temporário com um tensor de placeholder que contém
140
+ # as informações de dimensão corretas, inferidas a partir dos próprios latentes.
141
+
142
+ item_for_spatial_position = item
143
+ if item.media_item is None:
144
+ # Infere as dimensões em pixels a partir da forma dos latentes
145
+ latent_h, latent_w = item_latents.shape[-2:]
146
+ pixel_h = latent_h * self.vae_scale_factor
147
+ pixel_w = latent_w * self.vae_scale_factor
148
+
149
+ # Cria um tensor de placeholder com o shape esperado (o conteúdo não importa)
150
+ placeholder_media_item = torch.empty(
151
+ (1, 1, 1, pixel_h, pixel_w), device=item_latents.device, dtype=item_latents.dtype
152
+ )
153
+
154
+ # Usa `dataclasses.replace` para criar uma cópia temporária do item com o placeholder
155
+ item_for_spatial_position = replace(item, media_item=placeholder_media_item)
156
+
157
+ # Chama a função original com um item que ela pode processar sem erro
158
  item_latents, l_x, l_y = self._get_latent_spatial_position(
159
+ item_latents, item_for_spatial_position, height, width, strip_latent_border=True
160
  )
161
+ # --- FIM DA MODIFICAÇÃO ---
162
+
163
  _, _, f_l, h_l, w_l = item_latents.shape
164
  init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
165
  init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], item_latents, strength
 
192
  extra_conditioning_pixel_coords.append(pixel_coords)
193
  extra_conditioning_mask.append(conditioning_mask)
194
 
 
195
  init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
196
  init_pixel_coords = latent_to_pixel_coords(
197
  init_latent_coords, self.vae,
 
200
  init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
201
  init_conditioning_mask = init_conditioning_mask.squeeze(-1)
202
 
 
203
  if extra_conditioning_latents:
204
  init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
205
  init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
 
212
 
213
  return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
214
 
 
215
  # ==============================================================================
216
  # 3. CLASSE DO MONKEY PATCHER
217
  # ==============================================================================
 
219
  class LTXLatentConditioningPatch:
220
  """
221
  Classe estática para aplicar e reverter o monkey patch na pipeline LTX-Video.
 
 
 
 
222
  """
223
  _original_prepare_conditioning = None
224
  _is_patched = False
 
227
  def apply():
228
  """
229
  Aplica o monkey patch à classe `LTXVideoPipeline`.
 
 
 
230
  """
231
  if LTXLatentConditioningPatch._is_patched:
232
  print("[WARNING] LTXLatentConditioningPatch has already been applied. Ignoring.")
233
  return
234
 
235
  print("[INFO] Applying monkey patch for latent-based conditioning...")
236
+
 
237
  LTXLatentConditioningPatch._original_prepare_conditioning = LTXVideoPipeline.prepare_conditioning
 
 
 
238
  LTXVideoPipeline.prepare_conditioning = prepare_conditioning_with_latents
239
 
240
  LTXLatentConditioningPatch._is_patched = True