File size: 2,337 Bytes
795e89c c80b8f1 795e89c c80b8f1 795e89c c80b8f1 795e89c c80b8f1 795e89c c80b8f1 795e89c c80b8f1 795e89c c80b8f1 795e89c c80b8f1 795e89c c80b8f1 795e89c c80b8f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
# upscaler_specialist.py
# Copyright (C) 2025 Carlos Rodrigues dos Santos
# Especialista ADUC para upscaling espacial de tensores latentes.
import torch
import logging
from diffusers import LTXLatentUpsamplePipeline
from ltx_manager_helpers import ltx_manager_singleton
logger = logging.getLogger(__name__)
class UpscalerSpecialist:
    """Spatially upscale latent tensors with the LTX latent-upsample pipeline.

    The VAE and the pipeline are loaded lazily on the first call to
    :meth:`upscale`. Any ``Exception`` raised while loading or running the
    pipeline is logged, and ``upscale`` then returns its input unchanged.
    """

    def __init__(self, device="cuda"):
        """Record the target device; nothing heavy is loaded here.

        Args:
            device: Requested device. Used as given when CUDA is available,
                otherwise replaced by ``"cpu"``.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.pipe_upsample = None  # set by _lazy_init() on success
        self.base_vae = None  # VAE borrowed from the first LTX worker

    def _pipeline_dtype(self):
        """Return float16 for any CUDA device, float32 otherwise.

        Going through ``torch.device`` makes ``"cuda"``, ``"cuda:0"`` and
        ``torch.device`` objects all select half precision; a plain
        ``== "cuda"`` comparison only matched the bare string.
        """
        on_cuda = torch.device(self.device).type == "cuda"
        return torch.float16 if on_cuda else torch.float32

    def _lazy_init(self):
        """Load the VAE and the upsample pipeline on first use.

        Safe to call repeatedly: each resource is only loaded while it is
        still ``None``. Failures are logged and leave the attribute as
        ``None`` so that a later call tries again.
        """
        if self.base_vae is None:
            try:
                # Imported here so a missing/broken manager is caught and
                # logged by this handler.
                from ltx_manager_helpers import ltx_manager_singleton

                if ltx_manager_singleton.workers:
                    self.base_vae = ltx_manager_singleton.workers[0].pipeline.vae
                else:
                    logger.warning("[Upscaler] Nenhum worker disponível no ltx_manager_singleton.")
            except Exception as e:
                logger.exception("[Upscaler] Falha ao inicializar VAE: %s", e)
                return

        # Without a VAE the pipeline cannot be built; nothing more to do.
        if self.pipe_upsample is not None or self.base_vae is None:
            return

        try:
            # NOTE(review): the module also imports a class of this name from
            # `diffusers`; this local import shadows it. Confirm which one is
            # intended.
            from ltx_video.pipelines.latent_upscale import LTXLatentUpsamplePipeline

            self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
                "linoyts/LTX-Video-spatial-upscaler-0.9.8",
                vae=self.base_vae,
                torch_dtype=self._pipeline_dtype(),
            ).to(self.device)
            logger.info("[Upscaler] Pipeline carregado com sucesso.")
        except Exception as e:
            logger.exception("[Upscaler] Falha ao carregar pipeline: %s", e)

    def upscale(self, latents: torch.Tensor) -> torch.Tensor:
        """Run the upsample pipeline on ``latents``.

        Args:
            latents: Latent tensor handed to the pipeline as ``latents=``.

        Returns:
            The pipeline's latent output, or ``latents`` itself when the
            pipeline is unavailable or raises.
        """
        self._lazy_init()
        if self.pipe_upsample is None:
            logger.warning("[Upscaler] Pipeline indisponível. Retornando latentes originais.")
            return latents

        try:
            with torch.no_grad():
                result = self.pipe_upsample(latents=latents, output_type="latent")
            upscaled = getattr(result, "latents", None)
            if upscaled is None:
                # The diffusers pipeline output exposes its result as
                # `.frames` rather than `.latents`; accept either. If neither
                # exists the AttributeError is handled below.
                upscaled = result.frames
            return upscaled
        except Exception as e:
            logger.exception("[Upscaler] Erro durante upscale: %s", e)
            return latents
|