eeuuia committed
Commit 460fa35 · verified · 1 Parent(s): ae38dbc

Update api/ltx/ltx_aduc_manager.py

Files changed (1)
  1. api/ltx/ltx_aduc_manager.py +215 -132
api/ltx/ltx_aduc_manager.py CHANGED
@@ -1,160 +1,243 @@
  # FILE: api/ltx/ltx_aduc_manager.py
- # DESCRIPTION: The "secret weapon". A pool manager for LTX that applies
- # runtime patches to the pipeline for full control and ADUC-SDR compatibility.

  import logging
- from typing import Dict, List, Optional, Tuple, Union
- from dataclasses import dataclass
  import torch
- from diffusers.utils.torch_utils import randn_tensor
  import sys
  from pathlib import Path
- import os
- import random
- import yaml

- LTX_REPO_ID = "Lightricks/LTX-Video"
- DEPS_DIR = Path("/data")
- LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
- RESULTS_DIR = Path("/app/output")
-
- # --- Imports from our own architecture ---
  from managers.gpu_manager import gpu_manager
- from api.ltx.ltx_utils import build_ltx_pipeline_on_cpu

  def add_deps_to_path():
-     """
-     Adds the LTX repository directory to sys.path so that its
-     libraries can be imported.
-     """
      repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
      if repo_path not in sys.path:
          sys.path.insert(0, repo_path)
-         logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
-
- # Run the function immediately to set up the environment before any imports.
  add_deps_to_path()
- from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
- # Note: the patch below also needs latent_to_pixel_coords; its location in the LTX-Video repo is assumed here:
- from ltx_video.models.autoencoders.vae_encode import latent_to_pixel_coords
-
- # --- Definition of our data classes ---
- @dataclass
- class ConditioningItem:
-     pixel_tensor: torch.Tensor  # Always a pixel tensor
-     media_frame_number: int
-     conditioning_strength: float

- @dataclass
- class LatentConditioningItem:
-     latent_tensor: torch.Tensor  # Always a latent tensor
-     media_frame_number: int
-     conditioning_strength: float

  # ==============================================================================
- # --- THE MONKEY PATCH ---
- # This is our customized version of `prepare_conditioning`.
  # ==============================================================================
- def _aduc_prepare_conditioning_patch(
-     self: "LTXVideoPipeline",
-     conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
-     init_latents: torch.Tensor,
-     num_frames: int,
-     height: int,
-     width: int,
-     vae_per_channel_normalize: bool = False,
-     generator=None,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-     if not conditioning_items:
-         init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-         init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-         return init_latents, init_pixel_coords, None, 0
-
-     init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
-     extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
-     extra_conditioning_num_latents = 0
-
-     for item in conditioning_items:
-         if not isinstance(item, LatentConditioningItem):
-             logging.warning("ADUC patch: conditioning item is not a LatentConditioningItem and will be ignored.")
-             continue
-
-         media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
-         media_frame_number, strength = item.media_frame_number, item.conditioning_strength
-
-         if media_frame_number == 0:
-             f_l, h_l, w_l = media_item_latents.shape[-3:]
-             init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
-             init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
-         else:
-             noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
-             media_item_latents = torch.lerp(noise, media_item_latents, strength)
-             patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
-             pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-             pixel_coords[:, 0] += media_frame_number
-             extra_conditioning_num_latents += patched_latents.shape[1]
-             new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
-             extra_conditioning_latents.append(patched_latents)
-             extra_conditioning_pixel_coords.append(pixel_coords)
-             extra_conditioning_mask.append(new_mask)
-
-     init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-     init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-     init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
-     init_conditioning_mask = init_conditioning_mask.squeeze(-1)
-
-     if extra_conditioning_latents:
-         init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
-         init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
-         init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
-
-     return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents

  # ==============================================================================
- # --- LTX Worker and Pool Manager ---
  # ==============================================================================

- class LTXWorker:
-     """Manages one LTX pipeline instance on a pair of GPUs (main + vae)."""
-     def __init__(self, main_device: str, vae_device: str, config: dict):
-         self.main_device = torch.device(main_device)
-         self.vae_device = torch.device(vae_device)
-         self.config = config
-         self.pipeline: LTXVideoPipeline = None
-
-         self._load_and_patch_pipeline()

-     def _load_and_patch_pipeline(self):
-         logging.info(f"[LTXWorker-{self.main_device}] Loading LTX pipeline onto the CPU...")
-         self.pipeline, _ = build_ltx_pipeline_on_cpu(self.config)

-         logging.info(f"[LTXWorker-{self.main_device}] Moving pipeline to GPUs (Main: {self.main_device}, VAE: {self.vae_device})...")
-         self.pipeline.to(self.main_device)
-         self.pipeline.vae.to(self.vae_device)
-
-         logging.info(f"[LTXWorker-{self.main_device}] Applying the ADUC-SDR patch to 'prepare_conditioning'...")
-         # The monkey-patching "magic" happens here
-         self.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(self.pipeline, LTXVideoPipeline)
-         logging.info(f"[LTXWorker-{self.main_device}] ✅ Pipeline hot, patched, and ready.")

- class LTXAducManager:
-     def __init__(self):
-         main_device = gpu_manager.get_ltx_device()
-         vae_device = gpu_manager.get_ltx_vae_device()
-         # In a future architecture we could have multiple workers. For now, there is one.
-         self.worker = LTXWorker(str(main_device), str(vae_device), self.load_config())
-
-     def load_config(self) -> Dict:
-         """Loads the YAML configuration file."""
-         config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
-         with open(config_path, "r") as file:
-             return yaml.safe_load(file)
-
-     def get_pipeline(self) -> LTXVideoPipeline:
-         return self.worker.pipeline
-
- # Singleton instance
- ltx_aduc_manager = LTXAducManager()
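The removed worker grafts the patch onto a live pipeline via Python's descriptor protocol: `function.__get__(instance, cls)` returns the function bound to that instance, so the override receives `self` automatically without touching the class. A minimal, self-contained sketch of that binding technique (illustrative names only, not part of the commit):

class Greeter:
    def hello(self) -> str:
        return "original"

def patched_hello(self) -> str:
    # `self` arrives automatically because the function is bound below.
    return f"patched on {type(self).__name__}"

g = Greeter()
g.hello = patched_hello.__get__(g, Greeter)  # instance-level override, as in the worker above
print(g.hello())          # -> "patched on Greeter"
print(Greeter().hello())  # other instances keep the original -> "original"

Patching the instance rather than the class keeps the override scoped to the one pipeline the worker owns.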
 
  # FILE: api/ltx/ltx_aduc_manager.py
+ # DESCRIPTION: An advanced, fault-tolerant pool manager for LTX and VAE workers.
+ # It handles job queuing, load balancing, and health monitoring for production-grade stability.

  import logging
  import torch
  import sys
  from pathlib import Path
+ import threading
+ import queue
+ import time
+ from typing import List, Optional, Callable, Any, Tuple

+ # Imports for the builders and the gpu_manager
+ from api.ltx.ltx_utils import get_main_ltx_pipeline, get_main_vae
  from managers.gpu_manager import gpu_manager

+ # --- Add the LTX-Video path so its types can be imported ---
+ LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
  def add_deps_to_path():
      repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
      if repo_path not in sys.path:
          sys.path.insert(0, repo_path)
  add_deps_to_path()

+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder

  # ==============================================================================
+ # --- WORKER CLASSES (task specialists) ---
  # ==============================================================================

+ class BaseWorker(threading.Thread):
+     """Base class for our workers, with state and health management."""
+     def __init__(self, worker_id: int, device: torch.device):
+         super().__init__()
+         self.worker_id = worker_id
+         self.device = device
+         self.is_healthy = False
+         self.is_busy = False
+         self.daemon = True  # Lets the main program exit without waiting for workers
+
+     def run(self):
+         """The worker's lifecycle entry point, responsible for loading the models."""
+         try:
+             self._load_models()
+             self.is_healthy = True
+             logging.info(f"✅ Worker {self.worker_id} ({self.__class__.__name__}) on {self.device} is healthy and ready.")
+         except Exception:
+             self.is_healthy = False
+             logging.error(f"❌ Worker {self.worker_id} on {self.device} FAILED to initialize!", exc_info=True)
+
+     def _load_models(self):
+         """Method to be implemented by subclasses."""
+         raise NotImplementedError
+
+     def get_status(self) -> Tuple[bool, bool]:
+         """Returns (is_healthy, is_busy)."""
+         return self.is_healthy, self.is_busy
+
+ class LTXMainWorker(BaseWorker):
+     """Specialist worker for the main LTX pipeline."""
+     def __init__(self, worker_id: int, device: torch.device):
+         super().__init__(worker_id, device)
+         self.pipeline: Optional[LTXVideoPipeline] = None
+
+     def _load_models(self):
+         logging.info(f"[LTXWorker-{self.worker_id}] Loading models to CPU...")
+         self.pipeline = get_main_ltx_pipeline()
+         logging.info(f"[LTXWorker-{self.worker_id}] Moving pipeline to {self.device}...")
+         self.pipeline.to(self.device)
+
+     def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any:
+         """Executes a job, managing the 'busy' state."""
+         self.is_busy = True
+         logging.info(f"Worker {self.worker_id} (LTX) starting job: {job_func.__name__}")
+         try:
+             result = job_func(self.pipeline, *args, **kwargs)
+             logging.info(f"Worker {self.worker_id} (LTX) finished job successfully.")
+             return result
+         except Exception:
+             logging.error(f"Worker {self.worker_id} (LTX) job failed!", exc_info=True)
+             self.is_healthy = False  # A failed job marks the worker as unhealthy
+             raise
+         finally:
+             self.is_busy = False
+
+ class VAEWorker(BaseWorker):
+     """Specialist worker for the VAE model."""
+     def __init__(self, worker_id: int, device: torch.device):
+         super().__init__(worker_id, device)
+         self.vae: Optional[CausalVideoAutoencoder] = None
+
+     def _load_models(self):
+         logging.info(f"[VAEWorker-{self.worker_id}] Loading VAE model to CPU...")
+         self.vae = get_main_vae()
+         logging.info(f"[VAEWorker-{self.worker_id}] Moving VAE to {self.device}...")
+         self.vae.to(self.device)
+         self.vae.eval()
+
+     def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any:
+         """Executes a job, managing the 'busy' state."""
+         self.is_busy = True
+         logging.info(f"Worker {self.worker_id} (VAE) starting job: {job_func.__name__}")
+         try:
+             result = job_func(self.vae, *args, **kwargs)
+             logging.info(f"Worker {self.worker_id} (VAE) finished job successfully.")
+             return result
+         except Exception:
+             logging.error(f"Worker {self.worker_id} (VAE) job failed!", exc_info=True)
+             self.is_healthy = False
+             raise
+         finally:
+             self.is_busy = False

  # ==============================================================================
+ # --- THE ADVANCED POOL MANAGER (SINGLETON) ---
  # ==============================================================================
+ class LTXAducManager:
+     _instance = None
+     _initialized = False
+
+     def __new__(cls, *args, **kwargs):
+         if cls._instance is None:
+             cls._instance = super().__new__(cls)
+         return cls._instance
+
+     def __init__(self):
+         if self._initialized: return
+
+         logging.info("🏭 Initializing Advanced Pool Manager for LTX...")
+
+         self.ltx_workers: List[LTXMainWorker] = []
+         self.vae_workers: List[VAEWorker] = []
+         self.ltx_job_queue = queue.Queue()
+         self.vae_job_queue = queue.Queue()
+         self.pool_lock = threading.Lock()
+
+         self._initialize_workers()
+
+         # Start consumer threads to process the queues
+         self.ltx_dispatcher = threading.Thread(target=self._dispatch_jobs, args=(self.ltx_job_queue, self.ltx_workers), daemon=True)
+         self.vae_dispatcher = threading.Thread(target=self._dispatch_jobs, args=(self.vae_job_queue, self.vae_workers), daemon=True)
+         self.health_monitor = threading.Thread(target=self._health_check_loop, daemon=True)
+
+         self.ltx_dispatcher.start()
+         self.vae_dispatcher.start()
+         self.health_monitor.start()
+
+         self._initialized = True
+         logging.info("✅ Advanced Pool Manager is running with all threads started.")
+
+     def _initialize_workers(self):
+         """Creates and starts the workers based on the allocated GPUs."""
+         # Assumes gpu_manager exposes one device per role; adjust if the names differ
+         ltx_gpus = gpu_manager.get_ltx_device()
+         vae_gpus = gpu_manager.get_ltx_vae_device()
+
+         with self.pool_lock:
+             for i, device_id in enumerate([ltx_gpus]):  # The single returned device is wrapped in a list
+                 worker = LTXMainWorker(worker_id=i, device=torch.device(f"cuda:{device_id}"))
+                 self.ltx_workers.append(worker)
+                 worker.start()
+
+             for i, device_id in enumerate([vae_gpus]):  # The single returned device is wrapped in a list
+                 worker = VAEWorker(worker_id=i, device=torch.device(f"cuda:{device_id}"))
+                 self.vae_workers.append(worker)
+                 worker.start()
+
+     def _get_available_worker(self, worker_pool: List[BaseWorker]) -> Optional[BaseWorker]:
+         """Finds a healthy, idle worker in the pool."""
+         with self.pool_lock:
+             for worker in worker_pool:
+                 healthy, busy = worker.get_status()
+                 if healthy and not busy:
+                     return worker
+         return None
+
+     def _dispatch_jobs(self, job_queue: queue.Queue, worker_pool: List[BaseWorker]):
+         """Consumer-thread loop that takes jobs off the queue and dispatches them."""
+         while True:
+             job_func, args, kwargs, future = job_queue.get()
+             worker = None
+             while worker is None:
+                 worker = self._get_available_worker(worker_pool)
+                 if worker is None:
+                     time.sleep(0.1)  # Wait for a worker to become free
+
+             try:
+                 result = worker.execute(job_func, args, kwargs)
+                 future.put(result)
+             except Exception as e:
+                 future.put(e)
+
+     def _health_check_loop(self):
+         """Thread that periodically checks for and restarts unhealthy workers."""
+         while True:
+             time.sleep(30)
+             logging.debug("Running health check on all workers...")
+             with self.pool_lock:
+                 for i, worker in enumerate(self.ltx_workers):
+                     if not worker.is_alive() or not worker.is_healthy:
+                         logging.warning(f"LTX Worker {worker.worker_id} on {worker.device} is UNHEALTHY. Restarting...")
+                         new_worker = LTXMainWorker(worker.worker_id, worker.device)
+                         self.ltx_workers[i] = new_worker
+                         new_worker.start()
+                 # Repeat the loop for the VAE workers
+                 for i, worker in enumerate(self.vae_workers):
+                     if not worker.is_alive() or not worker.is_healthy:
+                         logging.warning(f"VAE Worker {worker.worker_id} on {worker.device} is UNHEALTHY. Restarting...")
+                         new_worker = VAEWorker(worker.worker_id, worker.device)
+                         self.vae_workers[i] = new_worker
+                         new_worker.start()
+
+     def submit_job(self, job_type: str, job_func: Callable, *args, **kwargs) -> Any:
+         """
+         Public entry point for submitting a job to the pool.
+         This function is synchronous: it waits for the result.
+         """
+         if job_type not in ['ltx', 'vae']:
+             raise ValueError("Invalid job_type. Must be 'ltx' or 'vae'.")
+
+         job_queue = self.ltx_job_queue if job_type == 'ltx' else self.vae_job_queue
+         future = queue.Queue()  # We use a queue as a 'future' to get the result back
+
+         job_queue.put((job_func, args, kwargs, future))
+
+         # Block until the dispatcher puts the result into the 'future'
+         result = future.get()
+
+         if isinstance(result, Exception):
+             raise result  # If the job failed, re-raise the exception in the main thread
+
+         return result

+ # ==============================================================================
+ # --- GLOBAL INSTANTIATION ---
+ # ==============================================================================
+ try:
+     ltx_aduc_manager = LTXAducManager()
+ except Exception:
+     logging.critical("CRITICAL ERROR: Failed to initialize the LTXAducManager pool.", exc_info=True)
+     ltx_aduc_manager = None
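For context, here is a minimal caller-side sketch of how a job flows through the new pool. It is hypothetical usage, not part of this commit, and the pipeline call inside `generate_clip` is an assumed signature: `submit_job` enqueues the job together with a per-job `queue.Queue` acting as a future, a dispatcher thread hands the worker's model to the job function as its first argument, and the caller blocks until the result (or the job's re-raised exception) comes back.

from api.ltx.ltx_aduc_manager import ltx_aduc_manager

def generate_clip(pipeline, prompt: str, num_frames: int):
    # Runs on whichever healthy, idle LTXMainWorker the dispatcher picks;
    # `pipeline` is that worker's LTXVideoPipeline, already on its GPU.
    # The keyword arguments below are illustrative, not the confirmed API.
    return pipeline(prompt=prompt, num_frames=num_frames)

if ltx_aduc_manager is not None:
    # Synchronous: blocks until the dispatcher returns a result or re-raises.
    video = ltx_aduc_manager.submit_job('ltx', generate_clip, "a calm ocean at dusk", num_frames=121)

Using a plain `queue.Queue` as the future keeps the design dependency-free; a `concurrent.futures.Future` would serve the same role if asynchronous submission were needed later.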