Spaces:
Paused
Paused
File size: 7,841 Bytes
9185209 b42e494 db85898 9185209 db85898 206971d b42e494 9185209 db85898 b42e494 1c4fd91 b42e494 9a6b3d7 b42e494 2f3144c 0d5ddf4 db85898 f18c91b 206971d f18c91b febdd12 f18c91b 206971d db85898 b42e494 db85898 b42e494 9185209 0d8f151 9185209 a5cbce5 b42e494 9a6b3d7 a7e6912 9a6b3d7 b42e494 9a6b3d7 b42e494 9a6b3d7 b42e494 9a6b3d7 febdd12 b42e494 9a6b3d7 febdd12 b42e494 9a6b3d7 80792cf b42e494 80792cf b42e494 9a6b3d7 b42e494 9a6b3d7 b42e494 9a6b3d7 b42e494 9a6b3d7 b42e494 9a6b3d7 b42e494 9a6b3d7 b42e494 9a6b3d7 b42e494 9a6b3d7 b42e494 77294b1 5ba240b 77294b1 b42e494 77294b1 b42e494 9a6b3d7 b42e494 a5cbce5 febdd12 a5cbce5 b42e494 a5cbce5 b42e494 febdd12 a5cbce5 b42e494 a5cbce5 b42e494 a5cbce5 b42e494 a5cbce5 b42e494 a5cbce5 b42e494 a5cbce5 b42e494 9a6b3d7 b42e494 d6edb59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
# FILE: api/ltx/ltx_utils.py
# DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
# Handles dependency path injection, model loading, pipeline creation, and tensor preparation.
import os
import random
import json
import time
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from huggingface_hub import hf_hub_download
import numpy as np
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from safetensors import safe_open
from transformers import T5EncoderModel, T5Tokenizer
import logging
import warnings
# Silence Python warnings process-wide. The catch-all ``message=".*"`` filter
# already ignores every warning; the category-specific filters are kept only
# for explicitness.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*")
from huggingface_hub import logging as ll
from utils.debug_utils import log_function_io
# Each set_verbosity_* call overwrites the previous one, so only the last call
# matters; set the effective level (DEBUG) once instead of chaining
# error -> warning -> info -> debug.
# NOTE(review): DEBUG makes huggingface_hub very chatty — confirm this is intended.
ll.set_verbosity_debug()
# ==============================================================================
# --- CRITICAL: DEPENDENCY PATH INJECTION ---
# ==============================================================================
# Location of the locally cloned LTX-Video repository. It is placed on sys.path
# so that the `ltx_video` package imported below resolves to this checkout.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
# Hugging Face Hub repository that model checkpoints are downloaded from.
LTX_REPO_ID = "Lightricks/LTX-Video"
# Hub cache directory taken from the HF_HOME environment variable (None if unset).
CACHE_DIR = os.getenv("HF_HOME")

# Prepend the repo path exactly once; index 0 gives it priority over any other
# `ltx_video` found later on sys.path.
repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
if repo_path not in sys.path:
    sys.path.insert(0, repo_path)

# ==============================================================================
# --- LTX-VIDEO LIBRARY IMPORTS (after the path setup above) ---
# ==============================================================================
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.schedulers.rf import RectifiedFlowScheduler
import ltx_video.pipelines.crf_compressor as crf_compressor
# ==============================================================================
# --- MODEL AND PIPELINE CONSTRUCTION FUNCTIONS ---
# ==============================================================================
@log_function_io
def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
    """Load a LatentUpsampler checkpoint, move it to ``device`` and switch it to eval mode.

    Args:
        latent_upsampler_model_path: Checkpoint location, passed unchanged to
            ``LatentUpsampler.from_pretrained``.
        device: Device the model is moved to (e.g. ``"cpu"``).

    Returns:
        The loaded upsampler, already in eval (inference) mode.
    """
    logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
    upsampler_model = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    # Move first, then disable training-only behaviour for inference.
    upsampler_model.to(device)
    upsampler_model.eval()
    return upsampler_model
def _download_ltx_file(filename: str, description: str) -> Path:
    """Download ``filename`` from the LTX Hub repo and return its verified local path.

    Args:
        filename: File name inside ``LTX_REPO_ID`` to fetch.
        description: Human-readable label used in the error message.

    Raises:
        FileNotFoundError: if the returned local path is not a regular file.
    """
    local_path = Path(hf_hub_download(repo_id=LTX_REPO_ID, filename=filename, cache_dir=CACHE_DIR))
    if not local_path.is_file():
        raise FileNotFoundError(f"{description} file not found: {local_path}")
    return local_path


def _read_allowed_inference_steps(ckpt_path: Path):
    """Return ``allowed_inference_steps`` from the checkpoint's embedded JSON config (None if absent)."""
    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata() or {}
    # The safetensors metadata stores the model config as a JSON string under "config".
    configs = json.loads(metadata.get("config", "{}"))
    return configs.get("allowed_inference_steps")


@log_function_io
def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
    """Build the complete LTX pipeline (and optional latent upsampler) on the CPU.

    Config keys read:
        checkpoint_path: file in the LTX Hub repo holding the VAE, transformer
            and scheduler weights (required).
        text_encoder_model_name_or_path: T5 source with ``text_encoder`` and
            ``tokenizer`` subfolders (required).
        precision: ``"bfloat16"`` (default) casts the models to bf16; any other
            value leaves them in the dtype they were loaded with.
        spatial_upscaler_model_path: file in the LTX Hub repo for the latent
            upsampler (optional; skipped when missing or empty).

    Returns:
        ``(pipeline, latent_upsampler)``; ``latent_upsampler`` is None when no
        spatial upscaler is configured.

    Raises:
        KeyError: if a required config key is missing.
        FileNotFoundError: if a downloaded checkpoint is not a regular file.
    """
    t0 = time.perf_counter()
    logging.info("Building LTX pipeline on CPU...")

    ckpt_path = _download_ltx_file(config["checkpoint_path"], "Main checkpoint")
    allowed_inference_steps = _read_allowed_inference_steps(ckpt_path)

    # VAE, transformer and scheduler are all loaded from the same checkpoint file.
    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
    transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
    scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

    text_encoder_path = config["text_encoder_model_name_or_path"]
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
    patchifier = SymmetricPatchifier(patch_size=1)

    # Only "bfloat16" triggers a cast; other precision values are currently ignored.
    precision = config.get("precision", "bfloat16")
    use_bf16 = precision == "bfloat16"
    if use_bf16:
        for module in (vae, transformer, text_encoder):
            module.to(torch.bfloat16)

    pipeline = LTXVideoPipeline(
        transformer=transformer,
        patchifier=patchifier,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        vae=vae,
        allowed_inference_steps=allowed_inference_steps,
        prompt_enhancer_image_caption_model=None,
        prompt_enhancer_image_caption_processor=None,
        prompt_enhancer_llm_model=None,
        prompt_enhancer_llm_tokenizer=None,
    )

    latent_upsampler = None
    if config.get("spatial_upscaler_model_path"):
        # Validate the downloaded upscaler file itself, not the main checkpoint.
        spatial_path = _download_ltx_file(
            config["spatial_upscaler_model_path"], "Spatial upscaler checkpoint"
        )
        latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
        if use_bf16:
            latent_upsampler.to(torch.bfloat16)

    logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
    return pipeline, latent_upsampler
# ==============================================================================
# --- HELPER FUNCTIONS (Seeding, Image Preparation) ---
# ==============================================================================
@log_function_io
def seed_everything(seed: int):
    """Seed every random number generator used here and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU plus all CUDA
    devices), and exports ``PYTHONHASHSEED``.

    NOTE(review): PYTHONHASHSEED is read at interpreter startup, so setting it
    here presumably only influences child processes — confirm if that matters.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    rng_seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in rng_seeders:
        apply_seed(seed)
    # Reproducibility over speed: no cuDNN benchmark autotuning, deterministic algorithms requested.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def _center_crop_to_aspect(image: Image.Image, aspect_ratio: float) -> Image.Image:
    """Return a centered crop of ``image`` whose width/height is (approximately) ``aspect_ratio``."""
    width, height = image.size
    if width / height > aspect_ratio:
        # Wider than the target: keep full height, trim left and right equally.
        crop_w, crop_h = int(height * aspect_ratio), height
    else:
        # Taller than (or equal to) the target: keep full width, trim top and bottom.
        crop_w, crop_h = width, int(width / aspect_ratio)
    left, top = (width - crop_w) // 2, (height - crop_h) // 2
    return image.crop((left, top, left + crop_w, top + crop_h))


@log_function_io
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Path, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """Load and process an image into a 5D pixel tensor compatible with the LTX pipeline.

    Steps: convert to RGB, center-crop to the target aspect ratio, resize with
    Lanczos resampling, apply a 3x3 Gaussian blur, pass the frame through
    ``crf_compressor.compress`` and map values from [0, 1] to [-1, 1] (the range
    the VAE expects for encoding).

    Args:
        image_input: Path to an image file (``str`` or path-like) or a PIL image
            in any mode; the result is always RGB.
        target_height: Output height in pixels; must be positive.
        target_width: Output width in pixels; must be positive.

    Returns:
        Tensor laid out as (batch=1, channels=3, num_frames=1, height, width),
        assuming ``crf_compressor.compress`` preserves its input's (H, W, C) shape.

    Raises:
        ValueError: if ``image_input`` has an unsupported type or a target
            dimension is not positive.
    """
    if target_height <= 0 or target_width <= 0:
        raise ValueError(
            f"target_height and target_width must be positive, got {target_height}x{target_width}"
        )

    if isinstance(image_input, (str, os.PathLike)):
        # The context manager closes the file; convert() loads the pixels before it closes.
        with Image.open(image_input) as opened:
            image = opened.convert("RGB")
    elif isinstance(image_input, Image.Image):
        # RGBA / L / P inputs would otherwise produce a tensor with != 3 channels.
        image = image_input if image_input.mode == "RGB" else image_input.convert("RGB")
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    image = _center_crop_to_aspect(image, target_width / target_height)
    image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)

    frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W) in [0, 1] range
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
    # crf_compressor.compress is fed a channels-last (H, W, C) frame; permute back afterwards.
    frame_tensor = crf_compressor.compress(frame_tensor.permute(1, 2, 0)).permute(2, 0, 1)
    # Normalize to [-1, 1] range, which the VAE expects for encoding.
    frame_tensor = (frame_tensor * 2.0) - 1.0
    # Add batch (dim 0) and frame (dim 2) axes.
    return frame_tensor.unsqueeze(0).unsqueeze(2)
|