Update api/ltx_server.py
Browse files- api/ltx_server.py +25 -8
api/ltx_server.py
CHANGED
|
@@ -588,14 +588,31 @@ class VideoService:
|
|
| 588 |
second_pass_args = self.config.get("second_pass", {}).copy()
|
| 589 |
second_pass_args["guidance_scale"] = float(guidance_scale)
|
| 590 |
|
| 591 |
-
multi_scale_call_kwargs = call_kwargs.copy()
|
| 592 |
-
multi_scale_call_kwargs.update(
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 599 |
print("[DEBUG] Chamando multi_scale_pipeline...")
|
| 600 |
t_ms = time.perf_counter()
|
| 601 |
ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
|
|
|
|
| 588 |
second_pass_args = self.config.get("second_pass", {}).copy()
|
| 589 |
second_pass_args["guidance_scale"] = float(guidance_scale)
|
| 590 |
|
| 591 |
+
#multi_scale_call_kwargs = call_kwargs.copy()
|
| 592 |
+
#multi_scale_call_kwargs.update(
|
| 593 |
+
# {
|
| 594 |
+
# "downscale_factor": self.config["downscale_factor"],
|
| 595 |
+
# "first_pass": first_pass_args,
|
| 596 |
+
# "second_pass": second_pass_args,
|
| 597 |
+
# }
|
| 598 |
+
#)
|
| 599 |
+
|
| 600 |
+
first_pass_kwargs = call_kwargs.copy()
|
| 601 |
+
first_pass_kwargs.update(first_pass_args)
|
| 602 |
+
|
| 603 |
+
with ctx:
|
| 604 |
+
_apply_precision_policyresult_first = self.pipeline(**first_pass_kwargs)
|
| 605 |
+
|
| 606 |
+
latents_first = result_first.latents
|
| 607 |
+
|
| 608 |
+
with ctx:
|
| 609 |
+
result_second = self.latent_upsampler(
|
| 610 |
+
latents=latents_first,
|
| 611 |
+
**second_pass_args
|
| 612 |
+
)
|
| 613 |
+
latents_final = result_second.latents if hasattr(result_second, "latents") else result_second
|
| 614 |
+
|
| 615 |
+
|
| 616 |
print("[DEBUG] Chamando multi_scale_pipeline...")
|
| 617 |
t_ms = time.perf_counter()
|
| 618 |
ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
|