euiia committed on
Commit
4553f61
·
verified ·
1 Parent(s): 70b520b

Update ltx_manager_helpers.py

Browse files
Files changed (1) hide show
  1. ltx_manager_helpers.py +27 -58
ltx_manager_helpers.py CHANGED
@@ -105,87 +105,52 @@ class LtxPoolManager:
105
  self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
106
  return worker
107
 
108
- # --- [NOVO] Função centralizada para preparar parâmetros da pipeline ---
109
  def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
110
  """Prepara o dicionário de parâmetros para a pipeline, tratando casos especiais como modelos destilados."""
111
-
112
- # Parâmetros obrigatórios
113
- height = kwargs['height']
114
- width = kwargs['width']
115
- num_frames = kwargs['video_total_frames']
116
-
117
- # Parâmetros com valores padrão
118
- motion_prompt = kwargs.get('motion_prompt', "")
119
- negative_prompt = kwargs.get('negative_prompt', "blurry, distorted, static, bad quality")
120
- guidance_scale = kwargs.get('guidance_scale', 1.0)
121
- stg_scale = kwargs.get('stg_scale', 0.0)
122
- rescaling_scale = kwargs.get('rescaling_scale', 0.15)
123
- num_inference_steps = kwargs.get('num_inference_steps', 20)
124
-
125
- # Parâmetros opcionais (para geração ou refinamento)
126
- latents_input = kwargs.get('latents')
127
- strength = kwargs.get('strength')
128
- conditioning_data = kwargs.get('conditioning_items_data')
129
-
130
- # Prepara os itens de condicionamento se existirem
131
- final_conditioning_items = []
132
- if conditioning_data:
133
- for item in conditioning_data:
134
- item.latent_tensor = item.latent_tensor.to(worker.device)
135
- final_conditioning_items.append(item)
136
-
137
- # Constrói o dicionário base de parâmetros
138
  pipeline_params = {
139
- "height": height, "width": width, "num_frames": num_frames,
140
  "frame_rate": kwargs.get('video_fps', 24),
141
  "generator": torch.Generator(device=worker.device).manual_seed(int(time.time()) + kwargs.get('current_fragment_index', 0)),
142
  "is_video": True, "vae_per_channel_normalize": True,
143
- "prompt": motion_prompt, "negative_prompt": negative_prompt,
144
- "guidance_scale": guidance_scale, "stg_scale": stg_scale,
145
- "rescaling_scale": rescaling_scale, "num_inference_steps": num_inference_steps,
146
  "output_type": "latent"
147
  }
148
 
149
- # Adiciona parâmetros opcionais se eles foram fornecidos
150
- if latents_input is not None:
151
- pipeline_params["latents"] = latents_input.to(worker.device, dtype=worker.pipeline.transformer.dtype)
152
- if strength is not None:
153
- pipeline_params["strength"] = strength
154
- if final_conditioning_items:
 
 
 
155
  pipeline_params["conditioning_items"] = final_conditioning_items
156
 
157
- # --- LÓGICA CENTRALIZADA E À PROVA DE ERRO ---
158
- # Se o modelo for destilado, sobrescreve os passos com os timesteps fixos obrigatórios.
159
  if worker.is_distilled:
160
  logger.info(f"Worker {worker.device} está usando um modelo destilado. Usando timesteps fixos.")
161
  fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps")
162
  pipeline_params["timesteps"] = fixed_timesteps
163
  if fixed_timesteps:
164
  pipeline_params["num_inference_steps"] = len(fixed_timesteps)
165
-
166
- # Log dos parâmetros para depuração
167
- log_params = {k: v for k, v in pipeline_params.items() if k not in ['generator', 'latents', 'conditioning_items']}
168
- logger.info(f"Parâmetros preparados para a pipeline em {worker.device}:\n{json.dumps(log_params, indent=2)}")
169
-
170
  return pipeline_params
171
 
172
- # --- [REATORADO] Função de Geração simplificada ---
173
  def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
174
  worker_to_use = self._get_next_worker()
175
  try:
176
- # Padding
177
  height, width = kwargs['height'], kwargs['width']
178
  padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
179
  padding_vals = calculate_padding(height, width, padded_h, padded_w)
180
-
181
- # Atualiza kwargs com as dimensões com padding
182
- kwargs['height'] = padded_h
183
- kwargs['width'] = padded_w
184
 
185
- # Prepara os parâmetros usando a função centralizada
186
  pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
187
 
188
- # Executa a geração
 
189
  if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline):
190
  result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images
191
  else:
@@ -200,22 +165,26 @@ class LtxPoolManager:
200
  with torch.cuda.device(worker_to_use.device):
201
  gc.collect(); torch.cuda.empty_cache()
202
 
203
- # --- [REATORADO] Função de Refinamento simplificada ---
204
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
205
  worker_to_use = self._get_next_worker()
206
  try:
207
- # Adiciona os tensores e a força de denoise aos kwargs para a função auxiliar
 
 
 
 
 
 
 
208
  kwargs['latents'] = latents_to_refine
209
  kwargs['strength'] = kwargs.get('denoise_strength', 0.4)
210
  kwargs['num_inference_steps'] = int(kwargs.get('refine_steps', 10))
211
 
212
- # Prepara os parâmetros usando a mesma função centralizada
213
  pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
214
 
215
- logger.info("LTX POOL MANAGER: Iniciando passe de refinamento (denoise)...")
216
 
217
  pipeline_to_call = worker_to_use.pipeline.video_pipeline if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline) else worker_to_use.pipeline
218
-
219
  result = pipeline_to_call(**pipeline_params).images
220
  return result, None
221
 
 
105
  self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
106
  return worker
107
 
 
108
  def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
109
  """Prepara o dicionário de parâmetros para a pipeline, tratando casos especiais como modelos destilados."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  pipeline_params = {
111
+ "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
112
  "frame_rate": kwargs.get('video_fps', 24),
113
  "generator": torch.Generator(device=worker.device).manual_seed(int(time.time()) + kwargs.get('current_fragment_index', 0)),
114
  "is_video": True, "vae_per_channel_normalize": True,
115
+ "prompt": kwargs.get('motion_prompt', ""), "negative_prompt": kwargs.get('negative_prompt', "blurry, distorted, static, bad quality"),
116
+ "guidance_scale": kwargs.get('guidance_scale', 1.0), "stg_scale": kwargs.get('stg_scale', 0.0),
117
+ "rescaling_scale": kwargs.get('rescaling_scale', 0.15), "num_inference_steps": kwargs.get('num_inference_steps', 20),
118
  "output_type": "latent"
119
  }
120
 
121
+ if 'latents' in kwargs:
122
+ pipeline_params["latents"] = kwargs['latents'].to(worker.device, dtype=worker.pipeline.transformer.dtype)
123
+ if 'strength' in kwargs:
124
+ pipeline_params["strength"] = kwargs['strength']
125
+ if 'conditioning_items_data' in kwargs:
126
+ final_conditioning_items = []
127
+ for item in kwargs['conditioning_items_data']:
128
+ item.latent_tensor = item.latent_tensor.to(worker.device)
129
+ final_conditioning_items.append(item)
130
  pipeline_params["conditioning_items"] = final_conditioning_items
131
 
 
 
132
  if worker.is_distilled:
133
  logger.info(f"Worker {worker.device} está usando um modelo destilado. Usando timesteps fixos.")
134
  fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps")
135
  pipeline_params["timesteps"] = fixed_timesteps
136
  if fixed_timesteps:
137
  pipeline_params["num_inference_steps"] = len(fixed_timesteps)
138
+
 
 
 
 
139
  return pipeline_params
140
 
 
141
  def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
142
  worker_to_use = self._get_next_worker()
143
  try:
144
+ # [CORREÇÃO] A lógica de padding é específica para a geração do zero.
145
  height, width = kwargs['height'], kwargs['width']
146
  padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
147
  padding_vals = calculate_padding(height, width, padded_h, padded_w)
148
+ kwargs['height'], kwargs['width'] = padded_h, padded_w
 
 
 
149
 
 
150
  pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
151
 
152
+ logger.info(f"Iniciando GERAÇÃO em {worker_to_use.device} com shape {padded_w}x{padded_h}")
153
+
154
  if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline):
155
  result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images
156
  else:
 
165
  with torch.cuda.device(worker_to_use.device):
166
  gc.collect(); torch.cuda.empty_cache()
167
 
 
168
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
169
  worker_to_use = self._get_next_worker()
170
  try:
171
+ # [CORREÇÃO] A lógica de dimensionamento para refinamento deriva da forma do latente.
172
+ _b, _c, _f, latent_h, latent_w = latents_to_refine.shape
173
+ vae_scale_factor = worker_to_use.pipeline.vae_scale_factor
174
+
175
+ # Garante que as dimensões correspondam EXATAMENTE ao latente fornecido.
176
+ kwargs['height'] = latent_h * vae_scale_factor
177
+ kwargs['width'] = latent_w * vae_scale_factor
178
+ kwargs['video_total_frames'] = kwargs.get('video_total_frames', _f * worker_to_use.pipeline.video_scale_factor)
179
  kwargs['latents'] = latents_to_refine
180
  kwargs['strength'] = kwargs.get('denoise_strength', 0.4)
181
  kwargs['num_inference_steps'] = int(kwargs.get('refine_steps', 10))
182
 
 
183
  pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
184
 
185
+ logger.info(f"Iniciando REFINAMENTO em {worker_to_use.device} com shape {kwargs['width']}x{kwargs['height']}")
186
 
187
  pipeline_to_call = worker_to_use.pipeline.video_pipeline if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline) else worker_to_use.pipeline
 
188
  result = pipeline_to_call(**pipeline_params).images
189
  return result, None
190