euiia committed on
Commit
3526526
·
verified ·
1 Parent(s): d1515a1

Update ltx_manager_helpers.py

Browse files
Files changed (1) hide show
  1. ltx_manager_helpers.py +89 -38
ltx_manager_helpers.py CHANGED
@@ -14,6 +14,7 @@ import huggingface_hub
14
  import time
15
  import threading
16
  import json
 
17
 
18
  from optimization import optimize_ltx_worker, can_optimize_fp8
19
  from hardware_manager import hardware_manager
@@ -103,33 +104,88 @@ class LtxPoolManager:
103
  worker = self.workers[self.current_worker_index]
104
  self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
105
  return worker
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
108
  worker_to_use = self._get_next_worker()
109
  try:
 
110
  height, width = kwargs['height'], kwargs['width']
111
  padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
112
  padding_vals = calculate_padding(height, width, padded_h, padded_w)
113
 
114
- conditioning_data = kwargs.get('conditioning_items_data', [])
115
- final_conditioning_items = []
116
- for item in conditioning_data:
117
- item.latent_tensor = item.latent_tensor.to(worker_to_use.device)
118
- final_conditioning_items.append(item)
119
-
120
- pipeline_params = {
121
- "height": padded_h, "width": padded_w, "num_frames": kwargs['video_total_frames'],
122
- "frame_rate": kwargs['video_fps'], "generator": torch.Generator(device=worker_to_use.device).manual_seed(int(time.time()) + kwargs['current_fragment_index']),
123
- "conditioning_items": final_conditioning_items, "is_video": True, "vae_per_channel_normalize": True,
124
- "prompt": kwargs['motion_prompt'], "negative_prompt": "blurry, distorted, static, bad quality",
125
- "guidance_scale": kwargs['guidance_scale'], "stg_scale": kwargs['stg_scale'],
126
- "rescaling_scale": kwargs['rescaling_scale'], "num_inference_steps": kwargs['num_inference_steps'],
127
- "output_type": "latent"
128
- }
129
- if worker_to_use.is_distilled:
130
- pipeline_params["timesteps"] = worker_to_use.config.get("first_pass", {}).get("timesteps")
131
- pipeline_params["num_inference_steps"] = len(pipeline_params["timesteps"]) if pipeline_params["timesteps"] else 20
132
 
 
 
 
 
133
  if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline):
134
  result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images
135
  else:
@@ -144,35 +200,30 @@ class LtxPoolManager:
144
  with torch.cuda.device(worker_to_use.device):
145
  gc.collect(); torch.cuda.empty_cache()
146
 
 
147
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
148
  worker_to_use = self._get_next_worker()
149
  try:
150
- height, width, num_frames = kwargs['height'], kwargs['width'], kwargs['video_total_frames']
 
 
 
151
 
152
- pipeline_params = {
153
- "latents": latents_to_refine.to(worker_to_use.device, dtype=worker_to_use.pipeline.transformer.dtype),
154
- "height": height, "width": width, "num_frames": num_frames, "frame_rate": kwargs['video_fps'],
155
- "generator": torch.Generator(device=worker_to_use.device).manual_seed(int(time.time()) + kwargs['current_fragment_index']),
156
- "is_video": True, "vae_per_channel_normalize": True,
157
- "prompt": kwargs['motion_prompt'], "negative_prompt": "blurry, distorted, static, bad quality",
158
- "guidance_scale": kwargs.get('guidance_scale', 1.0),
159
- "num_inference_steps": int(kwargs.get('refine_steps', 10)),
160
- "strength": kwargs.get('denoise_strength', 0.4),
161
- "output_type": "latent"
162
- }
163
 
164
  logger.info("LTX POOL MANAGER: Iniciando passe de refinamento (denoise)...")
165
 
166
  pipeline_to_call = worker_to_use.pipeline.video_pipeline if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline) else worker_to_use.pipeline
167
 
168
- try:
169
- result = pipeline_to_call(**pipeline_params).images
170
- return result, None
171
- except torch.cuda.OutOfMemoryError as e:
172
- logger.error(f"FALHA DE MEMÓRIA DURANTE O REFINAMENTO em {worker_to_use.device}: {e}")
173
- logger.warning("Limpando VRAM e retornando None para sinalizar a falha.")
174
- gc.collect(); torch.cuda.empty_cache()
175
- return None, None
176
  except Exception as e:
177
  logger.error(f"LTX POOL MANAGER: Erro inesperado durante o refinamento em {worker_to_use.device}: {e}", exc_info=True)
178
  raise e
 
14
  import time
15
  import threading
16
  import json
17
+ from typing import Optional, List
18
 
19
  from optimization import optimize_ltx_worker, can_optimize_fp8
20
  from hardware_manager import hardware_manager
 
104
  worker = self.workers[self.current_worker_index]
105
  self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
106
  return worker
107
+
108
+ # --- [NOVO] Função centralizada para preparar parâmetros da pipeline ---
109
+ def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
110
+ """Prepara o dicionário de parâmetros para a pipeline, tratando casos especiais como modelos destilados."""
111
+
112
+ # Parâmetros obrigatórios
113
+ height = kwargs['height']
114
+ width = kwargs['width']
115
+ num_frames = kwargs['video_total_frames']
116
+
117
+ # Parâmetros com valores padrão
118
+ motion_prompt = kwargs.get('motion_prompt', "")
119
+ negative_prompt = kwargs.get('negative_prompt', "blurry, distorted, static, bad quality")
120
+ guidance_scale = kwargs.get('guidance_scale', 1.0)
121
+ stg_scale = kwargs.get('stg_scale', 0.0)
122
+ rescaling_scale = kwargs.get('rescaling_scale', 0.15)
123
+ num_inference_steps = kwargs.get('num_inference_steps', 20)
124
+
125
+ # Parâmetros opcionais (para geração ou refinamento)
126
+ latents_input = kwargs.get('latents')
127
+ strength = kwargs.get('strength')
128
+ conditioning_data = kwargs.get('conditioning_items_data')
129
+
130
+ # Prepara os itens de condicionamento se existirem
131
+ final_conditioning_items = []
132
+ if conditioning_data:
133
+ for item in conditioning_data:
134
+ item.latent_tensor = item.latent_tensor.to(worker.device)
135
+ final_conditioning_items.append(item)
136
 
137
+ # Constrói o dicionário base de parâmetros
138
+ pipeline_params = {
139
+ "height": height, "width": width, "num_frames": num_frames,
140
+ "frame_rate": kwargs.get('video_fps', 24),
141
+ "generator": torch.Generator(device=worker.device).manual_seed(int(time.time()) + kwargs.get('current_fragment_index', 0)),
142
+ "is_video": True, "vae_per_channel_normalize": True,
143
+ "prompt": motion_prompt, "negative_prompt": negative_prompt,
144
+ "guidance_scale": guidance_scale, "stg_scale": stg_scale,
145
+ "rescaling_scale": rescaling_scale, "num_inference_steps": num_inference_steps,
146
+ "output_type": "latent"
147
+ }
148
+
149
+ # Adiciona parâmetros opcionais se eles foram fornecidos
150
+ if latents_input is not None:
151
+ pipeline_params["latents"] = latents_input.to(worker.device, dtype=worker.pipeline.transformer.dtype)
152
+ if strength is not None:
153
+ pipeline_params["strength"] = strength
154
+ if final_conditioning_items:
155
+ pipeline_params["conditioning_items"] = final_conditioning_items
156
+
157
+ # --- LÓGICA CENTRALIZADA E À PROVA DE ERRO ---
158
+ # Se o modelo for destilado, sobrescreve os passos com os timesteps fixos obrigatórios.
159
+ if worker.is_distilled:
160
+ logger.info(f"Worker {worker.device} está usando um modelo destilado. Usando timesteps fixos.")
161
+ fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps")
162
+ pipeline_params["timesteps"] = fixed_timesteps
163
+ if fixed_timesteps:
164
+ pipeline_params["num_inference_steps"] = len(fixed_timesteps)
165
+
166
+ # Log dos parâmetros para depuração
167
+ log_params = {k: v for k, v in pipeline_params.items() if k not in ['generator', 'latents', 'conditioning_items']}
168
+ logger.info(f"Parâmetros preparados para a pipeline em {worker.device}:\n{json.dumps(log_params, indent=2)}")
169
+
170
+ return pipeline_params
171
+
172
+ # --- [REATORADO] Função de Geração simplificada ---
173
  def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
174
  worker_to_use = self._get_next_worker()
175
  try:
176
+ # Padding
177
  height, width = kwargs['height'], kwargs['width']
178
  padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
179
  padding_vals = calculate_padding(height, width, padded_h, padded_w)
180
 
181
+ # Atualiza kwargs com as dimensões com padding
182
+ kwargs['height'] = padded_h
183
+ kwargs['width'] = padded_w
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
+ # Prepara os parâmetros usando a função centralizada
186
+ pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
187
+
188
+ # Executa a geração
189
  if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline):
190
  result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images
191
  else:
 
200
  with torch.cuda.device(worker_to_use.device):
201
  gc.collect(); torch.cuda.empty_cache()
202
 
203
+ # --- [REATORADO] Função de Refinamento simplificada ---
204
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
205
  worker_to_use = self._get_next_worker()
206
  try:
207
+ # Adiciona os tensores e a força de denoise aos kwargs para a função auxiliar
208
+ kwargs['latents'] = latents_to_refine
209
+ kwargs['strength'] = kwargs.get('denoise_strength', 0.4)
210
+ kwargs['num_inference_steps'] = int(kwargs.get('refine_steps', 10))
211
 
212
+ # Prepara os parâmetros usando a mesma função centralizada
213
+ pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
 
 
 
 
 
 
 
 
 
214
 
215
  logger.info("LTX POOL MANAGER: Iniciando passe de refinamento (denoise)...")
216
 
217
  pipeline_to_call = worker_to_use.pipeline.video_pipeline if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline) else worker_to_use.pipeline
218
 
219
+ result = pipeline_to_call(**pipeline_params).images
220
+ return result, None
221
+
222
+ except torch.cuda.OutOfMemoryError as e:
223
+ logger.error(f"FALHA DE MEMÓRIA DURANTE O REFINAMENTO em {worker_to_use.device}: {e}")
224
+ logger.warning("Limpando VRAM e retornando None para sinalizar a falha.")
225
+ gc.collect(); torch.cuda.empty_cache()
226
+ return None, None
227
  except Exception as e:
228
  logger.error(f"LTX POOL MANAGER: Erro inesperado durante o refinamento em {worker_to_use.device}: {e}", exc_info=True)
229
  raise e