# aduc_ltx_latent_patch.py
#
# This module provides a monkey patch for the LTXVideoPipeline class from the ltx_video library.
# The patch's main purpose is to optimize the conditioning process by allowing the pipeline to
# accept pre-computed latent tensors directly through a modified `ConditioningItem`. This avoids
# unnecessary re-encoding of media (images/videos) by the VAE, yielding a significant performance
# gain whenever the latents are already available.
import torch
from torch import Tensor
from typing import Optional, List, Tuple
from pathlib import Path
import os
import sys
from dataclasses import dataclass, replace
# --- PATH CONFIGURATION (assumes LTXV_DEBUG and _run_setup_script exist in the scope that loads this module) ---
# DEPS_DIR = Path("/data")
# LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
# def add_deps_to_path(repo_path: Path):
#     """Adds the repository directory to sys.path for local imports."""
#     resolved_path = str(repo_path.resolve())
#     if resolved_path not in sys.path:
#         sys.path.insert(0, resolved_path)
# add_deps_to_path(LTX_VIDEO_REPO_DIR)

# Try to import the required dependencies from the original module that will be patched.
try:
    from ltx_video.pipelines.pipeline_ltx_video import (
        LTXVideoPipeline,
        ConditioningItem as OriginalConditioningItem,
    )
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
    from diffusers.utils.torch_utils import randn_tensor
except ImportError as e:
    print(f"FATAL ERROR: Could not import dependencies from 'ltx_video'. "
          f"Please ensure the environment is correctly set up. Error: {e}")
    raise

print("[INFO] Patch module 'aduc_ltx_latent_patch' loaded successfully.")
# ==============================================================================
# 1. NEW `PatchedConditioningItem` DATACLASS DEFINITION
# ==============================================================================
@dataclass
class PatchedConditioningItem:
| """ | |
| Versão modificada do `ConditioningItem` que aceita tensores de pixel (`media_item`) | |
| ou tensores de latentes pré-codificados (`latents`). | |
| Attributes: | |
| media_frame_number (int): Quadro inicial do item de condicionamento no vídeo. | |
| conditioning_strength (float): Força do condicionamento (0.0 a 1.0). | |
| media_item (Optional[Tensor]): Tensor de mídia (pixels). Usado se `latents` for None. | |
| media_x (Optional[int]): Coordenada X (esquerda) para posicionamento espacial. | |
| media_y (Optional[int]): Coordenada Y (topo) para posicionamento espacial. | |
| latents (Optional[Tensor]): Tensor de latentes pré-codificado. Terá precedência sobre `media_item`. | |
| """ | |
    media_frame_number: int
    conditioning_strength: float
    media_item: Optional[Tensor] = None
    media_x: Optional[int] = None
    media_y: Optional[int] = None
    latents: Optional[Tensor] = None
    def __post_init__(self):
        """Validates the object's state after initialization."""
        if self.media_item is None and self.latents is None:
            raise ValueError("A `PatchedConditioningItem` must have either 'media_item' or 'latents' defined.")
        if self.media_item is not None and self.latents is not None:
            print("[WARNING] `PatchedConditioningItem` received both 'media_item' and 'latents'. "
                  "The 'latents' tensor will take precedence.")
# ==============================================================================
# 2. NEW `prepare_conditioning` IMPLEMENTATION
# ==============================================================================
def prepare_conditioning_with_latents(
    self: LTXVideoPipeline,
    conditioning_items: Optional[List[PatchedConditioningItem]],
    init_latents: Tensor,
    num_frames: int,
    height: int,
    width: int,
    vae_per_channel_normalize: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], int]:
| """ | |
| Versão modificada de `prepare_conditioning` que prioriza o uso de latentes pré-calculados | |
| dos `conditioning_items`, evitando a re-codificação desnecessária pela VAE. | |
| """ | |
    assert isinstance(self, LTXVideoPipeline), "This function must be called as a method of LTXVideoPipeline."
    assert isinstance(self.vae, CausalVideoAutoencoder), "VAE must be of type CausalVideoAutoencoder."

    if not conditioning_items:
        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
        init_pixel_coords = latent_to_pixel_coords(
            init_latent_coords, self.vae,
            causal_fix=self.transformer.config.causal_temporal_positioning,
        )
        return init_latents, init_pixel_coords, None, 0

    init_conditioning_mask = torch.zeros(
        init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device
    )
    extra_conditioning_latents = []
    extra_conditioning_pixel_coords = []
    extra_conditioning_mask = []
    extra_conditioning_num_latents = 0
    for item in conditioning_items:
        item_latents: Tensor
        if item.latents is not None:
            # Fast path: use the pre-encoded latents directly, skipping the VAE.
            item_latents = item.latents.to(dtype=init_latents.dtype, device=init_latents.device)
            if item_latents.ndim != 5:
                raise ValueError(f"Latents must have 5 dimensions (b, c, f, h, w), but got {item_latents.ndim}")
        elif item.media_item is not None:
            resized_item = self._resize_conditioning_item(item, height, width)
            media_item = resized_item.media_item
            assert media_item.ndim == 5, f"media_item must have 5 dims, but got {media_item.ndim}"
            item_latents = vae_encode(
                media_item.to(dtype=self.vae.dtype, device=self.vae.device),
                self.vae,
                vae_per_channel_normalize=vae_per_channel_normalize,
            ).to(dtype=init_latents.dtype)
        else:
            raise ValueError("ConditioningItem is invalid: it has neither 'latents' nor 'media_item'.")
        media_frame_number = item.media_frame_number
        strength = item.conditioning_strength

        if media_frame_number == 0:
            # --- START OF MODIFICATION ---
            # If `item.media_item` is None (our optimized use case), the original
            # `_get_latent_spatial_position` would break. To avoid that, we create a
            # temporary item with a placeholder tensor that carries the correct dimension
            # information, inferred from the latents themselves.
            item_for_spatial_position = item
            if item.media_item is None:
                # Infer the pixel dimensions from the latents' shape
                latent_h, latent_w = item_latents.shape[-2:]
                pixel_h = latent_h * self.vae_scale_factor
                pixel_w = latent_w * self.vae_scale_factor
                # Create a placeholder tensor with the expected shape (its content is irrelevant)
                placeholder_media_item = torch.empty(
                    (1, 1, 1, pixel_h, pixel_w), device=item_latents.device, dtype=item_latents.dtype
                )
                # Use `dataclasses.replace` to create a temporary copy of the item carrying
                # the placeholder; `latents` is cleared so `__post_init__` does not warn
                # about both fields being set (the copy is only used for positioning).
                item_for_spatial_position = replace(item, media_item=placeholder_media_item, latents=None)
            # Call the original helper with an item it can process without errors
            item_latents, l_x, l_y = self._get_latent_spatial_position(
                item_latents, item_for_spatial_position, height, width, strip_latent_border=True
            )
            # --- END OF MODIFICATION ---
            _, _, f_l, h_l, w_l = item_latents.shape
            init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
                init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], item_latents, strength
            )
            init_conditioning_mask[:, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = strength
        else:
            if item_latents.shape[2] > 1:
                (init_latents, init_conditioning_mask, item_latents) = self._handle_non_first_conditioning_sequence(
                    init_latents, init_conditioning_mask, item_latents, media_frame_number, strength
                )
            if item_latents is not None:
                noise = randn_tensor(
                    item_latents.shape, generator=generator,
                    device=item_latents.device, dtype=item_latents.dtype,
                )
                item_latents = torch.lerp(noise, item_latents, strength)
                item_latents, latent_coords = self.patchifier.patchify(latents=item_latents)
                pixel_coords = latent_to_pixel_coords(
                    latent_coords, self.vae,
                    causal_fix=self.transformer.config.causal_temporal_positioning,
                )
                pixel_coords[:, 0] += media_frame_number
                extra_conditioning_num_latents += item_latents.shape[1]
                conditioning_mask = torch.full(
                    item_latents.shape[:2], strength,
                    dtype=torch.float32, device=init_latents.device,
                )
                extra_conditioning_latents.append(item_latents)
                extra_conditioning_pixel_coords.append(pixel_coords)
                extra_conditioning_mask.append(conditioning_mask)

    init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
    init_pixel_coords = latent_to_pixel_coords(
        init_latent_coords, self.vae,
        causal_fix=self.transformer.config.causal_temporal_positioning,
    )
    init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
    init_conditioning_mask = init_conditioning_mask.squeeze(-1)

    if extra_conditioning_latents:
        init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
        init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
        init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
        if self.transformer.use_tpu_flash_attention:
            init_latents = init_latents[:, :-extra_conditioning_num_latents]
            init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
            init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]

    return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
# ==============================================================================
# 3. MONKEY PATCHER CLASS
# ==============================================================================
class LTXLatentConditioningPatch:
    """
    Static class for applying and reverting the monkey patch on the LTX-Video pipeline.
    """
    _original_prepare_conditioning = None
    _is_patched = False
    @staticmethod
    def apply():
        """
        Applies the monkey patch to the `LTXVideoPipeline` class.
        """
        if LTXLatentConditioningPatch._is_patched:
            print("[WARNING] LTXLatentConditioningPatch has already been applied. Ignoring.")
            return
        print("[INFO] Applying monkey patch for latent-based conditioning...")
        LTXLatentConditioningPatch._original_prepare_conditioning = LTXVideoPipeline.prepare_conditioning
        LTXVideoPipeline.prepare_conditioning = prepare_conditioning_with_latents
        LTXLatentConditioningPatch._is_patched = True
        print("[SUCCESS] Monkey patch applied successfully.")
        print("  - `LTXVideoPipeline.prepare_conditioning` has been updated.")
        print("  - NOTE: Remember to use `aduc_ltx_latent_patch.PatchedConditioningItem` when creating conditioning items.")
    @staticmethod
    def revert():
        """
        Reverts the monkey patch, restoring the original implementation.
        """
        if not LTXLatentConditioningPatch._is_patched:
            print("[WARNING] Patch is not currently applied. No action taken.")
            return
        if LTXLatentConditioningPatch._original_prepare_conditioning:
            print("[INFO] Reverting LTXLatentConditioningPatch...")
            LTXVideoPipeline.prepare_conditioning = LTXLatentConditioningPatch._original_prepare_conditioning
            LTXLatentConditioningPatch._is_patched = False
            print("[SUCCESS] Patch reverted successfully. Original functionality restored.")
        else:
            print("[ERROR] Cannot revert: original implementation was not saved.")