EuuIia commited on
Commit
1f0a06a
·
verified ·
1 Parent(s): b9b0da3

Update api/ltx_server.py

Browse files
Files changed (1) hide show
  1. api/ltx_server.py +122 -59
api/ltx_server.py CHANGED
@@ -396,40 +396,6 @@ class VideoService:
396
  return out
397
 
398
 
399
- def _dividir_latentes_em_partes(self, latents_brutos, quantidade: int):
400
- """
401
- Divide um tensor de latentes em `quantidade` partes e retorna uma lista de clones.
402
-
403
- Args:
404
- latents_brutos: tensor [B, C, T, H, W]
405
- quantidade: número de partes que queremos dividir
406
-
407
- Returns:
408
- List[Tensor]: lista de `quantidade` partes, cada uma cloneada
409
- """
410
- total = latents_brutos.shape[2] # dimensão temporal
411
- partes = []
412
-
413
- if quantidade <= 1 or quantidade > total:
414
- return [latents_brutos.clone()]
415
-
416
- # calcular tamanho aproximado de cada parte
417
- step = total // quantidade
418
- overlap = 0 # sobreposição mínima de 1 frame entre partes
419
-
420
- for i in range(quantidade):
421
- start = i * step
422
- end = start + step
423
- if i == quantidade - 1:
424
- end = total # última parte vai até o final
425
- else:
426
- end += overlap # sobreposição
427
- parte = latents_brutos[:, :, start-1:end+1, :, :].clone()
428
- partes.append(parte)
429
-
430
- return partes
431
-
432
-
433
  def _dividir_latentes(self, latents_brutos):
434
  total = latents_brutos.shape[2] # dimensão temporal (número de latentes)
435
 
@@ -578,38 +544,131 @@ class VideoService:
578
  multi_scale_pipeline = None
579
 
580
  try:
 
 
581
  if improve_texture:
582
  if not self.latent_upsampler:
583
  raise ValueError("Upscaler espacial não carregado.")
584
- print("[DEBUG] Multi-escala: construindo pipeline...")
585
- multi_scale_pipeline = LTXMultiScalePipeline(self.pipeline, self.latent_upsampler)
586
- first_pass_args = self.config.get("first_pass", {}).copy()
587
- first_pass_args["guidance_scale"] = float(guidance_scale)
588
- second_pass_args = self.config.get("second_pass", {}).copy()
589
- second_pass_args["guidance_scale"] = float(guidance_scale)
590
 
591
- multi_scale_call_kwargs = call_kwargs.copy()
592
- multi_scale_call_kwargs.update(
593
- {
594
- "downscale_factor": self.config["downscale_factor"],
595
- "first_pass": first_pass_args,
596
- "second_pass": second_pass_args,
597
- }
598
- )
599
- print("[DEBUG] Chamando multi_scale_pipeline...")
600
- t_ms = time.perf_counter()
601
  ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
  with ctx:
603
- result = multi_scale_pipeline(**multi_scale_call_kwargs)
604
- print(f"[DEBUG] multi_scale_pipeline tempo={time.perf_counter()-t_ms:.3f}s")
 
 
 
 
 
 
605
 
606
- if hasattr(result, "latents"):
607
- latents = result.latents
608
- elif hasattr(result, "images") and isinstance(result.images, torch.Tensor):
609
- latents = result.images
610
- else:
611
- latents = result
612
- print(f"[DEBUG] Latentes (multi-escala): shape={tuple(latents.shape)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
  else:
614
  single_pass_kwargs = call_kwargs.copy()
615
  first_pass_config = self.config.get("first_pass", {})
@@ -632,6 +691,10 @@ class VideoService:
632
  print("\n[INFO] Executando pipeline de etapa única...")
633
  t_sp = time.perf_counter()
634
  ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
 
 
 
 
635
  with ctx:
636
  result = self.pipeline(**single_pass_kwargs)
637
  print(f"[DEBUG] single-pass tempo={time.perf_counter()-t_sp:.3f}s")
 
396
  return out
397
 
398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  def _dividir_latentes(self, latents_brutos):
400
  total = latents_brutos.shape[2] # dimensão temporal (número de latentes)
401
 
 
544
  multi_scale_pipeline = None
545
 
546
  try:
547
+ # Em ltx_server.py, substitua o bloco 'if improve_texture:' por este:
548
+
549
  if improve_texture:
550
  if not self.latent_upsampler:
551
  raise ValueError("Upscaler espacial não carregado.")
552
+
553
+ # --- INÍCIO DA IMPLEMENTAÇÃO LIMPA DOS 3 PASSOS ---
 
 
 
 
554
 
 
 
 
 
 
 
 
 
 
 
555
  ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
556
+
557
+ # --- PASSO 1: GERAÇÃO DE LATENTES EM BAIXA RESOLUÇÃO ---
558
+ print("[DEBUG] Multi-escala: Iniciando Passo 1 (geração de latentes base).")
559
+
560
+ first_pass_args = self.config.get("first_pass", {}).copy()
561
+ first_pass_kwargs = call_kwargs.copy()
562
+
563
+ first_pass_kwargs.update({
564
+ "guidance_scale": first_pass_args.get("guidance_scale", guidance_scale),
565
+ "stg_scale": first_pass_args.get("stg_scale"),
566
+ "rescaling_scale": first_pass_args.get("rescaling_scale"),
567
+ "skip_block_list": first_pass_args.get("skip_block_list"),
568
+ "guidance_timesteps": first_pass_args.get("guidance_timesteps"),
569
+ "timesteps": first_pass_args.get("timesteps"),
570
+ "num_inference_steps": first_pass_args.get("num_inference_steps", 20)
571
+ })
572
+ print(f"[DEBUG] Passo 1: Parâmetros do config carregados.")
573
+
574
+ downscale_factor = self.config.get("downscale_factor", 2)
575
+ original_height = first_pass_kwargs["height"]
576
+ original_width = first_pass_kwargs["width"]
577
+ divisor = 24
578
+
579
+ if downscale_factor < 1.0:
580
+ target_height_p1 = original_height * downscale_factor
581
+ target_width_p1 = original_width * downscale_factor
582
+ else:
583
+ target_height_p1 = original_height // downscale_factor
584
+ target_width_p1 = original_width // downscale_factor
585
+
586
+ height_p1 = round(target_height_p1 / divisor) * divisor
587
+ if height_p1 == 0: height_p1 = divisor
588
+ first_pass_kwargs["height"] = height_p1
589
+
590
+ width_p1 = round(target_width_p1 / divisor) * divisor
591
+ if width_p1 == 0: width_p1 = divisor
592
+ first_pass_kwargs["width"] = width_p1
593
+
594
+ print(f"[DEBUG] Passo 1: Dimensões reduzidas e ajustadas para {height_p1}x{width_p1}")
595
+
596
+
597
+ print(f"[DEBUG] first_pass_kwargs {first_pass_kwargs}")
598
+
599
  with ctx:
600
+ first_pass_result = self.pipeline(**first_pass_kwargs)
601
+
602
+ latents_low_res = first_pass_result.images
603
+ log_tensor_info(latents_low_res, "Latentes (Passo 1)")
604
+
605
+ del first_pass_result, first_pass_kwargs
606
+ gc.collect()
607
+ if self.device == "cuda": torch.cuda.empty_cache()
608
 
609
+
610
+ # --- PASSO INTERMEDIÁRIO: UPSCALE DOS LATENTES ---
611
+ print("[DEBUG] Multi-escala: Fazendo upscale dos latentes com latent_upsampler.")
612
+ with ctx:
613
+ latents_high_res = self.latent_upsampler(latents_low_res)
614
+
615
+ log_tensor_info(latents_high_res, "Latentes (Pós-Upscale)")
616
+ del latents_low_res
617
+ gc.collect()
618
+ if self.device == "cuda": torch.cuda.empty_cache()
619
+
620
+ # --- PASSO 2: REFINAMENTO EM ALTA RESOLUÇÃO ---
621
+ print("[DEBUG] Multi-escala: Iniciando Passo 2 (refinamento em alta resolução).")
622
+ second_pass_args = self.config.get("second_pass", {}).copy()
623
+ second_pass_kwargs = call_kwargs.copy()
624
+
625
+ # Lógica de refinamento robusta usando 'strength'
626
+ strength = second_pass_args.get("strength", second_pass_args.get("denoising_strength"))
627
+ if strength is None and "skip_initial_inference_steps" in second_pass_args:
628
+ total_steps = second_pass_args.get("num_inference_steps", 30)
629
+ skip_steps = second_pass_args.get("skip_initial_inference_steps", 0)
630
+ if total_steps > 0:
631
+ strength = 1.0 - (skip_steps / total_steps)
632
+ elif strength is None and "timesteps" in second_pass_args:
633
+ # Se temos timesteps explícitos, o strength é o primeiro valor da lista
634
+ # (já que a lista começa "tarde", ex: [0.9, 0.7...])
635
+ strength = second_pass_args["timesteps"][0]
636
+ elif strength is None:
637
+ strength = 0.5 # Fallback seguro
638
+
639
+ second_pass_kwargs["strength"] = strength
640
+ print(f"[DEBUG] Passo 2: Usando 'strength'={strength:.3f} para o refinamento.")
641
+
642
+ # Removemos timesteps para que a pipeline os calcule a partir do strength
643
+ if "timesteps" in second_pass_kwargs: del second_pass_kwargs["timesteps"]
644
+ if "guidance_timesteps" in second_pass_kwargs: del second_pass_kwargs["guidance_timesteps"]
645
+
646
+ second_pass_kwargs.update({
647
+ "guidance_scale": second_pass_args.get("guidance_scale", guidance_scale),
648
+ "stg_scale": second_pass_args.get("stg_scale"),
649
+ "rescaling_scale": second_pass_args.get("rescaling_scale"),
650
+ "skip_block_list": second_pass_args.get("skip_block_list"),
651
+ "num_inference_steps": second_pass_args.get("num_inference_steps", 20)
652
+ })
653
+
654
+ height_p2 = height_p1 * 2
655
+ width_p2 = width_p1 * 2
656
+ second_pass_kwargs["height"] = height_p2
657
+ second_pass_kwargs["width"] = width_p2
658
+ print(f"[DEBUG] Passo 2: Dimensões definidas para {height_p2}x{width_p2}")
659
+
660
+ second_pass_kwargs["latents"] = latents_high_res
661
+
662
+ print(f"[DEBUG] second_pass_kwargs {second_pass_kwargs}")
663
+
664
+ with ctx:
665
+ second_pass_result = self.pipeline(**second_pass_kwargs)
666
+
667
+ latents = second_pass_result.images
668
+ log_tensor_info(latents, "Latentes Finais (Passo 2)")
669
+
670
+ # --- FIM DA IMPLEMENTAÇÃO LIMPA ---
671
+
672
  else:
673
  single_pass_kwargs = call_kwargs.copy()
674
  first_pass_config = self.config.get("first_pass", {})
 
691
  print("\n[INFO] Executando pipeline de etapa única...")
692
  t_sp = time.perf_counter()
693
  ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
694
+
695
+ print(f"[DEBUG] single_pass_kwargs {single_pass_kwargs}")
696
+
697
+
698
  with ctx:
699
  result = self.pipeline(**single_pass_kwargs)
700
  print(f"[DEBUG] single-pass tempo={time.perf_counter()-t_sp:.3f}s")