Update api/ltx_server.py
api/ltx_server.py  (+23 −18)
@@ -242,24 +242,6 @@ def log_tensor_info(tensor, name="Tensor"):
     print("------------------------------------------\n")
 
 
-    @torch.no_grad()
-    def _upsample_latents_internal(self, latents: torch.Tensor) -> torch.Tensor:
-        """
-        Logic extracted directly from LTXMultiScalePipeline for latent upscaling.
-        """
-        if not self.latent_upsampler:
-            raise ValueError("Latent upsampler is not loaded.")
-
-        # Make sure the models are on the correct device
-        self.latent_upsampler.to(self.device)
-        self.pipeline.vae.to(self.device)
-        print(f"[DEBUG-UPSAMPLE] Input shape: {tuple(latents.shape)}")
-        latents = un_normalize_latents(latents, self.pipeline.vae, vae_per_channel_normalize=True)
-        upsampled_latents = self.latent_upsampler(latents)
-        upsampled_latents = normalize_latents(upsampled_latents, self.pipeline.vae, vae_per_channel_normalize=True)
-        print(f"[DEBUG-UPSAMPLE] Output shape: {tuple(upsampled_latents.shape)}")
-
-        return upsampled_latents
 
 
 
@@ -453,6 +435,29 @@ class VideoService:
             pass
         print(f"[DEBUG] FP8→BF16: params_promoted={p_cnt}, buffers_promoted={b_cnt}")
 
+
+
+    @torch.no_grad()
+    def _upsample_latents_internal(self, latents: torch.Tensor) -> torch.Tensor:
+        """
+        Logic extracted directly from LTXMultiScalePipeline for latent upscaling.
+        """
+        if not self.latent_upsampler:
+            raise ValueError("Latent upsampler is not loaded.")
+
+        # Make sure the models are on the correct device
+        self.latent_upsampler.to(self.device)
+        self.pipeline.vae.to(self.device)
+        print(f"[DEBUG-UPSAMPLE] Input shape: {tuple(latents.shape)}")
+        latents = un_normalize_latents(latents, self.pipeline.vae, vae_per_channel_normalize=True)
+        upsampled_latents = self.latent_upsampler(latents)
+        upsampled_latents = normalize_latents(upsampled_latents, self.pipeline.vae, vae_per_channel_normalize=True)
+        print(f"[DEBUG-UPSAMPLE] Output shape: {tuple(upsampled_latents.shape)}")
+
+        return upsampled_latents
+
+
+
     def _apply_precision_policy(self):
         prec = str(self.config.get("precision", "")).lower()
         self.runtime_autocast_dtype = torch.float32
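The change is a pure move: the method body is untouched, it only relocates from after log_tensor_info into the VideoService class, next to the precision-policy code. For readers unfamiliar with what the helper does, below is a minimal, self-contained sketch of the round trip it performs (un-normalize, upsample, re-normalize). The identity "normalization" and the nearest-neighbor interpolation are placeholder stand-ins; the real code uses the VAE's per-channel statistics via un_normalize_latents / normalize_latents and a learned latent upsampler.

import torch
import torch.nn.functional as F

def un_normalize(latents: torch.Tensor) -> torch.Tensor:
    # Placeholder for un_normalize_latents(...): maps normalized latents
    # back to the VAE's raw latent scale (identity here for illustration).
    return latents

def normalize(latents: torch.Tensor) -> torch.Tensor:
    # Placeholder for normalize_latents(...): re-applies per-channel
    # normalization after upsampling (identity here for illustration).
    return latents

@torch.no_grad()
def upsample_latents_sketch(latents: torch.Tensor) -> torch.Tensor:
    # Same three steps as _upsample_latents_internal, with 2x spatial
    # nearest-neighbor interpolation standing in for the learned upsampler.
    x = un_normalize(latents)
    x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
    return normalize(x)

# (batch, channels, frames, height, width) -- this layout is an assumption.
latents = torch.randn(1, 128, 8, 16, 16)
print(tuple(upsample_latents_sketch(latents).shape))  # (1, 128, 8, 32, 32)

The un-normalize/re-normalize bracketing matters because the upsampler operates on latents in the VAE's raw scale, while the rest of the pipeline works with per-channel-normalized latents.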