# FILE: api/ltx/ltx_utils.py
# DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
# Handles dependency path injection, model loading, data structures, and helper functions.

import os
import random
import json
import logging
import time
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum, auto

import numpy as np
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from safetensors import safe_open
from transformers import T5EncoderModel, T5Tokenizer

# ==============================================================================
# --- CRITICAL: DEPENDENCY PATH INJECTION ---
# ==============================================================================

# Path to the cloned LTX-Video repository.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")


def add_deps_to_path() -> None:
    """
    Add the LTX-Video repository directory to ``sys.path``.

    The resolved form of ``LTX_VIDEO_REPO_DIR`` is inserted at position 0 so
    that the ``ltx_video`` package can be imported. Nothing happens when the
    resolved path is already present in ``sys.path``.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")


# Run immediately so the environment is configured before the imports below.
add_deps_to_path()

# ==============================================================================
# --- LTX-VIDEO LIBRARY IMPORTS (after path configuration) ---
# ==============================================================================
try:
    from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
    from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.transformers.transformer3d import Transformer3DModel
    from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
    from ltx_video.schedulers.rf import RectifiedFlowScheduler
    from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
    import ltx_video.pipelines.crf_compressor as crf_compressor
except ImportError as e:
    # Chain the original error explicitly so the root cause stays in the traceback.
    raise ImportError(
        f"Could not import from LTX-Video library even after setting sys.path. "
        f"Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}"
    ) from e

# ==============================================================================
# --- DATA STRUCTURES AND ENUMS (centralized here) ---
# ==============================================================================


@dataclass
class ConditioningItem:
    """Define a single frame-conditioning item, used to guide the generation pipeline."""

    # Conditioning media as a tensor.
    # NOTE(review): expected shape/dtype is not established in this module — confirm with the pipeline.
    media_item: torch.Tensor
    # Frame number associated with the media item (presumably its position in the video — confirm).
    media_frame_number: int
    # Strength of the conditioning; valid range is not established in this module.
    conditioning_strength: float
    # Optional x/y placement values; None when not provided.
    media_x: Optional[int] = None
    media_y: Optional[int] = None


class SkipLayerStrategy(Enum):
    """Defines the strategy for how spatio-temporal guidance is applied across transformer blocks."""

    AttentionSkip = auto()
    AttentionValues = auto()
    Residual = auto()
    TransformerBlock = auto()


# ==============================================================================
# --- MODEL AND PIPELINE CONSTRUCTION FUNCTIONS ---
# ==============================================================================


def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
    """
    Load the Latent Upsampler model from a checkpoint path.

    Args:
        latent_upsampler_model_path: Path handed to ``LatentUpsampler.from_pretrained``.
        device: Device the loaded model is moved to.

    Returns:
        The loaded upsampler, moved to ``device`` and switched to eval mode.
    """
    logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
    latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    latent_upsampler.to(device)
    latent_upsampler.eval()
    return latent_upsampler


def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
    """
    Build the complete LTX pipeline and the optional upsampler on the CPU.

    Config keys read by this function:
        ``checkpoint_path`` (required): main safetensors checkpoint file.
        ``text_encoder_model_name_or_path`` (required): source for the T5
            encoder (``text_encoder`` subfolder) and tokenizer (``tokenizer``
            subfolder).
        ``precision`` (optional, default ``"bfloat16"``): when equal to
            ``"bfloat16"`` the VAE, transformer, text encoder and upsampler are
            cast to ``torch.bfloat16``; any other value leaves them uncast.
        ``spatial_upscaler_model_path`` (optional): when truthy, a latent
            upsampler is loaded from this path.

    Returns:
        ``(pipeline, latent_upsampler)`` where ``latent_upsampler`` is ``None``
        when no spatial upscaler path is configured.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
    """
    t0 = time.perf_counter()
    logging.info("Building LTX pipeline on CPU...")

    ckpt_path = Path(config["checkpoint_path"])
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")

    # The checkpoint metadata may carry a JSON "config" entry; fall back to an
    # empty object when the metadata or the entry is missing.
    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata() or {}
        config_str = metadata.get("config", "{}")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps")

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
    transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
    scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

    text_encoder_path = config["text_encoder_model_name_or_path"]
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
    patchifier = SymmetricPatchifier(patch_size=1)

    precision = config.get("precision", "bfloat16")
    if precision == "bfloat16":
        vae.to(torch.bfloat16)
        transformer.to(torch.bfloat16)
        text_encoder.to(torch.bfloat16)

    pipeline = LTXVideoPipeline(
        transformer=transformer,
        patchifier=patchifier,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        vae=vae,
        allowed_inference_steps=allowed_inference_steps,
        prompt_enhancer_image_caption_model=None,
        prompt_enhancer_image_caption_processor=None,
        prompt_enhancer_llm_model=None,
        prompt_enhancer_llm_tokenizer=None,
    )

    latent_upsampler = None
    if config.get("spatial_upscaler_model_path"):
        spatial_path = config["spatial_upscaler_model_path"]
        latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
        if precision == "bfloat16":
            latent_upsampler.to(torch.bfloat16)

    logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
    return pipeline, latent_upsampler


# ==============================================================================
# --- HELPER FUNCTIONS (Latent Processing, Seed, Image Prep) ---
# ==============================================================================


def adain_filter_latent(
    latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0
) -> torch.Tensor:
    """
    Apply AdaIN to transfer the style from a reference latent to another.

    For every ``[i, c]`` slice over the first two dimensions of ``latents``
    (presumably batch and channel — confirm with callers), the slice is
    re-scaled to the standard deviation and mean of the matching
    ``reference_latents[i, c]`` slice. Slices whose own standard deviation is
    not above ``1e-6`` are left untouched to avoid dividing by ~zero.

    Args:
        latents: Tensor to restyle; it is not modified in place.
        reference_latents: Tensor supplying the target statistics; indexed with
            the same ``[i, c]`` indices as ``latents``.
        factor: Interpolation weight passed to ``torch.lerp``; ``0`` returns
            the original values and ``1`` the fully restyled values.

    Returns:
        ``torch.lerp(latents, restyled, factor)``.
    """
    result = latents.clone()
    for i in range(latents.size(0)):
        for c in range(latents.size(1)):
            r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)
            i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
            if i_sd > 1e-6:
                result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
    return torch.lerp(latents, result, factor)


def seed_everything(seed: int) -> None:
    """
    Set the seed for reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    writes ``PYTHONHASHSEED`` into ``os.environ``, and switches cuDNN to
    deterministic mode with benchmarking disabled.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, os.PathLike, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """
    Load and process an image into a 5D tensor compatible with the LTX pipeline.

    The image is converted to RGB, center-cropped to the target aspect ratio,
    resized with LANCZOS resampling, blurred with a 3x3 Gaussian kernel, passed
    through ``crf_compressor.compress`` and finally rescaled with ``x * 2 - 1``.

    Args:
        image_input: A file path (``str`` or path-like object) or a PIL image.
            PIL images are converted to RGB on a copy; the caller's image is
            not modified.
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        Tensor of shape ``(1, 3, 1, target_height, target_width)``.

    Raises:
        ValueError: If ``image_input`` is neither a path nor a PIL image.
    """
    if isinstance(image_input, (str, os.PathLike)):
        # convert() loads the pixel data and returns a new image, so the file
        # handle can be closed as soon as the conversion is done.
        with Image.open(image_input) as opened:
            image = opened.convert("RGB")
    elif isinstance(image_input, Image.Image):
        # Force 3 channels: RGBA / L / P inputs would otherwise break the
        # (1, 3, 1, H, W) output contract.
        image = image_input.convert("RGB")
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height

    if aspect_ratio_frame > aspect_ratio_target:
        # Source is wider than the target: keep full height, trim the sides.
        new_width, new_height = int(input_height * aspect_ratio_target), input_height
        x_start, y_start = (input_width - new_width) // 2, 0
    else:
        # Source is taller than (or matches) the target: keep full width, trim top/bottom.
        new_width, new_height = input_width, int(input_width / aspect_ratio_target)
        x_start, y_start = 0, (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)

    frame_tensor = TVF.to_tensor(image)
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))

    # crf_compressor.compress is called on a height-width-channel layout, so
    # permute there and back around the call.
    frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
    frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
    frame_tensor = frame_tensor_hwc.permute(2, 0, 1)

    # Normalize to [-1, 1] range
    frame_tensor = (frame_tensor * 2.0) - 1.0

    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)