EuuIia committed on
Commit
6077505
·
verified ·
1 Parent(s): c344381

Update api/ltx_server.py

Browse files
Files changed (1) hide show
  1. api/ltx_server.py +89 -25
api/ltx_server.py CHANGED
@@ -577,39 +577,103 @@ class VideoService:
577
  latents = None
578
  multi_scale_pipeline = None
579
 
 
 
580
  try:
581
  if improve_texture:
582
  if not self.latent_upsampler:
583
  raise ValueError("Upscaler espacial não carregado.")
584
- print("[DEBUG] Multi-escala: construindo pipeline...")
585
- multi_scale_pipeline = LTXMultiScalePipeline(self.pipeline, self.latent_upsampler)
 
 
586
  first_pass_args = self.config.get("first_pass", {}).copy()
587
- first_pass_args["guidance_scale"] = float(guidance_scale)
588
- second_pass_args = self.config.get("second_pass", {}).copy()
589
- second_pass_args["guidance_scale"] = float(guidance_scale)
590
-
591
- multi_scale_call_kwargs = call_kwargs.copy()
592
- multi_scale_call_kwargs.update(
593
- {
594
- "downscale_factor": self.config["downscale_factor"],
595
- "first_pass": first_pass_args,
596
- "second_pass": second_pass_args,
597
- }
598
- )
599
- print("[DEBUG] Chamando multi_scale_pipeline...")
600
- t_ms = time.perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
  ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
602
  with ctx:
603
- result = multi_scale_pipeline(**multi_scale_call_kwargs)
604
- print(f"[DEBUG] multi_scale_pipeline tempo={time.perf_counter()-t_ms:.3f}s")
 
 
 
 
 
 
605
 
606
- if hasattr(result, "latents"):
607
- latents = result.latents
608
- elif hasattr(result, "images") and isinstance(result.images, torch.Tensor):
609
- latents = result.images
610
- else:
611
- latents = result
612
- print(f"[DEBUG] Latentes (multi-escala): shape={tuple(latents.shape)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
  else:
614
  single_pass_kwargs = call_kwargs.copy()
615
  first_pass_config = self.config.get("first_pass", {})
 
577
  latents = None
578
  multi_scale_pipeline = None
579
 
580
+ # ltx_server.py (dentro da função generate)
581
+
582
  try:
583
  if improve_texture:
584
  if not self.latent_upsampler:
585
  raise ValueError("Upscaler espacial não carregado.")
586
+
587
+ print("[DEBUG] Multi-escala: Iniciando Passo 1 (geração de latentes base).")
588
+
589
+ # 1. Configurar e executar o primeiro passo
590
  first_pass_args = self.config.get("first_pass", {}).copy()
591
+ first_pass_kwargs = call_kwargs.copy()
592
+ first_pass_kwargs.update({
593
+ "guidance_scale": float(guidance_scale),
594
+ "stg_scale": first_pass_args.get("stg_scale"),
595
+ "rescaling_scale": first_pass_args.get("rescaling_scale"),
596
+ "skip_block_list": first_pass_args.get("skip_block_list"),
597
+ })
598
+ schedule = first_pass_args.get("timesteps") or first_pass_args.get("guidance_timesteps")
599
+ if schedule:
600
+ first_pass_kwargs["timesteps"] = schedule
601
+ first_pass_kwargs["guidance_timesteps"] = schedule
602
+
603
+ downscale_factor = self.config.get("downscale_factor", 2)
604
+ original_height = first_pass_kwargs["height"]
605
+ original_width = first_pass_kwargs["width"]
606
+ divisor = 24
607
+
608
+ target_height_p1 = original_height // downscale_factor
609
+ height_p1 = round(target_height_p1 / divisor) * divisor
610
+ if height_p1 == 0: height_p1 = divisor
611
+ first_pass_kwargs["height"] = height_p1
612
+
613
+ target_width_p1 = original_width // downscale_factor
614
+ width_p1 = round(target_width_p1 / divisor) * divisor
615
+ if width_p1 == 0: width_p1 = divisor
616
+ first_pass_kwargs["width"] = width_p1
617
+
618
+ print(f"[DEBUG] Passo 1: Dimensões reduzidas e ajustadas para {height_p1}x{width_p1}")
619
+
620
  ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
621
  with ctx:
622
+ first_pass_result = self.pipeline(**first_pass_kwargs)
623
+
624
+ latents_low_res = first_pass_result.latents if hasattr(first_pass_result, "latents") else first_pass_result
625
+ log_tensor_info(latents_low_res, "Latentes (Passo 1)")
626
+
627
+ del first_pass_result
628
+ gc.collect()
629
+ if self.device == "cuda": torch.cuda.empty_cache()
630
 
631
+ # 2. Upscale dos latentes
632
+ print("[DEBUG] Multi-escala: Fazendo upscale dos latentes com latent_upsampler.")
633
+ with ctx:
634
+ # Chamada posicional confirmada pelo código-fonte
635
+ latents_high_res = self.latent_upsampler(latents_low_res)
636
+
637
+ log_tensor_info(latents_high_res, "Latentes (Pós-Upscale)")
638
+ del latents_low_res
639
+ gc.collect()
640
+ if self.device == "cuda": torch.cuda.empty_cache()
641
+
642
+ # 3. Configurar e executar o segundo passo
643
+ print("[DEBUG] Multi-escala: Iniciando Passo 2 (refinamento em alta resolução).")
644
+ second_pass_args = self.config.get("second_pass", {}).copy()
645
+ second_pass_kwargs = call_kwargs.copy()
646
+
647
+ # ==================== LÓGICA DE DIMENSÃO FINAL ====================
648
+ # As dimensões do Passo 2 DEVEM ser o dobro das dimensões do Passo 1,
649
+ # para corresponder à saída do upsampler.
650
+ height_p2 = height_p1 * 2
651
+ width_p2 = width_p1 * 2
652
+ second_pass_kwargs["height"] = height_p2
653
+ second_pass_kwargs["width"] = width_p2
654
+ print(f"[DEBUG] Passo 2: Dimensões definidas para {height_p2}x{width_p2} para corresponder ao upscale.")
655
+ # =================================================================
656
+
657
+ second_pass_kwargs.update({
658
+ "guidance_scale": float(guidance_scale),
659
+ "stg_scale": second_pass_args.get("stg_scale"),
660
+ "rescaling_scale": second_pass_args.get("rescaling_scale"),
661
+ "skip_block_list": second_pass_args.get("skip_block_list"),
662
+ })
663
+
664
+ schedule_p2 = second_pass_args.get("timesteps") or second_pass_args.get("guidance_timesteps")
665
+ if schedule_p2:
666
+ second_pass_kwargs["timesteps"] = schedule_p2
667
+ second_pass_kwargs["guidance_timesteps"] = schedule_p2
668
+
669
+ second_pass_kwargs["latents"] = latents_high_res
670
+
671
+ with ctx:
672
+ second_pass_result = self.pipeline(**second_pass_kwargs)
673
+
674
+ latents = second_pass_result.latents if hasattr(second_pass_result, "latents") else second_pass_result
675
+ log_tensor_info(latents, "Latentes Finais (Passo 2)")
676
+
677
  else:
678
  single_pass_kwargs = call_kwargs.copy()
679
  first_pass_config = self.config.get("first_pass", {})