aducsdr committed on
Commit
b5e7c3e
·
verified ·
1 Parent(s): 6493ca5

Update aduc_framework/managers/ltx_manager.py

Browse files
aduc_framework/managers/ltx_manager.py CHANGED
@@ -3,6 +3,11 @@
3
  # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
4
  #
5
  # Versão 2.3.2 (Com correção de manipulação de dataclass)
 
 
 
 
 
6
 
7
  import torch
8
  import gc
@@ -17,19 +22,108 @@ import subprocess
17
  from pathlib import Path
18
  from typing import Optional, List, Tuple, Union
19
 
 
20
  from ..types import LatentConditioningItem
21
  from ..tools.optimization import optimize_ltx_worker, can_optimize_fp8
22
  from ..tools.hardware_manager import hardware_manager
23
 
24
  logger = logging.getLogger(__name__)
25
 
26
- # --- Gerenciamento de Dependências e Placeholders (sem alteração) ---
27
  DEPS_DIR = Path("./deps")
28
- # ... (código sem alteração) ...
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  class LtxPoolManager:
31
- # ... (__init__ e outros métodos sem alteração) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def _prepare_pipeline_params(self, worker: 'LtxWorker', **kwargs) -> dict:
34
  pipeline_params = {
35
  "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
@@ -46,16 +140,14 @@ class LtxPoolManager:
46
  if 'strength' in kwargs:
47
  pipeline_params["strength"] = kwargs['strength']
48
 
49
- # --- A CORREÇÃO ESTÁ AQUI ---
50
  if 'conditioning_items_data' in kwargs:
51
  final_conditioning_items = []
52
  for item in kwargs['conditioning_items_data']:
53
- # Como LatentConditioningItem é uma dataclass mutável,
54
- # nós modificamos o atributo diretamente.
55
  item.latent_tensor = item.latent_tensor.to(worker.device)
56
  final_conditioning_items.append(item)
57
  pipeline_params["conditioning_items"] = final_conditioning_items
58
- # --- FIM DA CORREÇÃO ---
59
 
60
  if worker.is_distilled:
61
  fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps")
@@ -70,9 +162,139 @@ class LtxPoolManager:
70
 
71
  return pipeline_params
72
 
73
- # ... (resto da classe LtxPoolManager, LtxWorker e _aduc_prepare_conditioning_patch sem alteração) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- # --- Instanciação Singleton (sem alteração) ---
76
  with open("config.yaml", 'r') as f:
77
  config = yaml.safe_load(f)
78
  ltx_gpus_required = config['specialists']['ltx']['gpus_required']
 
3
  # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
4
  #
5
  # Versão 2.3.2 (Com correção de manipulação de dataclass)
6
+ #
7
+ # Este manager é responsável por controlar a pipeline LTX-Video. Ele gerencia
8
+ # um pool de workers para otimizar o uso de múltiplas GPUs, lida com a inicialização
9
+ # e o setup de dependências complexas, e expõe uma interface de alto nível para a
10
+ # geração de fragmentos de vídeo no espaço latente.
11
 
12
  import torch
13
  import gc
 
22
  from pathlib import Path
23
  from typing import Optional, List, Tuple, Union
24
 
25
+ # --- Imports Relativos Corrigidos ---
26
  from ..types import LatentConditioningItem
27
  from ..tools.optimization import optimize_ltx_worker, can_optimize_fp8
28
  from ..tools.hardware_manager import hardware_manager
29
 
30
  logger = logging.getLogger(__name__)
31
 
32
# --- Dependency management and lazy-import placeholders ---
# Local directory that holds third-party checkouts, and the LTX-Video clone
# (with its upstream URL) that lives inside it.
DEPS_DIR = Path("./deps")
LTX_VIDEO_REPO_URL = "https://github.com/Lightricks/LTX-Video.git"
LTX_VIDEO_REPO_DIR = DEPS_DIR.joinpath("LTX-Video")

# These names start out as None and are rebound to the real objects by
# LtxPoolManager._lazy_load_ltx_modules once the LTX-Video checkout is on
# sys.path.
create_ltx_video_pipeline = calculate_padding = None
LTXVideoPipeline = ConditioningItem = LTXMultiScalePipeline = None
vae_encode = latent_to_pixel_coords = randn_tensor = None
46
 
47
  class LtxPoolManager:
48
+ """
49
+ Gerencia um pool de LtxWorkers e expõe a pipeline de aprimoramento de prompt.
50
+ """
51
def __init__(self, device_ids: List[str], ltx_config_file_name: str):
    """Build one worker per device and prepare the shared LTX state.

    Args:
        device_ids: Device identifiers; one LtxWorker is created for each.
        ltx_config_file_name: File name looked up under the ``configs``
            directory of the LTX-Video checkout.
    """
    logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")

    # The LTX-Video checkout has to exist and be imported before any worker
    # is created, because LtxWorker calls the lazily imported pipeline factory.
    self._ltx_modules_loaded = False
    self._setup_dependencies()
    self._lazy_load_ltx_modules()

    config_path = LTX_VIDEO_REPO_DIR / "configs" / ltx_config_file_name
    self.ltx_config_file = config_path

    pool = []
    for device_id in device_ids:
        pool.append(LtxWorker(device_id, config_path))
    self.workers = pool
    self.current_worker_index = 0
    self.lock = threading.Lock()

    # The first worker's pipeline doubles as the prompt-enhancement pipeline.
    first_pipeline = pool[0].pipeline if pool else None
    self.prompt_enhancement_pipeline = first_pipeline
    if first_pipeline:
        logger.info("LTX POOL MANAGER: Pipeline de aprimoramento de prompt exposta para outros especialistas.")

    self._apply_ltx_pipeline_patches()

    # Pre-warm only when every worker targets a CUDA device.
    has_non_cuda_worker = any(w.device.type != 'cuda' for w in self.workers)
    if has_non_cuda_worker:
        logger.info("LTX POOL MANAGER: Operando em modo CPU ou misto. Pré-aquecimento de GPU pulado.")
        return

    logger.info("LTX POOL MANAGER: MODO HOT START ATIVADO. Pré-aquecendo todas as GPUs...")
    for worker in self.workers:
        worker.to_gpu()
    logger.info("LTX POOL MANAGER: Todas as GPUs estão prontas.")
76
+
77
def _setup_dependencies(self):
    """Make sure the LTX-Video checkout exists and is importable.

    When ``LTX_VIDEO_REPO_DIR`` does not exist, ``DEPS_DIR`` is created and
    the repository at ``LTX_VIDEO_REPO_URL`` is shallow-cloned (depth 1)
    into it. The resolved checkout path is then inserted at the front of
    ``sys.path`` unless it is already present.

    Raises:
        RuntimeError: If ``git clone`` exits with a non-zero status. The
            original ``subprocess.CalledProcessError`` is attached as the
            explicit cause.
    """
    if not LTX_VIDEO_REPO_DIR.exists():
        logger.info(f"Repositório LTX-Video não encontrado em '{LTX_VIDEO_REPO_DIR}'. Clonando do GitHub...")
        DEPS_DIR.mkdir(exist_ok=True)
        try:
            # Argument list (no shell) with captured output so that git's
            # stderr can be logged on failure.
            subprocess.run(
                ["git", "clone", "--depth", "1", LTX_VIDEO_REPO_URL, str(LTX_VIDEO_REPO_DIR)],
                check=True, capture_output=True, text=True
            )
        except subprocess.CalledProcessError as e:
            logger.error(f"Falha ao clonar o repositório LTX-Video. Git stderr: {e.stderr}")
            # Chain explicitly so the git failure stays visible as the cause.
            raise RuntimeError("Não foi possível clonar a dependência LTX-Video do GitHub.") from e
        logger.info("Repositório LTX-Video clonado com sucesso.")
    else:
        logger.info("Repositório LTX-Video local encontrado.")

    # Resolve once; the same absolute path is used for the membership test,
    # the insertion and the log message.
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logger.info(f"Adicionado '{repo_path}' ao sys.path.")
97
+
98
def _lazy_load_ltx_modules(self):
    """Import the LTX-Video symbols and bind them to the module-level names.

    Does nothing when ``self._ltx_modules_loaded`` is already true; otherwise
    performs the imports, sets the flag and logs the event.
    """
    if not self._ltx_modules_loaded:
        # The imports below rebind module-level names, hence the globals.
        global create_ltx_video_pipeline, calculate_padding
        global LTXVideoPipeline, ConditioningItem, LTXMultiScalePipeline
        global vae_encode, latent_to_pixel_coords, randn_tensor

        from .ltx_pipeline_utils import create_ltx_video_pipeline, calculate_padding
        from ltx_video.pipelines.pipeline_ltx_video import (
            LTXVideoPipeline,
            ConditioningItem,
            LTXMultiScalePipeline,
        )
        from ltx_video.models.autoencoders.vae_encode import (
            vae_encode,
            latent_to_pixel_coords,
        )
        from diffusers.utils.torch_utils import randn_tensor

        self._ltx_modules_loaded = True
        logger.info("Módulos do LTX-Video foram carregados dinamicamente.")
113
 
114
def _apply_ltx_pipeline_patches(self):
    """Replace ``prepare_conditioning`` on every worker's pipeline.

    Each pipeline instance gets ``_aduc_prepare_conditioning_patch`` bound to
    it as an instance attribute, so the pipeline class itself is not modified.
    """
    logger.info("LTX POOL MANAGER: Aplicando patches ADUC-SDR na pipeline LTX...")
    for pipe in (member.pipeline for member in self.workers):
        # __get__ turns the plain function into a method bound to this pipeline.
        bound_patch = _aduc_prepare_conditioning_patch.__get__(pipe, LTXVideoPipeline)
        pipe.prepare_conditioning = bound_patch
    logger.info("LTX POOL MANAGER: Todas as instâncias da pipeline foram corrigidas com sucesso.")
120
+
121
def _get_next_worker(self) -> 'LtxWorker':
    """Return the next worker in round-robin order.

    The cursor ``self.current_worker_index`` is read and advanced while
    holding ``self.lock``.
    """
    with self.lock:
        cursor = self.current_worker_index
        selected = self.workers[cursor]
        # Wrap around to the first worker after the last one.
        self.current_worker_index = (cursor + 1) % len(self.workers)
    return selected
126
+
127
  def _prepare_pipeline_params(self, worker: 'LtxWorker', **kwargs) -> dict:
128
  pipeline_params = {
129
  "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
 
140
  if 'strength' in kwargs:
141
  pipeline_params["strength"] = kwargs['strength']
142
 
 
143
  if 'conditioning_items_data' in kwargs:
144
  final_conditioning_items = []
145
  for item in kwargs['conditioning_items_data']:
146
+ # CORREÇÃO: Como LatentConditioningItem é uma dataclass mutável,
147
+ # nós modificamos o atributo diretamente no dispositivo do worker.
148
  item.latent_tensor = item.latent_tensor.to(worker.device)
149
  final_conditioning_items.append(item)
150
  pipeline_params["conditioning_items"] = final_conditioning_items
 
151
 
152
  if worker.is_distilled:
153
  fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps")
 
162
 
163
  return pipeline_params
164
 
165
def generate_latent_fragment(self, **kwargs) -> Tuple[torch.Tensor, tuple]:
    """Run one generation on the next worker in the pool.

    ``kwargs['height']`` and ``kwargs['width']`` are required. Both are
    rounded up to the next multiple of 32 before the pipeline parameters are
    built; the padding that this introduces is computed by
    ``calculate_padding`` and returned alongside the result.

    Args:
        **kwargs: Generation options forwarded to
            ``_prepare_pipeline_params`` (with the padded height and width).

    Returns:
        A ``(result, padding_vals)`` pair: the ``.images`` attribute of the
        pipeline output and the value returned by ``calculate_padding``.

    Raises:
        Exception: Any error raised during generation is logged with its
            traceback and re-raised unchanged.
    """
    worker_to_use = self._get_next_worker()
    try:
        height, width = kwargs['height'], kwargs['width']
        # Round each dimension up to the next multiple of 32.
        padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
        padding_vals = calculate_padding(height, width, padded_h, padded_w)
        kwargs['height'], kwargs['width'] = padded_h, padded_w

        pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)

        logger.info(f"Iniciando GERAÇÃO em {worker_to_use.device} com shape {padded_w}x{padded_h}")

        if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline):
            # For a multi-scale wrapper, call its inner video pipeline directly.
            result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images
        else:
            result = worker_to_use.generate_video_fragment_internal(**pipeline_params)
        return result, padding_vals
    except Exception as e:
        logger.error(f"LTX POOL MANAGER: Erro durante a geração em {worker_to_use.device}: {e}", exc_info=True)
        # Bare raise re-raises the active exception without rebinding it.
        raise
    finally:
        # Free cached CUDA memory on the worker's device after every run,
        # whether it succeeded or failed.
        if worker_to_use.device.type == 'cuda':
            with torch.cuda.device(worker_to_use.device):
                gc.collect()
                torch.cuda.empty_cache()
190
+
191
def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, tuple]:
    """Placeholder for latent refinement; not implemented yet.

    Currently ignores its arguments and returns ``None``, despite the
    annotated tuple return type.
    """
    return None
193
+
194
class LtxWorker:
    """A single LTX-Video pipeline instance tied to one device."""

    def __init__(self, device_id, ltx_config_file):
        """Load the YAML config, download the checkpoint and build the pipeline on CPU.

        Args:
            device_id: Target device identifier; replaced by ``'cpu'`` when
                CUDA is not available.
            ltx_config_file: Path of the LTX YAML configuration file. It must
                provide ``checkpoint_path``, ``precision``,
                ``text_encoder_model_name_or_path`` and ``sampler``.
        """
        self.cpu_device = torch.device('cpu')
        self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
        logger.info(f"LTX Worker ({self.device}): Inicializando com config '{ltx_config_file}'...")

        # Read the config as UTF-8 explicitly rather than relying on the
        # platform's default text encoding.
        with open(ltx_config_file, "r", encoding="utf-8") as file:
            self.config = yaml.safe_load(file)

        # A checkpoint whose path contains "distilled" is treated as distilled.
        self.is_distilled = "distilled" in self.config.get("checkpoint_path", "")
        models_dir = LTX_VIDEO_REPO_DIR / "models_downloaded"

        logger.info(f"LTX Worker ({self.device}): Preparando para carregar modelo...")
        model_filename = self.config["checkpoint_path"]
        # NOTE(review): newer huggingface_hub releases reportedly deprecate
        # `local_dir_use_symlinks` — confirm against the pinned version.
        model_path = huggingface_hub.hf_hub_download(
            repo_id="Lightricks/LTX-Video", filename=model_filename,
            local_dir=str(models_dir), local_dir_use_symlinks=False
        )

        # The pipeline is always created on CPU; to_gpu() moves it later.
        self.pipeline = create_ltx_video_pipeline(
            ckpt_path=model_path,
            precision=self.config["precision"],
            text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
            sampler=self.config["sampler"],
            device='cpu'
        )
        logger.info(f"LTX Worker ({self.device}): Modelo pronto na CPU. É um modelo distilled? {self.is_distilled}")

    def to_gpu(self):
        """Move the pipeline to this worker's device; no-op for a CPU worker.

        On a CUDA device, ``optimize_ltx_worker`` is also applied when
        ``can_optimize_fp8()`` reports support.
        """
        if self.device.type == 'cpu':
            return
        logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
        self.pipeline.to(self.device)
        if self.device.type == 'cuda' and can_optimize_fp8():
            logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Otimizando...")
            optimize_ltx_worker(self)
            logger.info(f"LTX Worker ({self.device}): Otimização completa.")

    def to_cpu(self):
        """Move the pipeline back to CPU and release cached CUDA memory; no-op for a CPU worker."""
        if self.device.type == 'cpu':
            return
        logger.info(f"LTX Worker: Descarregando pipeline da GPU {self.device}...")
        self.pipeline.to('cpu')
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate_video_fragment_internal(self, **kwargs):
        """Call the pipeline with ``kwargs`` and return the ``.images`` attribute of its output."""
        return self.pipeline(**kwargs).images
241
+
242
def _aduc_prepare_conditioning_patch(
    self: "LTXVideoPipeline",
    conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
    init_latents: torch.Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator=None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
    """Replacement for the pipeline's ``prepare_conditioning`` that consumes latent-space items.

    Only ``LatentConditioningItem`` entries are used; any other item is
    skipped with a warning. ``num_frames``, ``height``, ``width`` and
    ``vae_per_channel_normalize`` are accepted but not read by this
    implementation.

    Args:
        conditioning_items: Items to condition on; ``None`` or empty means
            no conditioning.
        init_latents: Initial latents. Modified in place for items whose
            ``media_frame_number`` is 0.
            # assumes a (batch, channels, frames, height, width) layout — TODO confirm
        generator: Passed to ``randn_tensor`` when noise is drawn for items
            at a non-zero frame number.

    Returns:
        ``(latents, pixel_coords, conditioning_mask, num_extra_latents)``.
        Without conditioning items the mask is ``None`` and the count is 0.
    """
    # Fast path: nothing to condition on, just patchify the initial latents.
    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
        return init_latents, init_pixel_coords, None, 0

    # Float32 mask shaped like init_latents with dimension 1 indexed away;
    # it stores the conditioning strength applied at each position.
    init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
    extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
    extra_conditioning_num_latents = 0

    for item in conditioning_items:
        if not isinstance(item, LatentConditioningItem):
            logger.warning("Patch ADUC: Item de condicionamento não é um LatentConditioningItem e será ignorado.")
            continue

        # Match the item's dtype and device to the initial latents.
        media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
        media_frame_number, strength = item.media_frame_number, item.conditioning_strength

        if media_frame_number == 0:
            # Blend the item into the leading region of the last three
            # dimensions of init_latents (in place) and record the strength
            # in the mask over the same region.
            f_l, h_l, w_l = media_item_latents.shape[-3:]
            init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
            init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
        else:
            # Items at a later frame become extra tokens: blend with fresh
            # noise, patchify, and offset the first pixel-coordinate channel
            # by the item's frame number.
            noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
            media_item_latents = torch.lerp(noise, media_item_latents, strength)
            patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
            pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
            pixel_coords[:, 0] += media_frame_number
            extra_conditioning_num_latents += patched_latents.shape[1]
            # One mask value per extra token, all equal to this item's strength.
            new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
            extra_conditioning_latents.append(patched_latents)
            extra_conditioning_pixel_coords.append(pixel_coords)
            extra_conditioning_mask.append(new_mask)

    # Patchify the (possibly modified) initial latents and the mask. The mask
    # gets a singleton dimension at index 1 for patchify; the last dimension
    # of the patchified mask is squeezed afterwards.
    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)

    if extra_conditioning_latents:
        # Extra tokens are placed before the initial ones: dim 1 for latents
        # and mask, dim 2 for pixel coordinates.
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)

    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
296
 
297
+ # --- Instanciação Singleton ---
298
  with open("config.yaml", 'r') as f:
299
  config = yaml.safe_load(f)
300
  ltx_gpus_required = config['specialists']['ltx']['gpus_required']