# FILE: api/ltx/ltx_aduc_manager.py
# DESCRIPTION: The "secret weapon". A pool manager for LTX that applies
# runtime patches to the pipeline for full control and ADUC-SDR compatibility.
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import yaml
from diffusers.utils.torch_utils import randn_tensor

logger = logging.getLogger(__name__)

LTX_REPO_ID = "Lightricks/LTX-Video"
DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
RESULTS_DIR = Path("/app/output")
# --- Imports from our architecture ---
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_utils import build_ltx_pipeline_on_cpu

def add_deps_to_path():
    """
    Adds the LTX repository directory to sys.path so that its
    libraries can be imported.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")


# Run the function immediately to configure the environment before any LTX import.
add_deps_to_path()

from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
# `latent_to_pixel_coords` is needed by the patch below; the import path is
# assumed to match the upstream LTX-Video repository layout.
from ltx_video.models.autoencoders.vae_encode import latent_to_pixel_coords

# --- Our data class definitions ---
@dataclass
class ConditioningItem:
    pixel_tensor: torch.Tensor  # Always a pixel tensor
    media_frame_number: int
    conditioning_strength: float


@dataclass
class LatentConditioningItem:
    latent_tensor: torch.Tensor  # Always a latent tensor
    media_frame_number: int
    conditioning_strength: float

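# Illustrative usage (the tensor below is hypothetical; latent shapes depend on
# the VAE configuration of the loaded LTX model):
#
#   first_frame_latents = ...  # e.g. a [B, C, 1, H_lat, W_lat] tensor from the LTX VAE
#   anchor = LatentConditioningItem(
#       latent_tensor=first_frame_latents,
#       media_frame_number=0,        # blended into the start of the video
#       conditioning_strength=1.0,   # 1.0 fully replaces the initial latents
#   )
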
# ==============================================================================
# --- THE MONKEY PATCH ---
# This is our custom version of `prepare_conditioning`.
# ==============================================================================
def _aduc_prepare_conditioning_patch(
    self: "LTXVideoPipeline",
    conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
    init_latents: torch.Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator=None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
    # No conditioning: just patchify the initial latents and return.
    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(
            init_latent_coords, self.vae,
            causal_fix=self.transformer.config.causal_temporal_positioning,
        )
        return init_latents, init_pixel_coords, None, 0

    init_conditioning_mask = torch.zeros_like(
        init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device
    )
    extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
    extra_conditioning_num_latents = 0

    for item in conditioning_items:
        if not isinstance(item, LatentConditioningItem):
            logger.warning("ADUC patch: conditioning item is not a LatentConditioningItem and will be ignored.")
            continue
        media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
        media_frame_number, strength = item.media_frame_number, item.conditioning_strength

        if media_frame_number == 0:
            # Frame 0: blend the conditioning latents directly into the start
            # of the initial latents and record the strength in the mask.
            f_l, h_l, w_l = media_item_latents.shape[-3:]
            init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(
                init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength
            )
            init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
        else:
            # Later frames: mix the latents with noise, patchify them, and
            # append them as extra conditioning tokens at the target frame.
            noise = randn_tensor(
                media_item_latents.shape, generator=generator,
                device=media_item_latents.device, dtype=media_item_latents.dtype,
            )
            media_item_latents = torch.lerp(noise, media_item_latents, strength)
            patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
            pixel_coords = latent_to_pixel_coords(
                latent_coords, self.vae,
                causal_fix=self.transformer.config.causal_temporal_positioning,
            )
            pixel_coords[:, 0] += media_frame_number
            extra_conditioning_num_latents += patched_latents.shape[1]
            new_mask = torch.full(
                patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device
            )
            extra_conditioning_latents.append(patched_latents)
            extra_conditioning_pixel_coords.append(pixel_coords)
            extra_conditioning_mask.append(new_mask)

    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(
        init_latent_coords, self.vae,
        causal_fix=self.transformer.config.causal_temporal_positioning,
    )
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)

    if extra_conditioning_latents:
        # Prepend the extra conditioning tokens to the main latent sequence.
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)

    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents

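# A minimal sketch of the blend the patch relies on: `torch.lerp(a, b, w)`
# computes `a + w * (b - a)`, so strength=1.0 replaces the target with the
# conditioning latents and strength=0.5 averages the two:
#
#   >>> torch.lerp(torch.zeros(2), torch.ones(2), 0.5)
#   tensor([0.5000, 0.5000])
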
# ==============================================================================
# --- LTX Worker and Pool Manager ---
# ==============================================================================
class LTXWorker:
    """Manages one LTX pipeline instance on a pair of GPUs (main + VAE)."""

    def __init__(self, main_device: str, vae_device: str, config: dict):
        self.main_device = torch.device(main_device)
        self.vae_device = torch.device(vae_device)
        self.config = config
        self.pipeline: Optional[LTXVideoPipeline] = None
        self._load_and_patch_pipeline()

    def _load_and_patch_pipeline(self):
        logging.info(f"[LTXWorker-{self.main_device}] Loading LTX pipeline on the CPU...")
        self.pipeline, _ = build_ltx_pipeline_on_cpu(self.config)
        logging.info(
            f"[LTXWorker-{self.main_device}] Moving pipeline to GPUs "
            f"(Main: {self.main_device}, VAE: {self.vae_device})..."
        )
        self.pipeline.to(self.main_device)
        self.pipeline.vae.to(self.vae_device)
        logging.info(f"[LTXWorker-{self.main_device}] Applying the ADUC-SDR patch to 'prepare_conditioning'...")
        # The monkey-patching "magic" happens here: `__get__` binds the function
        # to this pipeline instance so it is called as a regular method.
        self.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(
            self.pipeline, LTXVideoPipeline
        )
        logging.info(f"[LTXWorker-{self.main_device}] ✅ Pipeline hot, patched, and ready.")

class LTXAducManager:
    def __init__(self):
        main_device = gpu_manager.get_ltx_device()
        vae_device = gpu_manager.get_ltx_vae_device()
        # A future architecture could run multiple workers; for now there is one.
        self.worker = LTXWorker(str(main_device), str(vae_device), self.load_config())

    def load_config(self) -> Dict:
        """Loads the YAML configuration file."""
        config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    def get_pipeline(self) -> LTXVideoPipeline:
        return self.worker.pipeline

# Singleton instance
ltx_aduc_manager = LTXAducManager()
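
# Usage sketch (illustrative; assumes the upstream pipeline's __call__ accepts
# `conditioning_items`, which is what the patched `prepare_conditioning` consumes):
#
#   pipeline = ltx_aduc_manager.get_pipeline()
#   items = [LatentConditioningItem(latents, media_frame_number=0, conditioning_strength=1.0)]
#   video = pipeline(prompt="...", conditioning_items=items, ...)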