File size: 13,016 Bytes
451b75f
ce10698
 
 
 
 
 
451b75f
 
 
ce10698
 
 
 
607756b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce10698
 
451b75f
ce10698
 
 
 
451b75f
 
 
 
ce10698
 
 
451b75f
 
ce10698
451b75f
 
ce10698
451b75f
 
ce10698
 
451b75f
 
 
ce10698
 
 
 
 
 
 
 
 
 
451b75f
 
 
 
 
 
 
 
 
ce10698
451b75f
ce10698
451b75f
ce10698
 
451b75f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce10698
 
451b75f
ce10698
451b75f
 
 
 
 
 
 
 
ce10698
451b75f
 
 
 
 
 
 
 
 
 
 
 
 
ce10698
451b75f
 
ce10698
 
 
451b75f
 
ce10698
451b75f
 
 
 
 
 
ce10698
 
 
451b75f
 
 
 
ce10698
 
451b75f
 
 
 
 
 
 
 
 
 
ce10698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451b75f
 
 
 
 
ce10698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451b75f
 
 
ce10698
451b75f
 
 
 
ce10698
 
 
 
 
451b75f
 
 
 
 
 
 
ce10698
 
 
 
451b75f
 
ce10698
451b75f
 
ce10698
451b75f
ce10698
451b75f
 
ce10698
 
451b75f
 
 
ce10698
 
 
451b75f
 
 
 
ce10698
451b75f
 
ce10698
451b75f
 
 
ce10698
451b75f
 
ce10698
451b75f
ce10698
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# aduc_ltx_latent_patch.py
#
# Este módulo fornece um monkey patch para a classe LTXVideoPipeline da biblioteca ltx_video.
# A principal funcionalidade deste patch é otimizar o processo de condicionamento, permitindo
# que a pipeline aceite tensores de latentes pré-calculados diretamente através de um
# `ConditioningItem` modificado. Isso evita a re-codificação desnecessária de mídias (imagens/vídeos)
# pela VAE, resultando em um ganho de performance significativo quando os latentes já estão disponíveis.

import torch
from torch import Tensor
from typing import Optional, List, Tuple
from pathlib import Path
import os
import sys

DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
def add_deps_to_path(repo_path: Path):
    """Adiciona o diretório do repositório ao sys.path para importações locais."""
    resolved_path = str(repo_path.resolve())
    if resolved_path not in sys.path:
        sys.path.insert(0, resolved_path)
        if LTXV_DEBUG:
            print(f"[DEBUG] Adicionado ao sys.path: {resolved_path}")

# --- Execução da configuração inicial ---
if not LTX_VIDEO_REPO_DIR.exists():
    _run_setup_script()
add_deps_to_path(LTX_VIDEO_REPO_DIR)


# Tenta importar as dependências necessárias do módulo original que será modificado.
# Isso requer que o ambiente Python tenha o pacote `ltx_video` acessível em seu sys.path.
try:
    from ltx_video.pipelines.pipeline_ltx_video import (
        LTXVideoPipeline,
        ConditioningItem as OriginalConditioningItem
    )
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
    from diffusers.utils.torch_utils import randn_tensor
except ImportError as e:
    print(f"FATAL ERROR: Could not import dependencies from 'ltx_video'. "
          f"Please ensure the environment is correctly set up. Error: {e}")
    # Interrompe a execução se as dependências essenciais não puderem ser encontradas.
    raise

print("[INFO] Patch module 'aduc_ltx_latent_patch' loaded successfully.")

# ==============================================================================
# 1. NOVA DEFINIÇÃO DA DATACLASS `ConditioningItem`
# ==============================================================================

from dataclasses import dataclass

@dataclass
class PatchedConditioningItem:
    """
    Versão modificada do `ConditioningItem` que aceita tensores de pixel (`media_item`)
    ou tensores de latentes pré-codificados (`latents`).

    Attributes:
        media_frame_number (int): Quadro inicial do item de condicionamento no vídeo.
        conditioning_strength (float): Força do condicionamento (0.0 a 1.0).
        media_item (Optional[Tensor]): Tensor de mídia (pixels). Usado se `latents` for None.
        media_x (Optional[int]): Coordenada X (esquerda) para posicionamento espacial.
        media_y (Optional[int]): Coordenada Y (topo) para posicionamento espacial.
        latents (Optional[Tensor]): Tensor de latentes pré-codificado. Terá precedência sobre `media_item`.
    """
    media_frame_number: int
    conditioning_strength: float
    media_item: Optional[Tensor] = None
    media_x: Optional[int] = None
    media_y: Optional[int] = None
    latents: Optional[Tensor] = None

    def __post_init__(self):
        """Valida o estado do objeto após a inicialização."""
        if self.media_item is None and self.latents is None:
            raise ValueError("A `PatchedConditioningItem` must have either 'media_item' or 'latents' defined.")
        if self.media_item is not None and self.latents is not None:
            print("[WARNING] `PatchedConditioningItem` received both 'media_item' and 'latents'. "
                  "The 'latents' tensor will take precedence.")

# ==============================================================================
# 2. NOVA IMPLEMENTAÇÃO DA FUNÇÃO `prepare_conditioning`
# ==============================================================================

def prepare_conditioning_with_latents(
    self: LTXVideoPipeline,
    conditioning_items: Optional[List[PatchedConditioningItem]],
    init_latents: Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], int]:
    """
    Versão modificada de `prepare_conditioning` que prioriza o uso de latentes pré-calculados
    dos `conditioning_items`, evitando a re-codificação desnecessária pela VAE.
    """
    assert isinstance(self, LTXVideoPipeline), "This function must be called as a method of LTXVideoPipeline."
    assert isinstance(self.vae, CausalVideoAutoencoder), "VAE must be of type CausalVideoAutoencoder."

    # Se não há itens de condicionamento, apenas patchifica os latentes e retorna.
    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(
            init_latent_coords, self.vae,
            causal_fix=self.transformer.config.causal_temporal_positioning
        )
        return init_latents, init_pixel_coords, None, 0

    # Inicializa tensores para acumular resultados
    init_conditioning_mask = torch.zeros(
        init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device
    )
    extra_conditioning_latents = []
    extra_conditioning_pixel_coords = []
    extra_conditioning_mask = []
    extra_conditioning_num_latents = 0

    for item in conditioning_items:
        item_latents: Tensor

        # --- LÓGICA CENTRAL DO PATCH ---
        if item.latents is not None:
            # 1. Se latentes pré-calculados existem, use-os diretamente.
            item_latents = item.latents.to(dtype=init_latents.dtype, device=init_latents.device)
            if item_latents.ndim != 5:
                raise ValueError(f"Latents must have 5 dimensions (b, c, f, h, w), but got {item_latents.ndim}")
        elif item.media_item is not None:
            # 2. Caso contrário, volte para o fluxo original de codificação da VAE.
            resized_item = self._resize_conditioning_item(item, height, width)
            media_item = resized_item.media_item
            assert media_item.ndim == 5, f"media_item must have 5 dims, but got {media_item.ndim}"

            item_latents = vae_encode(
                media_item.to(dtype=self.vae.dtype, device=self.vae.device),
                self.vae,
                vae_per_channel_normalize=vae_per_channel_normalize,
            ).to(dtype=init_latents.dtype)
        else:
            # Este caso é prevenido pelo __post_init__ do dataclass, mas é bom ter uma checagem.
            raise ValueError("ConditioningItem is invalid: it has neither 'latents' nor 'media_item'.")
        # --- FIM DA LÓGICA DO PATCH ---

        media_frame_number = item.media_frame_number
        strength = item.conditioning_strength
        
        # O resto da lógica da função original é aplicado sobre `item_latents`.
        if media_frame_number == 0:
            item_latents, l_x, l_y = self._get_latent_spatial_position(
                item_latents, item, height, width, strip_latent_border=True
            )
            _, _, f_l, h_l, w_l = item_latents.shape
            init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
                init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], item_latents, strength
            )
            init_conditioning_mask[:, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = strength
        else:
            if item_latents.shape[2] > 1:
                (init_latents, init_conditioning_mask, item_latents) = self._handle_non_first_conditioning_sequence(
                    init_latents, init_conditioning_mask, item_latents, media_frame_number, strength
                )

            if item_latents is not None:
                noise = randn_tensor(
                    item_latents.shape, generator=generator,
                    device=item_latents.device, dtype=item_latents.dtype
                )
                item_latents = torch.lerp(noise, item_latents, strength)
                item_latents, latent_coords = self.patchifier.patchify(latents=item_latents)
                pixel_coords = latent_to_pixel_coords(
                    latent_coords, self.vae,
                    causal_fix=self.transformer.config.causal_temporal_positioning
                )
                pixel_coords[:, 0] += media_frame_number
                extra_conditioning_num_latents += item_latents.shape[1]
                conditioning_mask = torch.full(
                    item_latents.shape[:2], strength,
                    dtype=torch.float32, device=init_latents.device
                )
                extra_conditioning_latents.append(item_latents)
                extra_conditioning_pixel_coords.append(pixel_coords)
                extra_conditioning_mask.append(conditioning_mask)

    # Patchifica os latentes principais e a máscara de condicionamento
    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(
        init_latent_coords, self.vae,
        causal_fix=self.transformer.config.causal_temporal_positioning
    )
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)

    # Concatena os latentes extras (se houver)
    if extra_conditioning_latents:
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)

        if self.transformer.use_tpu_flash_attention:
            init_latents = init_latents[:, :-extra_conditioning_num_latents]
            init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
            init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]

    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents


# ==============================================================================
# 3. CLASSE DO MONKEY PATCHER
# ==============================================================================

class LTXLatentConditioningPatch:
    """
    Classe estática para aplicar e reverter o monkey patch na pipeline LTX-Video.

    Esta classe substitui o método `prepare_conditioning` da `LTXVideoPipeline`
    pela versão otimizada que suporta latentes pré-calculados, e implicitamente
    requer o uso da `PatchedConditioningItem`.
    """
    _original_prepare_conditioning = None
    _is_patched = False

    @staticmethod
    def apply():
        """
        Aplica o monkey patch à classe `LTXVideoPipeline`.

        Guarda o método original e o substitui pela nova implementação.
        É idempotente; aplicar múltiplas vezes não causa efeito adicional.
        """
        if LTXLatentConditioningPatch._is_patched:
            print("[WARNING] LTXLatentConditioningPatch has already been applied. Ignoring.")
            return

        print("[INFO] Applying monkey patch for latent-based conditioning...")

        # Guarda a implementação original para permitir a reversão.
        LTXLatentConditioningPatch._original_prepare_conditioning = LTXVideoPipeline.prepare_conditioning

        # Substitui o método na classe LTXVideoPipeline.
        # Todas as instâncias futuras e existentes da classe usarão este novo método.
        LTXVideoPipeline.prepare_conditioning = prepare_conditioning_with_latents

        LTXLatentConditioningPatch._is_patched = True
        print("[SUCCESS] Monkey patch applied successfully.")
        print("  - `LTXVideoPipeline.prepare_conditioning` has been updated.")
        print("  - NOTE: Remember to use `aduc_ltx_latent_patch.PatchedConditioningItem` when creating conditioning items.")

    @staticmethod
    def revert():
        """
        Reverte o monkey patch, restaurando a implementação original.
        """
        if not LTXLatentConditioningPatch._is_patched:
            print("[WARNING] Patch is not currently applied. No action taken.")
            return

        if LTXLatentConditioningPatch._original_prepare_conditioning:
            print("[INFO] Reverting LTXLatentConditioningPatch...")
            LTXVideoPipeline.prepare_conditioning = LTXLatentConditioningPatch._original_prepare_conditioning
            LTXLatentConditioningPatch._is_patched = False
            print("[SUCCESS] Patch reverted successfully. Original functionality restored.")
        else:
            print("[ERROR] Cannot revert: original implementation was not saved.")