# FILE: api/ltx/ltx_utils.py
# DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
# Handles dependency path injection, model loading, data structures, and helper functions.
import os
import random
import json
import logging
import time
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum, auto
import numpy as np
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from safetensors import safe_open
from transformers import T5EncoderModel, T5Tokenizer

# ==============================================================================
# --- CRITICAL: DEPENDENCY PATH INJECTION ---
# ==============================================================================
# Path to the cloned LTX-Video repository.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")

def add_deps_to_path():
    """
    Adds the LTX repository directory to sys.path so that its
    libraries can be imported.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")

# Run the function immediately to configure the environment before any import below.
add_deps_to_path()

# ==============================================================================
# --- LTX-VIDEO LIBRARY IMPORTS (after path setup) ---
# ==============================================================================
try:
    from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
    from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.transformers.transformer3d import Transformer3DModel
    from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
    from ltx_video.schedulers.rf import RectifiedFlowScheduler
    from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
    import ltx_video.pipelines.crf_compressor as crf_compressor
except ImportError as e:
    raise ImportError(
        f"Could not import from LTX-Video library even after setting sys.path. "
        f"Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}"
    ) from e

# ==============================================================================
# --- DATA STRUCTURES AND ENUMS (centralized here) ---
# ==============================================================================
@dataclass
class ConditioningItem:
    """Defines a single frame-conditioning item, used to guide the generation pipeline."""
    media_item: torch.Tensor
    media_frame_number: int
    conditioning_strength: float
    media_x: Optional[int] = None
    media_y: Optional[int] = None


class SkipLayerStrategy(Enum):
    """Defines the strategy for how spatio-temporal guidance is applied across transformer blocks."""
    AttentionSkip = auto()
    AttentionValues = auto()
    Residual = auto()
    TransformerBlock = auto()
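
# Illustrative usage (not executed on import): condition generation on a single
# reference frame at frame 0 with full strength. "ref.png" and the 512x768
# target are placeholder values; the 5D tensor layout comes from
# load_image_to_tensor_with_resize_and_crop() defined below.
#
#   first_frame = load_image_to_tensor_with_resize_and_crop("ref.png", 512, 768)
#   item = ConditioningItem(media_item=first_frame, media_frame_number=0,
#                           conditioning_strength=1.0)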

# ==============================================================================
# --- MODEL AND PIPELINE CONSTRUCTION FUNCTIONS ---
# ==============================================================================
def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
    """Loads the Latent Upsampler model from a checkpoint path."""
    logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
    latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    latent_upsampler.to(device)
    latent_upsampler.eval()
    return latent_upsampler


def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
    """Builds the complete LTX pipeline and optional latent upsampler on the CPU."""
    t0 = time.perf_counter()
    logging.info("Building LTX pipeline on CPU...")
    ckpt_path = Path(config["checkpoint_path"])
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")
    # The checkpoint's safetensors metadata may embed a JSON config, including
    # the inference step counts the model supports.
    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata() or {}
    config_str = metadata.get("config", "{}")
    configs = json.loads(config_str)
    allowed_inference_steps = configs.get("allowed_inference_steps")
    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
    transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
    scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
    text_encoder_path = config["text_encoder_model_name_or_path"]
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
    patchifier = SymmetricPatchifier(patch_size=1)
    precision = config.get("precision", "bfloat16")
    if precision == "bfloat16":
        vae.to(torch.bfloat16)
        transformer.to(torch.bfloat16)
        text_encoder.to(torch.bfloat16)
    pipeline = LTXVideoPipeline(
        transformer=transformer, patchifier=patchifier, text_encoder=text_encoder,
        tokenizer=tokenizer, scheduler=scheduler, vae=vae,
        allowed_inference_steps=allowed_inference_steps,
        prompt_enhancer_image_caption_model=None, prompt_enhancer_image_caption_processor=None,
        prompt_enhancer_llm_model=None, prompt_enhancer_llm_tokenizer=None,
    )
    latent_upsampler = None
    if config.get("spatial_upscaler_model_path"):
        spatial_path = config["spatial_upscaler_model_path"]
        latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
        if precision == "bfloat16":
            latent_upsampler.to(torch.bfloat16)
    logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
    return pipeline, latent_upsampler
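
# Illustrative usage (paths are placeholders; the keys mirror the ones read by
# build_ltx_pipeline_on_cpu above):
#
#   config = {
#       "checkpoint_path": "/data/checkpoints/ltx-video.safetensors",
#       "text_encoder_model_name_or_path": "/data/checkpoints/text-encoder",
#       "spatial_upscaler_model_path": "/data/checkpoints/spatial-upsampler.safetensors",
#       "precision": "bfloat16",
#   }
#   pipeline, upsampler = build_ltx_pipeline_on_cpu(config)
#   pipeline.to("cuda")  # move to the GPU only when ready to run inference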

# ==============================================================================
# --- HELPER FUNCTIONS (Latent Processing, Seed, Image Prep) ---
# ==============================================================================
def adain_filter_latent(
    latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0
) -> torch.Tensor:
    """Applies AdaIN to transfer the style of a reference latent onto another latent."""
    result = latents.clone()
    for i in range(latents.size(0)):
        for c in range(latents.size(1)):
            r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)
            i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
            if i_sd > 1e-6:  # avoid dividing by a near-zero standard deviation
                result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
    return torch.lerp(latents, result, factor)
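
# Per (batch, channel) slice this computes the standard AdaIN update
#   x' = ((x - mean(x)) / std(x)) * std(ref) + mean(ref)
# and then blends the result back: lerp(x, x', factor). Illustrative call,
# blending in half of the reference statistics:
#
#   styled = adain_filter_latent(latents, reference_latents, factor=0.5)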

def seed_everything(seed: int):
    """Sets the seed for reproducibility."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
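
# Note: forcing deterministic cuDNN kernels and disabling benchmark mode trades
# some throughput for run-to-run reproducibility. Call once, before any random
# tensors are created:
#
#   seed_everything(42)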

def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """Loads and processes an image into a 5D tensor compatible with the LTX pipeline."""
    if isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")
    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    # Center-crop to the target aspect ratio before resizing.
    if aspect_ratio_frame > aspect_ratio_target:
        new_width, new_height = int(input_height * aspect_ratio_target), input_height
        x_start, y_start = (input_width - new_width) // 2, 0
    else:
        new_width, new_height = input_width, int(input_width / aspect_ratio_target)
        x_start, y_start = 0, (input_height - new_height) // 2
    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
    frame_tensor = TVF.to_tensor(image)  # CHW float tensor in [0, 1]
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
    frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
    frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
    frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
    # Normalize to [-1, 1] range
    frame_tensor = (frame_tensor * 2.0) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)
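
# Illustrative usage ("frame.png" is a placeholder): prepare a single
# conditioning frame for a 768x512 render and check the expected layout.
#
#   tensor = load_image_to_tensor_with_resize_and_crop("frame.png", 512, 768)
#   assert tensor.shape == (1, 3, 1, 512, 768)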