euiia committed on
Commit a0d2dcb · verified · 1 Parent(s): 795e89c

Update ltx_manager_helpers.py

Files changed (1)
  1. ltx_manager_helpers.py +66 -16
ltx_manager_helpers.py CHANGED
@@ -25,11 +25,14 @@ from optimization import optimize_ltx_worker, can_optimize_fp8
 from hardware_manager import hardware_manager
 from inference import create_ltx_video_pipeline, calculate_padding
 from ltx_video.pipelines.pipeline_ltx_video import LatentConditioningItem
-from ltx_video.models.autoencoders.vae_encode import vae_decode
 
 logger = logging.getLogger(__name__)
 
 class LtxWorker:
+    """
+    Represents a single instance of the LTX-Video pipeline on a specific device.
+    Manages loading the model on the CPU and moving it to/from the GPU.
+    """
     def __init__(self, device_id, ltx_config_file):
         self.cpu_device = torch.device('cpu')
         self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
@@ -67,11 +70,13 @@
         logger.info(f"LTX Worker ({self.device}): FP8 optimization not supported or disabled. Using the default model.")
 
     def to_gpu(self):
+        """Moves the pipeline to its designated GPU."""
         if self.device.type == 'cpu': return
         logger.info(f"LTX Worker: Moving pipeline to GPU {self.device}...")
         self.pipeline.to(self.device)
 
     def to_cpu(self):
+        """Moves the pipeline back to the CPU and frees GPU memory."""
         if self.device.type == 'cpu': return
         logger.info(f"LTX Worker: Unloading pipeline from GPU {self.device}...")
         self.pipeline.to('cpu')
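
Note: the to_gpu/to_cpu pair above is the worker's whole device lifecycle. A minimal usage sketch (the config filename here is illustrative, not from this repo):

    # Hypothetical usage of the LtxWorker lifecycle; the config path is made up.
    worker = LtxWorker("cuda:0", "ltx_config.yaml")  # weights are loaded on the CPU first
    worker.to_gpu()    # move the pipeline onto cuda:0 before generating
    # ... run generation ...
    worker.to_cpu()    # hand the VRAM back when another worker takes over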
@@ -79,9 +84,14 @@
         if torch.cuda.is_available(): torch.cuda.empty_cache()
 
     def generate_video_fragment_internal(self, **kwargs):
+        """Invokes the generation pipeline."""
        return self.pipeline(**kwargs).images
 
 class LtxPoolManager:
+    """
+    Manages a pool of LtxWorkers to make the best use of multiple GPUs,
+    rotating the active worker so the previous one can unload from VRAM in the background.
+    """
     def __init__(self, device_ids, ltx_config_file):
         logger.info(f"LTX POOL MANAGER: Creating workers for devices: {device_ids}")
         self.workers = [LtxWorker(dev_id, ltx_config_file) for dev_id in device_ids]
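
Note: the new class docstring describes the pool's rotation scheme: while one worker generates, the previously used worker is unloaded from VRAM on a background thread. A self-contained sketch of that pattern, with a hypothetical Worker stand-in rather than the repo's LtxWorker:

    import threading

    class Worker:
        """Hypothetical stand-in for LtxWorker: only tracks where it 'lives'."""
        def __init__(self, name):
            self.name = name
            self.on_gpu = False
        def to_gpu(self):
            self.on_gpu = True
        def to_cpu(self):
            self.on_gpu = False

    workers = [Worker("cuda:0"), Worker("cuda:1")]
    state = {"current": 0, "last_cleanup": None}
    lock = threading.Lock()

    def acquire_worker():
        """Rotate to the next worker; unload the previous one in the background."""
        with lock:
            if state["last_cleanup"] and state["last_cleanup"].is_alive():
                state["last_cleanup"].join()          # never let two cleanups overlap
            active = workers[state["current"]]
            previous = workers[(state["current"] - 1) % len(workers)]
            cleanup = threading.Thread(target=previous.to_cpu)
            cleanup.start()                           # VRAM frees while we generate
            state["last_cleanup"] = cleanup
            active.to_gpu()
            state["current"] = (state["current"] + 1) % len(workers)
        return active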
@@ -90,19 +100,20 @@ class LtxPoolManager:
         self.last_cleanup_thread = None
 
     def _cleanup_worker_thread(self, worker):
+        """Background thread target that unloads a worker from the GPU."""
         logger.info(f"LTX CLEANUP THREAD: Starting cleanup of {worker.device} in the background...")
         worker.to_cpu()
 
     def _prepare_and_log_params(self, worker_to_use, **kwargs):
+        """Prepares and logs the parameters for the LTX pipeline call."""
         target_device = worker_to_use.device
         height, width = kwargs['height'], kwargs['width']
 
         conditioning_data = kwargs.get('conditioning_items_data', [])
         final_conditioning_items = []
-
-        # --- ADDED LOG: details of the conditioning tensors ---
         conditioning_log_details = []
         for i, item in enumerate(conditioning_data):
+            # Handles both LatentConditioningItem and ConditioningItem (if used in the future)
             if hasattr(item, 'latent_tensor'):
                 item.latent_tensor = item.latent_tensor.to(target_device)
                 final_conditioning_items.append(item)
@@ -121,23 +132,21 @@
             "conditioning_items": final_conditioning_items,
             "is_video": True, "vae_per_channel_normalize": True,
             "decode_timestep": float(kwargs.get('decode_timestep', worker_to_use.config.get("decode_timestep", 0.05))),
-            "decode_noise_scale": float(kwargs.get('decode_noise_scale', worker_to_use.config.get("decode_noise_scale", 0.025))),
             "image_cond_noise_scale": float(kwargs.get('image_cond_noise_scale', 0.0)),
-            "stochastic_sampling": bool(kwargs.get('stochastic_sampling', worker_to_use.config.get("stochastic_sampling", False))),
             "prompt": kwargs['motion_prompt'],
             "negative_prompt": kwargs.get('negative_prompt', "blurry, distorted, static, bad quality, artifacts"),
-            "guidance_scale": float(kwargs.get('guidance_scale', 1.0)),
-            "stg_scale": float(kwargs.get('stg_scale', 0.0)),
-            "rescaling_scale": float(kwargs.get('rescaling_scale', 1.0)),
+            "guidance_scale": float(kwargs.get('guidance_scale', 2.0)),
+            "stg_scale": float(kwargs.get('stg_scale', 0.025)),
+            "rescaling_scale": float(kwargs.get('rescaling_scale', 0.15)),
         }
 
         if worker_to_use.is_distilled:
             pipeline_params["timesteps"] = first_pass_config.get("timesteps")
-            pipeline_params["num_inference_steps"] = len(pipeline_params["timesteps"]) if "timesteps" in first_pass_config else 8
+            pipeline_params["num_inference_steps"] = len(pipeline_params["timesteps"]) if "timesteps" in first_pass_config else 20
         else:
-            pipeline_params["num_inference_steps"] = int(kwargs.get('num_inference_steps', 7))
+            pipeline_params["num_inference_steps"] = int(kwargs.get('num_inference_steps', 20))
 
-        # --- ADDED LOG: full dump of the pipeline parameters ---
+        # Detailed parameter log for debugging.
         log_friendly_params = pipeline_params.copy()
         log_friendly_params.pop('generator', None)
         log_friendly_params.pop('conditioning_items', None)
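
Note: every entry in the parameter block resolves through the same three-level chain: explicit kwarg, then the worker's config, then a hard-coded fallback. A sketch of how that could be factored out (the _resolve helper is a name invented here, not in the repo):

    def _resolve(kwargs, config, key, fallback, cast=float):
        """Resolve a pipeline parameter: explicit kwarg > worker config > fallback."""
        return cast(kwargs.get(key, config.get(key, fallback)))

    # Equivalent to the inline expression used for decode_timestep above:
    # _resolve(kwargs, worker_to_use.config, 'decode_timestep', 0.05)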
@@ -148,15 +157,16 @@
         logger.info("-" * 20 + " PIPELINE PARAMETERS " + "-" * 20)
         logger.info(json.dumps(log_friendly_params, indent=2))
         logger.info("-" * 20 + " CONDITIONING ITEMS " + "-" * 19)
-        logger.info("\n".join(conditioning_log_details))
+        logger.info("\n".join(conditioning_log_details) if conditioning_log_details else " - None")
         logger.info("="*60)
-        # --- END OF ADDED LOG ---
 
         return pipeline_params, padding_vals
 
     def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
+        """
+        Orchestrates the generation of a new video fragment from scratch (noise).
+        """
         worker_to_use = None
-        progress = kwargs.get('progress')
         try:
             with self.lock:
                 if self.last_cleanup_thread and self.last_cleanup_thread.is_alive():
@@ -173,8 +183,6 @@
             pipeline_params, padding_vals = self._prepare_and_log_params(worker_to_use, **kwargs)
             pipeline_params['output_type'] = "latent"
 
-            if progress: progress(0.1, desc=f"[LTX specialist on {worker_to_use.device}] Generating latents...")
-
             with torch.no_grad():
                 result_tensor = worker_to_use.generate_video_fragment_internal(**pipeline_params)
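
Note: a hedged usage sketch for generate_latent_fragment; 'ltx_manager' stands in for the module-level singleton (its actual variable name is cut off at the bottom of this diff), and the argument names follow _prepare_and_log_params:

    # Hypothetical caller of the pool manager's from-scratch generation.
    latents, padding = ltx_manager.generate_latent_fragment(
        height=512, width=768,                  # read by _prepare_and_log_params
        motion_prompt="a red kite climbing over dunes",
        conditioning_items_data=[],             # optional list of LatentConditioningItem
        num_inference_steps=20,                 # non-distilled default after this commit
    )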
@@ -187,7 +195,49 @@
                 logger.info(f"LTX POOL MANAGER: Running final cleanup for {worker_to_use.device}...")
                 worker_to_use.to_cpu()
 
+    def refine_latents(self, upscaled_latents: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
+        """
+        Orchestrates a short diffusion pass over already-existing latents to refine textures.
+        Used in the upscale post-production step.
+        """
+        worker_to_use = None
+        try:
+            with self.lock:
+                if self.last_cleanup_thread and self.last_cleanup_thread.is_alive():
+                    self.last_cleanup_thread.join()
+                worker_to_use = self.workers[self.current_worker_index]
+                previous_worker_index = (self.current_worker_index - 1 + len(self.workers)) % len(self.workers)
+                worker_to_cleanup = self.workers[previous_worker_index]
+                cleanup_thread = threading.Thread(target=self._cleanup_worker_thread, args=(worker_to_cleanup,))
+                cleanup_thread.start()
+                self.last_cleanup_thread = cleanup_thread
+                worker_to_use.to_gpu()
+                self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
+
+            pipeline_params, padding_vals = self._prepare_and_log_params(worker_to_use, **kwargs)
+
+            # Parameters specific to the refinement (denoise) pass
+            pipeline_params['latents'] = upscaled_latents.to(worker_to_use.device, dtype=worker_to_use.pipeline.transformer.dtype)
+            pipeline_params['strength'] = kwargs.get('denoise_strength', 0.4)
+            pipeline_params['num_inference_steps'] = int(kwargs.get('refine_steps', 10))
+            pipeline_params['output_type'] = "latent"
+
+            logger.info("LTX POOL MANAGER: Starting refinement (denoise) pass on high-resolution latents.")
+
+            with torch.no_grad():
+                refined_tensor = worker_to_use.generate_video_fragment_internal(**pipeline_params)
+
+            return refined_tensor, padding_vals
+
+        except Exception as e:
+            logger.error(f"LTX POOL MANAGER: Error while refining latents: {e}", exc_info=True)
+            raise e
+        finally:
+            if worker_to_use:
+                logger.info(f"LTX POOL MANAGER: Running final cleanup for {worker_to_use.device}...")
+                worker_to_use.to_cpu()
 
+# --- Singleton instantiation ---
 logger.info("Reading config.yaml to initialize the LTX Pool Manager...")
 with open("config.yaml", 'r') as f:
     config = yaml.safe_load(f)
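
Note: a matching sketch for the new refine_latents pass, which (as in typical img2img pipelines, presumably) re-runs only part of the schedule so a short pass sharpens texture without redoing structure; 'ltx_manager' and 'upscaled_latents' are assumed names:

    # Hypothetical caller: upscale latents elsewhere, then run the short denoise pass.
    refined, padding = ltx_manager.refine_latents(
        upscaled_latents,                 # torch.Tensor of high-resolution latents
        height=1024, width=1024,          # read by _prepare_and_log_params
        motion_prompt="slow dolly-in, fine film grain",
        denoise_strength=0.4,             # fraction of the schedule to re-run
        refine_steps=10,                  # short pass refines texture, not structure
    )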
 