Update api/ltx/ltx_aduc_pipeline.py
api/ltx/ltx_aduc_pipeline.py  +106 −18
CHANGED

@@ -222,17 +222,93 @@ class LtxAducPipeline:
     # --- WORK UNITS AND INTERNAL HELPERS ---
     # ==========================================================================

-    def _log_conditioning_items(self, items: List[
-        """
-
-
-
+    def _log_conditioning_items(self, items: List[ConditioningItem]):
+        """
+        Logs detailed information about a list of ConditioningItem objects.
+        This is a dedicated debug helper function.
+        """
+        # Only build and emit the report when INFO-level logging is enabled
+        if logging.getLogger().isEnabledFor(logging.INFO):
+            log_str = ["\n" + "="*25 + " INFO: Conditioning Items " + "="*25]
+            if not items:
+                log_str.append("  -> The conditioning_items list is empty.")
+            else:
+                for i, item in enumerate(items):
+                    if hasattr(item, 'media_item') and isinstance(item.media_item, torch.Tensor):
+                        t = item.media_item
+                        log_str.append(
+                            f"  -> Item [{i}]: "
+                            f"Tensor(shape={list(t.shape)}, "
+                            f"device='{t.device}', "
+                            f"dtype={t.dtype}), "
+                            f"Target Frame = {item.media_frame_number}, "
+                            f"Strength = {item.conditioning_strength:.2f}"
+                        )
+                    else:
+                        log_str.append(f"  -> Item [{i}]: Does not contain a valid tensor.")
+            log_str.append("="*75 + "\n")
+
+            # Emit the assembled report as a single log record
+            logging.info("\n".join(log_str))
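For reference, a minimal standalone sketch of the same build-then-emit logging pattern. `ConditioningItemStub` is a hypothetical stand-in; only the three attributes the helper reads are assumed here, the real `ConditioningItem` class lives in the LTX codebase:

```python
import logging
from dataclasses import dataclass

import torch

@dataclass
class ConditioningItemStub:
    # Hypothetical stand-in exposing only the attributes the helper reads.
    media_item: torch.Tensor
    media_frame_number: int
    conditioning_strength: float

def log_conditioning_items(items):
    # Build every line first, then emit one log record, as the helper above does.
    lines = ["=" * 25 + " INFO: Conditioning Items " + "=" * 25]
    for i, item in enumerate(items):
        t = item.media_item
        lines.append(
            f"  -> Item [{i}]: Tensor(shape={list(t.shape)}, device='{t.device}', "
            f"dtype={t.dtype}), Target Frame = {item.media_frame_number}, "
            f"Strength = {item.conditioning_strength:.2f}"
        )
    logging.getLogger(__name__).info("\n".join(lines))

logging.basicConfig(level=logging.INFO)
log_conditioning_items([ConditioningItemStub(torch.zeros(1, 3, 1, 64, 64), 0, 1.0)])
```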

     @log_function_io
     def _generate_single_chunk_low(self, **kwargs) -> Optional[torch.Tensor]:
         """[WORKER] Calls the pipeline to generate a single chunk of latents."""
-
-
+        height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
+        downscale_factor = self.config.get("downscale_factor", 0.6666666)
+        vae_scale_factor = self.pipeline.vae_scale_factor
+        downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
+        downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)
+
+        # 1. Start from the default first-pass configuration
+        first_pass_config = self.config.get("first_pass", {}).copy()
+
+        # 2. Apply the UI overrides, if any
+        if kwargs.get("ltx_configs_override"):
+            self._apply_ui_overrides(first_pass_config, kwargs.get("ltx_configs_override"))
+
+        # 3. Assemble the argument dict WITHOUT conditioning_items first
+        pipeline_kwargs = {
+            "prompt": kwargs['prompt'],
+            "negative_prompt": kwargs['negative_prompt'],
+            "height": downscaled_height,
+            "width": downscaled_width,
+            "num_frames": kwargs['num_frames'],
+            "frame_rate": int(DEFAULT_FPS),
+            "generator": torch.Generator(device=self.main_device).manual_seed(kwargs['seed']),
+            "output_type": "latent",
+            #"conditioning_items": conditioning_items if conditioning_items else None,
+            "media_items": None,
+            "decode_timestep": self.config["decode_timestep"],
+            "decode_noise_scale": self.config["decode_noise_scale"],
+            "stochastic_sampling": self.config["stochastic_sampling"],
+            "image_cond_noise_scale": 0.01,
+            "is_video": True,
+            "vae_per_channel_normalize": True,
+            "mixed_precision": (self.config["precision"] == "mixed_precision"),
+            "offload_to_cpu": False,
+            "enhance_prompt": False,
+            #"skip_layer_strategy": SkipLayerStrategy.AttentionValues,
+            **first_pass_config
+        }
+
+        # --- Debug logging block ---
+        # 4. Log the pipeline arguments (without the conditioning tensors)
+        logging.info(f"\n[Info] Pipeline Arguments (BASE):\n {json.dumps(pipeline_kwargs, indent=2, default=str)}\n")
+
+        # Log the conditioning_items separately with the dedicated helper
+        conditioning_items_list = kwargs.get('conditioning_items')
+        self._log_conditioning_items(conditioning_items_list)
+        # --- End of debug logging block ---
+
+        # 5. Add the conditioning_items to the dict
+        pipeline_kwargs['conditioning_items'] = conditioning_items_list
+
+        # 6. Run the pipeline with the complete argument dict
+        with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled=("cuda" in self.main_device.type)):
+            latents_raw = self.pipeline(**pipeline_kwargs).images
+
+        return latents_raw.to(self.main_device)

     @log_function_io
     def _finalize_generation(self, final_latents: torch.Tensor, base_filename: str, seed: int) -> Tuple[str, str]:
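A worked sketch of the first-pass resolution math in `_generate_single_chunk_low`: the 0.6666666 downscale factor comes from the diff, while the alignment of 8 and `vae_scale_factor = 32` are illustrative assumptions:

```python
def align_up(dim: int, alignment: int) -> int:
    # The 'default' rule of _align: round up to the next multiple of `alignment`.
    return ((dim - 1) // alignment + 1) * alignment

height, width = 720, 1280
height_padded, width_padded = align_up(height, 8), align_up(width, 8)     # 720, 1280
downscale_factor = 0.6666666
downscaled_height = align_up(int(height_padded * downscale_factor), 32)   # int(479.99) = 479 -> 480
downscaled_width = align_up(int(width_padded * downscale_factor), 32)     # int(853.33) = 853 -> 864
print(downscaled_height, downscaled_width)  # 480 864
```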

@@ -246,30 +322,42 @@
             final_latents, decode_timestep=float(self.config.get("decode_timestep", 0.05))
         )
         video_path = self._save_and_log_video(pixel_tensor, f"{base_filename}_{seed}")
-        return str(video_path), str(final_latents_path)
-
+        return str(video_path), str(final_latents_path)
+
     def _apply_ui_overrides(self, config_dict: Dict, overrides: Dict):
         """Applies advanced settings from the UI to a config dictionary."""
-        #
-
-
+        # Override step counts
+        for key in ["num_inference_steps", "skip_initial_inference_steps", "skip_final_inference_steps"]:
+            ui_value = overrides.get(key)
+            if ui_value and ui_value > 0:
+                config_dict[key] = ui_value
+                logging.info(f"Override: '{key}' set to {ui_value} by UI.")
+
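The override loop only accepts positive values, so a UI field left at 0 (or absent) keeps the config default. A quick standalone illustration of that guard, with hypothetical values:

```python
config = {"num_inference_steps": 30}
overrides = {"num_inference_steps": 8, "skip_final_inference_steps": 0}

for key in ["num_inference_steps", "skip_initial_inference_steps", "skip_final_inference_steps"]:
    ui_value = overrides.get(key)
    if ui_value and ui_value > 0:   # both None and 0 fail this check
        config[key] = ui_value

print(config)  # {'num_inference_steps': 8} -- the 0 override and the missing key are ignored
```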
     def _save_and_log_video(self, pixel_tensor: torch.Tensor, base_filename: str) -> Path:
-
-
-
+        with tempfile.TemporaryDirectory() as temp_dir:
+            temp_path = os.path.join(temp_dir, f"{base_filename}.mp4")
+            video_encode_tool_singleton.save_video_from_tensor(pixel_tensor, temp_path, fps=DEFAULT_FPS)
+            final_path = RESULTS_DIR / f"{base_filename}.mp4"
+            shutil.move(temp_path, final_path)
+            logging.info(f"Video saved successfully to: {final_path}")
+            return final_path

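The save-then-move pattern writes the MP4 into a scratch directory first, so a partially encoded file never appears under RESULTS_DIR. A generic sketch of the same idea (names and paths here are illustrative, not from the repo):

```python
import shutil
import tempfile
from pathlib import Path

RESULTS_DIR = Path("results")  # assumption: a module-level constant in the real file
RESULTS_DIR.mkdir(exist_ok=True)

def publish(data: bytes, name: str) -> Path:
    with tempfile.TemporaryDirectory() as tmp:
        tmp_file = Path(tmp) / name
        tmp_file.write_bytes(data)          # stand-in for the video encoder call
        final = RESULTS_DIR / name
        shutil.move(str(tmp_file), final)   # the file appears only once fully written
        return final

print(publish(b"\x00", "demo.mp4"))
```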
     def _apply_precision_policy(self):
-
-
+        precision = str(self.config.get("precision", "bfloat16")).lower()
+        if precision in ["float8_e4m3fn", "bfloat16"]: self.runtime_autocast_dtype = torch.bfloat16
+        elif precision == "mixed_precision": self.runtime_autocast_dtype = torch.float16
+        else: self.runtime_autocast_dtype = torch.float32
+        logging.info(f"Runtime precision policy set for autocast: {self.runtime_autocast_dtype}")

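A standalone restatement of the precision policy; note that the `float8_e4m3fn` branch deliberately autocasts in bfloat16, presumably because `torch.autocast` expects a 16- or 32-bit compute dtype rather than fp8:

```python
import torch

# Assumed mapping, copied from the method above: fp8 weights still compute in bf16.
POLICY = {
    "float8_e4m3fn": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "mixed_precision": torch.float16,
}

def autocast_dtype(precision: str) -> torch.dtype:
    return POLICY.get(str(precision).lower(), torch.float32)

assert autocast_dtype("FLOAT8_E4M3FN") is torch.bfloat16
assert autocast_dtype("anything-else") is torch.float32
```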
     def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT, alignment_rule: str = 'default') -> int:
-        """Aligns a dimension
+        """Aligns a dimension to the nearest multiple of `alignment`."""
         if alignment_rule == 'n*8+1':
             return ((dim - 1) // alignment) * alignment + 1
         return ((dim - 1) // alignment + 1) * alignment

     def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
         num_frames = int(round(duration_s * DEFAULT_FPS))
+        # For the total duration, always round up to the nearest multiple of 8
         aligned_frames = self._align(num_frames, alignment=FRAMES_ALIGNMENT)
         return max(aligned_frames, min_frames)
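Finally, a worked example of the two alignment rules in `_align` and `_calculate_aligned_frames`, assuming `FRAMES_ALIGNMENT = 8` and `DEFAULT_FPS = 24` (neither constant is shown in the diff):

```python
def align(dim: int, alignment: int = 8, alignment_rule: str = 'default') -> int:
    if alignment_rule == 'n*8+1':
        # Round down to a multiple, then add one (frame counts of the form n*8+1).
        return ((dim - 1) // alignment) * alignment + 1
    # Default rule: round up to the next multiple.
    return ((dim - 1) // alignment + 1) * alignment

assert align(49) == 56                            # rounds up
assert align(48) == 48                            # already aligned
assert align(50, alignment_rule='n*8+1') == 49    # 6*8 + 1

# 2.3 s at 24 fps: round(2.3 * 24) = 55 frames, aligned up to 56.
assert max(align(int(round(2.3 * 24))), 1) == 56
```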