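"""Latent enhancer specialist: spatial upscaling and denoise-based refinement of latent tensors."""
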
import logging
import time

import torch

from diffusers import LTXLatentUpsamplePipeline

from managers.ltx_manager import ltx_manager_singleton

logger = logging.getLogger(__name__)


class LatentEnhancerSpecialist:
    """
    Specialist responsible for improving the quality of latent tensors,
    including spatial upscaling and denoise-based refinement.
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe_upsample = None
        self.base_vae = None

    def _lazy_init_upscaler(self):
        """Initializes the upscaling pipeline only when it is first needed."""
        if self.pipe_upsample is not None:
            return
        try:
            from diffusers.models.autoencoders import AutoencoderKLLTXVideo

            dtype = torch.float16 if self.device == "cuda" else torch.float32

            # Load the VAE bundled with the spatial upscaler checkpoint.
            self.base_vae = AutoencoderKLLTXVideo.from_pretrained(
                "linoyts/LTX-Video-spatial-upscaler-0.9.8",
                subfolder="vae",
                torch_dtype=dtype,
            ).to(self.device)

            # Build the latent upsampling pipeline on top of that VAE.
            self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
                "linoyts/LTX-Video-spatial-upscaler-0.9.8",
                vae=self.base_vae,
                torch_dtype=dtype,
            ).to(self.device)
            logger.info("[Enhancer] Upscale pipeline loaded successfully.")
        except Exception as e:
            logger.error(f"[Enhancer] Failed to load upscale pipeline: {e}")
            self.pipe_upsample = None

    @torch.no_grad()
    def upscale(self, latents: torch.Tensor) -> torch.Tensor:
        """Applies 2x spatial upscaling to the given latent tensors."""
        self._lazy_init_upscaler()
        if self.pipe_upsample is None:
            logger.warning("[Enhancer] Upscale pipeline unavailable. Returning original latents.")
            return latents
        try:
            logger.info(f"[Enhancer] Received shape {latents.shape} for upscaling.")
            result = self.pipe_upsample(latents=latents, output_type="latent")
            output_tensor = result.frames
            logger.info(f"[Enhancer] Upscaling complete. New shape: {output_tensor.shape}")
            return output_tensor
        except Exception as e:
            logger.error(f"[Enhancer] Error during upscaling: {e}", exc_info=True)
            return latents

    @torch.no_grad()
    def refine(self, latents: torch.Tensor, fps: int = 24, **kwargs) -> torch.Tensor:
        """
        Invokes the LTX Pool Manager to refine an existing latent tensor.
        """
        logger.info(f"[Enhancer] Refining latent tensor with shape {latents.shape}.")

        main_pipeline_vae = ltx_manager_singleton.workers[0].pipeline.vae
        video_scale_factor = getattr(main_pipeline_vae.config, "temporal_scale_factor", 8)

        _, _, num_latent_frames, _, _ = latents.shape

        # Map the latent frame count back to a pixel-space frame count using
        # the VAE's temporal compression factor.
        pixel_frames = (num_latent_frames - 1) * video_scale_factor

        final_ltx_params = {
            "video_total_frames": pixel_frames,
            "video_fps": fps,
            "current_fragment_index": int(time.time()),
            **kwargs,
        }

        refined_latents_tensor, _ = ltx_manager_singleton.refine_latents(latents, **final_ltx_params)

        if refined_latents_tensor is None:
            logger.warning("[Enhancer] Refinement failed. Returning the original, unrefined tensor.")
            return latents

        logger.info(f"[Enhancer] Returning refined latent tensor with shape: {refined_latents_tensor.shape}")
        return refined_latents_tensor


latent_enhancer_specialist_singleton = LatentEnhancerSpecialist() |
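

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The latent shape, dtype, and fps below are
# assumptions for demonstration; real latents come from the main LTX pipeline
# and must match its channel/frame layout.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Hypothetical 5D latent tensor: (batch, channels, frames, height, width).
    device = latent_enhancer_specialist_singleton.device
    dtype = torch.float16 if device == "cuda" else torch.float32
    dummy_latents = torch.randn(1, 128, 9, 16, 16, device=device, dtype=dtype)

    upscaled = latent_enhancer_specialist_singleton.upscale(dummy_latents)
    refined = latent_enhancer_specialist_singleton.refine(upscaled, fps=24)
    logger.info(f"Final refined latent shape: {refined.shape}")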