Spaces:
Paused
Paused
| # FILE: api/ltx/ltx_utils.py | |
| # DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline. | |
| # Handles dependency path injection, model loading, pipeline creation, and tensor preparation. | |
| import os | |
| import random | |
| import json | |
| import time | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, Optional, Tuple, Union | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms.functional as TVF | |
| from PIL import Image | |
| from safetensors import safe_open | |
| from transformers import T5EncoderModel, T5Tokenizer | |
| import logging | |
| import warnings | |
| # Silence Python warnings emitted from this point on (module-level side effect at import time). | |
| warnings.filterwarnings("ignore", category=UserWarning) | |
| warnings.filterwarnings("ignore", category=FutureWarning) | |
| # NOTE(review): the ".*" pattern matches every warning message, so this filter on its own | |
| # ignores all warnings and makes the two category filters above redundant. | |
| warnings.filterwarnings("ignore", message=".*") | |
| from huggingface_hub import logging as ll | |
| # NOTE(review): the set_verbosity_* calls below presumably each override the previous one, | |
| # which would leave the hub logger at debug level after the last call. Confirm which level is intended. | |
| ll.set_verbosity_error() | |
| ll.set_verbosity_warning() | |
| ll.set_verbosity_info() | |
| from utils.debug_utils import log_function_io | |
| ll.set_verbosity_debug() | |
| # ============================================================================== | |
| # --- CRITICAL: DEPENDENCY PATH INJECTION --- | |
| # ============================================================================== | |
| # Path to the cloned LTX-Video repository | |
| LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video") | |
| # Hugging Face Hub repository id passed to hf_hub_download when fetching checkpoints | |
| LTX_REPO_ID = "Lightricks/LTX-Video" | |
| # Cache directory passed to hf_hub_download; None when HF_HOME is not set | |
| CACHE_DIR = os.environ.get("HF_HOME") | |
| # ============================================================================== | |
| # --- LTX-VIDEO LIBRARY IMPORTS (after path configuration) --- | |
| # ============================================================================== | |
| # Put the cloned repository at the front of sys.path (once) so the ltx_video imports below resolve to it | |
| repo_path = str(LTX_VIDEO_REPO_DIR.resolve()) | |
| if repo_path not in sys.path: | |
| sys.path.insert(0, repo_path) | |
| from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline | |
| from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler | |
| from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder | |
| from ltx_video.models.transformers.transformer3d import Transformer3DModel | |
| from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier | |
| from ltx_video.schedulers.rf import RectifiedFlowScheduler | |
| import ltx_video.pipelines.crf_compressor as crf_compressor | |
| # ============================================================================== | |
| # --- MODEL AND PIPELINE CONSTRUCTION FUNCTIONS --- | |
| # ============================================================================== | |
def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
    """Load a LatentUpsampler checkpoint, move it to ``device`` and switch it to eval mode.

    Args:
        latent_upsampler_model_path: Checkpoint location handed to
            ``LatentUpsampler.from_pretrained``.
        device: Target device passed to the model's ``.to()`` (e.g. ``"cpu"``).

    Returns:
        The loaded upsampler after ``.to(device)`` and ``.eval()`` have been called on it.
    """
    logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
    upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    # Both calls are applied to the loaded object itself, exactly one after the other.
    upsampler.to(device)
    upsampler.eval()
    return upsampler
def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
    """Build the complete LTX pipeline, and optionally the latent upsampler, on the CPU.

    Args:
        config: Pipeline configuration. Keys read here:
            ``checkpoint_path`` (required): filename of the main checkpoint inside
                the ``LTX_REPO_ID`` hub repository.
            ``text_encoder_model_name_or_path`` (required): location from which the
                T5 text encoder (``text_encoder`` subfolder) and tokenizer
                (``tokenizer`` subfolder) are loaded.
            ``precision`` (optional, default ``"bfloat16"``): when equal to
                ``"bfloat16"`` the VAE, transformer, text encoder and upsampler are
                cast to ``torch.bfloat16``; any other value leaves them untouched.
            ``spatial_upscaler_model_path`` (optional): filename of the latent
                upsampler checkpoint in the same hub repository; when missing or
                falsy no upsampler is built.

    Returns:
        ``(pipeline, latent_upsampler)``; ``latent_upsampler`` is ``None`` when no
        spatial upscaler is configured.

    Raises:
        KeyError: If a required config key is missing.
        FileNotFoundError: If a downloaded checkpoint path is not a regular file.
    """
    t0 = time.perf_counter()
    logging.info("Building LTX pipeline on CPU...")

    ckpt_path_str = hf_hub_download(repo_id=LTX_REPO_ID, filename=config["checkpoint_path"], cache_dir=CACHE_DIR)
    ckpt_path = Path(ckpt_path_str)
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")

    # The safetensors header metadata may carry a JSON "config" entry; it is only
    # consulted for the optional "allowed_inference_steps" value.
    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata() or {}
    config_str = metadata.get("config", "{}")
    configs = json.loads(config_str)
    allowed_inference_steps = configs.get("allowed_inference_steps")

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
    transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
    scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

    text_encoder_path = config["text_encoder_model_name_or_path"]
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
    patchifier = SymmetricPatchifier(patch_size=1)

    precision = config.get("precision", "bfloat16")
    if precision == "bfloat16":
        vae.to(torch.bfloat16)
        transformer.to(torch.bfloat16)
        text_encoder.to(torch.bfloat16)

    pipeline = LTXVideoPipeline(
        transformer=transformer, patchifier=patchifier, text_encoder=text_encoder,
        tokenizer=tokenizer, scheduler=scheduler, vae=vae,
        allowed_inference_steps=allowed_inference_steps,
        prompt_enhancer_image_caption_model=None, prompt_enhancer_image_caption_processor=None,
        prompt_enhancer_llm_model=None, prompt_enhancer_llm_tokenizer=None,
    )

    latent_upsampler = None
    if config.get("spatial_upscaler_model_path"):
        spatial_path_str = hf_hub_download(repo_id=LTX_REPO_ID, filename=config["spatial_upscaler_model_path"], cache_dir=CACHE_DIR)
        spatial_path = Path(spatial_path_str)
        # Validate the upscaler file itself. The previous code re-checked the main
        # checkpoint here, so a missing upscaler file was never detected at this point.
        if not spatial_path.is_file():
            raise FileNotFoundError(f"Spatial upscaler checkpoint file not found: {spatial_path}")
        latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
        if precision == "bfloat16":
            latent_upsampler.to(torch.bfloat16)

    logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
    return pipeline, latent_upsampler
| # ============================================================================== | |
| # --- HELPER FUNCTIONS (Seed, Image Preparation) --- | |
| # ============================================================================== | |
def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN behaviour.

    Args:
        seed: Value handed unchanged to every seeding call and stored, as a
            string, in the ``PYTHONHASHSEED`` environment variable.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # NumPy first, then the PyTorch CPU generator, then all CUDA devices.
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, os.PathLike, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """Load an image and turn it into a 5D pixel tensor for the LTX pipeline.

    The image is converted to RGB, center-cropped to the target aspect ratio,
    resized to ``(target_width, target_height)`` with Lanczos resampling, blurred
    with a 3x3 Gaussian kernel, passed through ``crf_compressor.compress``
    (presumably a codec-style compression round trip; see that module) and finally
    rescaled with ``x * 2 - 1``.

    Args:
        image_input: A file path (``str`` or ``os.PathLike``) or an already loaded
            PIL image. PIL images in any mode are converted to RGB; the caller's
            image object is not modified.
        target_height: Output height in pixels; must be positive.
        target_width: Output width in pixels; must be positive.

    Returns:
        A tensor of shape ``(1, 3, 1, target_height, target_width)``.

    Raises:
        ValueError: If ``image_input`` is neither a path nor a PIL image, if a
            target dimension is not positive, or if the image has zero size.
    """
    if target_height <= 0 or target_width <= 0:
        raise ValueError(
            f"target_height and target_width must be positive, got {target_height}x{target_width}"
        )

    if isinstance(image_input, (str, os.PathLike)):
        # convert() returns a fully loaded copy, so the file handle can be closed right away.
        with Image.open(image_input) as opened:
            image = opened.convert("RGB")
    elif isinstance(image_input, Image.Image):
        # Normalise RGBA / grayscale / palette inputs so the result always has 3 channels.
        image = image_input.convert("RGB")
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    input_width, input_height = image.size
    if input_width == 0 or input_height == 0:
        raise ValueError("image_input has zero width or height")

    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        # Source is wider than the target: keep full height, trim the sides.
        new_width, new_height = max(1, int(input_height * aspect_ratio_target)), input_height
        x_start, y_start = (input_width - new_width) // 2, 0
    else:
        # Source is taller than (or matches) the target: keep full width, trim top and bottom.
        new_width, new_height = input_width, max(1, int(input_width / aspect_ratio_target))
        x_start, y_start = 0, (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)

    frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W) in [0, 1] range
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
    # crf_compressor.compress is fed channels-last data, so permute there and back.
    frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
    frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
    frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
    # Normalize to [-1, 1] range, which the VAE expects for encoding
    frame_tensor = (frame_tensor * 2.0) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)