# FILE: api/ltx/ltx_aduc_manager.py
# DESCRIPTION: A singleton pool manager for the LTX-Video pipeline.
# This module is the "secret weapon": it handles loading, device placement,
# and applies a runtime monkey patch to the LTX pipeline for full control
# and compatibility with the ADUC-SDR architecture, especially for latent conditioning.
import os
import sys
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union, Dict

import yaml
import torch
from diffusers.utils.torch_utils import randn_tensor

# --- Imports from our architecture ---
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_utils import build_ltx_pipeline_on_cpu

LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
LTX_REPO_ID = "Lightricks/LTX-Video"
CACHE_DIR = os.environ.get("HF_HOME")

# --- Imports from the LTX-Video library (a local checkout added to sys.path) ---
repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
if repo_path not in sys.path:
    sys.path.insert(0, repo_path)
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Standard logging for this module; the huggingface_hub logger is silenced separately.
import logging
from huggingface_hub import logging as hf_logging
hf_logging.set_verbosity_error()
# ==============================================================================
# --- ADUC-SDR CONDITIONING DATACLASSES ---
# ==============================================================================
@dataclass
class ConditioningItem:
    """Our data class for conditioning with PIXEL tensors (from images)."""
    pixel_tensor: torch.Tensor
    media_frame_number: int
    conditioning_strength: float


@dataclass
class LatentConditioningItem:
    """Our "secret weapon": a data class for conditioning with LATENT tensors (from overlap)."""
    latent_tensor: torch.Tensor
    media_frame_number: int
    conditioning_strength: float
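
# Illustrative usage (not executed at import time): building one item of each
# kind. `first_frame` and `overlap_latents` are hypothetical tensors, and the
# (B, C, F, H, W) shapes below are assumptions for illustration only.
#
#   first_frame = torch.randn(1, 3, 1, 512, 768)       # pixel-space frame
#   overlap_latents = torch.randn(1, 128, 2, 16, 24)   # latents from a previous chunk
#   items = [
#       ConditioningItem(first_frame, media_frame_number=0, conditioning_strength=1.0),
#       LatentConditioningItem(overlap_latents, media_frame_number=0, conditioning_strength=0.8),
#   ]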

# ==============================================================================
# --- THE MONKEY PATCH ---
# Our custom version of `prepare_conditioning` that understands both data classes.
# ==============================================================================
def _aduc_prepare_conditioning_patch(
    self: "LTXVideoPipeline",
    conditioning_items: Optional[List[Union[ConditioningItem, LatentConditioningItem]]],
    init_latents: torch.Tensor,
    num_frames: int, height: int, width: int,  # signature kept for compatibility
    vae_per_channel_normalize: bool = False,
    generator=None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
    if not conditioning_items:
        latents, latent_coords = self.patchifier.patchify(latents=init_latents)
        pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
        return latents, pixel_coords, None, 0

    init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
    extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
    extra_conditioning_num_latents = 0

    for item in conditioning_items:
        strength = item.conditioning_strength
        media_frame_number = item.media_frame_number
        if isinstance(item, ConditioningItem):
            logging.debug("ADUC patch: processing ConditioningItem (pixels).")
            # Pixel items must be encoded by the VAE, which may live on another GPU.
            pixel_tensor_on_vae_device = item.pixel_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
            media_item_latents = vae_encode(pixel_tensor_on_vae_device, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)
            media_item_latents = media_item_latents.to(device=init_latents.device, dtype=init_latents.dtype)
        elif isinstance(item, LatentConditioningItem):
            logging.debug("ADUC patch: processing LatentConditioningItem (latents).")
            # Latent items skip the VAE entirely; just match device and dtype.
            media_item_latents = item.latent_tensor.to(device=init_latents.device, dtype=init_latents.dtype)
        else:
            logging.warning(f"ADUC patch: conditioning item of unknown type '{type(item)}' will be ignored.")
            continue

        if media_frame_number == 0:
            # Frame 0: blend directly into the initial latents and record the strength in the mask.
            f_l, h_l, w_l = media_item_latents.shape[-3:]
            init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
            init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
        else:
            # Later frames: noise-blend, patchify, and append as extra conditioning
            # tokens whose pixel coordinates are shifted to the target frame.
            noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
            media_item_latents = torch.lerp(noise, media_item_latents, strength)
            patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
            pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
            pixel_coords[:, 0] += media_frame_number
            extra_conditioning_num_latents += patched_latents.shape[1]
            new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
            extra_conditioning_latents.append(patched_latents)
            extra_conditioning_pixel_coords.append(pixel_coords)
            extra_conditioning_mask.append(new_mask)

    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)

    if extra_conditioning_latents:
        # Extra conditioning tokens are prepended to the token sequence.
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)

    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
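
# Illustrative call shape once the patch is bound to a pipeline instance
# (hypothetical sizes, not executed at import time): given video-shaped
# `init_latents` of (B, C, F, H, W), the patch returns patchified token
# sequences, their pixel coordinates, a per-token conditioning mask, and
# the number of extra conditioning tokens that were prepended:
#
#   latents, coords, mask, n_extra = pipeline.prepare_conditioning(
#       conditioning_items=items, init_latents=init_latents,
#       num_frames=97, height=512, width=768,
#       generator=torch.Generator("cpu"),
#   )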

# ==============================================================================
# --- LTX WORKER AND POOL MANAGER ---
# ==============================================================================
class LTXWorker:
    """Manages one LTX pipeline instance on a pair of GPUs (main + VAE)."""
    def __init__(self, main_device_str: str, vae_device_str: str, config: dict):
        self.main_device = torch.device(main_device_str)
        self.vae_device = torch.device(vae_device_str)
        self.config = config
        self.pipeline: Optional[LTXVideoPipeline] = None
        self._load_and_patch_pipeline()

    def _load_and_patch_pipeline(self):
        logging.info(f"[LTXWorker-{self.main_device}] Loading LTX pipeline on the CPU...")
        self.pipeline, _ = build_ltx_pipeline_on_cpu(self.config)
        logging.info(f"[LTXWorker-{self.main_device}] Moving pipeline to GPUs (Main: {self.main_device}, VAE: {self.vae_device})...")
        self.pipeline.to(self.main_device)
        self.pipeline.vae.to(self.vae_device)
        logging.info(f"[LTXWorker-{self.main_device}] Applying the ADUC-SDR patch to 'prepare_conditioning'...")
        # `__get__` binds the plain function to this pipeline instance via the
        # descriptor protocol, so only this instance is patched, not the class.
        self.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(self.pipeline, LTXVideoPipeline)
        logging.info(f"[LTXWorker-{self.main_device}] ✅ Pipeline warm, patched, and ready to use.")
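
# LtxAducManager below is a thread-safe singleton: __new__ guards instance
# creation with a class-level lock, and __init__ uses double-checked locking
# so that concurrent first accesses initialize the worker exactly once.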
class LtxAducManager:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized: return
        with self._lock:
            if self._initialized: return
            logging.info("⚙️ Initializing LTXPoolManager singleton...")
            self.config = self._load_config()
            main_device_str = str(gpu_manager.get_ltx_device())
            vae_device_str = str(gpu_manager.get_ltx_vae_device())
            self.worker = LTXWorker(main_device_str, vae_device_str, self.config)
            self._initialized = True
            logging.info("✅ LTXPoolManager ready.")

    def _load_config(self) -> Dict:
        """Loads the main LTX YAML configuration."""
        config_path = Path("/data/LTX-Video/configs/ltxv-13b-0.9.8-distilled-fp8.yaml")
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    def get_pipeline(self) -> LTXVideoPipeline:
        """Returns the pipeline instance, already loaded and patched."""
        return self.worker.pipeline

# --- Global singleton instance ---
ltx_aduc_manager = LtxAducManager()
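
# Minimal smoke test (illustrative): only runs when this module is executed
# directly, and assumes the /data/LTX-Video checkout, the YAML config above,
# and the GPUs reported by gpu_manager are all available.
if __name__ == "__main__":
    pipeline = ltx_aduc_manager.get_pipeline()
    print(f"Pipeline ready: {type(pipeline).__name__} on {ltx_aduc_manager.worker.main_device}")
    print(f"prepare_conditioning patched: {pipeline.prepare_conditioning.__func__ is _aduc_prepare_conditioning_patch}")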