# FILE: api/ltx/ltx_aduc_manager.py
# DESCRIPTION: A simplified, robust pool manager for a unified LTX worker.
# This worker handles all tasks, including Transformer generation and VAE operations,
# while still respecting the GPU separation defined by the GPUManager.

import logging
import torch
import sys
from pathlib import Path
import threading
import queue
import time
import yaml
import os
from huggingface_hub import hf_hub_download
from typing import List, Optional, Callable, Any, Tuple, Dict

# --- Import the GPU manager and the low-level pipeline builder ---
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_utils import build_complete_pipeline_on_cpu, create_transformer

# --- Add the LTX-Video repository to the path so its types can be imported ---
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")


def add_deps_to_path():
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)


add_deps_to_path()

from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline

# ==============================================================================
# --- BUILD ORCHESTRATION FUNCTION (internal to the manager) ---
# ==============================================================================

def get_complete_pipeline() -> LTXVideoPipeline:
    """Orchestrates the construction of the COMPLETE LTX pipeline, including the VAE, on the CPU."""
    config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)

    ckpt_path = hf_hub_download(
        repo_id="Lightricks/LTX-Video",
        filename=config["checkpoint_path"],
        cache_dir=os.environ.get("HF_HOME"),
    )

    return build_complete_pipeline_on_cpu(ckpt_path, config)

# ==============================================================================
# --- UNIFIED WORKER CLASS ---
# ==============================================================================

class LTXWorker(threading.Thread):
    """
    A unified worker that manages a complete instance of the LTX pipeline.
    It loads the model and distributes its components (Transformer/VAE) to the correct GPUs.
    """
    def __init__(self, worker_id: int):
        super().__init__()
        self.worker_id = worker_id
        self.pipeline: Optional[LTXVideoPipeline] = None
        self.is_healthy = False
        self.is_busy = False
        self.daemon = True
        self.autocast_dtype: torch.dtype = torch.float32

    def run(self):
        """Initializes the worker: loads the pipeline and moves it to the GPUs."""
        try:
            self.pipeline = get_complete_pipeline()
            self._set_precision_policy()

            main_device = gpu_manager.get_ltx_device()
            vae_device = gpu_manager.get_ltx_vae_device()

            logging.info(f"[LTXWorker-{self.worker_id}] Moving components -> Main: {main_device}, VAE: {vae_device}")

            self.pipeline.to(main_device)     # Move everything to the main GPU first
            self.pipeline.vae.to(vae_device)  # Then move the VAE to its dedicated GPU

            self.is_healthy = True
            logging.info(f"✅ LTXWorker {self.worker_id} is healthy. Main on {main_device}, VAE on {vae_device}.")
Main on {main_device}, VAE on {vae_device}.") except Exception: self.is_healthy = False logging.error(f"❌ LTXWorker {self.worker_id} FAILED to initialize!", exc_info=True) def _set_precision_policy(self): """Define a política de precisão para operações de autocast.""" try: config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml" with open(config_path, "r") as file: config = yaml.safe_load(file) precision = str(config.get("precision", "bfloat16")).lower() if precision in ["float8_e4m3fn", "bfloat16"]: self.autocast_dtype = torch.bfloat16 elif precision == "mixed_precision": self.autocast_dtype = torch.float16 except Exception: logging.warning(f"[LTXWorker-{self.worker_id}] Could not set precision policy, defaulting to float32.", exc_info=True) def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any: self.is_busy = True try: # O job recebe o pipeline completo e o dtype para o autocast result = job_func(self.pipeline, self.autocast_dtype, *args, **kwargs) return result except Exception: self.is_healthy = False raise finally: self.is_busy = False # ============================================================================== # --- O GERENCIADOR DE POOL (SINGLETON) --- # ============================================================================== class LTXAducManager: _instance = None _initialized = False def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __init__(self): if self._initialized: return logging.info("🏭 Initializing Simplified Pool Manager for LTX...") self.workers: List[LTXWorker] = [] self.job_queue = queue.Queue() self.pool_lock = threading.Lock() self._initialize_workers() self.dispatcher = threading.Thread(target=self._dispatch_jobs, daemon=True) self.health_monitor = threading.Thread(target=self._health_check_loop, daemon=True) self.dispatcher.start() self.health_monitor.start() self._initialized = True logging.info("✅ Simplified Pool Manager is running.") def _initialize_workers(self): with self.pool_lock: # Por enquanto, criamos um único worker unificado. # No futuro, este loop pode criar múltiplos workers se houver mais GPUs. worker = LTXWorker(worker_id=0) self.workers.append(worker) worker.start() def _get_available_worker(self) -> Optional[LTXWorker]: with self.pool_lock: for worker in self.workers: if worker.is_healthy and not worker.is_busy: return worker return None def _dispatch_jobs(self): while True: job_func, args, kwargs, future = self.job_queue.get() worker = None while worker is None: worker = self._get_available_worker() if worker is None: time.sleep(0.1) try: result = worker.execute(job_func, args, kwargs) future.put(result) except Exception as e: future.put(e) def _health_check_loop(self): while True: time.sleep(30) with self.pool_lock: for i, worker in enumerate(self.workers): if not worker.is_alive() or not worker.is_healthy: logging.warning(f"LTX Worker {worker.worker_id} is UNHEALTHY. Restarting...") new_worker = LTXWorker(worker_id=worker.worker_id) self.workers[i] = new_worker new_worker.start() def submit_job(self, job_func: Callable, *args, **kwargs) -> Any: future = queue.Queue(1) self.job_queue.put((job_func, args, kwargs, future)) result = future.get() if isinstance(result, Exception): raise result return result # --- INSTANCIAÇÃO GLOBAL --- ltx_aduc_manager = LTXAducManager()