Update api/ltx_server.py
Browse files- api/ltx_server.py +77 -5
api/ltx_server.py
CHANGED
|
@@ -8,7 +8,7 @@ warnings.filterwarnings("ignore", category=UserWarning)
|
|
| 8 |
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 9 |
warnings.filterwarnings("ignore", message=".*")
|
| 10 |
|
| 11 |
-
from huggingface_hub import logging
|
| 12 |
|
| 13 |
logging.set_verbosity_error()
|
| 14 |
logging.set_verbosity_warning()
|
|
@@ -354,6 +354,80 @@ class VideoService:
|
|
| 354 |
return yaml.safe_load(file)
|
| 355 |
|
| 356 |
def _load_models(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
t0 = time.perf_counter()
|
| 358 |
LTX_REPO = "Lightricks/LTX-Video"
|
| 359 |
print("[DEBUG] Baixando checkpoint principal...")
|
|
@@ -445,7 +519,6 @@ class VideoService:
|
|
| 445 |
return upsampled_latents
|
| 446 |
|
| 447 |
|
| 448 |
-
|
| 449 |
def _apply_precision_policy(self):
|
| 450 |
prec = str(self.config.get("precision", "")).lower()
|
| 451 |
self.runtime_autocast_dtype = torch.float32
|
|
@@ -853,12 +926,11 @@ class VideoService:
|
|
| 853 |
temp_dir = tempfile.mkdtemp(prefix="ltxv_"); self._register_tmp_dir(temp_dir)
|
| 854 |
results_dir = "/app/output"; os.makedirs(results_dir, exist_ok=True)
|
| 855 |
|
| 856 |
-
|
| 857 |
-
|
| 858 |
partes_mp4 = []
|
| 859 |
par = 0
|
| 860 |
|
| 861 |
-
for latents_vae in
|
| 862 |
|
| 863 |
latents_cpu_vae = latents_vae.detach().to("cpu", non_blocking=True)
|
| 864 |
torch.cuda.empty_cache()
|
|
|
|
| 8 |
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 9 |
warnings.filterwarnings("ignore", message=".*")
|
| 10 |
|
| 11 |
+
from huggingface_hub import logging, hf_hub_download
|
| 12 |
|
| 13 |
logging.set_verbosity_error()
|
| 14 |
logging.set_verbosity_warning()
|
|
|
|
| 354 |
return yaml.safe_load(file)
|
| 355 |
|
| 356 |
def _load_models(self):
|
| 357 |
+
"""
|
| 358 |
+
Carrega os modelos de forma inteligente:
|
| 359 |
+
1. Tenta resolver o caminho do cache local (rápido, sem rede).
|
| 360 |
+
2. Se o arquivo não for encontrado localmente, baixa como fallback.
|
| 361 |
+
Garante que o serviço possa iniciar mesmo que o setup.py não tenha sido executado.
|
| 362 |
+
"""
|
| 363 |
+
t0 = time.perf_counter()
|
| 364 |
+
LTX_REPO = "Lightricks/LTX-Video"
|
| 365 |
+
|
| 366 |
+
print("[DEBUG] Resolvendo caminhos dos modelos de forma inteligente...")
|
| 367 |
+
|
| 368 |
+
# --- Função Auxiliar para Carregamento Inteligente ---
|
| 369 |
+
def get_or_download_model(repo_id, filename, description):
|
| 370 |
+
try:
|
| 371 |
+
# hf_hub_download é a ferramenta certa aqui. Ela verifica o cache PRIMEIRO.
|
| 372 |
+
# Se o arquivo estiver no cache, retorna o caminho instantaneamente (após uma verificação rápida de metadados).
|
| 373 |
+
# Se não estiver no cache, ela o baixa.
|
| 374 |
+
print(f"[DEBUG] Verificando {description}: {filename}...")
|
| 375 |
+
model_path = hf_hub_download(
|
| 376 |
+
repo_id=repo_id,
|
| 377 |
+
filename=filename,
|
| 378 |
+
# Forçar o uso de um cache específico se necessário
|
| 379 |
+
cache_dir=os.getenv("HF_HOME_CACHE"),
|
| 380 |
+
token=os.getenv("HF_TOKEN")
|
| 381 |
+
)
|
| 382 |
+
print(f"[DEBUG] Caminho do {description} resolvido com sucesso.")
|
| 383 |
+
return model_path
|
| 384 |
+
except Exception as e:
|
| 385 |
+
print("\n" + "="*80)
|
| 386 |
+
print(f"[ERRO CRÍTICO] Falha ao obter o modelo '{filename}'.")
|
| 387 |
+
print(f"Detalhe do erro: {e}")
|
| 388 |
+
print("Verifique sua conexão com a internet ou o estado do cache do Hugging Face.")
|
| 389 |
+
print("="*80 + "\n")
|
| 390 |
+
sys.exit(1)
|
| 391 |
+
|
| 392 |
+
# --- Checkpoint Principal ---
|
| 393 |
+
checkpoint_filename = self.config["checkpoint_path"]
|
| 394 |
+
distilled_model_path = get_or_download_model(
|
| 395 |
+
LTX_REPO, checkpoint_filename, "checkpoint principal"
|
| 396 |
+
)
|
| 397 |
+
self.config["checkpoint_path"] = distilled_model_path
|
| 398 |
+
|
| 399 |
+
# --- Upscaler Espacial ---
|
| 400 |
+
upscaler_filename = self.config["spatial_upscaler_model_path"]
|
| 401 |
+
spatial_upscaler_path = get_or_download_model(
|
| 402 |
+
LTX_REPO, upscaler_filename, "upscaler espacial"
|
| 403 |
+
)
|
| 404 |
+
self.config["spatial_upscaler_model_path"] = spatial_upscaler_path
|
| 405 |
+
|
| 406 |
+
# --- Construção dos Pipelines ---
|
| 407 |
+
print("\n[DEBUG] Construindo pipeline a partir dos caminhos resolvidos...")
|
| 408 |
+
pipeline = create_ltx_video_pipeline(
|
| 409 |
+
ckpt_path=self.config["checkpoint_path"],
|
| 410 |
+
precision=self.config["precision"],
|
| 411 |
+
text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
|
| 412 |
+
sampler=self.config["sampler"],
|
| 413 |
+
device="cpu",
|
| 414 |
+
enhance_prompt=False,
|
| 415 |
+
prompt_enhancer_image_caption_model_name_or_path=self.config["prompt_enhancer_image_caption_model_name_or_path"],
|
| 416 |
+
prompt_enhancer_llm_model_name_or_path=self.config["prompt_enhancer_llm_model_name_or_path"],
|
| 417 |
+
)
|
| 418 |
+
print("[DEBUG] Pipeline pronto.")
|
| 419 |
+
|
| 420 |
+
latent_upsampler = None
|
| 421 |
+
if self.config.get("spatial_upscaler_model_path"):
|
| 422 |
+
print("[DEBUG] Construindo latent_upsampler...")
|
| 423 |
+
latent_upsampler = create_latent_upsampler(self.config["spatial_upscaler_model_path"], device="cpu")
|
| 424 |
+
print("[DEBUG] Upsampler pronto.")
|
| 425 |
+
|
| 426 |
+
print(f"[DEBUG] _load_models() tempo total={time.perf_counter()-t0:.3f}s")
|
| 427 |
+
return pipeline, latent_upsampler```
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def _load_models_old(self):
|
| 431 |
t0 = time.perf_counter()
|
| 432 |
LTX_REPO = "Lightricks/LTX-Video"
|
| 433 |
print("[DEBUG] Baixando checkpoint principal...")
|
|
|
|
| 519 |
return upsampled_latents
|
| 520 |
|
| 521 |
|
|
|
|
| 522 |
def _apply_precision_policy(self):
|
| 523 |
prec = str(self.config.get("precision", "")).lower()
|
| 524 |
self.runtime_autocast_dtype = torch.float32
|
|
|
|
| 926 |
temp_dir = tempfile.mkdtemp(prefix="ltxv_"); self._register_tmp_dir(temp_dir)
|
| 927 |
results_dir = "/app/output"; os.makedirs(results_dir, exist_ok=True)
|
| 928 |
|
| 929 |
+
|
|
|
|
| 930 |
partes_mp4 = []
|
| 931 |
par = 0
|
| 932 |
|
| 933 |
+
for latents_vae in latents_list:
|
| 934 |
|
| 935 |
latents_cpu_vae = latents_vae.detach().to("cpu", non_blocking=True)
|
| 936 |
torch.cuda.empty_cache()
|