# FILE: api/ltx/ltx_utils.py
# DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
# Handles dependency path injection, model loading, data structures, and helper functions.

import os
import random
import json
import logging
import time
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum, auto

import numpy as np
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from safetensors import safe_open
from transformers import T5EncoderModel, T5Tokenizer

# ==============================================================================
# --- CRITICAL: DEPENDENCY PATH INJECTION ---
# ==============================================================================

# Path to the cloned LTX-Video repository
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")

def add_deps_to_path():
    """
    Adiciona o diretório do repositório LTX ao sys.path para garantir que suas
    bibliotecas possam ser importadas.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")

# Run immediately so the environment is configured before any LTX imports.
add_deps_to_path()


# ==============================================================================
# --- LTX-VIDEO LIBRARY IMPORTS (after path configuration) ---
# ==============================================================================
try:
    from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
    from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.transformers.transformer3d import Transformer3DModel
    from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
    from ltx_video.schedulers.rf import RectifiedFlowScheduler
    from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
    import ltx_video.pipelines.crf_compressor as crf_compressor
except ImportError as e:
    raise ImportError(
        f"Could not import from the LTX-Video library even after setting sys.path. "
        f"Check the repository integrity at '{LTX_VIDEO_REPO_DIR}'."
    ) from e


# ==============================================================================
# --- DATA STRUCTURES AND ENUMS (centralized here) ---
# ==============================================================================

@dataclass
class ConditioningItem:
    """Define a single frame-conditioning item, used to guide the generation pipeline."""
    media_item: torch.Tensor
    media_frame_number: int
    conditioning_strength: float
    media_x: Optional[int] = None
    media_y: Optional[int] = None
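
# Hypothetical usage sketch (names and shapes are illustrative only):
# first_frame = load_image_to_tensor_with_resize_and_crop("frame0.png", 512, 768)
# item = ConditioningItem(
#     media_item=first_frame,     # (1, 3, 1, H, W) pixel tensor in [-1, 1]
#     media_frame_number=0,       # condition the first output frame
#     conditioning_strength=1.0,  # full-strength conditioning
# )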


class SkipLayerStrategy(Enum):
    """Defines the strategy for how spatio-temporal guidance is applied across transformer blocks."""
    AttentionSkip = auto()
    AttentionValues = auto()
    Residual = auto()
    TransformerBlock = auto()


# ==============================================================================
# --- MODEL AND PIPELINE CONSTRUCTION FUNCTIONS ---
# ==============================================================================

def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
    """Loads the Latent Upsampler model from a checkpoint path."""
    logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
    latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    latent_upsampler.to(device)
    latent_upsampler.eval()
    return latent_upsampler

def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
    """Builds the complete LTX pipeline and upsampler on the CPU."""
    t0 = time.perf_counter()
    logging.info("Building LTX pipeline on CPU...")

    ckpt_path = Path(config["checkpoint_path"])
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")

    # The pipeline configuration is embedded as JSON in the safetensors metadata.
    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata() or {}
        config_str = metadata.get("config", "{}")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps")

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
    transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
    scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
    
    text_encoder_path = config["text_encoder_model_name_or_path"]
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
    patchifier = SymmetricPatchifier(patch_size=1)

    precision = config.get("precision", "bfloat16")
    if precision == "bfloat16":
        vae.to(torch.bfloat16)
        transformer.to(torch.bfloat16)
        text_encoder.to(torch.bfloat16)
    
    pipeline = LTXVideoPipeline(
        transformer=transformer, patchifier=patchifier, text_encoder=text_encoder,
        tokenizer=tokenizer, scheduler=scheduler, vae=vae,
        allowed_inference_steps=allowed_inference_steps,
        prompt_enhancer_image_caption_model=None, prompt_enhancer_image_caption_processor=None,
        prompt_enhancer_llm_model=None, prompt_enhancer_llm_tokenizer=None,
    )

    latent_upsampler = None
    if config.get("spatial_upscaler_model_path"):
        spatial_path = config["spatial_upscaler_model_path"]
        latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
        if precision == "bfloat16":
            latent_upsampler.to(torch.bfloat16)

    logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
    return pipeline, latent_upsampler
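
# Minimal usage sketch. The keys below are the ones this function reads; the
# paths are placeholders, not real checkpoints:
#
# config = {
#     "checkpoint_path": "/data/checkpoints/ltx-video.safetensors",
#     "text_encoder_model_name_or_path": "/data/checkpoints/t5-text-encoder",
#     "precision": "bfloat16",
#     "spatial_upscaler_model_path": "/data/checkpoints/ltx-spatial-upscaler.safetensors",
# }
# pipeline, upsampler = build_ltx_pipeline_on_cpu(config)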


# ==============================================================================
# --- HELPER FUNCTIONS (latent processing, seeding, image prep) ---
# ==============================================================================

def adain_filter_latent(
    latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
) -> torch.Tensor:
    """Applies AdaIN to transfer the style from a reference latent to another."""
    result = latents.clone()
    for i in range(latents.size(0)):
        for c in range(latents.size(1)):
            r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)
            i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
            if i_sd > 1e-6:
                result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
    return torch.lerp(latents, result, factor)
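
# Example (sketch): nudge upsampled latents toward the channel statistics of a
# low-resolution reference at half strength:
# styled = adain_filter_latent(upsampled_latents, reference_latents, factor=0.5)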

def seed_everything(seed: int):
    """Sets the seed for reproducibility."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """Loads and processes an image into a 5D tensor compatible with the LTX pipeline."""
    if isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height

    if aspect_ratio_frame > aspect_ratio_target:
        new_width, new_height = int(input_height * aspect_ratio_target), input_height
        x_start, y_start = (input_width - new_width) // 2, 0
    else:
        new_width, new_height = input_width, int(input_width / aspect_ratio_target)
        x_start, y_start = 0, (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)

    frame_tensor = TVF.to_tensor(image)  # (C, H, W), values in [0, 1]
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))

    # crf_compressor operates on HWC tensors, so permute around the call.
    frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
    frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
    frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
    # Normalize to [-1, 1] range
    frame_tensor = (frame_tensor * 2.0) - 1.0
    
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)
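
# Example (hypothetical path): prepare a single conditioning frame.
# frame = load_image_to_tensor_with_resize_and_crop("first_frame.png", 512, 768)
# frame.shape  # -> torch.Size([1, 3, 1, 512, 768]): (batch, channels, frames, H, W)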