# aduc_framework/managers/ltx_manager.py
#
# Copyright (C) August 4, 2025  Carlos Rodrigues dos Santos
#
# Version 2.3.2 (with dataclass-handling fix)
#
# This manager drives the LTX-Video pipeline. It maintains a pool of workers
# to make the most of multiple GPUs, handles initialization and the setup of
# complex dependencies, and exposes a high-level interface for generating
# video fragments in latent space.
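#
# A minimal usage sketch (an assumed call site; the accepted kwargs are the
# ones consumed by LtxPoolManager._prepare_pipeline_params below, and the
# concrete values here are illustrative only):
#
#   from aduc_framework.managers.ltx_manager import ltx_manager_singleton
#
#   latents, padding = ltx_manager_singleton.generate_latent_fragment(
#       height=480, width=704, video_total_frames=121, video_fps=24,
#       motion_prompt="a slow dolly shot through a neon-lit alley",
#   )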

import torch
import gc
import os
import sys
import yaml
import logging
import huggingface_hub
import time
import threading
import subprocess
from pathlib import Path
from typing import Optional, List, Tuple, Union

# --- Corrected relative imports ---
from ..types import LatentConditioningItem
from ..tools.optimization import optimize_ltx_worker, can_optimize_fp8
from ..tools.hardware_manager import hardware_manager

logger = logging.getLogger(__name__)

# --- Dependency management and placeholders ---
DEPS_DIR = Path("./deps")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
LTX_VIDEO_REPO_URL = "https://github.com/Lightricks/LTX-Video.git"

# Placeholders for lazily loaded modules (populated by _lazy_load_ltx_modules)
create_ltx_video_pipeline = None
calculate_padding = None
LTXVideoPipeline = None
ConditioningItem = None
LTXMultiScalePipeline = None
vae_encode = None
latent_to_pixel_coords = None
randn_tensor = None

class LtxPoolManager:
    """
    Manages a pool of LtxWorkers and exposes the prompt-enhancement pipeline.
    """
    def __init__(self, device_ids: List[str], ltx_config_file_name: str):
        logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
        self._ltx_modules_loaded = False
        self._setup_dependencies()
        self._lazy_load_ltx_modules()

        self.ltx_config_file = LTX_VIDEO_REPO_DIR / "configs" / ltx_config_file_name

        self.workers = [LtxWorker(dev_id, self.ltx_config_file) for dev_id in device_ids]
        self.current_worker_index = 0
        self.lock = threading.Lock()

        self.prompt_enhancement_pipeline = self.workers[0].pipeline if self.workers else None
        if self.prompt_enhancement_pipeline:
            logger.info("LTX POOL MANAGER: Pipeline de aprimoramento de prompt exposta para outros especialistas.")

        self._apply_ltx_pipeline_patches()

        if all(w.device.type == 'cuda' for w in self.workers):
            logger.info("LTX POOL MANAGER: MODO HOT START ATIVADO. Pré-aquecendo todas as GPUs...")
            for worker in self.workers:
                worker.to_gpu()
            logger.info("LTX POOL MANAGER: Todas as GPUs estão prontas.")
        else:
            logger.info("LTX POOL MANAGER: Operando em modo CPU ou misto. Pré-aquecimento de GPU pulado.")

    def _setup_dependencies(self):
        """Clona o repositório LTX-Video se não encontrado e o adiciona ao sys.path."""
        if not LTX_VIDEO_REPO_DIR.exists():
            logger.info(f"Repositório LTX-Video não encontrado em '{LTX_VIDEO_REPO_DIR}'. Clonando do GitHub...")
            try:
                DEPS_DIR.mkdir(exist_ok=True)
                subprocess.run(
                    ["git", "clone", "--depth", "1", LTX_VIDEO_REPO_URL, str(LTX_VIDEO_REPO_DIR)],
                    check=True, capture_output=True, text=True
                )
                logger.info("Repositório LTX-Video clonado com sucesso.")
            except subprocess.CalledProcessError as e:
                logger.error(f"Falha ao clonar o repositório LTX-Video. Git stderr: {e.stderr}")
                raise RuntimeError("Não foi possível clonar a dependência LTX-Video do GitHub.")
        else:
            logger.info("Repositório LTX-Video local encontrado.")

        if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
            sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))
            logger.info(f"Adicionado '{LTX_VIDEO_REPO_DIR.resolve()}' ao sys.path.")

    def _lazy_load_ltx_modules(self):
        """Importa dinamicamente os módulos do LTX-Video após garantir que o repositório existe."""
        if self._ltx_modules_loaded:
            return

        global create_ltx_video_pipeline, calculate_padding, LTXVideoPipeline, ConditioningItem, LTXMultiScalePipeline
        global vae_encode, latent_to_pixel_coords, randn_tensor
        
        from .ltx_pipeline_utils import create_ltx_video_pipeline, calculate_padding
        from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LTXMultiScalePipeline
        from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
        from diffusers.utils.torch_utils import randn_tensor
        
        self._ltx_modules_loaded = True
        logger.info("Módulos do LTX-Video foram carregados dinamicamente.")

    def _apply_ltx_pipeline_patches(self):
        """Aplica patches em tempo de execução na pipeline LTX para compatibilidade com ADUC-SDR."""
        logger.info("LTX POOL MANAGER: Aplicando patches ADUC-SDR na pipeline LTX...")
        for worker in self.workers:
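            # __get__ binds the module-level patch function as a method on this
            # pipeline instance, replacing prepare_conditioning with the
            # ADUC-aware version defined at the bottom of this file.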
            worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
        logger.info("LTX POOL MANAGER: Todas as instâncias da pipeline foram corrigidas com sucesso.")

    def _get_next_worker(self) -> 'LtxWorker':
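        # Round-robin selection: hand out workers in order so concurrent
        # fragment requests are spread evenly across the pool.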
        with self.lock:
            worker = self.workers[self.current_worker_index]
            self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
            return worker
    
    def _prepare_pipeline_params(self, worker: 'LtxWorker', **kwargs) -> dict:
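        # Translate the high-level kwargs into the pipeline's call signature;
        # output_type="latent" keeps the result in latent space, matching the
        # manager's contract of returning latent video fragments.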
        pipeline_params = {
            "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
            "frame_rate": kwargs.get('video_fps', 24),
            "generator": torch.Generator(device=worker.device).manual_seed(int(time.time()) + kwargs.get('current_fragment_index', 0)),
            "is_video": True, "vae_per_channel_normalize": True,
            "prompt": kwargs.get('motion_prompt', ""), "negative_prompt": kwargs.get('negative_prompt', "blurry, distorted, static, bad quality"),
            "guidance_scale": kwargs.get('guidance_scale', 1.0), "stg_scale": kwargs.get('stg_scale', 0.0),
            "rescaling_scale": kwargs.get('rescaling_scale', 0.15), "num_inference_steps": kwargs.get('num_inference_steps', 20),
            "output_type": "latent"
        }
        if 'latents' in kwargs:
            pipeline_params["latents"] = kwargs['latents'].to(worker.device, dtype=worker.pipeline.transformer.dtype)
        if 'strength' in kwargs:
            pipeline_params["strength"] = kwargs['strength']
        
        if 'conditioning_items_data' in kwargs:
            final_conditioning_items = []
            for item in kwargs['conditioning_items_data']:
                # FIX: LatentConditioningItem is a mutable dataclass, so we
                # move its tensor to the worker's device by mutating the attribute in place.
                item.latent_tensor = item.latent_tensor.to(worker.device)
                final_conditioning_items.append(item)
            pipeline_params["conditioning_items"] = final_conditioning_items

        if worker.is_distilled:
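            # Distilled checkpoints ship a fixed timestep schedule in their
            # config ("first_pass.timesteps"); when present, it overrides the
            # generic num_inference_steps.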
            fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps")
            if fixed_timesteps:
                pipeline_params["timesteps"] = fixed_timesteps
                pipeline_params["num_inference_steps"] = len(fixed_timesteps)
        
        callback = kwargs.get('callback')
        if callback:
            pipeline_params["callback_on_step_end"] = callback
            pipeline_params["callback_on_step_end_tensor_inputs"] = ["latents"]
        
        return pipeline_params

    def generate_latent_fragment(self, **kwargs) -> Tuple[torch.Tensor, tuple]:
        worker_to_use = self._get_next_worker()
        try:
            height, width = kwargs['height'], kwargs['width']
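            # Pad H and W up to the next multiple of 32 (the pipeline's spatial
            # granularity) and return the padding so the caller can undo it.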
            padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
            padding_vals = calculate_padding(height, width, padded_h, padded_w)
            kwargs['height'], kwargs['width'] = padded_h, padded_w
            
            pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
            
            logger.info(f"Iniciando GERAÇÃO em {worker_to_use.device} com shape {padded_w}x{padded_h}")
            
            if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline):
                result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images
            else:
                result = worker_to_use.generate_video_fragment_internal(**pipeline_params)
            return result, padding_vals
        except Exception as e:
            logger.error(f"LTX POOL MANAGER: Erro durante a geração em {worker_to_use.device}: {e}", exc_info=True)
            raise e
        finally:
            if worker_to_use and worker_to_use.device.type == 'cuda':
                with torch.cuda.device(worker_to_use.device):
                    gc.collect()
                    torch.cuda.empty_cache()

    def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, tuple]:
        raise NotImplementedError("refine_latents is a placeholder and has not been implemented yet.")

class LtxWorker:
    """Representa uma única instância da pipeline LTX-Video em um dispositivo específico."""
    def __init__(self, device_id, ltx_config_file):
        self.cpu_device = torch.device('cpu')
        self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
        logger.info(f"LTX Worker ({self.device}): Inicializando com config '{ltx_config_file}'...")
        
        with open(ltx_config_file, "r") as file:
            self.config = yaml.safe_load(file)
        
        self.is_distilled = "distilled" in self.config.get("checkpoint_path", "")
        models_dir = LTX_VIDEO_REPO_DIR / "models_downloaded"
        
        logger.info(f"LTX Worker ({self.device}): Preparando para carregar modelo...")
        model_filename = self.config["checkpoint_path"]
        model_path = huggingface_hub.hf_hub_download(
            repo_id="Lightricks/LTX-Video", filename=model_filename,
            local_dir=str(models_dir), local_dir_use_symlinks=False
        )
        
        self.pipeline = create_ltx_video_pipeline(
            ckpt_path=model_path, 
            precision=self.config["precision"],
            text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
            sampler=self.config["sampler"], 
            device='cpu'
        )
        logger.info(f"LTX Worker ({self.device}): Modelo pronto na CPU. É um modelo distilled? {self.is_distilled}")
    
    def to_gpu(self):
        if self.device.type == 'cpu': return
        logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
        self.pipeline.to(self.device)
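        # If the hardware supports it, apply the project's FP8 optimization
        # pass to this worker (see tools.optimization).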
        if self.device.type == 'cuda' and can_optimize_fp8():
            logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Otimizando...")
            optimize_ltx_worker(self)
            logger.info(f"LTX Worker ({self.device}): Otimização completa.")
    
    def to_cpu(self):
        if self.device.type == 'cpu': return
        logger.info(f"LTX Worker: Descarregando pipeline da GPU {self.device}...")
        self.pipeline.to('cpu')
        gc.collect()
        if torch.cuda.is_available(): torch.cuda.empty_cache()

    def generate_video_fragment_internal(self, **kwargs):
        return self.pipeline(**kwargs).images
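
# Replacement for LTXVideoPipeline.prepare_conditioning, bound to each worker's
# pipeline in _apply_ltx_pipeline_patches. Unlike the upstream method, it
# consumes ADUC's LatentConditioningItem objects (tensors already encoded to
# latent space) and skips any other item type with a warning.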

def _aduc_prepare_conditioning_patch(
    self: "LTXVideoPipeline",
    conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
    init_latents: torch.Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator=None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
        return init_latents, init_pixel_coords, None, 0
    
    init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
    extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
    extra_conditioning_num_latents = 0
    
    for item in conditioning_items:
        if not isinstance(item, LatentConditioningItem):
            logger.warning("Patch ADUC: Item de condicionamento não é um LatentConditioningItem e será ignorado.")
            continue
        
        media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
        media_frame_number, strength = item.media_frame_number, item.conditioning_strength
        
        if media_frame_number == 0:
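            # Conditioning at frame 0: lerp the item into the leading region of
            # the initial latents and record its strength in the mask.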
            f_l, h_l, w_l = media_item_latents.shape[-3:]
            init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
            init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
        else:
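            # Conditioning at a later frame: blend from noise toward the target
            # latents by `strength`, patchify, and shift the temporal coordinate
            # so the tokens land at the requested frame.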
            noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
            media_item_latents = torch.lerp(noise, media_item_latents, strength)
            patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
            pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
            pixel_coords[:, 0] += media_frame_number
            extra_conditioning_num_latents += patched_latents.shape[1]
            new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
            extra_conditioning_latents.append(patched_latents)
            extra_conditioning_pixel_coords.append(pixel_coords)
            extra_conditioning_mask.append(new_mask)

    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)
    
    if extra_conditioning_latents:
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
    
    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents

# --- Singleton instantiation ---
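# NOTE: instantiating the pool at import time reads config.yaml from the
# working directory and may clone the LTX-Video repo and download model
# weights, so importing this module is intentionally heavyweight.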
with open("config.yaml", 'r') as f:
    config = yaml.safe_load(f)
ltx_gpus_required = config['specialists']['ltx']['gpus_required']
ltx_device_ids = hardware_manager.allocate_gpus('LTX', ltx_gpus_required)
ltx_config_filename = config['specialists']['ltx']['config_file']
ltx_manager_singleton = LtxPoolManager(device_ids=ltx_device_ids, ltx_config_file_name=ltx_config_filename)
logger.info("Especialista de Vídeo (LTX) pronto.")