# aduc_ltx_latent_patch.py
#
# This module provides a monkey patch for the `LTXVideoPipeline` class from the
# ltx_video library. Its main purpose is to optimize the conditioning process by
# letting the pipeline accept pre-computed latent tensors directly through a
# modified `ConditioningItem`. This avoids needlessly re-encoding media
# (images/videos) through the VAE, which yields a significant performance gain
# whenever the latents are already available.
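#
# A minimal usage sketch (the names `pipeline` and `precomputed_latents` are
# placeholders, not part of this module; assumes the pipeline's `__call__`
# accepts `conditioning_items` as in the upstream LTX-Video pipeline):
#
#     from aduc_ltx_latent_patch import LTXLatentConditioningPatch, PatchedConditioningItem
#
#     LTXLatentConditioningPatch.apply()
#     item = PatchedConditioningItem(
#         media_frame_number=0,
#         conditioning_strength=1.0,
#         latents=precomputed_latents,  # pre-encoded, shape (b, c, f, h, w)
#     )
#     video = pipeline(..., conditioning_items=[item])
#     LTXLatentConditioningPatch.revert()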
import sys
from dataclasses import dataclass, replace
from pathlib import Path
from typing import List, Optional, Tuple

import torch
from torch import Tensor
# --- PATH CONFIGURATION (assumes LTXV_DEBUG and _run_setup_script exist in the scope that loads this module) ---
# DEPS_DIR = Path("/data")
# LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
# def add_deps_to_path(repo_path: Path):
#     """Adds the repository directory to sys.path for local imports."""
#     resolved_path = str(repo_path.resolve())
#     if resolved_path not in sys.path:
#         sys.path.insert(0, resolved_path)
# add_deps_to_path(LTX_VIDEO_REPO_DIR)

# Try to import the required dependencies from the original module being patched.
try:
    from ltx_video.pipelines.pipeline_ltx_video import (
        LTXVideoPipeline,
        ConditioningItem as OriginalConditioningItem,
    )
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
    from diffusers.utils.torch_utils import randn_tensor
except ImportError as e:
    print(f"FATAL ERROR: Could not import dependencies from 'ltx_video'. "
          f"Please ensure the environment is correctly set up. Error: {e}")
    raise
print("[INFO] Patch module 'aduc_ltx_latent_patch' loaded successfully.")
# ==============================================================================
# 1. NEW `PatchedConditioningItem` DATACLASS DEFINITION
# ==============================================================================
@dataclass
class PatchedConditioningItem:
    """
    Modified version of `ConditioningItem` that accepts either pixel tensors
    (`media_item`) or pre-encoded latent tensors (`latents`).

    Attributes:
        media_frame_number (int): Starting frame of the conditioning item in the video.
        conditioning_strength (float): Conditioning strength (0.0 to 1.0).
        media_item (Optional[Tensor]): Media (pixel) tensor. Used if `latents` is None.
        media_x (Optional[int]): X coordinate (left) for spatial placement.
        media_y (Optional[int]): Y coordinate (top) for spatial placement.
        latents (Optional[Tensor]): Pre-encoded latent tensor. Takes precedence over `media_item`.
    """
    media_frame_number: int
    conditioning_strength: float
    media_item: Optional[Tensor] = None
    media_x: Optional[int] = None
    media_y: Optional[int] = None
    latents: Optional[Tensor] = None

    def __post_init__(self):
        """Validates the object's state after initialization."""
        if self.media_item is None and self.latents is None:
            raise ValueError("A `PatchedConditioningItem` must have either 'media_item' or 'latents' defined.")
        if self.media_item is not None and self.latents is not None:
            print("[WARNING] `PatchedConditioningItem` received both 'media_item' and 'latents'. "
                  "The 'latents' tensor will take precedence.")
# ==============================================================================
# 2. NEW `prepare_conditioning` IMPLEMENTATION
# ==============================================================================
def prepare_conditioning_with_latents(
    self: LTXVideoPipeline,
    conditioning_items: Optional[List[PatchedConditioningItem]],
    init_latents: Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], int]:
    """
    Modified version of `prepare_conditioning` that prefers the pre-computed
    latents carried by the `conditioning_items`, avoiding needless re-encoding
    through the VAE.
    """
    assert isinstance(self, LTXVideoPipeline), "This function must be called as a method of LTXVideoPipeline."
    assert isinstance(self.vae, CausalVideoAutoencoder), "VAE must be of type CausalVideoAutoencoder."

    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(
            init_latent_coords, self.vae,
            causal_fix=self.transformer.config.causal_temporal_positioning
        )
        return init_latents, init_pixel_coords, None, 0

    init_conditioning_mask = torch.zeros(
        init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device
    )
    extra_conditioning_latents = []
    extra_conditioning_pixel_coords = []
    extra_conditioning_mask = []
    extra_conditioning_num_latents = 0

    for item in conditioning_items:
        item_latents: Tensor
        if item.latents is not None:
            item_latents = item.latents.to(dtype=init_latents.dtype, device=init_latents.device)
            if item_latents.ndim != 5:
                raise ValueError(f"Latents must have 5 dimensions (b, c, f, h, w), but got {item_latents.ndim}")
        elif item.media_item is not None:
            resized_item = self._resize_conditioning_item(item, height, width)
            media_item = resized_item.media_item
            assert media_item.ndim == 5, f"media_item must have 5 dims, but got {media_item.ndim}"
            item_latents = vae_encode(
                media_item.to(dtype=self.vae.dtype, device=self.vae.device),
                self.vae,
                vae_per_channel_normalize=vae_per_channel_normalize,
            ).to(dtype=init_latents.dtype)
        else:
            raise ValueError("ConditioningItem is invalid: it has neither 'latents' nor 'media_item'.")

        media_frame_number = item.media_frame_number
        strength = item.conditioning_strength

        if media_frame_number == 0:
            # --- START OF MODIFICATION ---
            # If `item.media_item` is None (our optimized use case), the original
            # `_get_latent_spatial_position` would break. To avoid that, we build a
            # temporary item carrying a placeholder tensor with the correct pixel
            # dimensions, inferred from the latents themselves.
            item_for_spatial_position = item
            if item.media_item is None:
                # Infer the pixel dimensions from the latent shape
                latent_h, latent_w = item_latents.shape[-2:]
                pixel_h = latent_h * self.vae_scale_factor
                pixel_w = latent_w * self.vae_scale_factor
                # Create a placeholder tensor with the expected shape (its contents do not matter)
                placeholder_media_item = torch.empty(
                    (1, 1, 1, pixel_h, pixel_w), device=item_latents.device, dtype=item_latents.dtype
                )
                # Use `dataclasses.replace` to build a temporary copy of the item with the placeholder
                item_for_spatial_position = replace(item, media_item=placeholder_media_item)
            # Call the original helper with an item it can process without error
            item_latents, l_x, l_y = self._get_latent_spatial_position(
                item_latents, item_for_spatial_position, height, width, strip_latent_border=True
            )
            # --- END OF MODIFICATION ---
            _, _, f_l, h_l, w_l = item_latents.shape
            init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
                init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], item_latents, strength
            )
            init_conditioning_mask[:, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = strength
        else:
            if item_latents.shape[2] > 1:
                (init_latents, init_conditioning_mask, item_latents) = self._handle_non_first_conditioning_sequence(
                    init_latents, init_conditioning_mask, item_latents, media_frame_number, strength
                )
            if item_latents is not None:
                noise = randn_tensor(
                    item_latents.shape, generator=generator,
                    device=item_latents.device, dtype=item_latents.dtype
                )
                item_latents = torch.lerp(noise, item_latents, strength)
                item_latents, latent_coords = self.patchifier.patchify(latents=item_latents)
                pixel_coords = latent_to_pixel_coords(
                    latent_coords, self.vae,
                    causal_fix=self.transformer.config.causal_temporal_positioning
                )
                pixel_coords[:, 0] += media_frame_number
                extra_conditioning_num_latents += item_latents.shape[1]
                conditioning_mask = torch.full(
                    item_latents.shape[:2], strength,
                    dtype=torch.float32, device=init_latents.device
                )
                extra_conditioning_latents.append(item_latents)
                extra_conditioning_pixel_coords.append(pixel_coords)
                extra_conditioning_mask.append(conditioning_mask)

    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(
        init_latent_coords, self.vae,
        causal_fix=self.transformer.config.causal_temporal_positioning
    )
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)

    if extra_conditioning_latents:
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
        if self.transformer.use_tpu_flash_attention:
            # With TPU flash attention, keep the original token count by dropping
            # the extra conditioning tokens from the end.
            init_latents = init_latents[:, :-extra_conditioning_num_latents]
            init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
            init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]

    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
# ==============================================================================
# 3. MONKEY PATCHER CLASS
# ==============================================================================
class LTXLatentConditioningPatch:
    """
    Static class to apply and revert the monkey patch on the LTX-Video pipeline.
    """
    _original_prepare_conditioning = None
    _is_patched = False

    @staticmethod
    def apply():
        """
        Applies the monkey patch to the `LTXVideoPipeline` class.
        """
        if LTXLatentConditioningPatch._is_patched:
            print("[WARNING] LTXLatentConditioningPatch has already been applied. Ignoring.")
            return
        print("[INFO] Applying monkey patch for latent-based conditioning...")
        LTXLatentConditioningPatch._original_prepare_conditioning = LTXVideoPipeline.prepare_conditioning
        LTXVideoPipeline.prepare_conditioning = prepare_conditioning_with_latents
        LTXLatentConditioningPatch._is_patched = True
        print("[SUCCESS] Monkey patch applied successfully.")
        print(" - `LTXVideoPipeline.prepare_conditioning` has been updated.")
        print(" - NOTE: Remember to use `aduc_ltx_latent_patch.PatchedConditioningItem` when creating conditioning items.")

    @staticmethod
    def revert():
        """
        Reverts the monkey patch, restoring the original implementation.
        """
        if not LTXLatentConditioningPatch._is_patched:
            print("[WARNING] Patch is not currently applied. No action taken.")
            return
        if LTXLatentConditioningPatch._original_prepare_conditioning:
            print("[INFO] Reverting LTXLatentConditioningPatch...")
            LTXVideoPipeline.prepare_conditioning = LTXLatentConditioningPatch._original_prepare_conditioning
            LTXLatentConditioningPatch._is_patched = False
            print("[SUCCESS] Patch reverted successfully. Original functionality restored.")
        else:
            print("[ERROR] Cannot revert: original implementation was not saved.")