|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import logging |
|
|
from diffusers import LTXLatentUpsamplePipeline |
|
|
from managers.ltx_manager import ltx_manager_singleton |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class UpscalerSpecialist:
    """Specialist that increases the spatial resolution of latent tensors
    using the LTX Video Spatial Upscaler.

    The heavyweight pipeline is loaded lazily on the first ``upscale()``
    call and cached for subsequent calls. On any load or inference failure
    the specialist degrades gracefully by returning its input unchanged.
    """

    # Hugging Face repo hosting both the upsampler pipeline and a
    # compatible fallback VAE (used when the ltx_manager VAE is unusable).
    _REPO_ID = "linoyts/LTX-Video-spatial-upscaler-0.9.8"

    def __init__(self):
        # Construction stays cheap: all model loading is deferred to
        # _lazy_init(), triggered by the first upscale() call.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.base_vae = None
        self.pipe_upsample = None

    @property
    def _dtype(self) -> torch.dtype:
        """Model dtype: fp16 on GPU, fp32 on CPU (fp16 CPU inference is unsupported/slow)."""
        return torch.float16 if self.device == "cuda" else torch.float32

    def _load_vae_from_hub(self):
        """Load a standalone ``AutoencoderKLLTXVideo`` from the upscaler repo.

        Shared fallback used both when no worker is available and when the
        worker's VAE is of an incompatible class.
        """
        # Local import keeps diffusers internals off the module import path.
        from diffusers.models.autoencoders import AutoencoderKLLTXVideo
        return AutoencoderKLLTXVideo.from_pretrained(
            self._REPO_ID,
            subfolder="vae",
            torch_dtype=self._dtype,
        ).to(self.device)

    def _lazy_init(self):
        """Load the upsampler pipeline on first use. Idempotent.

        Prefers reusing the VAE already held by the first ltx_manager
        worker (saves memory); otherwise downloads a compatible VAE.
        On failure, leaves ``self.pipe_upsample`` as ``None`` so callers
        can fall back.
        """
        if self.pipe_upsample is not None:
            # Already initialized — do NOT reload the pipeline on every call.
            return
        try:
            if ltx_manager_singleton.workers:
                candidate_vae = ltx_manager_singleton.workers[0].pipeline.vae
                # Compare by class name (not isinstance) to avoid importing
                # diffusers internals when the shared VAE is already usable.
                if candidate_vae.__class__.__name__ == "AutoencoderKLLTXVideo":
                    self.base_vae = candidate_vae
                    logger.info("[Upscaler] Usando VAE do ltx_manager (AutoencoderKLLTXVideo).")
                else:
                    logger.warning(f"[Upscaler] VAE incompatível: {type(candidate_vae)}. "
                                   "Carregando AutoencoderKLLTXVideo manualmente...")
                    self.base_vae = self._load_vae_from_hub()
            else:
                logger.warning("[Upscaler] Nenhum worker disponível, carregando VAE manualmente...")
                self.base_vae = self._load_vae_from_hub()

            self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
                self._REPO_ID,
                vae=self.base_vae,
                torch_dtype=self._dtype,
            ).to(self.device)

            logger.info("[Upscaler] Pipeline carregado com sucesso.")

        except Exception as e:
            logger.error(f"[Upscaler] Falha ao carregar pipeline: {e}")
            # Mark the pipeline unavailable; upscale() will return inputs as-is.
            self.pipe_upsample = None

    @torch.no_grad()
    def upscale(self, latents: torch.Tensor) -> torch.Tensor:
        """Apply the 2x spatial upscaling to the supplied latent tensors.

        Returns the upscaled latents, or ``latents`` unchanged if the
        pipeline is unavailable or inference raises.
        """
        self._lazy_init()
        if self.pipe_upsample is None:
            logger.warning("[Upscaler] Pipeline indisponível. Retornando latentes originais.")
            return latents

        try:
            logger.info(f"[Upscaler] Recebido shape {latents.shape}. Executando upscale em {self.device}...")

            # output_type="latent" keeps the result in latent space; the
            # pipeline still exposes it under the `.frames` attribute.
            result = self.pipe_upsample(latents=latents, output_type="latent")
            output_tensor = result.frames

            logger.info(f"[Upscaler] Upscale concluído. Novo shape: {output_tensor.shape}")
            return output_tensor

        except Exception as e:
            logger.error(f"[Upscaler] Erro durante upscale: {e}", exc_info=True)
            # Best-effort contract: never propagate, hand back the input.
            return latents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level singleton shared by importers. Safe to create at import time:
# __init__ only records the device and leaves model loading for later.
upscaler_specialist_singleton = UpscalerSpecialist()