File size: 7,652 Bytes
5105909
52c58b6
 
 
42998d3
 
 
6e680c2
 
460fa35
 
 
ecd2b0d
52c58b6
eaa00b5
 
52c58b6
2869224
5105909
52c58b6
8504d9f
460fa35
 
8504d9f
 
 
 
 
42998d3
460fa35
42998d3
 
52c58b6
42998d3
 
52c58b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460fa35
 
52c58b6
460fa35
 
eaa00b5
52c58b6
460fa35
 
52c58b6
 
 
 
 
 
 
 
 
 
 
 
460fa35
52c58b6
 
460fa35
52c58b6
460fa35
ecd2b0d
52c58b6
 
2869224
52c58b6
ecd2b0d
eaa00b5
 
52c58b6
 
ecd2b0d
460fa35
 
52c58b6
 
ecd2b0d
460fa35
52c58b6
 
 
 
 
42998d3
 
52c58b6
42998d3
460fa35
 
 
42998d3
460fa35
52c58b6
460fa35
42998d3
460fa35
 
42998d3
52c58b6
460fa35
52c58b6
 
460fa35
42998d3
460fa35
 
52c58b6
460fa35
52c58b6
460fa35
 
 
52c58b6
2869224
460fa35
 
52c58b6
 
 
 
 
460fa35
52c58b6
460fa35
52c58b6
 
 
460fa35
 
52c58b6
460fa35
52c58b6
460fa35
 
52c58b6
 
460fa35
 
 
 
 
 
 
 
 
 
52c58b6
460fa35
52c58b6
 
 
460fa35
 
52c58b6
ecd2b0d
52c58b6
460fa35
52c58b6
460fa35
c8f8a7f
460fa35
52c58b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# FILE: api/ltx/ltx_aduc_manager.py
# DESCRIPTION: A simplified, robust pool manager for a unified LTX worker.
# This worker handles all tasks, including Transformer generation and VAE operations,
# while still respecting the GPU separation defined by the GPUManager.

import logging
import torch
import sys
from pathlib import Path
import threading
import queue
import time
import yaml
import os
from huggingface_hub import hf_hub_download
from typing import List, Optional, Callable, Any, Tuple, Dict

# --- Import the GPU manager and the low-level pipeline builders ---
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_utils import build_complete_pipeline_on_cpu, create_transformer

# --- Make the LTX-Video checkout importable (needed for the type import below) ---
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")


def add_deps_to_path():
    """Put the resolved LTX-Video repository directory at the front of ``sys.path``.

    Nothing happens if the resolved path is already present anywhere in
    ``sys.path``, so calling this repeatedly is harmless.
    """
    resolved_repo = str(LTX_VIDEO_REPO_DIR.resolve())
    if resolved_repo in sys.path:
        return
    sys.path.insert(0, resolved_repo)


add_deps_to_path()

from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline

# ==============================================================================
# --- BUILD ORCHESTRATION FUNCTION (internal to the manager) ---
# ==============================================================================

def get_complete_pipeline() -> LTXVideoPipeline:
    """Build the COMPLETE LTX pipeline, VAE included, on the CPU.

    Reads the ``ltxv-13b-0.9.8-distilled-fp8.yaml`` config from the LTX-Video
    repository, downloads the checkpoint file named by its ``checkpoint_path``
    entry from the ``Lightricks/LTX-Video`` Hub repo, and passes both the
    checkpoint path and the parsed config to ``build_complete_pipeline_on_cpu``.

    Returns:
        Whatever ``build_complete_pipeline_on_cpu`` returns.

    Raises:
        OSError if the config file cannot be opened; KeyError if the config has
        no ``checkpoint_path``; plus anything raised by the download or build.
    """
    yaml_file = LTX_VIDEO_REPO_DIR.joinpath("configs", "ltxv-13b-0.9.8-distilled-fp8.yaml")
    with yaml_file.open("r") as stream:
        pipeline_config = yaml.safe_load(stream)

    # NOTE(review): HF_HOME itself is used as cache_dir (None when unset);
    # huggingface_hub's own default cache lives under HF_HOME/hub — confirm
    # this layout is intended.
    checkpoint_file = hf_hub_download(
        repo_id="Lightricks/LTX-Video",
        filename=pipeline_config["checkpoint_path"],
        cache_dir=os.environ.get("HF_HOME"),
    )
    return build_complete_pipeline_on_cpu(checkpoint_file, pipeline_config)

# ==============================================================================
# --- UNIFIED WORKER CLASS ---
# ==============================================================================

class LTXWorker(threading.Thread):
    """Daemon thread that owns one complete LTX pipeline instance.

    ``run()`` (the thread body) builds the pipeline on the CPU, moves it to the
    main LTX device and then moves its VAE to the dedicated VAE device, after
    which the thread finishes. Jobs are executed synchronously through
    :meth:`execute` by whichever thread calls it.

    Status flags:
        is_healthy: set to True only at the end of a successful ``run()``;
            reset to False if initialization fails or a job raises.
        is_busy: True while :meth:`execute` is running a job.
    """

    def __init__(self, worker_id: int):
        super().__init__()
        self.daemon = True
        self.worker_id = worker_id
        self.pipeline: Optional[LTXVideoPipeline] = None
        # Dtype handed to every job for autocast; may be changed by
        # _set_precision_policy() during run().
        self.autocast_dtype: torch.dtype = torch.float32
        self.is_healthy = False
        self.is_busy = False

    def run(self):
        """Thread entry point: build the pipeline and place it on its devices.

        On any exception the worker is left unhealthy and the error is logged.
        """
        try:
            self.pipeline = get_complete_pipeline()
            self._set_precision_policy()

            target_main = gpu_manager.get_ltx_device()
            target_vae = gpu_manager.get_ltx_vae_device()
            logging.info(f"[LTXWorker-{self.worker_id}] Moving components -> Main: {target_main}, VAE: {target_vae}")

            # The whole pipeline goes to the main device first; the VAE is then
            # relocated on its own so it ends up on its dedicated device.
            pipe = self.pipeline
            pipe.to(target_main)
            pipe.vae.to(target_vae)

            self.is_healthy = True
            logging.info(f"✅ LTXWorker {self.worker_id} is healthy. Main on {target_main}, VAE on {target_vae}.")
        except Exception:
            self.is_healthy = False
            logging.error(f"❌ LTXWorker {self.worker_id} FAILED to initialize!", exc_info=True)

    def _set_precision_policy(self):
        """Pick ``autocast_dtype`` from the ``precision`` entry of the model config.

        ``float8_e4m3fn`` and ``bfloat16`` select ``torch.bfloat16``;
        ``mixed_precision`` selects ``torch.float16``; any other value leaves the
        current dtype untouched. A missing entry counts as ``bfloat16``. Any error
        is logged as a warning and also leaves the current dtype untouched.
        """
        precision_to_dtype = {
            "float8_e4m3fn": torch.bfloat16,
            "bfloat16": torch.bfloat16,
            "mixed_precision": torch.float16,
        }
        try:
            cfg_file = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
            with open(cfg_file, "r") as handle:
                cfg = yaml.safe_load(handle)
            requested = str(cfg.get("precision", "bfloat16")).lower()
            chosen = precision_to_dtype.get(requested)
            if chosen is not None:
                self.autocast_dtype = chosen
        except Exception:
            logging.warning(f"[LTXWorker-{self.worker_id}] Could not set precision policy, defaulting to float32.", exc_info=True)

    def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any:
        """Run ``job_func(pipeline, autocast_dtype, *args, **kwargs)`` and return its result.

        ``is_busy`` is raised for the duration of the call. If the job raises, the
        worker is flagged unhealthy and the exception propagates to the caller.
        """
        self.is_busy = True
        try:
            return job_func(self.pipeline, self.autocast_dtype, *args, **kwargs)
        except Exception:
            self.is_healthy = False
            raise
        finally:
            self.is_busy = False

# ==============================================================================
# --- THE POOL MANAGER (SINGLETON) ---
# ==============================================================================
class LTXAducManager:
    """Process-wide singleton that owns the LTX worker pool and dispatches jobs.

    Jobs submitted through :meth:`submit_job` are queued and run one at a time
    by a background dispatcher thread on the first healthy, idle worker. A
    second background thread periodically replaces workers that have failed.
    """

    _instance = None
    _initialized = False
    # Seconds the health monitor sleeps between passes over the pool.
    _HEALTH_CHECK_INTERVAL_S = 30
    # Seconds the dispatcher sleeps between polls while no worker is available.
    _WORKER_POLL_INTERVAL_S = 0.1

    def __new__(cls, *args, **kwargs):
        # Always hand back the same instance; __init__ guards re-initialization.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        logging.info("🏭 Initializing Simplified Pool Manager for LTX...")

        self.workers: List[LTXWorker] = []
        self.job_queue = queue.Queue()
        self.pool_lock = threading.Lock()

        self._initialize_workers()

        self.dispatcher = threading.Thread(target=self._dispatch_jobs, daemon=True)
        self.health_monitor = threading.Thread(target=self._health_check_loop, daemon=True)
        self.dispatcher.start()
        self.health_monitor.start()

        self._initialized = True
        logging.info("✅ Simplified Pool Manager is running.")

    def _initialize_workers(self):
        """Create and start the pool's workers."""
        with self.pool_lock:
            # For now a single unified worker is created. In the future this
            # could create several workers if more GPUs are available.
            worker = LTXWorker(worker_id=0)
            self.workers.append(worker)
            worker.start()

    def _get_available_worker(self) -> Optional["LTXWorker"]:
        """Return the first healthy, idle worker in the pool, or None."""
        with self.pool_lock:
            for worker in self.workers:
                if worker.is_healthy and not worker.is_busy:
                    return worker
        return None

    def _wait_for_worker(self) -> "LTXWorker":
        """Poll until a healthy, idle worker is available and return it.

        NOTE(review): blocks indefinitely if no worker ever becomes healthy.
        """
        worker = self._get_available_worker()
        while worker is None:
            time.sleep(self._WORKER_POLL_INTERVAL_S)
            worker = self._get_available_worker()
        return worker

    def _dispatch_jobs(self):
        """Dispatcher thread body: run queued jobs one at a time and report outcomes.

        Each outcome goes back to the submitter's reply queue as a
        ``(succeeded, value)`` tuple. This way an exception *returned* by a job
        is never confused with one it *raised*.
        """
        while True:
            job_func, args, kwargs, reply = self.job_queue.get()
            worker = self._wait_for_worker()
            try:
                outcome = (True, worker.execute(job_func, args, kwargs))
            except BaseException as exc:
                # Forward everything (including SystemExit) so the dispatcher
                # thread survives and the submitter is never left blocked.
                outcome = (False, exc)
            reply.put(outcome)

    @staticmethod
    def _needs_restart(worker: "LTXWorker") -> bool:
        """Return True when ``worker`` has failed and should be replaced.

        ``LTXWorker.run()`` only performs initialization and then returns, so a
        finished thread is the normal state of a worker that loaded successfully.
        A worker is replaced only when its thread has finished AND it is not
        healthy (initialization failed, or a job raised). A worker that is still
        loading (thread alive, not yet healthy) is left alone.
        """
        return not worker.is_alive() and not worker.is_healthy

    @staticmethod
    def _release_worker(worker: "LTXWorker") -> None:
        """Drop a retired worker's pipeline reference before its replacement loads.

        The memory it occupied can then be freed once no other references
        remain, instead of coexisting with the replacement's freshly loaded copy.
        """
        worker.pipeline = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _health_check_loop(self):
        """Health-monitor thread body: periodically replace failed workers."""
        while True:
            time.sleep(self._HEALTH_CHECK_INTERVAL_S)
            with self.pool_lock:
                for i, worker in enumerate(self.workers):
                    if not self._needs_restart(worker):
                        continue
                    logging.warning(f"LTX Worker {worker.worker_id} is UNHEALTHY. Restarting...")
                    self._release_worker(worker)
                    new_worker = LTXWorker(worker_id=worker.worker_id)
                    self.workers[i] = new_worker
                    new_worker.start()

    def submit_job(self, job_func: Callable, *args, **kwargs) -> Any:
        """Queue ``job_func`` and block until a worker has run it.

        The worker calls ``job_func(pipeline, autocast_dtype, *args, **kwargs)``.

        Returns:
            Whatever ``job_func`` returns. Exception instances that are returned
            rather than raised are returned as-is.

        Raises:
            Whatever exception ``job_func`` raised, re-raised in the caller.
        """
        reply = queue.Queue(maxsize=1)
        self.job_queue.put((job_func, args, kwargs, reply))
        succeeded, value = reply.get()
        if not succeeded:
            raise value
        return value

# --- GLOBAL INSTANTIATION ---
# Module-level singleton: importing this module constructs the manager, which
# starts the worker (pipeline load) plus the dispatcher and health-check threads.
ltx_aduc_manager = LTXAducManager()