Update deformes4D_engine.py
Browse files- deformes4D_engine.py +22 -12
deformes4D_engine.py
CHANGED
|
@@ -292,29 +292,39 @@ class Deformes4DEngine:
|
|
| 292 |
return refined_latents_tensor
|
| 293 |
|
| 294 |
|
|
|
|
|
|
|
| 295 |
def refine_latents(self, latents: torch.Tensor,
|
| 296 |
fps: int = 24,
|
| 297 |
denoise_strength: float = 0.35,
|
| 298 |
refine_steps: int = 12,
|
| 299 |
motion_prompt: str = "refining video, improving details, cinematic quality") -> torch.Tensor:
|
| 300 |
-
"""
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
| 302 |
|
|
|
|
| 303 |
_, _, num_latent_frames, latent_h, latent_w = latents.shape
|
| 304 |
|
| 305 |
-
#
|
| 306 |
-
#
|
| 307 |
-
|
| 308 |
-
vae_scale_factor =
|
| 309 |
|
|
|
|
| 310 |
pixel_height = latent_h * vae_scale_factor
|
| 311 |
pixel_width = latent_w * vae_scale_factor
|
| 312 |
-
# --- [INÍCIO DA CORREÇÃO] ---
|
| 313 |
-
# Converte o número de frames latentes para frames de pixel.
|
| 314 |
-
pixel_frames = num_latent_frames * vae_scale_factor
|
| 315 |
-
# --- [FIM DA CORREÇÃO] ---
|
| 316 |
|
| 317 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
refined_latents_tensor, _ = self.ltx_manager.refine_latents(
|
| 319 |
latents,
|
| 320 |
height=pixel_height,
|
|
@@ -329,7 +339,7 @@ class Deformes4DEngine:
|
|
| 329 |
|
| 330 |
logger.info(f"Retornando tensor latente refinado com shape: {refined_latents_tensor.shape}")
|
| 331 |
return refined_latents_tensor
|
| 332 |
-
|
| 333 |
|
| 334 |
|
| 335 |
def upscale_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 292 |
return refined_latents_tensor
|
| 293 |
|
| 294 |
|
| 295 |
+
|
| 296 |
+
|
| 297 |
def refine_latents(self, latents: torch.Tensor,
|
| 298 |
fps: int = 24,
|
| 299 |
denoise_strength: float = 0.35,
|
| 300 |
refine_steps: int = 12,
|
| 301 |
motion_prompt: str = "refining video, improving details, cinematic quality") -> torch.Tensor:
|
| 302 |
+
"""
|
| 303 |
+
Aplica um passe de refinamento (denoise) em um tensor latente.
|
| 304 |
+
[CORRIGIDO] Calcula os frames de pixel de forma a alinhar com a lógica do VAE causal.
|
| 305 |
+
"""
|
| 306 |
+
logger.info(f"Refinando tensor latente com shape {latents.shape} para refinamento.")
|
| 307 |
|
| 308 |
+
# Extrai as dimensões do tensor latente de ENTRADA.
|
| 309 |
_, _, num_latent_frames, latent_h, latent_w = latents.shape
|
| 310 |
|
| 311 |
+
# Busca os fatores de escala do VAE. Assumimos que o fator temporal e espacial são iguais.
|
| 312 |
+
# Esta é uma suposição segura para o LTX-Video.
|
| 313 |
+
video_scale_factor = getattr(self.vae, 'temporal_downscale_factor', 8)
|
| 314 |
+
vae_scale_factor = getattr(self.vae, 'spatial_downscale_factor', 8)
|
| 315 |
|
| 316 |
+
# Converte as dimensões latentes para as dimensões de pixel correspondentes.
|
| 317 |
pixel_height = latent_h * vae_scale_factor
|
| 318 |
pixel_width = latent_w * vae_scale_factor
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
+
# --- [A CORREÇÃO PRINCIPAL ESTÁ AQUI] ---
|
| 321 |
+
# Para que a pipeline espere um latente com 'num_latent_frames', precisamos
|
| 322 |
+
# fornecer um número de frames de pixel que, após a divisão e a adição de 1
|
| 323 |
+
# (devido ao VAE causal), resulte no número original de frames latentes.
|
| 324 |
+
# A fórmula inversa é: (num_latent_frames - 1) * video_scale_factor
|
| 325 |
+
pixel_frames = (num_latent_frames - 1) * video_scale_factor
|
| 326 |
+
|
| 327 |
+
# Chama o ltx_manager com os parâmetros corretos.
|
| 328 |
refined_latents_tensor, _ = self.ltx_manager.refine_latents(
|
| 329 |
latents,
|
| 330 |
height=pixel_height,
|
|
|
|
| 339 |
|
| 340 |
logger.info(f"Retornando tensor latente refinado com shape: {refined_latents_tensor.shape}")
|
| 341 |
return refined_latents_tensor
|
| 342 |
+
|
| 343 |
|
| 344 |
|
| 345 |
def upscale_latents(self, latents: torch.Tensor) -> torch.Tensor:
|