Update api/ltx_server.py
Browse files

api/ltx_server.py  CHANGED  (+72 -6)
@@ -385,6 +385,72 @@ class VideoService:
         print(f"[DEBUG] Cond shape={tuple(out.shape)} dtype={out.dtype} device={out.device}")
         return out
 
+
+    def _decode_one_latent_to_pixel(self, latent_chw: torch.Tensor) -> torch.Tensor:
+        """
+        Decode one latent (C,H,W) to a pixel frame (C,H,W) in the range [0,1].
+        Uses pipeline.decode_latents if it exists, otherwise pipeline.vae.decode.
+        """
+        if self.device == "cuda":
+            ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype)
+        else:
+            ctx = contextlib.nullcontext()
+        with ctx:
+            if hasattr(self.pipeline, "decode_latents"):
+                img_bchw = self.pipeline.decode_latents(latent_chw.unsqueeze(0))
+            elif hasattr(self.pipeline, "vae") and hasattr(self.pipeline.vae, "decode"):
+                img_bchw = self.pipeline.vae.decode(latent_chw.unsqueeze(0))
+            else:
+                raise RuntimeError("No decoder found (decode_latents/vae.decode).")
+        img_chw = img_bchw[0]
+        # Normalize to [0,1] in case the decoder outputs [-1,1]
+        if img_chw.min() < 0:
+            img_chw = (img_chw.clamp(-1, 1) + 1.0) / 2.0
+        else:
+            img_chw = img_chw.clamp(0, 1)
+        return img_chw
+
+    def _pixels_to_uint8_np(self, pixel_chw: torch.Tensor, padding_values) -> np.ndarray:
+        """
+        Convert (C,H,W) float [0,1] to (H,W,C) uint8, cropping off the padding.
+        """
+        pad_left, pad_right, pad_top, pad_bottom = padding_values
+        H, W = pixel_chw.shape[1], pixel_chw.shape[2]
+        h_end = H - pad_bottom if pad_bottom > 0 else H
+        w_end = W - pad_right if pad_right > 0 else W
+        pixel_chw = pixel_chw[:, pad_top:h_end, pad_left:w_end]
+        frame_hwc_u8 = (pixel_chw.permute(1, 2, 0)
+                        .mul(255)
+                        .to(torch.uint8)
+                        .cpu()
+                        .numpy())
+        return frame_hwc_u8
+
+    def encode_latents_to_mp4(self, latents: torch.Tensor, output_path: str, fps: int, padding_values,
+                              progress_callback=None):
+        """
+        Final pipeline: latents (B,C,T,H,W) -> decode each frame -> write the MP4 incrementally.
+        Follows the encoder pattern from the other app (frame by frame, no giant 4D array).
+        """
+        T = latents.shape[2]
+        print(f"[DEBUG] encode_latents_to_mp4: frames={T} out={output_path}")
+        with imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8) as writer:
+            for i in range(T):
+                latent_chw = latents[0, :, i].to(self.device)
+                pixel_chw = self._decode_one_latent_to_pixel(latent_chw)
+                frame_hwc_u8 = self._pixels_to_uint8_np(pixel_chw, padding_values)
+                writer.append_data(frame_hwc_u8)
+                if progress_callback:
+                    progress_callback(i + 1, T)
+                if i % getattr(self, "frame_log_every", 8) == 0:
+                    print(f"[DEBUG] encode frame {i}/{T}")
+
+
     def _decode_latents_to_video(self, latents: torch.Tensor, output_video_path: str, frame_rate: int,
                                  padding_values, progress_callback=None):
         print(f"[DEBUG] Decoding latents → video: {output_video_path}")
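The three new methods form a small pipeline: decode one latent frame, crop and convert it to uint8, append it to the MP4 writer. A minimal sketch of how the encoder might be driven, assuming an already-initialized VideoService instance named `service` and an illustrative latent shape (neither is part of this commit):

import torch

# Assumed: `service` is a VideoService with .pipeline and .device already set up.
latents = torch.randn(1, 128, 57, 16, 24)  # (B, C, T, H, W) — shape is illustrative only

def on_progress(done: int, total: int) -> None:
    # Matches the progress_callback(i + 1, T) call inside encode_latents_to_mp4.
    print(f"decoded {done}/{total} frames")

service.encode_latents_to_mp4(
    latents=latents,
    output_path="/tmp/output.mp4",
    fps=24,
    padding_values=(0, 0, 0, 0),  # (pad_left, pad_right, pad_top, pad_bottom)
    progress_callback=on_progress,
)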
@@ -620,14 +686,14 @@ class VideoService:
         output_video_path = os.path.join(temp_dir, f"output_{used_seed}.mp4")
         final_output_path = None
 
-        if external_decode:
-            print("[DEBUG] …
-            self._decode_latents_to_video(
+        if external_decode:
+            print("[DEBUG] Encoding from latents (external VAE) → MP4...")
+            self.encode_latents_to_mp4(
                 latents=latents,
-                output_video_path=output_video_path,
-                frame_rate=call_kwargs["frame_rate"],
+                output_path=output_video_path,
+                fps=call_kwargs["frame_rate"],
                 padding_values=padding_values,
-                progress_callback=progress_callback
+                progress_callback=progress_callback
             )
         else:
            print("[DEBUG] Writing video from pixels (no latents)...")