euiiiia commited on
Commit
f0f0810
·
verified ·
1 Parent(s): 4aa7f1b

Update api/ltx_server_refactored.py

Browse files
Files changed (1) hide show
  1. api/ltx_server_refactored.py +239 -19
api/ltx_server_refactored.py CHANGED
@@ -299,28 +299,53 @@ class VideoService:
299
  # --- Métodos Públicos (API do Serviço) ---
300
  # --------------------------------------------------------------------------
301
 
302
- def _prepare_condition_items(self, items_list: List[Tuple], height: int, width: int, num_frames: int) -> List[ConditioningItem]:
303
- """Prepara os tensores de condicionamento a partir de imagens ou tensores."""
 
 
 
 
 
 
 
 
304
  if not items_list:
305
  return []
306
-
307
- height, width = self._calculate_downscaled_dims(height, width)
308
-
309
- height_padded = ((height - 1) // 8 + 1) * 8
310
- width_padded = ((width - 1) // 8 + 1) * 8
311
- padding_values = calculate_padding(height, width, height_padded, width_padded)
312
-
313
- conditioning_items = []
 
314
  for media, frame_idx, weight in items_list:
 
315
  if isinstance(media, str):
316
- tensor = self._prepare_conditioning_tensor_from_path(media, height, width, padding_values)
317
- else: # Assume que é um tensor
318
- tensor = media.to(self.device, dtype=self.runtime_autocast_dtype)
319
-
320
- # Garante que o frame de condicionamento esteja dentro dos limites do vídeo
321
- safe_frame_idx = max(0, min(int(frame_idx), num_frames - 1))
322
- conditioning_items.append(ConditioningItem(tensor, safe_frame_idx, float(weight)))
323
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  return conditioning_items
325
 
326
  def generate_low_resolution(
@@ -457,4 +482,199 @@ class VideoService:
457
  pixel_chunk = vae_manager_singleton.decode(chunk.to(self.device), decode_timestep=float(self.config.get("decode_timestep", 0.05)))
458
  pixel_chunks.append(pixel_chunk)
459
 
460
- final_pixel_tensor = self._merge_chunks_with_ove
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  # --- Métodos Públicos (API do Serviço) ---
300
  # --------------------------------------------------------------------------
301
 
302
+ def _prepare_condition_items(
303
+ self,
304
+ items_list: List[Tuple[Union[str, Image.Image, torch.Tensor], int, float]],
305
+ height: int,
306
+ width: int,
307
+ num_frames: int,
308
+ ) -> List[ConditioningItem]:
309
+ """
310
+ Prepara ConditioningItem a partir de paths, PIL.Images ou tensores.
311
+ """
312
  if not items_list:
313
  return []
314
+
315
+ # calcula dims downscaled (múltiplo de patch temporal)
316
+ down_h, down_w = self._calculate_downscaled_dims(height, width)
317
+ # ajusta padding para múltiplos de 8
318
+ pad_h = ((down_h - 1) // 8 + 1) * 8
319
+ pad_w = ((down_w - 1) // 8 + 1) * 8
320
+ padding = calculate_padding(down_h, down_w, pad_h, pad_w)
321
+
322
+ conditioning_items: List[ConditioningItem] = []
323
  for media, frame_idx, weight in items_list:
324
+ # carrega raw_item como PIL.Image ou tensor
325
  if isinstance(media, str):
326
+ img = Image.open(media).convert("RGB")
327
+ raw_item = ImageOps.fit(img, (down_w, down_h), Image.LANCZOS)
328
+ elif isinstance(media, Image.Image):
329
+ raw_item = ImageOps.fit(media, (down_w, down_h), Image.LANCZOS)
330
+ elif isinstance(media, torch.Tensor):
331
+ raw_item = media.to(device=self.device, dtype=self.runtime_autocast_dtype)
332
+ else:
333
+ raise TypeError(f"Tipo de media não suportado: {type(media)}")
334
+
335
+ # garante frame index seguro
336
+ safe_frame = max(0, min(int(frame_idx), num_frames - 1))
337
+
338
+ # codifica raw_item em latentes via VAE (inclui expansão de frame se precisar)
339
+ cond_item = self.encode_conditioning_item(
340
+ raw_item,
341
+ frame_number=safe_frame,
342
+ strength=float(weight),
343
+ height=down_h,
344
+ width=down_w,
345
+ vae_per_channel_normalize=self.vae_per_channel_normalize,
346
+ )
347
+ conditioning_items.append(cond_item)
348
+
349
  return conditioning_items
350
 
351
  def generate_low_resolution(
 
482
  pixel_chunk = vae_manager_singleton.decode(chunk.to(self.device), decode_timestep=float(self.config.get("decode_timestep", 0.05)))
483
  pixel_chunks.append(pixel_chunk)
484
 
485
+ final_pixel_tensor = self._merge_chunks_with_overlap(pixel_chunks)
486
+ final_video_path = self._save_video_from_tensor(final_pixel_tensor, f"final_video_{seed}", seed, temp_dir, fps=fps)
487
+ return final_video_path
488
+
489
+ except Exception as e:
490
+ print(f"[ERROR] Falha ao encodar latentes para MP4: {e}")
491
+ traceback.print_exc()
492
+ raise
493
+ finally:
494
+ self._finalize()
495
+
496
+ # --------------------------------------------------------------------------
497
+ # --- Métodos Internos e Auxiliares ---
498
+ # --------------------------------------------------------------------------
499
+
500
+ def _finalize(self):
501
+ """Limpa a memória da GPU e os diretórios temporários."""
502
+ if LTXV_DEBUG:
503
+ print("[DEBUG] Finalize: iniciando limpeza...")
504
+
505
+ gc.collect()
506
+ if torch.cuda.is_available():
507
+ torch.cuda.empty_cache()
508
+ torch.cuda.ipc_collect()
509
+
510
+ # Limpa todos os diretórios temporários registrados
511
+ for d in list(self._tmp_dirs):
512
+ shutil.rmtree(d, ignore_errors=True)
513
+ self._tmp_dirs.remove(d)
514
+ if LTXV_DEBUG:
515
+ print(f"[DEBUG] Diretório temporário removido: {d}")
516
+
517
+ def _load_config(self, config_filename: str) -> Dict:
518
+ """Carrega o arquivo de configuração YAML."""
519
+ config_path = LTX_VIDEO_REPO_DIR / "configs" / config_filename
520
+ print(f"[INFO] Carregando configuração de: {config_path}")
521
+ with open(config_path, "r") as file:
522
+ return yaml.safe_load(file)
523
+
524
+ def _load_models_from_hub(self) -> Tuple[LTXMultiScalePipeline, Optional[torch.nn.Module]]:
525
+ """Baixa e cria as instâncias da pipeline e do upsampler."""
526
+ t0 = time.perf_counter()
527
+ LTX_REPO = "Lightricks/LTX-Video"
528
+
529
+ print("[INFO] Baixando checkpoint principal...")
530
+ self.config["checkpoint_path"] = hf_hub_download(
531
+ repo_id=LTX_REPO, filename=self.config["checkpoint_path"],
532
+ token=os.getenv("HF_TOKEN")
533
+ )
534
+ print(f"[INFO] Checkpoint principal em: {self.config['checkpoint_path']}")
535
+
536
+ print("[INFO] Construindo pipeline...")
537
+ pipeline = create_ltx_video_pipeline(
538
+ ckpt_path=self.config["checkpoint_path"],
539
+ precision=self.config["precision"],
540
+ text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
541
+ sampler=self.config["sampler"],
542
+ device="cpu", # Carrega em CPU primeiro
543
+ enhance_prompt=False
544
+ )
545
+ print("[INFO] Pipeline construída.")
546
+
547
+ latent_upsampler = None
548
+ if self.config.get("spatial_upscaler_model_path"):
549
+ print("[INFO] Baixando upscaler espacial...")
550
+ self.config["spatial_upscaler_model_path"] = hf_hub_download(
551
+ repo_id=LTX_REPO, filename=self.config["spatial_upscaler_model_path"],
552
+ token=os.getenv("HF_TOKEN")
553
+ )
554
+ print(f"[INFO] Upscaler em: {self.config['spatial_upscaler_model_path']}")
555
+
556
+ print("[INFO] Construindo latent_upsampler...")
557
+ latent_upsampler = create_latent_upsampler(self.config["spatial_upscaler_model_path"], device="cpu")
558
+ print("[INFO] Latent upsampler construído.")
559
+
560
+ print(f"[INFO] Carregamento de modelos concluído em {time.perf_counter()-t0:.2f}s")
561
+ return pipeline, latent_upsampler
562
+
563
+ def _move_models_to_device(self):
564
+ """Move os modelos carregados para o dispositivo de computação (GPU/CPU)."""
565
+ print(f"[INFO] Movendo modelos para o dispositivo: {self.device}")
566
+ self.pipeline.to(self.device)
567
+ if self.latent_upsampler:
568
+ self.latent_upsampler.to(self.device)
569
+
570
+ def _get_precision_dtype(self) -> torch.dtype:
571
+ """Determina o dtype para autocast com base na configuração de precisão."""
572
+ prec = str(self.config.get("precision", "")).lower()
573
+ if prec in ["float8_e4m3fn", "bfloat16"]:
574
+ return torch.bfloat16
575
+ elif prec == "mixed_precision":
576
+ return torch.float16
577
+ return torch.float32
578
+
579
+ @torch.no_grad()
580
+ def _upsample_and_filter_latents(self, latents: torch.Tensor) -> torch.Tensor:
581
+ """Aplica o upsample espacial e o filtro AdaIN aos latentes."""
582
+ if not self.latent_upsampler:
583
+ raise ValueError("Latent Upsampler não está carregado para a operação de upscale.")
584
+
585
+ latents_unnormalized = un_normalize_latents(latents, self.pipeline.vae, vae_per_channel_normalize=True)
586
+ upsampled_latents_unnormalized = self.latent_upsampler(latents_unnormalized)
587
+ upsampled_latents_normalized = normalize_latents(upsampled_latents_unnormalized, self.pipeline.vae, vae_per_channel_normalize=True)
588
+
589
+ # Filtro AdaIN para manter consistência de cor/estilo com o vídeo de baixa resolução
590
+ return adain_filter_latent(latents=upsampled_latents_normalized, reference_latents=latents)
591
+
592
+ def _prepare_conditioning_tensor_from_path(self, filepath: str, height: int, width: int, padding: Tuple) -> torch.Tensor:
593
+ """Carrega uma imagem, redimensiona, aplica padding e move para o dispositivo."""
594
+ tensor = load_image_to_tensor_with_resize_and_crop(filepath, height, width)
595
+ tensor = F.pad(tensor, padding)
596
+ return tensor.to(self.device, dtype=self.runtime_autocast_dtype)
597
+
598
+ def _calculate_downscaled_dims(self, height: int, width: int) -> Tuple[int, int]:
599
+ """Calcula as dimensões para o primeiro passo (baixa resolução)."""
600
+ height_padded = ((height - 1) // 8 + 1) * 8
601
+ width_padded = ((width - 1) // 8 + 1) * 8
602
+
603
+ downscale_factor = self.config.get("downscale_factor", 0.6666666)
604
+ vae_scale_factor = self.pipeline.vae_scale_factor
605
+
606
+ target_w = int(width_padded * downscale_factor)
607
+ downscaled_width = target_w - (target_w % vae_scale_factor)
608
+
609
+ target_h = int(height_padded * downscale_factor)
610
+ downscaled_height = target_h - (target_h % vae_scale_factor)
611
+
612
+ return downscaled_height, downscaled_width
613
+
614
+ def _split_latents_with_overlap(self, latents: torch.Tensor, overlap: int = 1) -> List[torch.Tensor]:
615
+ """Divide um tensor de latentes em dois chunks com sobreposição."""
616
+ total_frames = latents.shape[2]
617
+ if total_frames <= overlap:
618
+ return [latents]
619
+
620
+ mid_point = max(overlap, total_frames // 2)
621
+ chunk1 = latents[:, :, :mid_point, :, :]
622
+ # O segundo chunk começa 'overlap' frames antes para criar a sobreposição
623
+ chunk2 = latents[:, :, mid_point - overlap:, :, :]
624
+
625
+ return [c for c in [chunk1, chunk2] if c.shape[2] > 0]
626
+
627
+ def _merge_chunks_with_overlap(self, chunks: List[torch.Tensor], overlap: int = 1) -> torch.Tensor:
628
+ """Junta uma lista de chunks, removendo a sobreposição."""
629
+ if not chunks:
630
+ return torch.empty(0)
631
+ if len(chunks) == 1:
632
+ return chunks[0]
633
+
634
+ # Pega o primeiro chunk sem o frame de sobreposição final
635
+ merged_list = [chunks[0][:, :, :-overlap, :, :]]
636
+ # Adiciona os chunks restantes
637
+ merged_list.extend(chunks[1:])
638
+
639
+ return torch.cat(merged_list, dim=2)
640
+
641
+ def _save_latents_to_disk(self, latents_tensor: torch.Tensor, base_filename: str, seed: int) -> str:
642
+ """Salva um tensor de latentes em um arquivo .pt."""
643
+ latents_cpu = latents_tensor.detach().to("cpu")
644
+ tensor_path = RESULTS_DIR / f"{base_filename}_{seed}.pt"
645
+ torch.save(latents_cpu, tensor_path)
646
+ if LTXV_DEBUG:
647
+ print(f"[DEBUG] Latentes salvos em: {tensor_path}")
648
+ return str(tensor_path)
649
+
650
+ def _save_video_from_tensor(self, pixel_tensor: torch.Tensor, base_filename: str, seed: int, temp_dir: str, fps: int = int(DEFAULT_FPS)) -> str:
651
+ """Salva um tensor de pixels como um arquivo de vídeo MP4."""
652
+ temp_path = os.path.join(temp_dir, f"{base_filename}_{seed}.mp4")
653
+ video_encode_tool_singleton.save_video_from_tensor(pixel_tensor, temp_path, fps=fps)
654
+
655
+ final_path = RESULTS_DIR / f"{base_filename}_{seed}.mp4"
656
+ shutil.move(temp_path, final_path)
657
+ print(f"[INFO] Vídeo final salvo em: {final_path}")
658
+ return str(final_path)
659
+
660
+ def _register_tmp_dir(self, dir_path: str):
661
+ """Registra um diretório temporário para limpeza posterior."""
662
+ if dir_path and os.path.isdir(dir_path):
663
+ self._tmp_dirs.add(dir_path)
664
+ if LTXV_DEBUG:
665
+ print(f"[DEBUG] Diretório temporário registrado: {dir_path}")
666
+
667
+ def _seed_everething(self, seed: int):
668
+ random.seed(seed)
669
+ np.random.seed(seed)
670
+ torch.manual_seed(seed)
671
+ if torch.cuda.is_available():
672
+ torch.cuda.manual_seed(seed)
673
+ if torch.backends.mps.is_available():
674
+ torch.mps.manual_seed(seed)
675
+
676
+ # ==============================================================================
677
+ # 4. INSTANCIAÇÃO E PONTO DE ENTRADA (Exemplo)
678
+ # ==============================================================================
679
+ video_generation_service = VideoService()
680
+ print("Instância do VideoService pronta para uso.")