Update api/ltx_server.py
api/ltx_server.py (+39 -17) CHANGED
@@ -544,6 +544,8 @@ class VideoService:
         multi_scale_pipeline = None

         try:
+            # In ltx_server.py, replace the 'if improve_texture:' block with this one:
+
             if improve_texture:
                 if not self.latent_upsampler:
                     raise ValueError("Upscaler espacial não carregado.")

@@ -556,28 +558,26 @@
                 first_pass_args = self.config.get("first_pass", {}).copy()
                 first_pass_kwargs = call_kwargs.copy()

-                # Load the parameters from the config, including the timesteps and guidance lists
                 first_pass_kwargs.update({
                     "guidance_scale": first_pass_args.get("guidance_scale", guidance_scale),
                     "stg_scale": first_pass_args.get("stg_scale"),
                     "rescaling_scale": first_pass_args.get("rescaling_scale"),
                     "skip_block_list": first_pass_args.get("skip_block_list"),
                     "guidance_timesteps": first_pass_args.get("guidance_timesteps"),
-                    "timesteps": first_pass_args.get("timesteps")
+                    "timesteps": first_pass_args.get("timesteps"),
+                    "num_inference_steps": first_pass_args.get("num_inference_steps", 20)
                 })
                 print(f"[DEBUG] Passo 1: Parâmetros do config carregados.")

-                # Compute the low-resolution dimensions
                 downscale_factor = self.config.get("downscale_factor", 2)
                 original_height = first_pass_kwargs["height"]
                 original_width = first_pass_kwargs["width"]
                 divisor = 24

-                # For downscale_factor < 1 (e.g. 0.666), the logic is to multiply
                 if downscale_factor < 1.0:
                     target_height_p1 = original_height * downscale_factor
                     target_width_p1 = original_width * downscale_factor
-                else:
+                else:
                     target_height_p1 = original_height // downscale_factor
                     target_width_p1 = original_width // downscale_factor

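The unchanged lines between this hunk and the next (old 584-589) are not shown, but from the surrounding code they evidently turn target_height_p1/target_width_p1 into the height_p1/width_p1 used below, compatible with divisor = 24. A minimal sketch of that reduce-and-snap computation, assuming the hidden lines round each dimension down to a multiple of the divisor; compute_first_pass_size is a hypothetical helper, not code from the repository:

# Sketch only: reduce the render size for pass 1 and snap it to the divisor.
# The snapping rule is an assumption about the lines not shown in the diff.
def compute_first_pass_size(height, width, downscale_factor, divisor=24):
    if downscale_factor < 1.0:            # e.g. 0.666 -> multiply
        target_h = height * downscale_factor
        target_w = width * downscale_factor
    else:                                 # e.g. 2 -> integer divide
        target_h = height // downscale_factor
        target_w = width // downscale_factor
    # Round down to the nearest multiple of the divisor, never below one tile.
    height_p1 = max(divisor, int(target_h) // divisor * divisor)
    width_p1 = max(divisor, int(target_w) // divisor * divisor)
    return height_p1, width_p1

print(compute_first_pass_size(768, 1280, 2))      # (384, 624)
print(compute_first_pass_size(768, 1280, 0.666))  # (504, 840)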
@@ -590,7 +590,10 @@ class VideoService:
                 first_pass_kwargs["width"] = width_p1

                 print(f"[DEBUG] Passo 1: Dimensões reduzidas e ajustadas para {height_p1}x{width_p1}")
-
+
+
+                print(f"[DEBUG] first_pass_kwargs {first_pass_kwargs}")
+
                 with ctx:
                     first_pass_result = self.pipeline(**first_pass_kwargs)

@@ -601,6 +604,7 @@
                 gc.collect()
                 if self.device == "cuda": torch.cuda.empty_cache()

+
                 # --- INTERMEDIATE STEP: UPSCALE THE LATENTS ---
                 print("[DEBUG] Multi-escala: Fazendo upscale dos latentes com latent_upsampler.")
                 with ctx:

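The intermediate step feeds the pass-1 latents through self.latent_upsampler, and the next hunk sizes the refinement at exactly twice the pass-1 resolution, so the second pass assumes the upsampler doubles the spatial dimensions of the latents. A minimal stand-in sketch of that shape contract, using nearest-neighbour interpolation purely as a placeholder; the real latent upsampler is a learned model, and the (batch, channels, frames, height, width) layout here is an assumption:

import torch
import torch.nn.functional as F

# Placeholder for self.latent_upsampler: only illustrates the expected
# shape contract (height and width doubled, frames untouched).
def fake_latent_upsampler(latents: torch.Tensor) -> torch.Tensor:
    return F.interpolate(latents, scale_factor=(1, 2, 2), mode="nearest")

low_res = torch.randn(1, 128, 8, 16, 28)       # illustrative latent shape
high_res = fake_latent_upsampler(low_res)
print(tuple(low_res.shape), "->", tuple(high_res.shape))  # (..., 16, 28) -> (..., 32, 56)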
@@ -616,38 +620,52 @@
                 second_pass_args = self.config.get("second_pass", {}).copy()
                 second_pass_kwargs = call_kwargs.copy()

-                #
+                # Robust refinement logic using 'strength'
+                strength = second_pass_args.get("strength", second_pass_args.get("denoising_strength"))
+                if strength is None and "skip_initial_inference_steps" in second_pass_args:
+                    total_steps = second_pass_args.get("num_inference_steps", 30)
+                    skip_steps = second_pass_args.get("skip_initial_inference_steps", 0)
+                    if total_steps > 0:
+                        strength = 1.0 - (skip_steps / total_steps)
+                elif strength is None and "timesteps" in second_pass_args:
+                    # With explicit timesteps, strength is the first value in the list
+                    # (since the list starts "late", e.g. [0.9, 0.7, ...])
+                    strength = second_pass_args["timesteps"][0]
+                elif strength is None:
+                    strength = 0.5  # Safe fallback
+
+                second_pass_kwargs["strength"] = strength
+                print(f"[DEBUG] Passo 2: Usando 'strength'={strength:.3f} para o refinamento.")
+
+                # Drop the timestep lists so the pipeline derives them from 'strength'
+                if "timesteps" in second_pass_kwargs: del second_pass_kwargs["timesteps"]
+                if "guidance_timesteps" in second_pass_kwargs: del second_pass_kwargs["guidance_timesteps"]
+
                 second_pass_kwargs.update({
                     "guidance_scale": second_pass_args.get("guidance_scale", guidance_scale),
                     "stg_scale": second_pass_args.get("stg_scale"),
                     "rescaling_scale": second_pass_args.get("rescaling_scale"),
                     "skip_block_list": second_pass_args.get("skip_block_list"),
-                    "guidance_timesteps": second_pass_args.get("guidance_timesteps"),
-                    "timesteps": second_pass_args.get("timesteps")
+                    "num_inference_steps": second_pass_args.get("num_inference_steps", 20)
                 })
-                print(f"[DEBUG] Passo 2: Parâmetros do config carregados.")

-                # Set the high-resolution dimensions based on the upscale
-                # The spatial upsampler doubles the resolution, so multiply by 2
                 height_p2 = height_p1 * 2
                 width_p2 = width_p1 * 2
                 second_pass_kwargs["height"] = height_p2
                 second_pass_kwargs["width"] = width_p2
                 print(f"[DEBUG] Passo 2: Dimensões definidas para {height_p2}x{width_p2}")

-                # The input to the refinement is the latents that were upscaled
                 second_pass_kwargs["latents"] = latents_high_res
-
-                # Make sure 'strength' is not passed, since we are controlling this via timesteps
-                if "strength" in second_pass_kwargs:
-                    del second_pass_kwargs["strength"]

+                print(f"[DEBUG] second_pass_kwargs {second_pass_kwargs}")
+
                 with ctx:
                     second_pass_result = self.pipeline(**second_pass_kwargs)

                 latents = second_pass_result.images
                 log_tensor_info(latents, "Latentes Finais (Passo 2)")

+                # --- END OF THE CLEAN IMPLEMENTATION ---

             else:
                 single_pass_kwargs = call_kwargs.copy()

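The core of this hunk is the strength derivation that replaces the explicit timestep lists for the refinement pass. Pulled out as a pure function it is easier to follow and test; the sketch below mirrors the priority order of the diff (explicit strength or denoising_strength, then skip_initial_inference_steps relative to num_inference_steps, then the first explicit timestep, then a 0.5 fallback). derive_refine_strength is a hypothetical name, not part of ltx_server.py:

# Sketch of the strength derivation introduced above, same priority order.
def derive_refine_strength(second_pass_args: dict):
    strength = second_pass_args.get("strength", second_pass_args.get("denoising_strength"))
    if strength is None and "skip_initial_inference_steps" in second_pass_args:
        total_steps = second_pass_args.get("num_inference_steps", 30)
        skip_steps = second_pass_args.get("skip_initial_inference_steps", 0)
        if total_steps > 0:
            strength = 1.0 - (skip_steps / total_steps)
    elif strength is None and "timesteps" in second_pass_args:
        # Explicit timesteps start "late" (e.g. [0.9, 0.7, ...]); reuse the first value.
        strength = second_pass_args["timesteps"][0]
    elif strength is None:
        strength = 0.5  # safe fallback
    return strength

print(derive_refine_strength({"num_inference_steps": 30, "skip_initial_inference_steps": 9}))  # 0.7
print(derive_refine_strength({"timesteps": [0.9, 0.7, 0.5]}))                                  # 0.9
print(derive_refine_strength({}))                                                              # 0.5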
@@ -671,6 +689,10 @@
                 print("\n[INFO] Executando pipeline de etapa única...")
                 t_sp = time.perf_counter()
                 ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
+
+                print(f"[DEBUG] single_pass_kwargs {single_pass_kwargs}")
+
+
                 with ctx:
                     result = self.pipeline(**single_pass_kwargs)
                 print(f"[DEBUG] single-pass tempo={time.perf_counter()-t_sp:.3f}s")
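The first_pass and second_pass blocks read above come from self.config, which is not part of this diff. A hypothetical fragment consistent with the keys the code queries might look like the sketch below; every value is illustrative and should not be taken as the repository's actual configuration:

# Hypothetical config shape; keys mirror what the code reads with .get(),
# but all values are made up for illustration.
config = {
    "downscale_factor": 2,
    "first_pass": {
        "guidance_scale": 3.0,
        "stg_scale": 1.0,
        "rescaling_scale": 0.7,
        "skip_block_list": [19],
        "guidance_timesteps": [0.9, 0.7, 0.5, 0.3],
        "timesteps": [0.9, 0.7, 0.5, 0.3],
        "num_inference_steps": 20,
    },
    "second_pass": {
        "guidance_scale": 3.0,
        "stg_scale": 1.0,
        "rescaling_scale": 0.7,
        "skip_block_list": [19],
        "num_inference_steps": 30,
        "skip_initial_inference_steps": 9,
    },
}

# With these illustrative values the refinement strength would come out as:
print(1.0 - config["second_pass"]["skip_initial_inference_steps"]
          / config["second_pass"]["num_inference_steps"])  # 0.7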