Aduc_sdr / upscaler_specialist.py
euiia's picture
Create upscaler_specialist.py
795e89c verified
raw
history blame
2.47 kB
# upscaler_specialist.py
# Copyright (C) 2025 Carlos Rodrigues dos Santos
# Especialista ADUC para upscaling espacial de tensores latentes.
import torch
import logging
from diffusers import LTXLatentUpsamplePipeline
from ltx_manager_helpers import ltx_manager_singleton
logger = logging.getLogger(__name__)
class UpscalerSpecialist:
"""
Especialista responsável por aumentar a resolução espacial de tensores latentes
usando o LTX Video Spatial Upscaler.
"""
def __init__(self, base_vae):
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.pipe_upsample = None
if base_vae is not None:
logger.info("Inicializando Especialista de Upscale Latente...")
try:
self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
"linoyts/LTX-Video-spatial-upscaler-0.9.8",
vae=base_vae,
torch_dtype=torch.bfloat16,
).to(self.device)
logger.info("Especialista de Upscale Latente pronto.")
except Exception as e:
logger.error(f"Falha ao carregar o modelo de upscale: {e}", exc_info=True)
else:
logger.warning("VAE base não fornecido. Especialista de Upscale desativado.")
@torch.no_grad()
def upscale(self, latents: torch.Tensor) -> torch.Tensor:
"""
Aplica o upscaling 2x nos tensores latentes fornecidos.
"""
if self.pipe_upsample is None:
logger.warning("Upscaler não está disponível. Retornando latentes originais.")
return latents
logger.info(f"Upscaler: Recebeu latentes com shape {latents.shape}. Aplicando upscale 2x...")
# O upscaler opera em um batch de latentes.
upscaled_latents = self.pipe_upsample(
latents=latents,
output_type="latent"
).frames
logger.info(f"Upscaler: Latentes redimensionados para {upscaled_latents.shape}.")
return upscaled_latents
# Instanciação Singleton
# Depende do VAE do ltx_manager, então o obtemos de lá.
try:
base_vae_for_upscaler = ltx_manager_singleton.workers[0].pipeline.vae
upscaler_specialist_singleton = UpscalerSpecialist(base_vae=base_vae_for_upscaler)
except Exception as e:
logger.error(f"Não foi possível inicializar o UpscalerSpecialist Singleton: {e}")
upscaler_specialist_singleton = None