Test / api /ltx /ltx_aduc_manager.py
eeuuia's picture
Update api/ltx/ltx_aduc_manager.py
b54d196 verified
raw
history blame
7.87 kB
# FILE: api/ltx/ltx_aduc_manager.py
# DESCRIPTION: A singleton manager for the LTX-Video pipeline.
# This module loads the pipeline, places it on the correct devices, and applies a
# targeted runtime monkey patch to delegate conditioning tasks to the specialized
# VaeAducPipeline service, enabling full control for the ADUC-SDR architecture.
import time
import yaml
from pathlib import Path
from typing import List, Optional, Tuple, Union, Dict
import threading
import sys
import torch
from dataclasses import dataclass
# --- Importações da arquitetura ADUC-SDR ---
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_utils import build_ltx_pipeline_on_cpu
from utils.debug_utils import log_function_io
# Importa o serviço VAE que fará o trabalho real
# --- Importações da biblioteca LTX-Video ---
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
if repo_path not in sys.path:
sys.path.insert(0, repo_path)
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
# Importa o tipo original de conditioning item para type hinting
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem as PipelineConditioningItem
import logging
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*")
try:
from huggingface_hub import logging as hf_logging
hf_logging.set_verbosity_error()
except ImportError:
pass
logger = logging.getLogger("AducDebug")
logging.basicConfig(level=logging.DEBUG)
logger.setLevel(logging.DEBUG)
@dataclass
class LatentConditioningItem:
latent_tensor: torch.Tensor
media_frame_number: int
conditioning_strength: float
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem
# ==============================================================================
# --- O MONKEY PATCH DIRECIONADO E SIMPLES ---
# ==============================================================================
@log_function_io
def _aduc_prepare_conditioning_patch(
self: "LTXVideoPipeline",
conditioning_items: Optional[List[Union[ConditioningItem, LatentConditioningItem]]],
init_latents: torch.Tensor,
num_frames: int,
height: int,
width: int,
vae_per_channel_normalize: bool = False,
generator=None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
"""
[PATCH] Substitui o método `prepare_conditioning` original da LTXVideoPipeline.
Esta função atua como um proxy (intermediário). Ela não contém lógica de processamento.
Em vez disso, ela delega 100% do trabalho para o `vae_aduc_pipeline`, que é o nosso
serviço especializado e otimizado para essa tarefa.
"""
logging.debug(f"Patch ADUC: Interceptado 'prepare_conditioning'. Delegando para o serviço VaeAducPipeline.")
from api.ltx.vae_aduc_pipeline import vae_aduc_pipeline
# 1. Chama o serviço especializado para fazer todo o trabalho pesado.
# O serviço VAE processa na sua própria GPU dedicada e retorna os tensores na CPU.
latents_cpu, coords_cpu, mask_cpu, num_latents = vae_aduc_pipeline.prepare_conditioning(
conditioning_items=conditioning_items,
init_latents=init_latents,
num_frames=num_frames,
height=height,
width=width,
vae_per_channel_normalize=vae_per_channel_normalize,
generator=generator,
)
# 2. Move os resultados da CPU para o dispositivo correto que a pipeline principal espera.
# O `init_latents.device` garante que estamos usando o dispositivo principal da pipeline (ex: 'cuda:0').
device = init_latents.device
latents = latents_cpu.to(device)
pixel_coords = coords_cpu.to(device)
conditioning_mask = mask_cpu.to(device) if mask_cpu is not None else None
# 3. Retorna os tensores prontos. A pipeline principal continua sua execução normalmente,
# sem saber que a lógica de condicionamento foi executada por um serviço externo.
return latents, pixel_coords, conditioning_mask, num_latents
# ==============================================================================
# --- LTX WORKER E POOL MANAGER ---
# ==============================================================================
class LTXWorker:
"""
Gerencia uma instância única da LTXVideoPipeline, aplicando o patch
necessário durante a inicialização.
"""
def __init__(self, main_device_str: str, vae_device_str: str, config: dict):
self.main_device = torch.device(main_device_str)
self.vae_device = torch.device(vae_device_str)
self.config = config
self.pipeline: LTXVideoPipeline = None
self._load_and_patch_pipeline()
@log_function_io
def _load_and_patch_pipeline(self):
"""
Orquestra o carregamento da pipeline e a aplicação do monkey patch.
"""
logging.info(f"[LTXWorker-{self.main_device}] Carregando pipeline LTX para a CPU...")
self.pipeline, _ = build_ltx_pipeline_on_cpu(self.config)
logging.info(f"[LTXWorker-{self.main_device}] Movendo pipeline para GPUs (Main: {self.main_device}, VAE: {self.vae_device})...")
self.pipeline.to(self.main_device) # Move a maioria dos componentes
self.pipeline.vae.to(self.vae_device) # Move o VAE para sua GPU dedicada
logging.info(f"[LTXWorker-{self.main_device}] Aplicando patch ADUC-SDR em 'prepare_conditioning'...")
# A "mágica" simples e eficaz acontece aqui:
self.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(self.pipeline, LTXVideoPipeline)
logging.info(f"[LTXWorker-{self.main_device}] ✅ Pipeline 'quente', corrigida e pronta para uso.")
class LtxAducManager:
"""
Implementa o padrão Singleton para garantir que a pipeline LTX seja
carregada e corrigida apenas uma vez durante a vida útil da aplicação.
"""
_instance = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if hasattr(self, '_initialized') and self._initialized:
return
with self._lock:
if hasattr(self, '_initialized') and self._initialized:
return
logging.info("⚙️ Inicializando LtxAducManager Singleton...")
self.config = self._load_config()
main_device_str = str(gpu_manager.get_ltx_device())
vae_device_str = str(gpu_manager.get_ltx_vae_device())
# Cria o worker que irá carregar e patchear a pipeline
self.worker = LTXWorker(main_device_str, vae_device_str, self.config)
self._initialized = True
logging.info("✅ LtxAducManager pronto.")
def _load_config(self) -> Dict:
"""Carrega a configuração YAML principal do LTX."""
# TODO: Considerar mover o path da configuração para uma variável de ambiente ou config central
config_path = Path("/data/LTX-Video/configs/ltxv-13b-0.9.8-dev-fp8.yaml")
with open(config_path, "r") as file:
return yaml.safe_load(file)
def get_pipeline(self) -> LTXVideoPipeline:
"""
Ponto de acesso principal para obter a instância da pipeline.
Returns:
LTXVideoPipeline: A instância única, carregada e já corrigida.
"""
return self.worker.pipeline
# --- Instância Singleton Global ---
# Outras partes do código importarão esta instância para interagir com a pipeline.
ltx_aduc_manager = LtxAducManager()