# FILE: api/ltx/ltx_aduc_manager.py
# DESCRIPTION: A simplified, robust pool manager for a unified LTX worker.
# This worker handles all tasks, including Transformer generation and VAE operations,
# while still respecting the GPU separation defined by the GPUManager.
import logging
import torch
import sys
from pathlib import Path
import threading
import queue
import time
import yaml
import os
from huggingface_hub import hf_hub_download
from typing import List, Optional, Callable, Any, Tuple, Dict
# --- Import the GPU manager and the low-level pipeline builder ---
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_utils import build_complete_pipeline_on_cpu, create_transformer
# --- Add the LTX-Video path so its types can be imported ---
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
def add_deps_to_path():
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
add_deps_to_path()
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
# ==============================================================================
# --- BUILD ORCHESTRATION FUNCTION (Internal to the Manager) ---
# ==============================================================================
def get_complete_pipeline() -> LTXVideoPipeline:
    """
    Orchestrates the construction of the COMPLETE LTX pipeline, including the VAE, on the CPU.
    """
    config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)
    ckpt_path = hf_hub_download(
        repo_id="Lightricks/LTX-Video",
        filename=config["checkpoint_path"],
        cache_dir=os.environ.get("HF_HOME"),
    )
    return build_complete_pipeline_on_cpu(ckpt_path, config)
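# Illustrative note (an assumption, not part of the original file): the distilled config
# loaded above is expected to expose at least the two keys this module reads, roughly:
#
#   checkpoint_path: <safetensors filename inside the Lightricks/LTX-Video repo>
#   precision: float8_e4m3fn   # or bfloat16 / mixed_precision, see _set_precision_policy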
# ==============================================================================
# --- UNIFIED WORKER CLASS ---
# ==============================================================================
class LTXWorker(threading.Thread):
    """
    A unified worker that manages a complete instance of the LTX pipeline.
    It loads the model and distributes its components (Transformer/VAE) to the correct GPUs.
    """
    def __init__(self, worker_id: int):
        super().__init__()
        self.worker_id = worker_id
        self.pipeline: Optional[LTXVideoPipeline] = None
        self.is_healthy = False
        self.is_busy = False
        self.daemon = True
        self.autocast_dtype: torch.dtype = torch.float32
    def run(self):
        """Initializes the worker: loads the pipeline and moves it to the GPUs."""
        try:
            self.pipeline = get_complete_pipeline()
            self._set_precision_policy()
            main_device = gpu_manager.get_ltx_device()
            vae_device = gpu_manager.get_ltx_vae_device()
            logging.info(f"[LTXWorker-{self.worker_id}] Moving components -> Main: {main_device}, VAE: {vae_device}")
            self.pipeline.to(main_device)  # Move everything to the main GPU first
            self.pipeline.vae.to(vae_device)  # Then move the VAE specifically to its dedicated GPU
            self.is_healthy = True
            logging.info(f"✅ LTXWorker {self.worker_id} is healthy. Main on {main_device}, VAE on {vae_device}.")
        except Exception:
            self.is_healthy = False
            logging.error(f"❌ LTXWorker {self.worker_id} FAILED to initialize!", exc_info=True)
    def _set_precision_policy(self):
        """Sets the precision policy used for autocast operations."""
        try:
            config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
            with open(config_path, "r") as file:
                config = yaml.safe_load(file)
            precision = str(config.get("precision", "bfloat16")).lower()
            if precision in ["float8_e4m3fn", "bfloat16"]:
                self.autocast_dtype = torch.bfloat16
            elif precision == "mixed_precision":
                self.autocast_dtype = torch.float16
        except Exception:
            logging.warning(f"[LTXWorker-{self.worker_id}] Could not set precision policy, defaulting to float32.", exc_info=True)
    def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any:
        self.is_busy = True
        try:
            # The job receives the complete pipeline and the dtype to use for autocast
            result = job_func(self.pipeline, self.autocast_dtype, *args, **kwargs)
            return result
        except Exception:
            self.is_healthy = False
            raise
        finally:
            self.is_busy = False
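# Illustrative sketch (not part of the original module): a job function submitted to the
# pool must accept the full pipeline and the autocast dtype as its first two arguments,
# followed by whatever was passed to submit_job. The helper name and its body below are
# hypothetical; it only shows the expected call signature.
#
#   def my_vae_decode_job(pipeline, autocast_dtype, latents):
#       with torch.autocast(device_type="cuda", dtype=autocast_dtype):
#           ...  # use pipeline.vae (already on its dedicated GPU) to decode `latents`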
# ==============================================================================
# --- THE POOL MANAGER (SINGLETON) ---
# ==============================================================================
class LTXAducManager:
    _instance = None
    _initialized = False
    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    def __init__(self):
        if self._initialized:
            return
        logging.info("🏭 Initializing Simplified Pool Manager for LTX...")
        self.workers: List[LTXWorker] = []
        self.job_queue = queue.Queue()
        self.pool_lock = threading.Lock()
        self._initialize_workers()
        self.dispatcher = threading.Thread(target=self._dispatch_jobs, daemon=True)
        self.health_monitor = threading.Thread(target=self._health_check_loop, daemon=True)
        self.dispatcher.start()
        self.health_monitor.start()
        self._initialized = True
        logging.info("✅ Simplified Pool Manager is running.")
    def _initialize_workers(self):
        with self.pool_lock:
            # For now, we create a single unified worker.
            # In the future, this loop can create multiple workers if more GPUs are available.
            worker = LTXWorker(worker_id=0)
            self.workers.append(worker)
            worker.start()
    def _get_available_worker(self) -> Optional[LTXWorker]:
        with self.pool_lock:
            for worker in self.workers:
                if worker.is_healthy and not worker.is_busy:
                    return worker
            return None
    def _dispatch_jobs(self):
        while True:
            job_func, args, kwargs, future = self.job_queue.get()
            worker = None
            while worker is None:
                worker = self._get_available_worker()
                if worker is None:
                    time.sleep(0.1)
            try:
                result = worker.execute(job_func, args, kwargs)
                future.put(result)
            except Exception as e:
                future.put(e)
    def _health_check_loop(self):
        while True:
            time.sleep(30)
            with self.pool_lock:
                for i, worker in enumerate(self.workers):
                    if not worker.is_alive() or not worker.is_healthy:
                        logging.warning(f"LTX Worker {worker.worker_id} is UNHEALTHY. Restarting...")
                        new_worker = LTXWorker(worker_id=worker.worker_id)
                        self.workers[i] = new_worker
                        new_worker.start()
    def submit_job(self, job_func: Callable, *args, **kwargs) -> Any:
        future = queue.Queue(1)
        self.job_queue.put((job_func, args, kwargs, future))
        result = future.get()
        if isinstance(result, Exception):
            raise result
        return result
# --- GLOBAL INSTANTIATION ---
ltx_aduc_manager = LTXAducManager()
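# Illustrative usage sketch (not part of the original module): a caller defines a job
# following the contract described above and submits it through the global singleton.
# submit_job blocks until the result is ready and re-raises any exception from the worker.
# The function `generate_clip_job` and its arguments below are hypothetical.
#
#   def generate_clip_job(pipeline, autocast_dtype, prompt: str, num_frames: int):
#       with torch.autocast(device_type="cuda", dtype=autocast_dtype):
#           return pipeline(prompt=prompt, num_frames=num_frames)
#
#   result = ltx_aduc_manager.submit_job(generate_clip_job, "a slow pan over dunes", num_frames=97)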