Eueuiaa commited on
Commit
e27a641
·
verified ·
1 Parent(s): ed9c16e

Update api/ltx_server.py

Browse files
Files changed (1) hide show
  1. api/ltx_server.py +77 -5
api/ltx_server.py CHANGED
@@ -8,7 +8,7 @@ warnings.filterwarnings("ignore", category=UserWarning)
8
  warnings.filterwarnings("ignore", category=FutureWarning)
9
  warnings.filterwarnings("ignore", message=".*")
10
 
11
- from huggingface_hub import logging
12
 
13
  logging.set_verbosity_error()
14
  logging.set_verbosity_warning()
@@ -354,6 +354,80 @@ class VideoService:
354
  return yaml.safe_load(file)
355
 
356
  def _load_models(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  t0 = time.perf_counter()
358
  LTX_REPO = "Lightricks/LTX-Video"
359
  print("[DEBUG] Baixando checkpoint principal...")
@@ -445,7 +519,6 @@ class VideoService:
445
  return upsampled_latents
446
 
447
 
448
-
449
  def _apply_precision_policy(self):
450
  prec = str(self.config.get("precision", "")).lower()
451
  self.runtime_autocast_dtype = torch.float32
@@ -853,12 +926,11 @@ class VideoService:
853
  temp_dir = tempfile.mkdtemp(prefix="ltxv_"); self._register_tmp_dir(temp_dir)
854
  results_dir = "/app/output"; os.makedirs(results_dir, exist_ok=True)
855
 
856
- latents_parts_up = self._dividir_latentes_por_tamanho(latents_list,15,1)
857
-
858
  partes_mp4 = []
859
  par = 0
860
 
861
- for latents_vae in latents_parts_up:
862
 
863
  latents_cpu_vae = latents_vae.detach().to("cpu", non_blocking=True)
864
  torch.cuda.empty_cache()
 
8
  warnings.filterwarnings("ignore", category=FutureWarning)
9
  warnings.filterwarnings("ignore", message=".*")
10
 
11
+ from huggingface_hub import logging, hf_hub_download
12
 
13
  logging.set_verbosity_error()
14
  logging.set_verbosity_warning()
 
354
  return yaml.safe_load(file)
355
 
356
  def _load_models(self):
357
+ """
358
+ Carrega os modelos de forma inteligente:
359
+ 1. Tenta resolver o caminho do cache local (rápido, sem rede).
360
+ 2. Se o arquivo não for encontrado localmente, baixa como fallback.
361
+ Garante que o serviço possa iniciar mesmo que o setup.py não tenha sido executado.
362
+ """
363
+ t0 = time.perf_counter()
364
+ LTX_REPO = "Lightricks/LTX-Video"
365
+
366
+ print("[DEBUG] Resolvendo caminhos dos modelos de forma inteligente...")
367
+
368
+ # --- Função Auxiliar para Carregamento Inteligente ---
369
+ def get_or_download_model(repo_id, filename, description):
370
+ try:
371
+ # hf_hub_download é a ferramenta certa aqui. Ela verifica o cache PRIMEIRO.
372
+ # Se o arquivo estiver no cache, retorna o caminho instantaneamente (após uma verificação rápida de metadados).
373
+ # Se não estiver no cache, ela o baixa.
374
+ print(f"[DEBUG] Verificando {description}: {filename}...")
375
+ model_path = hf_hub_download(
376
+ repo_id=repo_id,
377
+ filename=filename,
378
+ # Forçar o uso de um cache específico se necessário
379
+ cache_dir=os.getenv("HF_HOME_CACHE"),
380
+ token=os.getenv("HF_TOKEN")
381
+ )
382
+ print(f"[DEBUG] Caminho do {description} resolvido com sucesso.")
383
+ return model_path
384
+ except Exception as e:
385
+ print("\n" + "="*80)
386
+ print(f"[ERRO CRÍTICO] Falha ao obter o modelo '{filename}'.")
387
+ print(f"Detalhe do erro: {e}")
388
+ print("Verifique sua conexão com a internet ou o estado do cache do Hugging Face.")
389
+ print("="*80 + "\n")
390
+ sys.exit(1)
391
+
392
+ # --- Checkpoint Principal ---
393
+ checkpoint_filename = self.config["checkpoint_path"]
394
+ distilled_model_path = get_or_download_model(
395
+ LTX_REPO, checkpoint_filename, "checkpoint principal"
396
+ )
397
+ self.config["checkpoint_path"] = distilled_model_path
398
+
399
+ # --- Upscaler Espacial ---
400
+ upscaler_filename = self.config["spatial_upscaler_model_path"]
401
+ spatial_upscaler_path = get_or_download_model(
402
+ LTX_REPO, upscaler_filename, "upscaler espacial"
403
+ )
404
+ self.config["spatial_upscaler_model_path"] = spatial_upscaler_path
405
+
406
+ # --- Construção dos Pipelines ---
407
+ print("\n[DEBUG] Construindo pipeline a partir dos caminhos resolvidos...")
408
+ pipeline = create_ltx_video_pipeline(
409
+ ckpt_path=self.config["checkpoint_path"],
410
+ precision=self.config["precision"],
411
+ text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
412
+ sampler=self.config["sampler"],
413
+ device="cpu",
414
+ enhance_prompt=False,
415
+ prompt_enhancer_image_caption_model_name_or_path=self.config["prompt_enhancer_image_caption_model_name_or_path"],
416
+ prompt_enhancer_llm_model_name_or_path=self.config["prompt_enhancer_llm_model_name_or_path"],
417
+ )
418
+ print("[DEBUG] Pipeline pronto.")
419
+
420
+ latent_upsampler = None
421
+ if self.config.get("spatial_upscaler_model_path"):
422
+ print("[DEBUG] Construindo latent_upsampler...")
423
+ latent_upsampler = create_latent_upsampler(self.config["spatial_upscaler_model_path"], device="cpu")
424
+ print("[DEBUG] Upsampler pronto.")
425
+
426
+ print(f"[DEBUG] _load_models() tempo total={time.perf_counter()-t0:.3f}s")
427
+ return pipeline, latent_upsampler```
428
+
429
+
430
+ def _load_models_old(self):
431
  t0 = time.perf_counter()
432
  LTX_REPO = "Lightricks/LTX-Video"
433
  print("[DEBUG] Baixando checkpoint principal...")
 
519
  return upsampled_latents
520
 
521
 
 
522
  def _apply_precision_policy(self):
523
  prec = str(self.config.get("precision", "")).lower()
524
  self.runtime_autocast_dtype = torch.float32
 
926
  temp_dir = tempfile.mkdtemp(prefix="ltxv_"); self._register_tmp_dir(temp_dir)
927
  results_dir = "/app/output"; os.makedirs(results_dir, exist_ok=True)
928
 
929
+
 
930
  partes_mp4 = []
931
  par = 0
932
 
933
+ for latents_vae in latents_list:
934
 
935
  latents_cpu_vae = latents_vae.detach().to("cpu", non_blocking=True)
936
  torch.cuda.empty_cache()