# FILE: api/ltx/ltx_utils.py
# DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
# Handles dependency path injection, model loading, data structures, and helper functions.
import os
import random
import json
import logging
import time
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum, auto
import numpy as np
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from safetensors import safe_open
from transformers import T5EncoderModel, T5Tokenizer
# ==============================================================================
# --- CRITICAL: DEPENDENCY PATH INJECTION ---
# ==============================================================================
# Path to the cloned LTX-Video repository
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
def add_deps_to_path():
"""
Adiciona o diret贸rio do reposit贸rio LTX ao sys.path para garantir que suas
bibliotecas possam ser importadas.
"""
repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
if repo_path not in sys.path:
sys.path.insert(0, repo_path)
logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
# Run immediately so the path is configured before any LTX-Video import below.
add_deps_to_path()
# ==============================================================================
# --- LTX-VIDEO LIBRARY IMPORTS (after path configuration) ---
# ==============================================================================
try:
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
import ltx_video.pipelines.crf_compressor as crf_compressor
except ImportError as e:
    raise ImportError(
        "Could not import from the LTX-Video library even after setting sys.path. "
        f"Check repo integrity at '{LTX_VIDEO_REPO_DIR}'."
    ) from e
# ==============================================================================
# --- DATA STRUCTURES AND ENUMS (centralized here) ---
# ==============================================================================
@dataclass
class ConditioningItem:
"""Define a single frame-conditioning item, used to guide the generation pipeline."""
media_item: torch.Tensor
media_frame_number: int
conditioning_strength: float
media_x: Optional[int] = None
media_y: Optional[int] = None
class SkipLayerStrategy(Enum):
"""Defines the strategy for how spatio-temporal guidance is applied across transformer blocks."""
AttentionSkip = auto()
AttentionValues = auto()
Residual = auto()
TransformerBlock = auto()
# ==============================================================================
# --- MODEL AND PIPELINE CONSTRUCTION FUNCTIONS ---
# ==============================================================================
def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
"""Loads the Latent Upsampler model from a checkpoint path."""
logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
latent_upsampler.to(device)
latent_upsampler.eval()
return latent_upsampler
def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
"""Builds the complete LTX pipeline and upsampler on the CPU."""
t0 = time.perf_counter()
logging.info("Building LTX pipeline on CPU...")
ckpt_path = Path(config["checkpoint_path"])
if not ckpt_path.is_file():
raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")
with safe_open(ckpt_path, framework="pt") as f:
metadata = f.metadata() or {}
config_str = metadata.get("config", "{}")
configs = json.loads(config_str)
allowed_inference_steps = configs.get("allowed_inference_steps")
vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
text_encoder_path = config["text_encoder_model_name_or_path"]
text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
patchifier = SymmetricPatchifier(patch_size=1)
precision = config.get("precision", "bfloat16")
if precision == "bfloat16":
vae.to(torch.bfloat16)
transformer.to(torch.bfloat16)
text_encoder.to(torch.bfloat16)
pipeline = LTXVideoPipeline(
transformer=transformer, patchifier=patchifier, text_encoder=text_encoder,
tokenizer=tokenizer, scheduler=scheduler, vae=vae,
allowed_inference_steps=allowed_inference_steps,
prompt_enhancer_image_caption_model=None, prompt_enhancer_image_caption_processor=None,
prompt_enhancer_llm_model=None, prompt_enhancer_llm_tokenizer=None,
)
latent_upsampler = None
if config.get("spatial_upscaler_model_path"):
spatial_path = config["spatial_upscaler_model_path"]
latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
if precision == "bfloat16":
latent_upsampler.to(torch.bfloat16)
logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
return pipeline, latent_upsampler
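# Example usage (a sketch, not the project's canonical entry point; the paths
# below are placeholders, but the keys match those read by
# build_ltx_pipeline_on_cpu above):
#
#   config = {
#       "checkpoint_path": "/data/checkpoints/ltx-video.safetensors",
#       "text_encoder_model_name_or_path": "/data/checkpoints/text_encoder_repo",
#       "precision": "bfloat16",
#       "spatial_upscaler_model_path": "/data/checkpoints/upsampler",  # optional
#   }
#   pipeline, upsampler = build_ltx_pipeline_on_cpu(config)
#   pipeline.to("cuda")  # move to the GPU only when ready to generate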
# ==============================================================================
# --- HELPER FUNCTIONS (latent processing, seeding, image prep) ---
# ==============================================================================
def adain_filter_latent(
latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
) -> torch.Tensor:
"""Applies AdaIN to transfer the style from a reference latent to another."""
result = latents.clone()
for i in range(latents.size(0)):
for c in range(latents.size(1)):
r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)
i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
if i_sd > 1e-6:
result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
return torch.lerp(latents, result, factor)
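# Example (hypothetical tensors): because the blend is torch.lerp(latents,
# result, factor), factor=0.0 is a no-op and factor=1.0 fully matches the
# reference statistics:
#
#   matched = adain_filter_latent(new_latents, reference_latents, factor=0.8)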
def seed_everything(seed: int):
"""Sets the seed for reproducibility."""
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
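# Example: seed_everything(42) makes the Python, NumPy, and torch RNGs
# repeatable across runs; the deterministic cuDNN settings above trade some
# throughput for bit-exact results.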
def load_image_to_tensor_with_resize_and_crop(
image_input: Union[str, Image.Image],
target_height: int,
target_width: int,
) -> torch.Tensor:
"""Loads and processes an image into a 5D tensor compatible with the LTX pipeline."""
if isinstance(image_input, str):
image = Image.open(image_input).convert("RGB")
elif isinstance(image_input, Image.Image):
image = image_input
else:
raise ValueError("image_input must be a file path or a PIL Image object")
input_width, input_height = image.size
aspect_ratio_target = target_width / target_height
aspect_ratio_frame = input_width / input_height
if aspect_ratio_frame > aspect_ratio_target:
new_width, new_height = int(input_height * aspect_ratio_target), input_height
x_start, y_start = (input_width - new_width) // 2, 0
else:
new_width, new_height = input_width, int(input_width / aspect_ratio_target)
x_start, y_start = 0, (input_height - new_height) // 2
image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
frame_tensor = TVF.to_tensor(image)
frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
# Normalize to [-1, 1] range
frame_tensor = (frame_tensor * 2.0) - 1.0
# Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
return frame_tensor.unsqueeze(0).unsqueeze(2)
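# Example (hypothetical path): the returned tensor is 5D, normalized to
# [-1, 1], and ready to be wrapped in a ConditioningItem:
#
#   frame = load_image_to_tensor_with_resize_and_crop("ref.png", 512, 768)
#   assert frame.shape == (1, 3, 1, 512, 768)
#   cond = ConditioningItem(media_item=frame, media_frame_number=0,
#                           conditioning_strength=1.0)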