eeuuia committed
Commit 7ea6441 · verified · 1 Parent(s): 66af513

Update api/ltx/ltx_aduc_manager.py

Files changed (1):
  api/ltx/ltx_aduc_manager.py  +90 -110
api/ltx/ltx_aduc_manager.py CHANGED
@@ -1,148 +1,107 @@
 # FILE: api/ltx/ltx_aduc_manager.py
-# DESCRIPTION: A singleton pool manager for the LTX-Video pipeline.
-# This module is the "secret weapon": it handles loading, device placement,
-# and applies a runtime monkey patch to the LTX pipeline for full control
-# and compatibility with the ADUC-SDR architecture, especially for latent conditioning.
-
 
 import time
-import os
 import yaml
-import json
 from pathlib import Path
 from typing import List, Optional, Tuple, Union, Dict
-from dataclasses import dataclass
 import threading
 import sys
-from pathlib import Path
 import torch
-from diffusers.utils.torch_utils import randn_tensor
-from huggingface_hub import hf_hub_download
 
-# --- Our architecture imports ---
 from managers.gpu_manager import gpu_manager
 from api.ltx.ltx_utils import build_ltx_pipeline_on_cpu
 from utils.debug_utils import log_function_io
-
-LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
-LTX_REPO_ID = "Lightricks/LTX-Video"
-CACHE_DIR = os.environ.get("HF_HOME")
 
 # --- LTX-Video library imports ---
 repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
 if repo_path not in sys.path:
     sys.path.insert(0, repo_path)
 
 from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
-from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
 import logging
 
 import warnings
 warnings.filterwarnings("ignore", category=UserWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", message=".*")
-from huggingface_hub import logging as ll
-ll.set_verbosity_error()
-ll.set_verbosity_warning()
-ll.set_verbosity_info()
-ll.set_verbosity_debug()
 
 logger = logging.getLogger("AducDebug")
 logging.basicConfig(level=logging.DEBUG)
 logger.setLevel(logging.DEBUG)
 
-# ==============================================================================
-# --- ADUC-SDR CONDITIONING DATACLASS DEFINITIONS ---
-# ==============================================================================
-
-@dataclass
-class ConditioningItem:
-    """Our data class for conditioning with PIXEL tensors (from images)."""
-    pixel_tensor: torch.Tensor
-    media_frame_number: int
-    conditioning_strength: float
-
-@dataclass
-class LatentConditioningItem:
-    """Our "secret weapon": a data class for conditioning with LATENT tensors (from overlaps)."""
-    latent_tensor: torch.Tensor
-    media_frame_number: int
-    conditioning_strength: float
 
 # ==============================================================================
-# --- THE MONKEY PATCH ---
-# Our custom version of `prepare_conditioning` that understands both data classes.
 # ==============================================================================
 
 @log_function_io
 def _aduc_prepare_conditioning_patch(
     self: "LTXVideoPipeline",
-    conditioning_items: Optional[List[Union[ConditioningItem, LatentConditioningItem]]],
     init_latents: torch.Tensor,
-    num_frames: int, height: int, width: int,  # signature kept for compatibility
     vae_per_channel_normalize: bool = False,
     generator=None,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-
-    if not conditioning_items:
-        latents, latent_coords = self.patchifier.patchify(latents=init_latents)
-        pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-        return latents, pixel_coords, None, 0
-
-    init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
-    extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
-    extra_conditioning_num_latents = 0
-
-    for item in conditioning_items:
-        strength = item.conditioning_strength
-        media_frame_number = item.media_frame_number
-
-        if isinstance(item, ConditioningItem):
-            logging.debug("ADUC patch: processing ConditioningItem (pixels).")
-            pixel_tensor_on_vae_device = item.pixel_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
-            media_item_latents = vae_encode(pixel_tensor_on_vae_device, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)
-            media_item_latents = media_item_latents.to(device=init_latents.device, dtype=init_latents.dtype)
-        elif isinstance(item, LatentConditioningItem):
-            logging.debug("ADUC patch: processing LatentConditioningItem (latents).")
-            media_item_latents = item.latent_tensor.to(device=init_latents.device, dtype=init_latents.dtype)
-        else:
-            logging.warning(f"ADUC patch: ignoring conditioning item of unknown type '{type(item)}'.")
-            continue
-
-        if media_frame_number == 0:
-            f_l, h_l, w_l = media_item_latents.shape[-3:]
-            init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
-            init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
-        else:
-            noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
-            media_item_latents = torch.lerp(noise, media_item_latents, strength)
-            patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
-            pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-            pixel_coords[:, 0] += media_frame_number
-            extra_conditioning_num_latents += patched_latents.shape[1]
-            new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
-            extra_conditioning_latents.append(patched_latents)
-            extra_conditioning_pixel_coords.append(pixel_coords)
-            extra_conditioning_mask.append(new_mask)
-
-    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-    init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
-    init_conditioning_mask = init_conditioning_mask.squeeze(-1)
-
-    if extra_conditioning_latents:
-        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
-        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
-        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
-
-    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
 
 # ==============================================================================
 # --- LTX WORKER AND POOL MANAGER ---
 # ==============================================================================
 
 class LTXWorker:
-    """Manages an LTX pipeline instance on a pair of GPUs (main + VAE)."""
     def __init__(self, main_device_str: str, vae_device_str: str, config: dict):
         self.main_device = torch.device(main_device_str)
         self.vae_device = torch.device(vae_device_str)
@@ -152,16 +111,27 @@ class LTXWorker:
 
     @log_function_io
     def _load_and_patch_pipeline(self):
         logging.info(f"[LTXWorker-{self.main_device}] Loading the LTX pipeline on the CPU...")
         self.pipeline, _ = build_ltx_pipeline_on_cpu(self.config)
         logging.info(f"[LTXWorker-{self.main_device}] Moving pipeline to GPUs (Main: {self.main_device}, VAE: {self.vae_device})...")
-        self.pipeline.to(self.main_device)
-        self.pipeline.vae.to(self.vae_device)
-        logging.info(f"[LTXWorker-{self.main_device}] Applying the ADUC-SDR patch to the 'prepare_conditioning' function...")
         self.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(self.pipeline, LTXVideoPipeline)
-        logging.info(f"[LTXWorker-{self.main_device}] ✅ Pipeline warm, patched, and ready to use.")
 
 class LtxAducManager:
     _instance = None
     _lock = threading.Lock()
 
@@ -173,28 +143,38 @@ class LtxAducManager:
         return cls._instance
 
     def __init__(self):
-        if self._initialized: return
         with self._lock:
-            if self._initialized: return
-            logging.info("⚙️ Initializing the LTXPoolManager singleton...")
             self.config = self._load_config()
             main_device_str = str(gpu_manager.get_ltx_device())
             vae_device_str = str(gpu_manager.get_ltx_vae_device())
             self.worker = LTXWorker(main_device_str, vae_device_str, self.config)
             self._initialized = True
-            logging.info("✅ LTXPoolManager ready.")
 
-    @log_function_io
     def _load_config(self) -> Dict:
         """Loads the main LTX YAML configuration."""
         config_path = Path("/data/LTX-Video/configs/ltxv-13b-0.9.8-dev-fp8.yaml")
         with open(config_path, "r") as file:
             return yaml.safe_load(file)
 
-    @log_function_io
     def get_pipeline(self) -> LTXVideoPipeline:
-        """Returns the pipeline instance, already loaded and patched."""
         return self.worker.pipeline
 
 # --- Global singleton instance ---
-ltx_aduc_manager = LtxAducManager()
 # FILE: api/ltx/ltx_aduc_manager.py
+# DESCRIPTION: A singleton manager for the LTX-Video pipeline.
+# This module loads the pipeline, places it on the correct devices, and applies a
+# targeted runtime monkey patch to delegate conditioning tasks to the specialized
+# VaeAducPipeline service, enabling full control for the ADUC-SDR architecture.
 
 import time
 import yaml
 from pathlib import Path
 from typing import List, Optional, Tuple, Union, Dict
 import threading
 import sys
 import torch
 
+# --- ADUC-SDR architecture imports ---
 from managers.gpu_manager import gpu_manager
 from api.ltx.ltx_utils import build_ltx_pipeline_on_cpu
 from utils.debug_utils import log_function_io
+# Import the VAE service that does the real work
+from api.ltx.vae_aduc_pipeline import vae_aduc_pipeline, LatentConditioningItem
 
 # --- LTX-Video library imports ---
+LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
 repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
 if repo_path not in sys.path:
     sys.path.insert(0, repo_path)
 
 from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
+# Import the original conditioning item type for type hinting
+from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem as PipelineConditioningItem
 import logging
 
 import warnings
 warnings.filterwarnings("ignore", category=UserWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", message=".*")
+
+try:
+    from huggingface_hub import logging as hf_logging
+    hf_logging.set_verbosity_error()
+except ImportError:
+    pass
 
 logger = logging.getLogger("AducDebug")
 logging.basicConfig(level=logging.DEBUG)
 logger.setLevel(logging.DEBUG)
 
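Note: `warnings.filterwarnings("ignore", message=".*")` above suppresses every warning process-wide, which can hide genuinely useful diagnostics. A minimal sketch of a scoped alternative, using only the standard library (`noisy_setup` is a hypothetical stand-in for whatever import or call emits the noise):

    import warnings

    def noisy_setup():
        # Hypothetical stand-in for a chatty third-party import or call.
        warnings.warn("legacy API", FutureWarning)

    with warnings.catch_warnings():
        # Filters set here are restored automatically when the block exits,
        # so the rest of the process keeps its normal warning behavior.
        warnings.simplefilter("ignore")
        noisy_setup()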
 # ==============================================================================
+# --- THE TARGETED, SIMPLE MONKEY PATCH ---
 # ==============================================================================
 
 @log_function_io
 def _aduc_prepare_conditioning_patch(
     self: "LTXVideoPipeline",
+    conditioning_items: Optional[List[Union[PipelineConditioningItem, LatentConditioningItem]]],
     init_latents: torch.Tensor,
+    num_frames: int,
+    height: int,
+    width: int,
     vae_per_channel_normalize: bool = False,
     generator=None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
+    """
+    [PATCH] Replaces the original `prepare_conditioning` method of LTXVideoPipeline.
+
+    This function acts as a proxy: it contains no processing logic of its own.
+    Instead, it delegates 100% of the work to `vae_aduc_pipeline`, our specialized
+    service optimized for this task.
+    """
+    logging.debug("ADUC patch: intercepted 'prepare_conditioning'; delegating to the VaeAducPipeline service.")
+
+    # 1. Call the specialized service to do all the heavy lifting.
+    #    The VAE service runs on its own dedicated GPU and returns the tensors on the CPU.
+    latents_cpu, coords_cpu, mask_cpu, num_latents = vae_aduc_pipeline.prepare_conditioning(
+        conditioning_items=conditioning_items,
+        init_latents=init_latents,
+        num_frames=num_frames,
+        height=height,
+        width=width,
+        vae_per_channel_normalize=vae_per_channel_normalize,
+        generator=generator,
+    )
+
+    # 2. Move the results from the CPU to the device the main pipeline expects.
+    #    `init_latents.device` guarantees we use the pipeline's main device (e.g. 'cuda:0').
+    device = init_latents.device
+    latents = latents_cpu.to(device)
+    pixel_coords = coords_cpu.to(device)
+    conditioning_mask = mask_cpu.to(device) if mask_cpu is not None else None
+
+    # 3. Return the finished tensors. The main pipeline continues executing normally,
+    #    unaware that the conditioning logic ran in an external service.
+    return latents, pixel_coords, conditioning_mask, num_latents
 
 # ==============================================================================
 # --- LTX WORKER AND POOL MANAGER ---
 # ==============================================================================
 
 class LTXWorker:
+    """
+    Manages a single LTXVideoPipeline instance, applying the required patch
+    during initialization.
+    """
     def __init__(self, main_device_str: str, vae_device_str: str, config: dict):
         self.main_device = torch.device(main_device_str)
         self.vae_device = torch.device(vae_device_str)
@@ -152,16 +111,27 @@ class LTXWorker:
 
     @log_function_io
     def _load_and_patch_pipeline(self):
+        """
+        Orchestrates loading the pipeline and applying the monkey patch.
+        """
         logging.info(f"[LTXWorker-{self.main_device}] Loading the LTX pipeline on the CPU...")
         self.pipeline, _ = build_ltx_pipeline_on_cpu(self.config)
+
         logging.info(f"[LTXWorker-{self.main_device}] Moving pipeline to GPUs (Main: {self.main_device}, VAE: {self.vae_device})...")
+        self.pipeline.to(self.main_device)      # Move most components
+        self.pipeline.vae.to(self.vae_device)   # Move the VAE to its dedicated GPU
+
+        logging.info(f"[LTXWorker-{self.main_device}] Applying the ADUC-SDR patch to 'prepare_conditioning'...")
+        # The simple, effective "magic" happens here:
         self.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(self.pipeline, LTXVideoPipeline)
+
+        logging.info(f"[LTXWorker-{self.main_device}] ✅ Pipeline warm, patched, and ready to use.")
 
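Note: the assignment above relies on the descriptor protocol: calling `__get__(instance, cls)` on a plain function returns a method bound to that instance, so the patched `prepare_conditioning` receives the pipeline as `self`. A minimal, self-contained sketch of the same binding technique (the `Greeter` class is hypothetical):

    class Greeter:
        def __init__(self, name: str):
            self.name = name

    def shout(self) -> str:
        # A plain function; once bound, `self` is the Greeter instance.
        return f"HELLO, {self.name.upper()}!"

    g = Greeter("ltx")
    # Functions are descriptors: __get__ returns `shout` bound to `g`,
    # exactly as the worker binds the ADUC patch to the pipeline object.
    g.greet = shout.__get__(g, Greeter)
    print(g.greet())  # -> HELLO, LTX!

Because the bound method is assigned on the instance, only this pipeline object is patched; any other LTXVideoPipeline instance keeps the stock implementation.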
 class LtxAducManager:
+    """
+    Implements the Singleton pattern to guarantee that the LTX pipeline is
+    loaded and patched only once during the application's lifetime.
+    """
     _instance = None
     _lock = threading.Lock()
 
@@ -173,28 +143,38 @@ class LtxAducManager:
         return cls._instance
 
     def __init__(self):
+        if hasattr(self, '_initialized') and self._initialized:
+            return
         with self._lock:
+            if hasattr(self, '_initialized') and self._initialized:
+                return
+            logging.info("⚙️ Initializing the LtxAducManager singleton...")
             self.config = self._load_config()
             main_device_str = str(gpu_manager.get_ltx_device())
             vae_device_str = str(gpu_manager.get_ltx_vae_device())
+
+            # Create the worker that will load and patch the pipeline
            self.worker = LTXWorker(main_device_str, vae_device_str, self.config)
+
             self._initialized = True
+            logging.info("✅ LtxAducManager ready.")
 
     def _load_config(self) -> Dict:
         """Loads the main LTX YAML configuration."""
+        # TODO: consider moving the config path to an environment variable or a central config
         config_path = Path("/data/LTX-Video/configs/ltxv-13b-0.9.8-dev-fp8.yaml")
         with open(config_path, "r") as file:
             return yaml.safe_load(file)
 
     def get_pipeline(self) -> LTXVideoPipeline:
+        """
+        Main access point for obtaining the pipeline instance.
+
+        Returns:
+            LTXVideoPipeline: the single, already-loaded and patched instance.
+        """
         return self.worker.pipeline
 
 # --- Global singleton instance ---
+# Other parts of the codebase import this instance to interact with the pipeline.
+ltx_aduc_manager = LtxAducManager()
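Note: LtxAducManager pairs `__new__`-based instance caching with double-checked locking in `__init__`, so the expensive pipeline load happens exactly once even under concurrent first access. A condensed, runnable sketch of the same pattern (the `Singleton` class and its `resource` attribute are hypothetical stand-ins):

    import threading

    class Singleton:
        _instance = None
        _lock = threading.Lock()

        def __new__(cls):
            # First check without the lock keeps the hot path cheap.
            if cls._instance is None:
                with cls._lock:
                    # Second check: another thread may have won the race.
                    if cls._instance is None:
                        cls._instance = super().__new__(cls)
            return cls._instance

        def __init__(self):
            if getattr(self, "_initialized", False):
                return
            with self._lock:
                if getattr(self, "_initialized", False):
                    return
                self.resource = object()  # stands in for the LTX pipeline load
                self._initialized = True

    a, b = Singleton(), Singleton()
    assert a is b  # both names refer to the one shared instance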
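Downstream usage is then a single import; a short illustrative sketch (only `get_pipeline()` is confirmed by this file; anything beyond it depends on the pipeline's own API):

    from api.ltx.ltx_aduc_manager import ltx_aduc_manager

    # The singleton is constructed at import time, so by now the pipeline
    # is already on its GPUs and already patched.
    pipeline = ltx_aduc_manager.get_pipeline()

    # Any call that reaches pipeline.prepare_conditioning(...) is now
    # transparently delegated to the VaeAducPipeline service.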