Spaces:
Paused
Paused
Update api/ltx_server_refactored.py
Browse files- api/ltx_server_refactored.py +239 -19
api/ltx_server_refactored.py
CHANGED
|
@@ -299,28 +299,53 @@ class VideoService:
|
|
| 299 |
# --- Métodos Públicos (API do Serviço) ---
|
| 300 |
# --------------------------------------------------------------------------
|
| 301 |
|
| 302 |
-
def _prepare_condition_items(
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
if not items_list:
|
| 305 |
return []
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
|
|
|
| 314 |
for media, frame_idx, weight in items_list:
|
|
|
|
| 315 |
if isinstance(media, str):
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
return conditioning_items
|
| 325 |
|
| 326 |
def generate_low_resolution(
|
|
@@ -457,4 +482,199 @@ class VideoService:
|
|
| 457 |
pixel_chunk = vae_manager_singleton.decode(chunk.to(self.device), decode_timestep=float(self.config.get("decode_timestep", 0.05)))
|
| 458 |
pixel_chunks.append(pixel_chunk)
|
| 459 |
|
| 460 |
-
final_pixel_tensor = self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
# --- Métodos Públicos (API do Serviço) ---
|
| 300 |
# --------------------------------------------------------------------------
|
| 301 |
|
| 302 |
+
def _prepare_condition_items(
|
| 303 |
+
self,
|
| 304 |
+
items_list: List[Tuple[Union[str, Image.Image, torch.Tensor], int, float]],
|
| 305 |
+
height: int,
|
| 306 |
+
width: int,
|
| 307 |
+
num_frames: int,
|
| 308 |
+
) -> List[ConditioningItem]:
|
| 309 |
+
"""
|
| 310 |
+
Prepara ConditioningItem a partir de paths, PIL.Images ou tensores.
|
| 311 |
+
"""
|
| 312 |
if not items_list:
|
| 313 |
return []
|
| 314 |
+
|
| 315 |
+
# calcula dims downscaled (múltiplo de patch temporal)
|
| 316 |
+
down_h, down_w = self._calculate_downscaled_dims(height, width)
|
| 317 |
+
# ajusta padding para múltiplos de 8
|
| 318 |
+
pad_h = ((down_h - 1) // 8 + 1) * 8
|
| 319 |
+
pad_w = ((down_w - 1) // 8 + 1) * 8
|
| 320 |
+
padding = calculate_padding(down_h, down_w, pad_h, pad_w)
|
| 321 |
+
|
| 322 |
+
conditioning_items: List[ConditioningItem] = []
|
| 323 |
for media, frame_idx, weight in items_list:
|
| 324 |
+
# carrega raw_item como PIL.Image ou tensor
|
| 325 |
if isinstance(media, str):
|
| 326 |
+
img = Image.open(media).convert("RGB")
|
| 327 |
+
raw_item = ImageOps.fit(img, (down_w, down_h), Image.LANCZOS)
|
| 328 |
+
elif isinstance(media, Image.Image):
|
| 329 |
+
raw_item = ImageOps.fit(media, (down_w, down_h), Image.LANCZOS)
|
| 330 |
+
elif isinstance(media, torch.Tensor):
|
| 331 |
+
raw_item = media.to(device=self.device, dtype=self.runtime_autocast_dtype)
|
| 332 |
+
else:
|
| 333 |
+
raise TypeError(f"Tipo de media não suportado: {type(media)}")
|
| 334 |
+
|
| 335 |
+
# garante frame index seguro
|
| 336 |
+
safe_frame = max(0, min(int(frame_idx), num_frames - 1))
|
| 337 |
+
|
| 338 |
+
# codifica raw_item em latentes via VAE (inclui expansão de frame se precisar)
|
| 339 |
+
cond_item = self.encode_conditioning_item(
|
| 340 |
+
raw_item,
|
| 341 |
+
frame_number=safe_frame,
|
| 342 |
+
strength=float(weight),
|
| 343 |
+
height=down_h,
|
| 344 |
+
width=down_w,
|
| 345 |
+
vae_per_channel_normalize=self.vae_per_channel_normalize,
|
| 346 |
+
)
|
| 347 |
+
conditioning_items.append(cond_item)
|
| 348 |
+
|
| 349 |
return conditioning_items
|
| 350 |
|
| 351 |
def generate_low_resolution(
|
|
|
|
| 482 |
pixel_chunk = vae_manager_singleton.decode(chunk.to(self.device), decode_timestep=float(self.config.get("decode_timestep", 0.05)))
|
| 483 |
pixel_chunks.append(pixel_chunk)
|
| 484 |
|
| 485 |
+
final_pixel_tensor = self._merge_chunks_with_overlap(pixel_chunks)
|
| 486 |
+
final_video_path = self._save_video_from_tensor(final_pixel_tensor, f"final_video_{seed}", seed, temp_dir, fps=fps)
|
| 487 |
+
return final_video_path
|
| 488 |
+
|
| 489 |
+
except Exception as e:
|
| 490 |
+
print(f"[ERROR] Falha ao encodar latentes para MP4: {e}")
|
| 491 |
+
traceback.print_exc()
|
| 492 |
+
raise
|
| 493 |
+
finally:
|
| 494 |
+
self._finalize()
|
| 495 |
+
|
| 496 |
+
# --------------------------------------------------------------------------
|
| 497 |
+
# --- Métodos Internos e Auxiliares ---
|
| 498 |
+
# --------------------------------------------------------------------------
|
| 499 |
+
|
| 500 |
+
def _finalize(self):
|
| 501 |
+
"""Limpa a memória da GPU e os diretórios temporários."""
|
| 502 |
+
if LTXV_DEBUG:
|
| 503 |
+
print("[DEBUG] Finalize: iniciando limpeza...")
|
| 504 |
+
|
| 505 |
+
gc.collect()
|
| 506 |
+
if torch.cuda.is_available():
|
| 507 |
+
torch.cuda.empty_cache()
|
| 508 |
+
torch.cuda.ipc_collect()
|
| 509 |
+
|
| 510 |
+
# Limpa todos os diretórios temporários registrados
|
| 511 |
+
for d in list(self._tmp_dirs):
|
| 512 |
+
shutil.rmtree(d, ignore_errors=True)
|
| 513 |
+
self._tmp_dirs.remove(d)
|
| 514 |
+
if LTXV_DEBUG:
|
| 515 |
+
print(f"[DEBUG] Diretório temporário removido: {d}")
|
| 516 |
+
|
| 517 |
+
def _load_config(self, config_filename: str) -> Dict:
|
| 518 |
+
"""Carrega o arquivo de configuração YAML."""
|
| 519 |
+
config_path = LTX_VIDEO_REPO_DIR / "configs" / config_filename
|
| 520 |
+
print(f"[INFO] Carregando configuração de: {config_path}")
|
| 521 |
+
with open(config_path, "r") as file:
|
| 522 |
+
return yaml.safe_load(file)
|
| 523 |
+
|
| 524 |
+
def _load_models_from_hub(self) -> Tuple[LTXMultiScalePipeline, Optional[torch.nn.Module]]:
|
| 525 |
+
"""Baixa e cria as instâncias da pipeline e do upsampler."""
|
| 526 |
+
t0 = time.perf_counter()
|
| 527 |
+
LTX_REPO = "Lightricks/LTX-Video"
|
| 528 |
+
|
| 529 |
+
print("[INFO] Baixando checkpoint principal...")
|
| 530 |
+
self.config["checkpoint_path"] = hf_hub_download(
|
| 531 |
+
repo_id=LTX_REPO, filename=self.config["checkpoint_path"],
|
| 532 |
+
token=os.getenv("HF_TOKEN")
|
| 533 |
+
)
|
| 534 |
+
print(f"[INFO] Checkpoint principal em: {self.config['checkpoint_path']}")
|
| 535 |
+
|
| 536 |
+
print("[INFO] Construindo pipeline...")
|
| 537 |
+
pipeline = create_ltx_video_pipeline(
|
| 538 |
+
ckpt_path=self.config["checkpoint_path"],
|
| 539 |
+
precision=self.config["precision"],
|
| 540 |
+
text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
|
| 541 |
+
sampler=self.config["sampler"],
|
| 542 |
+
device="cpu", # Carrega em CPU primeiro
|
| 543 |
+
enhance_prompt=False
|
| 544 |
+
)
|
| 545 |
+
print("[INFO] Pipeline construída.")
|
| 546 |
+
|
| 547 |
+
latent_upsampler = None
|
| 548 |
+
if self.config.get("spatial_upscaler_model_path"):
|
| 549 |
+
print("[INFO] Baixando upscaler espacial...")
|
| 550 |
+
self.config["spatial_upscaler_model_path"] = hf_hub_download(
|
| 551 |
+
repo_id=LTX_REPO, filename=self.config["spatial_upscaler_model_path"],
|
| 552 |
+
token=os.getenv("HF_TOKEN")
|
| 553 |
+
)
|
| 554 |
+
print(f"[INFO] Upscaler em: {self.config['spatial_upscaler_model_path']}")
|
| 555 |
+
|
| 556 |
+
print("[INFO] Construindo latent_upsampler...")
|
| 557 |
+
latent_upsampler = create_latent_upsampler(self.config["spatial_upscaler_model_path"], device="cpu")
|
| 558 |
+
print("[INFO] Latent upsampler construído.")
|
| 559 |
+
|
| 560 |
+
print(f"[INFO] Carregamento de modelos concluído em {time.perf_counter()-t0:.2f}s")
|
| 561 |
+
return pipeline, latent_upsampler
|
| 562 |
+
|
| 563 |
+
def _move_models_to_device(self):
|
| 564 |
+
"""Move os modelos carregados para o dispositivo de computação (GPU/CPU)."""
|
| 565 |
+
print(f"[INFO] Movendo modelos para o dispositivo: {self.device}")
|
| 566 |
+
self.pipeline.to(self.device)
|
| 567 |
+
if self.latent_upsampler:
|
| 568 |
+
self.latent_upsampler.to(self.device)
|
| 569 |
+
|
| 570 |
+
def _get_precision_dtype(self) -> torch.dtype:
|
| 571 |
+
"""Determina o dtype para autocast com base na configuração de precisão."""
|
| 572 |
+
prec = str(self.config.get("precision", "")).lower()
|
| 573 |
+
if prec in ["float8_e4m3fn", "bfloat16"]:
|
| 574 |
+
return torch.bfloat16
|
| 575 |
+
elif prec == "mixed_precision":
|
| 576 |
+
return torch.float16
|
| 577 |
+
return torch.float32
|
| 578 |
+
|
| 579 |
+
@torch.no_grad()
|
| 580 |
+
def _upsample_and_filter_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 581 |
+
"""Aplica o upsample espacial e o filtro AdaIN aos latentes."""
|
| 582 |
+
if not self.latent_upsampler:
|
| 583 |
+
raise ValueError("Latent Upsampler não está carregado para a operação de upscale.")
|
| 584 |
+
|
| 585 |
+
latents_unnormalized = un_normalize_latents(latents, self.pipeline.vae, vae_per_channel_normalize=True)
|
| 586 |
+
upsampled_latents_unnormalized = self.latent_upsampler(latents_unnormalized)
|
| 587 |
+
upsampled_latents_normalized = normalize_latents(upsampled_latents_unnormalized, self.pipeline.vae, vae_per_channel_normalize=True)
|
| 588 |
+
|
| 589 |
+
# Filtro AdaIN para manter consistência de cor/estilo com o vídeo de baixa resolução
|
| 590 |
+
return adain_filter_latent(latents=upsampled_latents_normalized, reference_latents=latents)
|
| 591 |
+
|
| 592 |
+
def _prepare_conditioning_tensor_from_path(self, filepath: str, height: int, width: int, padding: Tuple) -> torch.Tensor:
|
| 593 |
+
"""Carrega uma imagem, redimensiona, aplica padding e move para o dispositivo."""
|
| 594 |
+
tensor = load_image_to_tensor_with_resize_and_crop(filepath, height, width)
|
| 595 |
+
tensor = F.pad(tensor, padding)
|
| 596 |
+
return tensor.to(self.device, dtype=self.runtime_autocast_dtype)
|
| 597 |
+
|
| 598 |
+
def _calculate_downscaled_dims(self, height: int, width: int) -> Tuple[int, int]:
|
| 599 |
+
"""Calcula as dimensões para o primeiro passo (baixa resolução)."""
|
| 600 |
+
height_padded = ((height - 1) // 8 + 1) * 8
|
| 601 |
+
width_padded = ((width - 1) // 8 + 1) * 8
|
| 602 |
+
|
| 603 |
+
downscale_factor = self.config.get("downscale_factor", 0.6666666)
|
| 604 |
+
vae_scale_factor = self.pipeline.vae_scale_factor
|
| 605 |
+
|
| 606 |
+
target_w = int(width_padded * downscale_factor)
|
| 607 |
+
downscaled_width = target_w - (target_w % vae_scale_factor)
|
| 608 |
+
|
| 609 |
+
target_h = int(height_padded * downscale_factor)
|
| 610 |
+
downscaled_height = target_h - (target_h % vae_scale_factor)
|
| 611 |
+
|
| 612 |
+
return downscaled_height, downscaled_width
|
| 613 |
+
|
| 614 |
+
def _split_latents_with_overlap(self, latents: torch.Tensor, overlap: int = 1) -> List[torch.Tensor]:
|
| 615 |
+
"""Divide um tensor de latentes em dois chunks com sobreposição."""
|
| 616 |
+
total_frames = latents.shape[2]
|
| 617 |
+
if total_frames <= overlap:
|
| 618 |
+
return [latents]
|
| 619 |
+
|
| 620 |
+
mid_point = max(overlap, total_frames // 2)
|
| 621 |
+
chunk1 = latents[:, :, :mid_point, :, :]
|
| 622 |
+
# O segundo chunk começa 'overlap' frames antes para criar a sobreposição
|
| 623 |
+
chunk2 = latents[:, :, mid_point - overlap:, :, :]
|
| 624 |
+
|
| 625 |
+
return [c for c in [chunk1, chunk2] if c.shape[2] > 0]
|
| 626 |
+
|
| 627 |
+
def _merge_chunks_with_overlap(self, chunks: List[torch.Tensor], overlap: int = 1) -> torch.Tensor:
|
| 628 |
+
"""Junta uma lista de chunks, removendo a sobreposição."""
|
| 629 |
+
if not chunks:
|
| 630 |
+
return torch.empty(0)
|
| 631 |
+
if len(chunks) == 1:
|
| 632 |
+
return chunks[0]
|
| 633 |
+
|
| 634 |
+
# Pega o primeiro chunk sem o frame de sobreposição final
|
| 635 |
+
merged_list = [chunks[0][:, :, :-overlap, :, :]]
|
| 636 |
+
# Adiciona os chunks restantes
|
| 637 |
+
merged_list.extend(chunks[1:])
|
| 638 |
+
|
| 639 |
+
return torch.cat(merged_list, dim=2)
|
| 640 |
+
|
| 641 |
+
def _save_latents_to_disk(self, latents_tensor: torch.Tensor, base_filename: str, seed: int) -> str:
|
| 642 |
+
"""Salva um tensor de latentes em um arquivo .pt."""
|
| 643 |
+
latents_cpu = latents_tensor.detach().to("cpu")
|
| 644 |
+
tensor_path = RESULTS_DIR / f"{base_filename}_{seed}.pt"
|
| 645 |
+
torch.save(latents_cpu, tensor_path)
|
| 646 |
+
if LTXV_DEBUG:
|
| 647 |
+
print(f"[DEBUG] Latentes salvos em: {tensor_path}")
|
| 648 |
+
return str(tensor_path)
|
| 649 |
+
|
| 650 |
+
def _save_video_from_tensor(self, pixel_tensor: torch.Tensor, base_filename: str, seed: int, temp_dir: str, fps: int = int(DEFAULT_FPS)) -> str:
|
| 651 |
+
"""Salva um tensor de pixels como um arquivo de vídeo MP4."""
|
| 652 |
+
temp_path = os.path.join(temp_dir, f"{base_filename}_{seed}.mp4")
|
| 653 |
+
video_encode_tool_singleton.save_video_from_tensor(pixel_tensor, temp_path, fps=fps)
|
| 654 |
+
|
| 655 |
+
final_path = RESULTS_DIR / f"{base_filename}_{seed}.mp4"
|
| 656 |
+
shutil.move(temp_path, final_path)
|
| 657 |
+
print(f"[INFO] Vídeo final salvo em: {final_path}")
|
| 658 |
+
return str(final_path)
|
| 659 |
+
|
| 660 |
+
def _register_tmp_dir(self, dir_path: str):
|
| 661 |
+
"""Registra um diretório temporário para limpeza posterior."""
|
| 662 |
+
if dir_path and os.path.isdir(dir_path):
|
| 663 |
+
self._tmp_dirs.add(dir_path)
|
| 664 |
+
if LTXV_DEBUG:
|
| 665 |
+
print(f"[DEBUG] Diretório temporário registrado: {dir_path}")
|
| 666 |
+
|
| 667 |
+
def _seed_everething(self, seed: int):
|
| 668 |
+
random.seed(seed)
|
| 669 |
+
np.random.seed(seed)
|
| 670 |
+
torch.manual_seed(seed)
|
| 671 |
+
if torch.cuda.is_available():
|
| 672 |
+
torch.cuda.manual_seed(seed)
|
| 673 |
+
if torch.backends.mps.is_available():
|
| 674 |
+
torch.mps.manual_seed(seed)
|
| 675 |
+
|
| 676 |
+
# ==============================================================================
|
| 677 |
+
# 4. INSTANCIAÇÃO E PONTO DE ENTRADA (Exemplo)
|
| 678 |
+
# ==============================================================================
|
| 679 |
+
video_generation_service = VideoService()
|
| 680 |
+
print("Instância do VideoService pronta para uso.")
|