eeuuia committed
Commit 52c58b6 · verified · 1 Parent(s): 24228a2

Update api/ltx/ltx_aduc_manager.py

Files changed (1)
  1. api/ltx/ltx_aduc_manager.py +90 -142
api/ltx/ltx_aduc_manager.py CHANGED
@@ -1,6 +1,7 @@
  # FILE: api/ltx/ltx_aduc_manager.py
- # DESCRIPTION: An advanced, fault-tolerant pool manager for LTX and VAE workers.
- # It is self-contained, orchestrating the construction, health, and job dispatching for its workers.

  import logging
  import torch
@@ -10,12 +11,13 @@ import threading
  import queue
  import time
  import yaml
  from huggingface_hub import hf_hub_download
  from typing import List, Optional, Callable, Any, Tuple, Dict
- import os
  # --- Import the GPU manager and the low-level builder ---
  from managers.gpu_manager import gpu_manager
- from api.ltx.ltx_utils import build_components_on_cpu

  # --- Add the LTX-Video path so types can be imported ---
  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
@@ -26,174 +28,138 @@ def add_deps_to_path():
  add_deps_to_path()

  from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
- from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder

  # ==============================================================================
- # --- WORKER CLASSES (task specialists) ---
  # ==============================================================================

- class BaseWorker(threading.Thread):
-     """Base class for our workers, with state and health management."""
-     def __init__(self, worker_id: int, device: torch.device, model: torch.nn.Module):
          super().__init__()
          self.worker_id = worker_id
-         self.device = device
-         self.model = model
          self.is_healthy = False
          self.is_busy = False
          self.daemon = True

      def run(self):
-         """The worker's life-cycle loop, responsible for moving the model to the GPU."""
-         if True:
-             logging.info(f"Worker {self.worker_id} ({self.__class__.__name__}) moving model to {self.device}...")
-             self.model.to(self.device)
-             self._post_load_hook()
              self.is_healthy = True
-             logging.info(f"✅ Worker {self.worker_id} ({self.__class__.__name__}) on {self.device} is healthy and ready.")
-         #except Exception:
              self.is_healthy = False
-             logging.error(f"❌ Worker {self.worker_id} on {self.device} FAILED to initialize!", exc_info=True)
-
-     def _post_load_hook(self):
-         """Hook for post-load actions, such as calling .eval()."""
-         pass
-
-     def get_status(self) -> Tuple[bool, bool]:
-         return self.is_healthy, self.is_busy
-
- class LTXMainWorker(BaseWorker):
-     """Specialist worker for the main LTX pipeline."""
-     def __init__(self, worker_id: int, device: torch.device, pipeline: LTXVideoPipeline):
-         super().__init__(worker_id, device, pipeline)
-         self.pipeline = self.model
-         self.autocast_dtype: torch.dtype = torch.float32
-
-     def _post_load_hook(self):
-         self._set_precision_policy()

      def _set_precision_policy(self):
-         if True: #try:
              config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
-             with open(config_path, "r") as file:
-                 config = yaml.safe_load(file)
              precision = str(config.get("precision", "bfloat16")).lower()
              if precision in ["float8_e4m3fn", "bfloat16"]: self.autocast_dtype = torch.bfloat16
              elif precision == "mixed_precision": self.autocast_dtype = torch.float16
-             logging.info(f"[LTXWorker-{self.worker_id}] Autocast precision policy set to {self.autocast_dtype}")
-         #except Exception as e:
-             #logging.warning(f"[LTXWorker-{self.worker_id}] Could not set precision policy from config. Defaulting to float32. Error: {e}")

      def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any:
          self.is_busy = True
-         if True: #try:
              result = job_func(self.pipeline, self.autocast_dtype, *args, **kwargs)
              return result
-         #except Exception:
-         #    self.is_healthy = False
-         #    raise
-         #finally:
-         #    self.is_busy = False
-
- class VAEWorker(BaseWorker):
-     """Specialist worker for the VAE model."""
-     def __init__(self, worker_id: int, device: torch.device, vae: CausalVideoAutoencoder):
-         super().__init__(worker_id, device, vae)
-         self.vae = self.model
-
-     def _post_load_hook(self):
-         self.vae.eval()
-
-     def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any:
-         self.is_busy = True
-         if True: #try:
-             result = job_func(self.vae, *args, **kwargs)
-             return result
-         #except Exception:
-         #    self.is_healthy = False
-         #    raise
-         #finally:
-         #    self.is_busy = False

  # ==============================================================================
- # --- THE ADVANCED POOL MANAGER (SINGLETON) ---
  # ==============================================================================
  class LTXAducManager:
      _instance = None
      _initialized = False

      def __new__(cls, *args, **kwargs):
-         if cls._instance is None:
-             cls._instance = super().__new__(cls)
          return cls._instance

      def __init__(self):
          if self._initialized: return

-         logging.info("🏭 Initializing Advanced Pool Manager for LTX...")

-         self.ltx_workers: List[LTXMainWorker] = []
-         self.vae_workers: List[VAEWorker] = []
-         self.ltx_job_queue = queue.Queue()
-         self.vae_job_queue = queue.Queue()
          self.pool_lock = threading.Lock()

-         # Load the models on the CPU before creating the workers
-         self.main_pipeline, self.main_vae = self._load_components_once()
-
          self._initialize_workers()

-         self.ltx_dispatcher = threading.Thread(target=self._dispatch_jobs, args=(self.ltx_job_queue, self.ltx_workers), daemon=True)
-         self.vae_dispatcher = threading.Thread(target=self._dispatch_jobs, args=(self.vae_job_queue, self.vae_workers), daemon=True)
          self.health_monitor = threading.Thread(target=self._health_check_loop, daemon=True)
-
-         self.ltx_dispatcher.start()
-         self.vae_dispatcher.start()
          self.health_monitor.start()

          self._initialized = True
-         logging.info("✅ Advanced Pool Manager is running with all threads started.")
-
-     def _load_components_once(self) -> Tuple[LTXVideoPipeline, CausalVideoAutoencoder]:
-         """Orchestrates building ALL of the components on the CPU a single time."""
-         logging.info("Manager loading all components onto CPU...")
-         config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
-         with open(config_path, "r") as file:
-             config = yaml.safe_load(file)
-
-         ckpt_path = hf_hub_download(repo_id="Lightricks/LTX-Video", filename=config["checkpoint_path"], cache_dir=os.environ.get("HF_HOME"))
-         pipeline, vae = build_components_on_cpu(ckpt_path, config)
-         logging.info("✅ All components loaded to CPU successfully.")
-         return pipeline, vae

      def _initialize_workers(self):
-         """Creates and starts the workers, injecting the already-loaded models."""
-         ltx_device = gpu_manager.get_ltx_device()
-         vae_device = gpu_manager.get_ltx_vae_device()
-
          with self.pool_lock:
-             ltx_worker = LTXMainWorker(worker_id=0, device=ltx_device, pipeline=self.main_pipeline)
-             self.ltx_workers.append(ltx_worker)
-             ltx_worker.start()
-
-             vae_worker = VAEWorker(worker_id=0, device=vae_device, vae=self.main_vae)
-             self.vae_workers.append(vae_worker)
-             vae_worker.start()

-     def _get_available_worker(self, worker_pool: List[BaseWorker]) -> Optional[BaseWorker]:
          with self.pool_lock:
-             for worker in worker_pool:
-                 healthy, busy = worker.get_status()
-                 if healthy and not busy: return worker
          return None

-     def _dispatch_jobs(self, job_queue: queue.Queue, worker_pool: List[BaseWorker]):
          while True:
-             job_func, args, kwargs, future = job_queue.get()
              worker = None
              while worker is None:
-                 worker = self._get_available_worker(worker_pool)
-                 if worker is None: time.sleep(0.1)
              try:
                  result = worker.execute(job_func, args, kwargs)
                  future.put(result)
@@ -204,37 +170,19 @@ class LTXAducManager:
          while True:
              time.sleep(30)
              with self.pool_lock:
-                 for i, worker in enumerate(self.ltx_workers):
-                     if not worker.is_alive() or not worker.is_healthy:
-                         logging.warning(f"LTX Worker {worker.worker_id} on {worker.device} is UNHEALTHY. Restarting...")
-                         new_worker = LTXMainWorker(worker.worker_id, worker.device, self.main_pipeline)
-                         self.ltx_workers[i] = new_worker
-                         new_worker.start()
-
-                 for i, worker in enumerate(self.vae_workers):
                      if not worker.is_alive() or not worker.is_healthy:
-                         logging.warning(f"VAE Worker {worker.worker_id} on {worker.device} is UNHEALTHY. Restarting...")
-                         new_worker = VAEWorker(worker.worker_id, worker.device, self.main_vae)
-                         self.vae_workers[i] = new_worker
                          new_worker.start()

-     def submit_job(self, job_type: str, job_func: Callable, *args, **kwargs) -> Any:
-         if job_type not in ['ltx', 'vae']:
-             raise ValueError("Invalid job_type. Must be 'ltx' or 'vae'.")
-
-         job_queue = self.ltx_job_queue if job_type == 'ltx' else self.vae_job_queue
          future = queue.Queue(1)
-         job_queue.put((job_func, args, kwargs, future))
          result = future.get()
-
-         if isinstance(result, Exception):
-             raise result
-
          return result

  # --- GLOBAL INSTANTIATION ---
- #try:
- ltx_aduc_manager = LTXAducManager()
- #except Exception:
-     # logging.critical("CRITICAL ERROR: Failed to initialize the LTXAducManager pool.", exc_info=True)
-     # ltx_aduc_manager = None
 
  # FILE: api/ltx/ltx_aduc_manager.py
+ # DESCRIPTION: A simplified, robust pool manager for a unified LTX worker.
+ # This worker handles all tasks, including Transformer generation and VAE operations,
+ # while still respecting the GPU separation defined by the GPUManager.

  import logging
  import torch

  import queue
  import time
  import yaml
+ import os
  from huggingface_hub import hf_hub_download
  from typing import List, Optional, Callable, Any, Tuple, Dict
+
  # --- Import the GPU manager and the low-level builder ---
  from managers.gpu_manager import gpu_manager
+ from api.ltx.ltx_utils import build_complete_pipeline_on_cpu, create_transformer

  # --- Add the LTX-Video path so types can be imported ---
  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")

  add_deps_to_path()

  from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline

  # ==============================================================================
+ # --- BUILD ORCHESTRATION FUNCTION (internal to the manager) ---
  # ==============================================================================

+ def get_complete_pipeline() -> LTXVideoPipeline:
+     """
+     Orchestrates building the COMPLETE LTX pipeline, including the VAE, on the CPU.
+     """
+     config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
+     with open(config_path, "r") as file:
+         config = yaml.safe_load(file)
+
+     ckpt_path = hf_hub_download(
+         repo_id="Lightricks/LTX-Video",
+         filename=config["checkpoint_path"],
+         cache_dir=os.environ.get("HF_HOME")
+     )
+     return build_complete_pipeline_on_cpu(ckpt_path, config)
+
+ # ==============================================================================
+ # --- UNIFIED WORKER CLASS ---
+ # ==============================================================================
+
+ class LTXWorker(threading.Thread):
+     """
+     A unified worker that manages a complete LTX pipeline instance.
+     It loads the model and distributes its components (Transformer/VAE) to the correct GPUs.
+     """
+     def __init__(self, worker_id: int):
          super().__init__()
          self.worker_id = worker_id
+         self.pipeline: Optional[LTXVideoPipeline] = None
          self.is_healthy = False
          self.is_busy = False
          self.daemon = True
+         self.autocast_dtype: torch.dtype = torch.float32

      def run(self):
+         """Initializes the worker: loads the pipeline and moves it onto the GPUs."""
+         try:
+             self.pipeline = get_complete_pipeline()
+             self._set_precision_policy()
+
+             main_device = gpu_manager.get_ltx_device()
+             vae_device = gpu_manager.get_ltx_vae_device()
+
+             logging.info(f"[LTXWorker-{self.worker_id}] Moving components -> Main: {main_device}, VAE: {vae_device}")
+             self.pipeline.to(main_device)  # Move everything to the main GPU first
+             self.pipeline.vae.to(vae_device)  # Then move the VAE specifically to its dedicated GPU
+
              self.is_healthy = True
+             logging.info(f"✅ LTXWorker {self.worker_id} is healthy. Main on {main_device}, VAE on {vae_device}.")
+         except Exception:
              self.is_healthy = False
+             logging.error(f"❌ LTXWorker {self.worker_id} FAILED to initialize!", exc_info=True)

      def _set_precision_policy(self):
+         """Sets the precision policy for autocast operations."""
+         try:
              config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
+             with open(config_path, "r") as file: config = yaml.safe_load(file)
              precision = str(config.get("precision", "bfloat16")).lower()
              if precision in ["float8_e4m3fn", "bfloat16"]: self.autocast_dtype = torch.bfloat16
              elif precision == "mixed_precision": self.autocast_dtype = torch.float16
+         except Exception:
+             logging.warning(f"[LTXWorker-{self.worker_id}] Could not set precision policy, defaulting to float32.", exc_info=True)

      def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any:
          self.is_busy = True
+         try:
+             # The job receives the complete pipeline and the dtype for autocast
              result = job_func(self.pipeline, self.autocast_dtype, *args, **kwargs)
              return result
+         except Exception:
+             self.is_healthy = False
+             raise
+         finally:
+             self.is_busy = False

  # ==============================================================================
+ # --- THE POOL MANAGER (SINGLETON) ---
  # ==============================================================================
  class LTXAducManager:
      _instance = None
      _initialized = False

      def __new__(cls, *args, **kwargs):
+         if cls._instance is None: cls._instance = super().__new__(cls)
          return cls._instance

      def __init__(self):
          if self._initialized: return

+         logging.info("🏭 Initializing Simplified Pool Manager for LTX...")

+         self.workers: List[LTXWorker] = []
+         self.job_queue = queue.Queue()
          self.pool_lock = threading.Lock()

          self._initialize_workers()

+         self.dispatcher = threading.Thread(target=self._dispatch_jobs, daemon=True)
          self.health_monitor = threading.Thread(target=self._health_check_loop, daemon=True)
+         self.dispatcher.start()
          self.health_monitor.start()

          self._initialized = True
+         logging.info("✅ Simplified Pool Manager is running.")

      def _initialize_workers(self):
          with self.pool_lock:
+             # For now we create a single unified worker.
+             # In the future, this loop can create multiple workers if more GPUs are available.
+             worker = LTXWorker(worker_id=0)
+             self.workers.append(worker)
+             worker.start()

+     def _get_available_worker(self) -> Optional[LTXWorker]:
          with self.pool_lock:
+             for worker in self.workers:
+                 if worker.is_healthy and not worker.is_busy:
+                     return worker
          return None

+     def _dispatch_jobs(self):
          while True:
+             job_func, args, kwargs, future = self.job_queue.get()
              worker = None
              while worker is None:
+                 worker = self._get_available_worker()
+                 if worker is None: time.sleep(0.1)
              try:
                  result = worker.execute(job_func, args, kwargs)
                  future.put(result)

          while True:
              time.sleep(30)
              with self.pool_lock:
+                 for i, worker in enumerate(self.workers):
                      if not worker.is_alive() or not worker.is_healthy:
+                         logging.warning(f"LTX Worker {worker.worker_id} is UNHEALTHY. Restarting...")
+                         new_worker = LTXWorker(worker_id=worker.worker_id)
+                         self.workers[i] = new_worker
                          new_worker.start()

+     def submit_job(self, job_func: Callable, *args, **kwargs) -> Any:
          future = queue.Queue(1)
+         self.job_queue.put((job_func, args, kwargs, future))
          result = future.get()
+         if isinstance(result, Exception): raise result
          return result

  # --- GLOBAL INSTANTIATION ---
+ ltx_aduc_manager = LTXAducManager()
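
Caller-side sketch (not part of the commit): with this change, submit_job takes only a job function plus its arguments, and the dispatcher invokes that function with the complete pipeline and the worker's autocast dtype, exactly as LTXWorker.execute does. The job body below (generate_clip, its prompt handling, and the pipeline call arguments) is illustrative only; the real LTXVideoPipeline call signature may differ.

  # Hypothetical usage sketch -- the job body and pipeline arguments are illustrative,
  # not taken from this repository.
  import torch
  from api.ltx.ltx_aduc_manager import ltx_aduc_manager

  def generate_clip(pipeline, autocast_dtype, prompt: str, num_frames: int):
      """Runs inside the worker thread; receives the pipeline and dtype from execute()."""
      with torch.autocast(device_type="cuda", dtype=autocast_dtype):
          # Placeholder call: actual LTXVideoPipeline arguments may differ.
          return pipeline(prompt=prompt, num_frames=num_frames)

  # Blocks until a healthy worker picks the job up and the result (or the re-raised exception) comes back.
  video = ltx_aduc_manager.submit_job(generate_clip, "a drone shot over a coastline", num_frames=65)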