EuuIia committed on
Commit
2e1e83b
·
verified ·
1 Parent(s): f0b5401

Update api/ltx_server.py

Browse files
Files changed (1) hide show
  1. api/ltx_server.py +91 -36
api/ltx_server.py CHANGED
@@ -1,6 +1,3 @@
1
- # video_service.py
2
-
3
- # --- 1. IMPORTAÇÕES ---
4
  import torch
5
  import numpy as np
6
  import random
@@ -63,9 +60,9 @@ def _query_gpu_processes_via_nvidiasmi(device_index: int) -> List[Dict]:
63
  parts = [p.strip() for p in line.split(",")]
64
  if len(parts) >= 3:
65
  try:
66
- pid = int(parts[0])
67
- name = parts[1]
68
- used_mb = int(parts[2])
69
  user = "unknown"
70
  try:
71
  import psutil
@@ -349,6 +346,41 @@ class VideoService:
349
  return tensor.to(self.device, dtype=self.runtime_autocast_dtype)
350
  return tensor.to(self.device)
351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  def generate(
353
  self,
354
  prompt,
@@ -370,6 +402,7 @@ class VideoService:
370
  guidance_scale=3.0,
371
  improve_texture=True,
372
  progress_callback=None,
 
373
  ):
374
  if self.device == "cuda":
375
  torch.cuda.empty_cache()
@@ -417,7 +450,7 @@ class VideoService:
417
  "num_frames": actual_num_frames,
418
  "frame_rate": int(FPS),
419
  "generator": generator,
420
- "output_type": "pt",
421
  "conditioning_items": conditioning_items if conditioning_items else None,
422
  "media_items": None,
423
  "decode_timestep": self.config["decode_timestep"],
@@ -441,6 +474,7 @@ class VideoService:
441
  padding=padding_values,
442
  ).to(self.device)
443
 
 
444
  result_tensor = None
445
  multi_scale_pipeline = None
446
 
@@ -452,7 +486,6 @@ class VideoService:
452
  first_pass_args["guidance_scale"] = float(guidance_scale)
453
  second_pass_args = self.config.get("second_pass", {}).copy()
454
  second_pass_args["guidance_scale"] = float(guidance_scale)
455
-
456
  multi_scale_call_kwargs = call_kwargs.copy()
457
  multi_scale_call_kwargs.update(
458
  {
@@ -461,13 +494,18 @@ class VideoService:
461
  "second_pass": second_pass_args,
462
  }
463
  )
464
-
465
  ctx = contextlib.nullcontext()
466
  if self.device == "cuda":
467
  ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype)
468
  with ctx:
469
- result_tensor = multi_scale_pipeline(**multi_scale_call_kwargs).images
470
- log_tensor_info(result_tensor, "Resultado da Etapa 2 (Saída do Pipeline Multi-Scale)")
 
 
 
 
 
 
471
  else:
472
  single_pass_kwargs = call_kwargs.copy()
473
  first_pass_config = self.config.get("first_pass", {})
@@ -479,8 +517,7 @@ class VideoService:
479
  "skip_block_list": first_pass_config.get("skip_block_list"),
480
  }
481
  )
482
-
483
- # Escolha de schedule única para garantir guidance_mapping definido e consistente
484
  schedule = first_pass_config.get("timesteps")
485
  if schedule is None:
486
  schedule = first_pass_config.get("guidance_timesteps")
@@ -489,20 +526,18 @@ class VideoService:
489
  print("[INFO] Modo video-to-video (etapa única): definindo timesteps (força) para [0.7]")
490
  if isinstance(schedule, (list, tuple)) and len(schedule) > 0:
491
  single_pass_kwargs["timesteps"] = schedule
492
- single_pass_kwargs["guidance_timesteps"] = schedule # garante criação de guidance_mapping
493
 
494
  print("\n[INFO] Executando pipeline de etapa única...")
495
  ctx = contextlib.nullcontext()
496
  if self.device == "cuda":
497
  ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype)
498
  with ctx:
499
- result_tensor = self.pipeline(**single_pass_kwargs).images
500
-
501
- pad_left, pad_right, pad_top, pad_bottom = padding_values
502
- slice_h_end = -pad_bottom if pad_bottom > 0 else None
503
- slice_w_end = -pad_right if pad_right > 0 else None
504
- result_tensor = result_tensor[:, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end]
505
- log_tensor_info(result_tensor, "Tensor Final (Após Pós-processamento, Antes de Salvar)")
506
 
507
  # Staging seguro em tmp e move para diretório persistente
508
  temp_dir = tempfile.mkdtemp(prefix="ltxv_")
@@ -513,20 +548,36 @@ class VideoService:
513
  final_output_path = None
514
  output_video_path = os.path.join(temp_dir, f"output_{used_seed}.mp4")
515
  try:
516
- # Escrita quadro a quadro para evitar array 4D gigante em RAM
517
- with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], codec="libx264", quality=8) as writer:
518
- T = result_tensor.shape[2] # (B, C, T, H, W)
519
- for i in range(T):
520
- frame_chw = result_tensor[0, :, i] # (C,H,W) no device
521
- frame_hwc_u8 = (frame_chw.permute(1, 2, 0) # (H,W,C)
522
- .clamp(0, 1)
523
- .mul(255)
524
- .to(torch.uint8)
525
- .cpu()
526
- .numpy())
527
- writer.append_data(frame_hwc_u8)
528
- if progress_callback:
529
- progress_callback(i + 1, T)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
 
531
  candidate_final = os.path.join(results_dir, f"output_{used_seed}.mp4")
532
  try:
@@ -539,6 +590,10 @@ class VideoService:
539
  self._log_gpu_memory("Fim da Geração")
540
  return final_output_path, used_seed
541
  finally:
 
 
 
 
542
  try:
543
  del result_tensor
544
  except Exception:
@@ -565,4 +620,4 @@ class VideoService:
565
  pass
566
 
567
  print("Criando instância do VideoService. O carregamento do modelo começará agora...")
568
- video_generation_service = VideoService()
 
 
 
 
1
  import torch
2
  import numpy as np
3
  import random
 
60
  parts = [p.strip() for p in line.split(",")]
61
  if len(parts) >= 3:
62
  try:
63
+ pid = int(parts[0])
64
+ name = parts[1]
65
+ used_mb = int(parts[2])
66
  user = "unknown"
67
  try:
68
  import psutil
 
346
  return tensor.to(self.device, dtype=self.runtime_autocast_dtype)
347
  return tensor.to(self.device)
348
 
349
+ # Nova: decodificação de latentes fora da pipeline com VAE e escrita incremental
350
+ def _decode_latents_to_video(self, latents: torch.Tensor, output_video_path: str, frame_rate: int,
351
+ padding_values, progress_callback=None):
352
+ pad_left, pad_right, pad_top, pad_bottom = padding_values
353
+ with imageio.get_writer(output_video_path, fps=frame_rate, codec="libx264", quality=8) as writer:
354
+ T = latents.shape[2]
355
+ for i in range(T):
356
+ latent_chw = latents[0, :, i].to(self.device)
357
+ with torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext():
358
+ pixel_bchw = None
359
+ if hasattr(self.pipeline, "decode_latents"):
360
+ pixel_bchw = self.pipeline.decode_latents(latent_chw.unsqueeze(0))
361
+ elif hasattr(self.pipeline, "vae") and hasattr(self.pipeline.vae, "decode"):
362
+ pixel_bchw = self.pipeline.vae.decode(latent_chw.unsqueeze(0))
363
+ else:
364
+ raise RuntimeError("Pipeline não expõe decode_latents nem vae.decode para decodificar latentes.")
365
+ pixel_chw = pixel_bchw[0]
366
+ if pixel_chw.min() < 0:
367
+ pixel_chw = (pixel_chw.clamp(-1, 1) + 1.0) / 2.0
368
+ else:
369
+ pixel_chw = pixel_chw.clamp(0, 1)
370
+ H = pixel_chw.shape[1]
371
+ W = pixel_chw.shape[2]
372
+ h_end = H - pad_bottom if pad_bottom > 0 else H
373
+ w_end = W - pad_right if pad_right > 0 else W
374
+ pixel_chw = pixel_chw[:, pad_top:h_end, pad_left:w_end]
375
+ frame_hwc_u8 = (pixel_chw.permute(1, 2, 0)
376
+ .mul(255)
377
+ .to(torch.uint8)
378
+ .cpu()
379
+ .numpy())
380
+ writer.append_data(frame_hwc_u8)
381
+ if progress_callback:
382
+ progress_callback(i + 1, T)
383
+
384
  def generate(
385
  self,
386
  prompt,
 
402
  guidance_scale=3.0,
403
  improve_texture=True,
404
  progress_callback=None,
405
+ external_decode=True, # NOVO: decodificar fora da pipeline
406
  ):
407
  if self.device == "cuda":
408
  torch.cuda.empty_cache()
 
450
  "num_frames": actual_num_frames,
451
  "frame_rate": int(FPS),
452
  "generator": generator,
453
+ "output_type": "latent" if external_decode else "pt", # aqui alternamos o tipo de saída
454
  "conditioning_items": conditioning_items if conditioning_items else None,
455
  "media_items": None,
456
  "decode_timestep": self.config["decode_timestep"],
 
474
  padding=padding_values,
475
  ).to(self.device)
476
 
477
+ latents = None
478
  result_tensor = None
479
  multi_scale_pipeline = None
480
 
 
486
  first_pass_args["guidance_scale"] = float(guidance_scale)
487
  second_pass_args = self.config.get("second_pass", {}).copy()
488
  second_pass_args["guidance_scale"] = float(guidance_scale)
 
489
  multi_scale_call_kwargs = call_kwargs.copy()
490
  multi_scale_call_kwargs.update(
491
  {
 
494
  "second_pass": second_pass_args,
495
  }
496
  )
 
497
  ctx = contextlib.nullcontext()
498
  if self.device == "cuda":
499
  ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype)
500
  with ctx:
501
+ result = multi_scale_pipeline(**multi_scale_call_kwargs)
502
+ # Captura latentes ou imagens conforme o output_type
503
+ if external_decode:
504
+ latents = getattr(result, "latents", None) or getattr(result, "images", None) or result
505
+ else:
506
+ result_tensor = getattr(result, "images", None) or result
507
+ if not external_decode:
508
+ log_tensor_info(result_tensor, "Resultado da Etapa 2 (Saída do Pipeline Multi-Scale)")
509
  else:
510
  single_pass_kwargs = call_kwargs.copy()
511
  first_pass_config = self.config.get("first_pass", {})
 
517
  "skip_block_list": first_pass_config.get("skip_block_list"),
518
  }
519
  )
520
+ # Agenda única para guidance_mapping consistente
 
521
  schedule = first_pass_config.get("timesteps")
522
  if schedule is None:
523
  schedule = first_pass_config.get("guidance_timesteps")
 
526
  print("[INFO] Modo video-to-video (etapa única): definindo timesteps (força) para [0.7]")
527
  if isinstance(schedule, (list, tuple)) and len(schedule) > 0:
528
  single_pass_kwargs["timesteps"] = schedule
529
+ single_pass_kwargs["guidance_timesteps"] = schedule
530
 
531
  print("\n[INFO] Executando pipeline de etapa única...")
532
  ctx = contextlib.nullcontext()
533
  if self.device == "cuda":
534
  ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype)
535
  with ctx:
536
+ result = self.pipeline(**single_pass_kwargs)
537
+ if external_decode:
538
+ latents = getattr(result, "latents", None) or getattr(result, "images", None) or result
539
+ else:
540
+ result_tensor = getattr(result, "images", None) or result
 
 
541
 
542
  # Staging seguro em tmp e move para diretório persistente
543
  temp_dir = tempfile.mkdtemp(prefix="ltxv_")
 
548
  final_output_path = None
549
  output_video_path = os.path.join(temp_dir, f"output_{used_seed}.mp4")
550
  try:
551
+ if external_decode:
552
+ # Decodifica latentes -> MP4, quadro a quadro
553
+ self._decode_latents_to_video(
554
+ latents=latents,
555
+ output_video_path=output_video_path,
556
+ frame_rate=call_kwargs["frame_rate"],
557
+ padding_values=padding_values,
558
+ progress_callback=progress_callback,
559
+ )
560
+ else:
561
+ # Caminho antigo: tensor já em espaço de pixels -> escrever quadro a quadro
562
+ # Aplicar corte de padding antes de escrever
563
+ pad_left, pad_right, pad_top, pad_bottom = padding_values
564
+ slice_h_end = -pad_bottom if pad_bottom > 0 else None
565
+ slice_w_end = -pad_right if pad_right > 0 else None
566
+ result_tensor = result_tensor[:, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end]
567
+ log_tensor_info(result_tensor, "Tensor Final (Após Pós-processamento, Antes de Salvar)")
568
+ with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], codec="libx264", quality=8) as writer:
569
+ T = result_tensor.shape[2]
570
+ for i in range(T):
571
+ frame_chw = result_tensor[0, :, i]
572
+ frame_hwc_u8 = (frame_chw.permute(1, 2, 0)
573
+ .clamp(0, 1)
574
+ .mul(255)
575
+ .to(torch.uint8)
576
+ .cpu()
577
+ .numpy())
578
+ writer.append_data(frame_hwc_u8)
579
+ if progress_callback:
580
+ progress_callback(i + 1, T)
581
 
582
  candidate_final = os.path.join(results_dir, f"output_{used_seed}.mp4")
583
  try:
 
590
  self._log_gpu_memory("Fim da Geração")
591
  return final_output_path, used_seed
592
  finally:
593
+ try:
594
+ del latents
595
+ except Exception:
596
+ pass
597
  try:
598
  del result_tensor
599
  except Exception:
 
620
  pass
621
 
622
  print("Criando instância do VideoService. O carregamento do modelo começará agora...")
623
+ video_generation_service = VideoService()