euiia commited on
Commit
1319e4b
·
verified ·
1 Parent(s): a0d2dcb

Update ltx_manager_helpers.py

Browse files
Files changed (1) hide show
  1. ltx_manager_helpers.py +57 -82
ltx_manager_helpers.py CHANGED
@@ -1,15 +1,6 @@
1
  # ltx_manager_helpers.py
2
  # Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
3
- #
4
- # ORIGINAL SOURCE: LTX-Video by Lightricks Ltd. & other open-source projects.
5
- # Licensed under the Apache License, Version 2.0
6
- # https://github.com/Lightricks/LTX-Video
7
- #
8
- # MODIFICATIONS FOR ADUC-SDR_Video:
9
- # This file is part of ADUC-SDR_Video, a derivative work based on LTX-Video.
10
- # It has been modified to manage pools of LTX workers, handle GPU memory,
11
- # and prepare parameters for the ADUC-SDR orchestration framework.
12
- # All modifications are also licensed under the Apache License, Version 2.0.
13
 
14
  import torch
15
  import gc
@@ -34,6 +25,7 @@ class LtxWorker:
34
  Gerencia o carregamento do modelo para a CPU e a movimentação de/para a GPU.
35
  """
36
  def __init__(self, device_id, ltx_config_file):
 
37
  self.cpu_device = torch.device('cpu')
38
  self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
39
  logger.info(f"LTX Worker ({self.device}): Inicializando com config '{ltx_config_file}'...")
@@ -60,20 +52,19 @@ class LtxWorker:
60
  )
61
  logger.info(f"LTX Worker ({self.device}): Modelo pronto na CPU. É um modelo destilado? {self.is_distilled}")
62
 
63
- if self.device.type == 'cuda' and can_optimize_fp8():
64
- logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Iniciando otimização...")
65
- self.pipeline.to(self.device)
66
- optimize_ltx_worker(self)
67
- self.pipeline.to(self.cpu_device)
68
- logger.info(f"LTX Worker ({self.device}): Otimização concluída. Modelo pronto.")
69
- elif self.device.type == 'cuda':
70
- logger.info(f"LTX Worker ({self.device}): Otimização FP8 não suportada ou desativada. Usando modelo padrão.")
71
-
72
  def to_gpu(self):
73
- """Move o pipeline para a GPU designada."""
74
  if self.device.type == 'cpu': return
75
  logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
76
  self.pipeline.to(self.device)
 
 
 
 
 
 
 
 
77
 
78
  def to_cpu(self):
79
  """Move o pipeline de volta para a CPU e libera a memória da GPU."""
@@ -89,23 +80,29 @@ class LtxWorker:
89
 
90
  class LtxPoolManager:
91
  """
92
- Gerencia um pool de LtxWorkers para otimizar o uso de múltiplas GPUs,
93
- alternando o worker ativo para permitir que o anterior descarregue da VRAM em segundo plano.
94
  """
95
  def __init__(self, device_ids, ltx_config_file):
96
  logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
97
  self.workers = [LtxWorker(dev_id, ltx_config_file) for dev_id in device_ids]
98
  self.current_worker_index = 0
99
  self.lock = threading.Lock()
100
- self.last_cleanup_thread = None
101
 
102
- def _cleanup_worker_thread(self, worker):
103
- """Thread para descarregar um worker da GPU em segundo plano."""
104
- logger.info(f"LTX CLEANUP THREAD: Iniciando limpeza de {worker.device} em background...")
105
- worker.to_cpu()
 
 
 
 
 
 
 
106
 
107
  def _prepare_and_log_params(self, worker_to_use, **kwargs):
108
- """Prepara e registra os parâmetros para a chamada da pipeline LTX."""
109
  target_device = worker_to_use.device
110
  height, width = kwargs['height'], kwargs['width']
111
 
@@ -113,7 +110,6 @@ class LtxPoolManager:
113
  final_conditioning_items = []
114
  conditioning_log_details = []
115
  for i, item in enumerate(conditioning_data):
116
- # Lida tanto com LatentConditioningItem quanto ConditioningItem (se usado no futuro)
117
  if hasattr(item, 'latent_tensor'):
118
  item.latent_tensor = item.latent_tensor.to(target_device)
119
  final_conditioning_items.append(item)
@@ -146,7 +142,6 @@ class LtxPoolManager:
146
  else:
147
  pipeline_params["num_inference_steps"] = int(kwargs.get('num_inference_steps', 20))
148
 
149
- # Log detalhado dos parâmetros para depuração.
150
  log_friendly_params = pipeline_params.copy()
151
  log_friendly_params.pop('generator', None)
152
  log_friendly_params.pop('conditioning_items', None)
@@ -162,80 +157,60 @@ class LtxPoolManager:
162
 
163
  return pipeline_params, padding_vals
164
 
165
- def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
166
  """
167
- Orquestra a geração de um novo fragmento de vídeo a partir do zero (ruído).
 
168
  """
169
  worker_to_use = None
170
  try:
171
  with self.lock:
172
- if self.last_cleanup_thread and self.last_cleanup_thread.is_alive():
173
- self.last_cleanup_thread.join()
174
  worker_to_use = self.workers[self.current_worker_index]
175
- previous_worker_index = (self.current_worker_index - 1 + len(self.workers)) % len(self.workers)
176
- worker_to_cleanup = self.workers[previous_worker_index]
177
- cleanup_thread = threading.Thread(target=self._cleanup_worker_thread, args=(worker_to_cleanup,))
178
- cleanup_thread.start()
179
- self.last_cleanup_thread = cleanup_thread
180
- worker_to_use.to_gpu()
181
  self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
182
 
183
  pipeline_params, padding_vals = self._prepare_and_log_params(worker_to_use, **kwargs)
184
- pipeline_params['output_type'] = "latent"
185
-
186
- with torch.no_grad():
187
- result_tensor = worker_to_use.generate_video_fragment_internal(**pipeline_params)
188
 
189
- return result_tensor, padding_vals
 
 
 
190
  except Exception as e:
191
- logger.error(f"LTX POOL MANAGER: Erro durante a geração de latentes: {e}", exc_info=True)
192
  raise e
193
  finally:
194
- if worker_to_use:
195
- logger.info(f"LTX POOL MANAGER: Executando limpeza final para {worker_to_use.device}...")
196
- worker_to_use.to_cpu()
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  def refine_latents(self, upscaled_latents: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
199
  """
200
- Orquestra um passe de difusão curto em latentes já existentes para refinar texturas.
201
- Usado na etapa de pós-produção de upscale.
202
  """
203
- worker_to_use = None
204
- try:
205
- with self.lock:
206
- if self.last_cleanup_thread and self.last_cleanup_thread.is_alive():
207
- self.last_cleanup_thread.join()
208
- worker_to_use = self.workers[self.current_worker_index]
209
- previous_worker_index = (self.current_worker_index - 1 + len(self.workers)) % len(self.workers)
210
- worker_to_cleanup = self.workers[previous_worker_index]
211
- cleanup_thread = threading.Thread(target=self._cleanup_worker_thread, args=(worker_to_cleanup,))
212
- cleanup_thread.start()
213
- self.last_cleanup_thread = cleanup_thread
214
- worker_to_use.to_gpu()
215
- self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
216
-
217
- pipeline_params, padding_vals = self._prepare_and_log_params(worker_to_use, **kwargs)
218
-
219
- # Parâmetros específicos para o passe de refinamento (denoise)
220
- pipeline_params['latents'] = upscaled_latents.to(worker_to_use.device, dtype=worker_to_use.pipeline.transformer.dtype)
221
- pipeline_params['strength'] = kwargs.get('denoise_strength', 0.4)
222
- pipeline_params['num_inference_steps'] = int(kwargs.get('refine_steps', 10))
223
- pipeline_params['output_type'] = "latent"
224
 
225
  logger.info("LTX POOL MANAGER: Iniciando passe de refinamento (denoise) em latentes de alta resolução.")
226
 
227
  with torch.no_grad():
228
- refined_tensor = worker_to_use.generate_video_fragment_internal(**pipeline_params)
229
-
230
- return refined_tensor, padding_vals
231
 
232
- except Exception as e:
233
- logger.error(f"LTX POOL MANAGER: Erro durante o refinamento de latentes: {e}", exc_info=True)
234
- raise e
235
- finally:
236
- if worker_to_use:
237
- logger.info(f"LTX POOL MANAGER: Executando limpeza final para {worker_to_use.device}...")
238
- worker_to_use.to_cpu()
239
 
240
  # --- Instanciação Singleton ---
241
  logger.info("Lendo config.yaml para inicializar o LTX Pool Manager...")
 
1
  # ltx_manager_helpers.py
2
  # Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
3
+ # (Licenciamento e cabeçalhos permanecem os mesmos)
 
 
 
 
 
 
 
 
 
4
 
5
  import torch
6
  import gc
 
25
  Gerencia o carregamento do modelo para a CPU e a movimentação de/para a GPU.
26
  """
27
  def __init__(self, device_id, ltx_config_file):
28
+ # ... (código do LtxWorker __init__ permanece o mesmo) ...
29
  self.cpu_device = torch.device('cpu')
30
  self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
31
  logger.info(f"LTX Worker ({self.device}): Inicializando com config '{ltx_config_file}'...")
 
52
  )
53
  logger.info(f"LTX Worker ({self.device}): Modelo pronto na CPU. É um modelo destilado? {self.is_distilled}")
54
 
 
 
 
 
 
 
 
 
 
55
  def to_gpu(self):
56
+ """Move o pipeline para a GPU designada E OTIMIZA SE POSSÍVEL."""
57
  if self.device.type == 'cpu': return
58
  logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
59
  self.pipeline.to(self.device)
60
+
61
+ # A otimização agora ocorre aqui, uma única vez, quando o modelo vai para a GPU.
62
+ if self.device.type == 'cuda' and can_optimize_fp8():
63
+ logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Iniciando otimização...")
64
+ optimize_ltx_worker(self)
65
+ logger.info(f"LTX Worker ({self.device}): Otimização concluída.")
66
+ elif self.device.type == 'cuda':
67
+ logger.info(f"LTX Worker ({self.device}): Otimização FP8 não suportada ou desativada.")
68
 
69
  def to_cpu(self):
70
  """Move o pipeline de volta para a CPU e libera a memória da GPU."""
 
80
 
81
  class LtxPoolManager:
82
  """
83
+ Gerencia um pool de LtxWorkers para otimizar o uso de múltiplas GPUs.
84
+ NOVO MODO "HOT START": Mantém todos os modelos carregados na VRAM para latência mínima.
85
  """
86
  def __init__(self, device_ids, ltx_config_file):
87
  logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
88
  self.workers = [LtxWorker(dev_id, ltx_config_file) for dev_id in device_ids]
89
  self.current_worker_index = 0
90
  self.lock = threading.Lock()
 
91
 
92
+ # ######################################################################
93
+ # ## MUDANÇA 1: PRÉ-AQUECIMENTO DAS GPUs ##
94
+ # ######################################################################
95
+ if all(w.device.type == 'cuda' for w in self.workers):
96
+ logger.info("LTX POOL MANAGER: MODO HOT START ATIVADO. Pré-aquecendo todas as GPUs...")
97
+ for worker in self.workers:
98
+ worker.to_gpu()
99
+ logger.info("LTX POOL MANAGER: Todas as GPUs estão quentes e prontas.")
100
+ else:
101
+ logger.info("LTX POOL MANAGER: Operando em modo CPU ou misto. O pré-aquecimento de GPU foi ignorado.")
102
+ # ######################################################################
103
 
104
  def _prepare_and_log_params(self, worker_to_use, **kwargs):
105
+ # ... (Esta função permanece exatamente a mesma) ...
106
  target_device = worker_to_use.device
107
  height, width = kwargs['height'], kwargs['width']
108
 
 
110
  final_conditioning_items = []
111
  conditioning_log_details = []
112
  for i, item in enumerate(conditioning_data):
 
113
  if hasattr(item, 'latent_tensor'):
114
  item.latent_tensor = item.latent_tensor.to(target_device)
115
  final_conditioning_items.append(item)
 
142
  else:
143
  pipeline_params["num_inference_steps"] = int(kwargs.get('num_inference_steps', 20))
144
 
 
145
  log_friendly_params = pipeline_params.copy()
146
  log_friendly_params.pop('generator', None)
147
  log_friendly_params.pop('conditioning_items', None)
 
157
 
158
  return pipeline_params, padding_vals
159
 
160
+ def _execute_on_worker(self, execution_fn, **kwargs):
161
  """
162
+ Função unificada para selecionar um worker e executar uma tarefa,
163
+ sem a lógica de carregar/descarregar.
164
  """
165
  worker_to_use = None
166
  try:
167
  with self.lock:
 
 
168
  worker_to_use = self.workers[self.current_worker_index]
 
 
 
 
 
 
169
  self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
170
 
171
  pipeline_params, padding_vals = self._prepare_and_log_params(worker_to_use, **kwargs)
 
 
 
 
172
 
173
+ result = execution_fn(worker_to_use, pipeline_params, **kwargs)
174
+
175
+ return result, padding_vals
176
+
177
  except Exception as e:
178
+ logger.error(f"LTX POOL MANAGER: Erro durante a execução em {worker_to_use.device if worker_to_use else 'N/A'}: {e}", exc_info=True)
179
  raise e
180
  finally:
181
+ # Apenas limpa o cache da GPU, não descarrega o modelo.
182
+ if worker_to_use and worker_to_use.device.type == 'cuda':
183
+ with torch.cuda.device(worker_to_use.device):
184
+ gc.collect()
185
+ torch.cuda.empty_cache()
186
+
187
+ def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
188
+ """
189
+ Orquestra a geração de um novo fragmento de vídeo a partir do ruído.
190
+ """
191
+ def execution_logic(worker, params, **inner_kwargs):
192
+ params['output_type'] = "latent"
193
+ with torch.no_grad():
194
+ return worker.generate_video_fragment_internal(**params)
195
+
196
+ return self._execute_on_worker(execution_logic, **kwargs)
197
 
198
  def refine_latents(self, upscaled_latents: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
199
  """
200
+ Orquestra um passe de difusão curto em latentes já existentes para refinamento.
 
201
  """
202
+ def execution_logic(worker, params, **inner_kwargs):
203
+ params['latents'] = upscaled_latents.to(worker.device, dtype=worker.pipeline.transformer.dtype)
204
+ params['strength'] = inner_kwargs.get('denoise_strength', 0.4)
205
+ params['num_inference_steps'] = int(inner_kwargs.get('refine_steps', 10))
206
+ params['output_type'] = "latent"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  logger.info("LTX POOL MANAGER: Iniciando passe de refinamento (denoise) em latentes de alta resolução.")
209
 
210
  with torch.no_grad():
211
+ return worker.generate_video_fragment_internal(**params)
 
 
212
 
213
+ return self._execute_on_worker(execution_logic, upscaled_latents=upscaled_latents, **kwargs)
 
 
 
 
 
 
214
 
215
  # --- Instanciação Singleton ---
216
  logger.info("Lendo config.yaml para inicializar o LTX Pool Manager...")