euiia commited on
Commit
6ad3f18
·
verified ·
1 Parent(s): a6f24ba

Update latent_enhancer_specialist.py

Browse files
Files changed (1) hide show
  1. latent_enhancer_specialist.py +96 -0
latent_enhancer_specialist.py CHANGED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # latent_enhancer_specialist.py
2
+ # Copyright (C) 2025 Carlos Rodrigues
3
+ # Especialista ADUC para pós-produção e melhoria de tensores latentes.
4
+
5
+ import torch
6
+ import logging
7
+ from diffusers import LTXLatentUpsamplePipeline
8
+ from ltx_manager_helpers import ltx_manager_singleton
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class LatentEnhancerSpecialist:
13
+ """
14
+ Especialista responsável por melhorar a qualidade de tensores latentes,
15
+ incluindo upscaling espacial e refinamento por denoise.
16
+ """
17
+ def __init__(self):
18
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ self.pipe_upsample = None
20
+ self.base_vae = None # VAE para o upscaler
21
+
22
+ def _lazy_init_upscaler(self):
23
+ """Inicializa a pipeline de upscaling apenas quando for usada."""
24
+ if self.pipe_upsample is not None:
25
+ return
26
+ try:
27
+ # A pipeline de upscale requer um VAE específico
28
+ from diffusers.models.autoencoders import AutoencoderKLLTXVideo
29
+ self.base_vae = AutoencoderKLLTXVideo.from_pretrained(
30
+ "linoyts/LTX-Video-spatial-upscaler-0.9.8",
31
+ subfolder="vae",
32
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
33
+ ).to(self.device)
34
+
35
+ self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
36
+ "linoyts/LTX-Video-spatial-upscaler-0.9.8",
37
+ vae=self.base_vae,
38
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
39
+ ).to(self.device)
40
+ logger.info("[Enhancer] Pipeline de Upscale carregada com sucesso.")
41
+ except Exception as e:
42
+ logger.error(f"[Enhancer] Falha ao carregar pipeline de Upscale: {e}")
43
+ self.pipe_upsample = None
44
+
45
+ @torch.no_grad()
46
+ def upscale(self, latents: torch.Tensor) -> torch.Tensor:
47
+ """Aplica o upscaling 2x nos tensores latentes fornecidos."""
48
+ self._lazy_init_upscaler()
49
+ if self.pipe_upsample is None:
50
+ logger.warning("[Enhancer] Pipeline de Upscale indisponível. Retornando latentes originais.")
51
+ return latents
52
+ try:
53
+ logger.info(f"[Enhancer] Recebido shape {latents.shape} para Upscale.")
54
+ result = self.pipe_upsample(latents=latents, output_type="latent")
55
+ output_tensor = result.frames
56
+ logger.info(f"[Enhancer] Upscale concluído. Novo shape: {output_tensor.shape}")
57
+ return output_tensor
58
+ except Exception as e:
59
+ logger.error(f"[Enhancer] Erro durante upscale: {e}", exc_info=True)
60
+ return latents
61
+
62
+ @torch.no_grad()
63
+ def refine(self, latents: torch.Tensor, fps: int = 24, **kwargs) -> torch.Tensor:
64
+ """
65
+ Invoca o LTX Pool Manager para refinar um tensor latente existente.
66
+ Esta função foi movida de Deformes4DEngine para centralizar a lógica.
67
+ """
68
+ logger.info(f"[Enhancer] Refinando tensor latente com shape {latents.shape}.")
69
+
70
+ # A lógica de refinamento usa o VAE principal do ltx_manager, não o do upscaler
71
+ main_pipeline_vae = ltx_manager_singleton.workers[0].pipeline.vae
72
+ video_scale_factor = getattr(main_pipeline_vae.config, 'temporal_scale_factor', 8)
73
+
74
+ # O ltx_manager agora lida com o dimensionamento, então não precisamos pré-calcular
75
+ # Apenas garantimos que o número de frames seja passado corretamente
76
+ _, _, num_latent_frames, _, _ = latents.shape
77
+ pixel_frames = num_latent_frames * video_scale_factor
78
+
79
+ final_ltx_params = {
80
+ "video_total_frames": pixel_frames,
81
+ "video_fps": fps,
82
+ "current_fragment_index": int(time.time()),
83
+ **kwargs
84
+ }
85
+
86
+ refined_latents_tensor, _ = ltx_manager_singleton.refine_latents(latents, **final_ltx_params)
87
+
88
+ if refined_latents_tensor is None:
89
+ logger.warning("[Enhancer] O refinamento falhou. Retornando tensor original não refinado.")
90
+ return latents
91
+
92
+ logger.info(f"[Enhancer] Retornando tensor latente refinado com shape: {refined_latents_tensor.shape}")
93
+ return refined_latents_tensor
94
+
95
+ # --- Singleton Global ---
96
+ latent_enhancer_specialist_singleton = LatentEnhancerSpecialist()