# FILE: api/ltx/ltx_aduc_manager.py
# DESCRIPTION: The "secret weapon". A pool manager for LTX that applies
# runtime patches to the pipeline for full control and ADUC-SDR compatibility.
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import yaml
from diffusers.utils.torch_utils import randn_tensor

logger = logging.getLogger(__name__)

LTX_REPO_ID = "Lightricks/LTX-Video"
DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
RESULTS_DIR = Path("/app/output")

# --- Imports from our architecture ---
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_utils import build_ltx_pipeline_on_cpu
def add_deps_to_path():
    """
    Adds the LTX repository directory to sys.path so that its libraries
    can be imported.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logger.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")

# Run immediately so the environment is configured before any LTX imports.
add_deps_to_path()
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
# `latent_to_pixel_coords` is used by the patch below but was not imported in the
# original file; the module path here is assumed from the upstream LTX-Video layout.
from ltx_video.models.autoencoders.vae_encode import latent_to_pixel_coords
# --- Our data classes ---
@dataclass
class ConditioningItem:
    pixel_tensor: torch.Tensor  # Always a pixel-space tensor
    media_frame_number: int
    conditioning_strength: float

@dataclass
class LatentConditioningItem:
    latent_tensor: torch.Tensor  # Always a latent-space tensor
    media_frame_number: int
    conditioning_strength: float
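# A hedged usage sketch: how a caller might describe "anchor at frame 0 plus a
# mid-clip latent" for the patched pipeline below. Tensor names and shapes are
# purely illustrative; real latents come from the VAE encoder.
#
#   anchor = LatentConditioningItem(
#       latent_tensor=first_frame_latents,   # hypothetical: output of the VAE encoder
#       media_frame_number=0,                # frame 0 -> blended into init_latents
#       conditioning_strength=1.0,
#   )
#   mid = LatentConditioningItem(
#       latent_tensor=mid_clip_latents,      # hypothetical
#       media_frame_number=48,               # frame > 0 -> appended as extra tokens
#       conditioning_strength=0.8,
#   )
#   conditioning_items = [anchor, mid]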
# ==============================================================================
# --- THE MONKEY PATCH ---
# This is our custom version of `prepare_conditioning`.
# ==============================================================================
def _aduc_prepare_conditioning_patch(
    self: "LTXVideoPipeline",
    conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
    init_latents: torch.Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator=None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
    # No conditioning: just patchify the initial latents and return.
    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(
            init_latent_coords, self.vae,
            causal_fix=self.transformer.config.causal_temporal_positioning,
        )
        return init_latents, init_pixel_coords, None, 0

    init_conditioning_mask = torch.zeros_like(
        init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device
    )
    extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
    extra_conditioning_num_latents = 0

    for item in conditioning_items:
        if not isinstance(item, LatentConditioningItem):
            logger.warning("ADUC patch: conditioning item is not a LatentConditioningItem and will be ignored.")
            continue
        media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
        media_frame_number, strength = item.media_frame_number, item.conditioning_strength

        if media_frame_number == 0:
            # Frame-0 conditioning: blend directly into the initial latents.
            f_l, h_l, w_l = media_item_latents.shape[-3:]
            init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(
                init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength
            )
            init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
        else:
            # Mid-sequence conditioning: noise-blend, patchify, and collect as
            # extra tokens whose pixel coords are shifted to the target frame.
            noise = randn_tensor(
                media_item_latents.shape, generator=generator,
                device=media_item_latents.device, dtype=media_item_latents.dtype,
            )
            media_item_latents = torch.lerp(noise, media_item_latents, strength)
            patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
            pixel_coords = latent_to_pixel_coords(
                latent_coords, self.vae,
                causal_fix=self.transformer.config.causal_temporal_positioning,
            )
            pixel_coords[:, 0] += media_frame_number
            extra_conditioning_num_latents += patched_latents.shape[1]
            new_mask = torch.full(
                patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device
            )
            extra_conditioning_latents.append(patched_latents)
            extra_conditioning_pixel_coords.append(pixel_coords)
            extra_conditioning_mask.append(new_mask)

    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(
        init_latent_coords, self.vae,
        causal_fix=self.transformer.config.causal_temporal_positioning,
    )
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)

    if extra_conditioning_latents:
        # Prepend the conditioning tokens ahead of the video tokens.
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)

    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
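# A hedged sketch of calling the patched method once it is bound to a pipeline
# instance (see LTXWorker below). Argument names follow the signature above;
# everything else is illustrative:
#
#   latents, pixel_coords, mask, n_extra = pipeline.prepare_conditioning(
#       conditioning_items=conditioning_items,
#       init_latents=init_latents,
#       num_frames=num_frames, height=height, width=width,
#       generator=torch.Generator(device="cuda").manual_seed(42),
#   )
#   # latents: extra conditioning tokens prepended along the token dimension
#   # mask:    per-token conditioning strength
#   # n_extra: number of prepended conditioning tokens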
# ==============================================================================
# --- LTX Worker and Pool Manager ---
# ==============================================================================
class LTXWorker:
    """Manages one LTX pipeline instance on a pair of GPUs (main + VAE)."""
    def __init__(self, main_device: str, vae_device: str, config: dict):
        self.main_device = torch.device(main_device)
        self.vae_device = torch.device(vae_device)
        self.config = config
        self.pipeline: Optional[LTXVideoPipeline] = None
        self._load_and_patch_pipeline()

    def _load_and_patch_pipeline(self):
        logger.info(f"[LTXWorker-{self.main_device}] Loading LTX pipeline on the CPU...")
        self.pipeline, _ = build_ltx_pipeline_on_cpu(self.config)

        logger.info(
            f"[LTXWorker-{self.main_device}] Moving pipeline to GPUs "
            f"(Main: {self.main_device}, VAE: {self.vae_device})..."
        )
        self.pipeline.to(self.main_device)
        self.pipeline.vae.to(self.vae_device)

        logger.info(f"[LTXWorker-{self.main_device}] Applying the ADUC-SDR patch to 'prepare_conditioning'...")
        # The monkey-patching "magic" happens here: __get__ binds the function
        # to this pipeline instance so it is called like a regular method.
        self.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(
            self.pipeline, LTXVideoPipeline
        )
        logger.info(f"[LTXWorker-{self.main_device}] ✅ Pipeline warm, patched, and ready.")
class LTXAducManager:
    def __init__(self):
        main_device = gpu_manager.get_ltx_device()
        vae_device = gpu_manager.get_ltx_vae_device()
        # A future architecture could run multiple workers; for now there is one.
        self.worker = LTXWorker(str(main_device), str(vae_device), self.load_config())

    def load_config(self) -> Dict:
        """Loads the YAML configuration file."""
        config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    def get_pipeline(self) -> LTXVideoPipeline:
        return self.worker.pipeline
# Singleton instance
ltx_aduc_manager = LTXAducManager()
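# A hedged usage sketch from a consumer module (assumes gpu_manager has already
# assigned devices for this process; the pipeline's full call signature follows
# the upstream LTXVideoPipeline and is not reproduced here):
#
#   from api.ltx.ltx_aduc_manager import ltx_aduc_manager, LatentConditioningItem
#
#   pipeline = ltx_aduc_manager.get_pipeline()
#   # pipeline.prepare_conditioning now runs the ADUC-SDR patch, so lists of
#   # LatentConditioningItem can be passed wherever the upstream pipeline
#   # accepts conditioning items.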