eeuuia commited on
Commit
db85898
·
verified ·
1 Parent(s): 3a201e7

Create ltx_utils.py

Browse files
Files changed (1) hide show
  1. api/ltx/ltx_utils.py +203 -0
api/ltx/ltx_utils.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FILE: api/ltx/ltx_utils.py
2
+ # DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
3
+ # Handles dependency path injection, model loading, data structures, and helper functions.
4
+
5
+ import os
6
+ import random
7
+ import json
8
+ import logging
9
+ import time
10
+ import sys
11
+ from pathlib import Path
12
+ from typing import Dict, Optional, Tuple, Union
13
+ from dataclasses import dataclass
14
+ from enum import Enum, auto
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torchvision.transforms.functional as TVF
19
+ from PIL import Image
20
+ from safetensors import safe_open
21
+ from transformers import T5EncoderModel, T5Tokenizer
22
+
23
+ # ==============================================================================
24
+ # --- CRITICAL: DEPENDENCY PATH INJECTION ---
25
+ # ==============================================================================
26
+
27
+ # Define o caminho para o repositório clonado
28
+ LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
29
+
30
+ def add_deps_to_path():
31
+ """
32
+ Adiciona o diretório do repositório LTX ao sys.path para garantir que suas
33
+ bibliotecas possam ser importadas.
34
+ """
35
+ repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
36
+ if repo_path not in sys.path:
37
+ sys.path.insert(0, repo_path)
38
+ logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
39
+
40
+ # Executa a função imediatamente para configurar o ambiente antes de qualquer importação.
41
+ add_deps_to_path()
42
+
43
+
44
+ # --- Importações da Biblioteca LTX-Video (Agora devem funcionar) ---
45
+ try:
46
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
47
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
48
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
49
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
50
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
51
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
52
+ from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
53
+ import ltx_video.pipelines.crf_compressor as crf_compressor
54
+ except ImportError as e:
55
+ raise ImportError(f"Could not import from LTX-Video library even after setting sys.path. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
56
+
57
+
58
+ # ==============================================================================
59
+ # --- ESTRUTURAS DE DADOS E ENUMS (Centralizadas aqui) ---
60
+ # ==============================================================================
61
+
62
+ @dataclass
63
+ class ConditioningItem:
64
+ """Defines a single frame-conditioning item, used to guide the generation pipeline."""
65
+ media_item: torch.Tensor
66
+ media_frame_number: int
67
+ conditioning_strength: float
68
+ media_x: Optional[int] = None
69
+ media_y: Optional[int] = None
70
+
71
+
72
+ class SkipLayerStrategy(Enum):
73
+ """Defines the strategy for how spatio-temporal guidance is applied."""
74
+ AttentionSkip = auto()
75
+ AttentionValues = auto()
76
+ Residual = auto()
77
+ TransformerBlock = auto()
78
+
79
+
80
+ # ==============================================================================
81
+ # --- FUNÇÕES DE CONSTRUÇÃO DE MODELO E PIPELINE ---
82
+ # ==============================================================================
83
+
84
+ def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
85
+ """Loads the Latent Upsampler model from a checkpoint path."""
86
+ logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
87
+ latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
88
+ latent_upsampler.to(device)
89
+ latent_upsampler.eval()
90
+ return latent_upsampler
91
+
92
+ def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
93
+ """Builds the complete LTX pipeline and upsampler on the CPU."""
94
+ t0 = time.perf_counter()
95
+ logging.info("Building LTX pipeline on CPU...")
96
+
97
+ ckpt_path = Path(config["checkpoint_path"])
98
+ if not ckpt_path.is_file():
99
+ raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")
100
+
101
+ with safe_open(ckpt_path, framework="pt") as f:
102
+ metadata = f.metadata() or {}
103
+ config_str = metadata.get("config", "{}")
104
+ configs = json.loads(config_str)
105
+ allowed_inference_steps = configs.get("allowed_inference_steps")
106
+
107
+ vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
108
+ transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
109
+ scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
110
+
111
+ text_encoder_path = config["text_encoder_model_name_or_path"]
112
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
113
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
114
+ patchifier = SymmetricPatchifier(patch_size=1)
115
+
116
+ precision = config.get("precision", "bfloat16")
117
+ if precision == "bfloat16":
118
+ vae.to(torch.bfloat16)
119
+ transformer.to(torch.bfloat16)
120
+ text_encoder.to(torch.bfloat16)
121
+
122
+ pipeline = LTXVideoPipeline(
123
+ transformer=transformer, patchifier=patchifier, text_encoder=text_encoder,
124
+ tokenizer=tokenizer, scheduler=scheduler, vae=vae,
125
+ allowed_inference_steps=allowed_inference_steps,
126
+ prompt_enhancer_image_caption_model=None, prompt_enhancer_image_caption_processor=None,
127
+ prompt_enhancer_llm_model=None, prompt_enhancer_llm_tokenizer=None,
128
+ )
129
+
130
+ latent_upsampler = None
131
+ if config.get("spatial_upscaler_model_path"):
132
+ spatial_path = config["spatial_upscaler_model_path"]
133
+ latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
134
+ if precision == "bfloat16":
135
+ latent_upsampler.to(torch.bfloat16)
136
+
137
+ logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
138
+ return pipeline, latent_upsampler
139
+
140
+
141
+ # ==============================================================================
142
+ # --- FUNÇÕES AUXILIARES (Latent Processing, Seed, Image Prep) ---
143
+ # ==============================================================================
144
+
145
+ def adain_filter_latent(
146
+ latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
147
+ ) -> torch.Tensor:
148
+ """Applies AdaIN to transfer the style from a reference latent to another."""
149
+ result = latents.clone()
150
+ for i in range(latents.size(0)):
151
+ for c in range(latents.size(1)):
152
+ r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)
153
+ i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
154
+ if i_sd > 1e-6:
155
+ result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
156
+ return torch.lerp(latents, result, factor)
157
+
158
+ def seed_everything(seed: int):
159
+ """Sets the seed for reproducibility."""
160
+ random.seed(seed)
161
+ os.environ['PYTHONHASHSEED'] = str(seed)
162
+ np.random.seed(seed)
163
+ torch.manual_seed(seed)
164
+ torch.cuda.manual_seed_all(seed)
165
+ torch.backends.cudnn.deterministic = True
166
+ torch.backends.cudnn.benchmark = False
167
+
168
+ def load_image_to_tensor_with_resize_and_crop(
169
+ image_input: Union[str, Image.Image],
170
+ target_height: int,
171
+ target_width: int,
172
+ ) -> torch.Tensor:
173
+ """Loads and processes an image into a 5D tensor compatible with the LTX pipeline."""
174
+ if isinstance(image_input, str):
175
+ image = Image.open(image_input).convert("RGB")
176
+ elif isinstance(image_input, Image.Image):
177
+ image = image_input
178
+ else:
179
+ raise ValueError("image_input must be a file path or a PIL Image object")
180
+
181
+ input_width, input_height = image.size
182
+ aspect_ratio_target = target_width / target_height
183
+ aspect_ratio_frame = input_width / input_height
184
+
185
+ if aspect_ratio_frame > aspect_ratio_target:
186
+ new_width, new_height = int(input_height * aspect_ratio_target), input_height
187
+ x_start, y_start = (input_width - new_width) // 2, 0
188
+ else:
189
+ new_width, new_height = input_width, int(input_width / aspect_ratio_target)
190
+ x_start, y_start = 0, (input_height - new_height) // 2
191
+
192
+ image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
193
+ image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
194
+
195
+ frame_tensor = TVF.to_tensor(image)
196
+ frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
197
+
198
+ frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
199
+ frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
200
+ frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
201
+ frame_tensor = (frame_tensor * 2.0) - 1.0
202
+
203
+ return frame_tensor.unsqueeze(0).unsqueeze(2)