Aduc_sdr / upscaler_specialist.py
euiia's picture
Update upscaler_specialist.py
bc96793 verified
raw
history blame
2.8 kB
# upscaler_specialist.py
# Copyright (C) 2025 Carlos Rodrigues
# Especialista ADUC para upscaling espacial de tensores latentes.
import torch
import logging
from diffusers import LTXLatentUpsamplePipeline
from ltx_manager_helpers import ltx_manager_singleton
logger = logging.getLogger(__name__)
class UpscalerSpecialist:
"""
Especialista responsável por aumentar a resolução espacial de tensores latentes
usando o LTX Video Spatial Upscaler.
"""
def __init__(self):
# Força uso de CUDA se disponível
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.base_vae = None
self.pipe_upsample = None
def _lazy_init(self):
"""Inicializa VAE e pipeline apenas quando necessário."""
if self.base_vae is None:
try:
if ltx_manager_singleton.workers:
self.base_vae = ltx_manager_singleton.workers[0].pipeline.vae
logger.info("[Upscaler] VAE base obtido com sucesso.")
else:
logger.warning("[Upscaler] Nenhum worker disponível no ltx_manager_singleton.")
except Exception as e:
logger.error(f"[Upscaler] Falha ao inicializar VAE: {e}")
return
if self.pipe_upsample is None and self.base_vae is not None:
try:
self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
"linoyts/LTX-Video-spatial-upscaler-0.9.8",
vae=self.base_vae,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
).to(self.device)
logger.info("[Upscaler] Pipeline carregado com sucesso.")
except Exception as e:
logger.error(f"[Upscaler] Falha ao carregar pipeline: {e}")
@torch.no_grad()
def upscale(self, latents: torch.Tensor) -> torch.Tensor:
"""Aplica o upscaling 2x nos tensores latentes fornecidos."""
self._lazy_init()
if self.pipe_upsample is None:
logger.warning("[Upscaler] Pipeline indisponível. Retornando latentes originais.")
return latents
try:
logger.info(f"[Upscaler] Recebido shape {latents.shape}. Executando upscale em {self.device}...")
result = self.pipe_upsample(latents=latents, output_type="latent")
logger.info(f"[Upscaler] Upscale concluído. Novo shape: {result.latents.shape}")
return result.latents
except Exception as e:
logger.error(f"[Upscaler] Erro durante upscale: {e}", exc_info=True)
return latents
# ---------------------------
# Singleton global
# ---------------------------
upscaler_specialist_singleton = UpscalerSpecialist()