# FILE: api/ltx/ltx_aduc_manager.py
# DESCRIPTION: A simplified, robust pool manager for a unified LTX worker.
# This worker handles all tasks, including Transformer generation and VAE operations,
# while still respecting the GPU separation defined by the GPUManager.

import logging
import torch
import sys
from pathlib import Path
import threading
import queue
import time
import yaml
import os
from huggingface_hub import hf_hub_download
from typing import List, Optional, Callable, Any, Tuple, Dict

# --- Import the GPU manager and the low-level pipeline builder ---
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_utils import build_complete_pipeline_on_cpu, create_transformer

# --- Add the LTX-Video repo to sys.path so its types can be imported ---
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")


def add_deps_to_path():
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)


add_deps_to_path()

from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline

# ==============================================================================
# --- PIPELINE CONSTRUCTION ORCHESTRATION (internal to the manager) ---
# ==============================================================================
def get_complete_pipeline() -> LTXVideoPipeline:
    """
    Orchestrates the construction of the COMPLETE LTX pipeline, including the VAE, on the CPU.
    """
    config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)
    ckpt_path = hf_hub_download(
        repo_id="Lightricks/LTX-Video",
        filename=config["checkpoint_path"],
        cache_dir=os.environ.get("HF_HOME"),
    )
    return build_complete_pipeline_on_cpu(ckpt_path, config)
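
# NOTE (added for clarity, not in the original): the only config keys this module itself
# reads from the YAML file are "checkpoint_path" (consumed above by hf_hub_download) and
# "precision" (consumed below by LTXWorker._set_precision_policy). The full config dict
# is handed through unchanged to build_complete_pipeline_on_cpu.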


# ==============================================================================
# --- UNIFIED WORKER CLASS ---
# ==============================================================================
class LTXWorker(threading.Thread):
    """
    A unified worker that manages a complete LTX pipeline instance.
    It loads the model and distributes its components (Transformer/VAE) to the correct GPUs.
    """
    def __init__(self, worker_id: int):
        super().__init__()
        self.worker_id = worker_id
        self.pipeline: Optional[LTXVideoPipeline] = None
        self.is_healthy = False
        self.is_busy = False
        self.daemon = True
        self.autocast_dtype: torch.dtype = torch.float32

    def run(self):
        """Initializes the worker: loads the pipeline and moves it to the GPUs."""
        try:
            self.pipeline = get_complete_pipeline()
            self._set_precision_policy()
            main_device = gpu_manager.get_ltx_device()
            vae_device = gpu_manager.get_ltx_vae_device()
            logging.info(f"[LTXWorker-{self.worker_id}] Moving components -> Main: {main_device}, VAE: {vae_device}")
            self.pipeline.to(main_device)     # Move everything to the main GPU first
            self.pipeline.vae.to(vae_device)  # Then move the VAE to its dedicated GPU
            self.is_healthy = True
            logging.info(f"✅ LTXWorker {self.worker_id} is healthy. Main on {main_device}, VAE on {vae_device}.")
        except Exception:
            self.is_healthy = False
            logging.error(f"❌ LTXWorker {self.worker_id} FAILED to initialize!", exc_info=True)

    def _set_precision_policy(self):
        """Defines the precision policy used for autocast operations."""
        try:
            config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
            with open(config_path, "r") as file:
                config = yaml.safe_load(file)
            precision = str(config.get("precision", "bfloat16")).lower()
            if precision in ["float8_e4m3fn", "bfloat16"]:
                self.autocast_dtype = torch.bfloat16
            elif precision == "mixed_precision":
                self.autocast_dtype = torch.float16
        except Exception:
            logging.warning(f"[LTXWorker-{self.worker_id}] Could not set precision policy, defaulting to float32.", exc_info=True)

    def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any:
        self.is_busy = True
        try:
            # The job receives the complete pipeline and the dtype to use under autocast
            result = job_func(self.pipeline, self.autocast_dtype, *args, **kwargs)
            return result
        except Exception:
            self.is_healthy = False
            raise
        finally:
            self.is_busy = False
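

# --- Illustrative example (added; a sketch, not part of the original module) ---
# LTXWorker.execute calls job_func(self.pipeline, self.autocast_dtype, *args, **kwargs),
# so a minimal job function compatible with this pool could look like the one below.
# The pipeline call itself is hypothetical: the exact keyword arguments accepted by
# LTXVideoPipeline are defined in the LTX-Video repo, not in this file.
def example_generation_job(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, **pipeline_kwargs) -> Any:
    # Autocast is only meaningful for reduced-precision dtypes; skip it for float32.
    if autocast_dtype in (torch.float16, torch.bfloat16):
        with torch.autocast(device_type="cuda", dtype=autocast_dtype):
            return pipeline(**pipeline_kwargs)
    return pipeline(**pipeline_kwargs)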


# ==============================================================================
# --- THE POOL MANAGER (SINGLETON) ---
# ==============================================================================
class LTXAducManager:
    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        logging.info("🏭 Initializing Simplified Pool Manager for LTX...")
        self.workers: List[LTXWorker] = []
        self.job_queue = queue.Queue()
        self.pool_lock = threading.Lock()
        self._initialize_workers()
        self.dispatcher = threading.Thread(target=self._dispatch_jobs, daemon=True)
        self.health_monitor = threading.Thread(target=self._health_check_loop, daemon=True)
        self.dispatcher.start()
        self.health_monitor.start()
        self._initialized = True
        logging.info("✅ Simplified Pool Manager is running.")

    def _initialize_workers(self):
        with self.pool_lock:
            # For now we create a single unified worker.
            # In the future this loop can create multiple workers if more GPUs are available.
            worker = LTXWorker(worker_id=0)
            self.workers.append(worker)
            worker.start()

    def _get_available_worker(self) -> Optional[LTXWorker]:
        with self.pool_lock:
            for worker in self.workers:
                if worker.is_healthy and not worker.is_busy:
                    return worker
        return None

    def _dispatch_jobs(self):
        while True:
            job_func, args, kwargs, future = self.job_queue.get()
            worker = None
            while worker is None:
                worker = self._get_available_worker()
                if worker is None:
                    time.sleep(0.1)
            try:
                result = worker.execute(job_func, args, kwargs)
                future.put(result)
            except Exception as e:
                future.put(e)

    def _health_check_loop(self):
        while True:
            time.sleep(30)
            with self.pool_lock:
                for i, worker in enumerate(self.workers):
                    if not worker.is_alive() or not worker.is_healthy:
                        logging.warning(f"LTX Worker {worker.worker_id} is UNHEALTHY. Restarting...")
                        new_worker = LTXWorker(worker_id=worker.worker_id)
                        self.workers[i] = new_worker
                        new_worker.start()

    def submit_job(self, job_func: Callable, *args, **kwargs) -> Any:
        future = queue.Queue(1)
        self.job_queue.put((job_func, args, kwargs, future))
        result = future.get()
        if isinstance(result, Exception):
            raise result
        return result


# --- GLOBAL INSTANTIATION ---
ltx_aduc_manager = LTXAducManager()
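

# --- Usage sketch (added; an assumption, not part of the original module) ---
# Callers are expected to go through the singleton's submit_job, which blocks until a
# healthy worker has run the job and returns the result (or re-raises the exception).
# The "prompt" keyword below is hypothetical and simply flows into
# example_generation_job's **pipeline_kwargs; only the submit_job contract itself is
# defined in this file.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    output = ltx_aduc_manager.submit_job(
        example_generation_job,
        prompt="a slow pan across a foggy forest",  # hypothetical LTXVideoPipeline kwarg
    )
    logging.info("Job finished, result type: %r", type(output))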