File size: 7,490 Bytes
5105909
42998d3
 
 
 
9e82695
42998d3
 
 
6e680c2
 
 
25d9c99
511f633
42998d3
61082a6
 
 
 
 
42998d3
5105909
42998d3
8504d9f
 
 
 
 
 
 
 
 
 
 
 
 
42998d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5105909
42998d3
 
 
 
 
 
 
5105909
 
 
 
42998d3
5105909
 
 
42998d3
5105909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42998d3
5105909
 
 
 
 
 
 
 
 
 
42998d3
5105909
 
42998d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf5e057
42998d3
 
 
 
24af70a
81d098e
 
 
 
 
 
42998d3
 
 
 
 
bf5e057
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# FILE: api/ltx/ltx_aduc_manager.py
# DESCRIPTION: The "secret weapon". A pool manager for LTX that applies
# runtime patches to the pipeline for full control and ADUC-SDR compatibility.

import logging
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
import torch
from diffusers.utils.torch_utils import randn_tensor
import sys
from pathlib import Path
import os
import random
import yaml

# Identifier of the LTX-Video model repository (presumably a Hugging Face repo id; not referenced in this file).
LTX_REPO_ID = "Lightricks/LTX-Video"
# Base directory under which the LTX-Video checkout is expected.
DEPS_DIR = Path("/data")
# Local LTX-Video checkout: added to sys.path and used to locate the YAML config.
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
# Output directory (not referenced in this file).
RESULTS_DIR = Path("/app/output")

# --- Imports from our own architecture ---
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_utils import build_ltx_pipeline_on_cpu

def add_deps_to_path():
    """
    Make the LTX-Video checkout importable.

    Puts the resolved ``LTX_VIDEO_REPO_DIR`` at the front of ``sys.path`` and
    logs the addition. Does nothing when that directory is already listed.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path in sys.path:
        # Already registered: avoid duplicate entries and a duplicate log line.
        return

    sys.path.insert(0, repo_path)
    logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")

# Run immediately so the LTX-Video repository is on sys.path before the
# `ltx_video` import below is attempted.
add_deps_to_path()
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline

# --- Definition of our data classes ---
@dataclass
class ConditioningItem:
    """Conditioning input carried as a pixel-space tensor.

    Note: `_aduc_prepare_conditioning_patch` in this file skips any item that
    is not a `LatentConditioningItem`, so instances of this class are ignored
    by the patched `prepare_conditioning`.
    """
    pixel_tensor: torch.Tensor # Always a pixel tensor
    media_frame_number: int
    conditioning_strength: float

@dataclass
class LatentConditioningItem:
    """Conditioning input carried as a latent-space tensor.

    Consumed by `_aduc_prepare_conditioning_patch`, which reads all three fields.
    """
    latent_tensor: torch.Tensor # Always a latent tensor
    # 0 -> blended into the start of the initial latents; any other value is
    # added to the frame coordinate of the item's tokens.
    media_frame_number: int
    # Used as the lerp weight toward the item and as the conditioning-mask value.
    conditioning_strength: float

# ==============================================================================
# --- THE MONKEY PATCH ---
# This is our custom version of `prepare_conditioning`
# ==============================================================================
def _aduc_prepare_conditioning_patch(
    self: "LTXVideoPipeline",
    conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
    init_latents: torch.Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator=None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
    """Replacement for ``prepare_conditioning`` that consumes latent-space items.

    Only ``LatentConditioningItem`` entries are used; any other item is logged
    and skipped. An item with ``media_frame_number == 0`` is blended directly
    into the leading region of ``init_latents`` (in place). Any other item is
    mixed with fresh noise, patchified, and prepended as extra conditioning
    tokens whose first coordinate is shifted by ``media_frame_number``.

    Args:
        conditioning_items: Items to condition on. ``None`` or an empty list
            means no conditioning.
        init_latents: Initial latents. Modified in place by frame-0 items.
            The trailing three dims are treated as (frames, height, width) --
            assumption based on how they are sliced here; confirm with caller.
        num_frames: Unused here.
        height: Unused here.
        width: Unused here.
        vae_per_channel_normalize: Unused here. These four are presumably kept
            so the signature matches the method being replaced.
        generator: Forwarded to ``randn_tensor`` when drawing the noise for
            items whose ``media_frame_number`` is not 0.

    Returns:
        ``(latents, pixel_coords, conditioning_mask, num_extra_latents)``:
        the patchified latents, their pixel coordinates, the conditioning mask
        (``None`` when there are no conditioning items), and the number of
        extra conditioning tokens prepended (0 when there are none).
    """
    # The module never imported this helper, so every call raised NameError.
    # Imported locally: it must run after add_deps_to_path() has put the
    # LTX-Video checkout on sys.path, which is guaranteed by call time.
    from ltx_video.models.autoencoders.vae_encode import latent_to_pixel_coords

    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
        return init_latents, init_pixel_coords, None, 0

    # One mask value per latent position: dim 1 of init_latents is dropped.
    init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
    extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
    extra_conditioning_num_latents = 0

    for item in conditioning_items:
        if not isinstance(item, LatentConditioningItem):
            # `logger` was never defined in this module; use the `logging`
            # module directly, like the rest of the file.
            logging.warning("Patch ADUC: Item de condicionamento não é um LatentConditioningItem e será ignorado.")
            continue

        media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
        media_frame_number, strength = item.media_frame_number, item.conditioning_strength

        if media_frame_number == 0:
            # Blend the item into the leading region of the initial latents
            # and record its strength in the mask over the same region.
            f_l, h_l, w_l = media_item_latents.shape[-3:]
            init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
            init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
        else:
            # Mix the item with noise, then carry it as separate tokens.
            noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
            media_item_latents = torch.lerp(noise, media_item_latents, strength)
            patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
            pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
            # Shift the first coordinate channel by the target frame number.
            pixel_coords[:, 0] += media_frame_number
            extra_conditioning_num_latents += patched_latents.shape[1]
            new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
            extra_conditioning_latents.append(patched_latents)
            extra_conditioning_pixel_coords.append(pixel_coords)
            extra_conditioning_mask.append(new_mask)

    # Patchify the (possibly updated) initial latents and the mask alike.
    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)

    if extra_conditioning_latents:
        # Extra conditioning tokens go in front of the initial ones.
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)

    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
    



# ==============================================================================
# --- LTX Worker and Pool Manager ---
# ==============================================================================

class LTXWorker:
    """Manages one LTX pipeline instance on a pair of devices (main + VAE)."""

    def __init__(self, main_device: str, vae_device: str, config: dict):
        """Store the devices and config, then load and patch the pipeline."""
        self.config = config
        self.main_device = torch.device(main_device)
        self.vae_device = torch.device(vae_device)
        self.pipeline: LTXVideoPipeline = None

        self._load_and_patch_pipeline()

    def _load_and_patch_pipeline(self):
        """Build the pipeline, move it to its devices and install the ADUC patch."""
        logging.info(f"[LTXWorker-{self.main_device}] Carregando pipeline LTX para a CPU...")
        self.pipeline, _ = build_ltx_pipeline_on_cpu(self.config)
        pipe = self.pipeline

        logging.info(f"[LTXWorker-{self.main_device}] Movendo pipeline para GPUs (Main: {self.main_device}, VAE: {self.vae_device})...")
        # Move the whole pipeline first, then relocate only the VAE.
        pipe.to(self.main_device)
        pipe.vae.to(self.vae_device)

        logging.info(f"[LTXWorker-{self.main_device}] Aplicando patch ADUC-SDR na função 'prepare_conditioning'...")
        # Bind the module-level function to this pipeline instance so it is
        # called as a method (it receives the pipeline as `self`).
        bound_patch = _aduc_prepare_conditioning_patch.__get__(pipe, LTXVideoPipeline)
        pipe.prepare_conditioning = bound_patch
        logging.info(f"[LTXWorker-{self.main_device}] ✅ Pipeline 'quente', corrigido e pronto.")


class LTXAducManager:
    """Owns the LTX worker and exposes its patched pipeline.

    On construction it asks ``gpu_manager`` for the main and VAE devices,
    loads the YAML config and builds a single ``LTXWorker``.
    """

    def __init__(self):
        main_device = gpu_manager.get_ltx_device()
        vae_device = gpu_manager.get_ltx_vae_device()
        # A future architecture could hold several workers; for now there is one.
        # `load_config` is a method, so it must be called through `self`:
        # the bare call raised NameError as soon as the manager was built.
        self.worker = LTXWorker(str(main_device), str(vae_device), self.load_config())

    def load_config(self) -> Dict:
        """Load the LTX YAML configuration file from the repository checkout.

        Returns:
            The parsed YAML content, as returned by ``yaml.safe_load``.

        Raises:
            OSError: If the config file cannot be opened.
        """
        config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(config_path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)

    def get_pipeline(self) -> "LTXVideoPipeline":
        """Return the worker's pipeline instance."""
        return self.worker.pipeline

# Singleton instance, created at import time (constructing it loads the
# config and builds the worker).
ltx_aduc_manager = LTXAducManager()