eeuuia commited on
Commit
a5cbce5
·
verified ·
1 Parent(s): fb84d49

Update api/ltx/ltx_utils.py

Browse files
Files changed (1) hide show
  1. api/ltx/ltx_utils.py +177 -0
api/ltx/ltx_utils.py CHANGED
@@ -10,6 +10,8 @@ import logging
10
  import sys
11
  from pathlib import Path
12
  from typing import Dict, Tuple
 
 
13
 
14
  import torch
15
  from safetensors import safe_open
@@ -114,6 +116,181 @@ def build_complete_pipeline_on_cpu(checkpoint_path: str, config: Dict) -> LTXVid
114
  # --- FUNÇÕES AUXILIARES GENÉRICAS ---
115
  # ==============================================================================
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def seed_everything(seed: int):
118
  """
119
  Define a semente para PyTorch, NumPy e Python para garantir reprodutibilidade.
 
10
  import sys
11
  from pathlib import Path
12
  from typing import Dict, Tuple
13
+ import torchvision.transforms.functional as TVF
14
+ from PIL import Image
15
 
16
  import torch
17
  from safetensors import safe_open
 
116
  # --- FUNÇÕES AUXILIARES GENÉRICAS ---
117
  # ==============================================================================
118
 
119
+
120
+ # # FILE: api/ltx/ltx_utils.py
121
+ # DESCRIPTION: A pure utility library for the LTX ecosystem.
122
+ # Contains the official low-level builder function for the complete pipeline
123
+ # and other stateless helper functions.
124
+
125
+ import os
126
+ import random
127
+ import json
128
+ import logging
129
+ import sys
130
+ from pathlib import Path
131
+ from typing import Dict, Tuple
132
+
133
+ import torch
134
+ from safetensors import safe_open
135
+ from transformers import T5EncoderModel, T5Tokenizer
136
+
137
+ # ==============================================================================
138
+ # --- CONFIGURAÇÃO DE PATH E IMPORTS DA BIBLIOTECA LTX ---
139
+ # ==============================================================================
140
+
141
+ LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
142
+
143
+ def add_deps_to_path():
144
+ """Adiciona o diretório do repositório LTX ao sys.path para importação de suas bibliotecas."""
145
+ repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
146
+ if repo_path not in sys.path:
147
+ sys.path.insert(0, repo_path)
148
+ logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
149
+
150
+ add_deps_to_path()
151
+
152
+ try:
153
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
154
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
155
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
156
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
157
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
158
+ except ImportError as e:
159
+ logging.critical("Failed to import a core LTX-Video library component.", exc_info=True)
160
+ raise ImportError(f"Could not import from LTX-Video library. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
161
+
162
+ # ==============================================================================
163
+ # --- FUNÇÃO HELPER 'create_transformer' (Essencial) ---
164
+ # ==============================================================================
165
+
166
+ def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
167
+ """
168
+ Cria e carrega o modelo Transformer3D com a lógica de precisão correta,
169
+ incluindo suporte para a otimização float8_e4m3fn.
170
+ """
171
+ if precision == "float8_e4m3fn":
172
+ try:
173
+ from q8_kernels.integration.patch_transformer import patch_diffusers_transformer as patch_transformer_for_q8_kernels
174
+ transformer = Transformer3DModel.from_pretrained(ckpt_path, dtype=torch.float8_e4m3fn)
175
+ patch_transformer_for_q8_kernels(transformer)
176
+ return transformer
177
+ except ImportError:
178
+ raise ValueError("Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from the project's wheels.")
179
+ elif precision == "bfloat16":
180
+ return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
181
+ else:
182
+ return Transformer3DModel.from_pretrained(ckpt_path)
183
+
184
+ # ==============================================================================
185
+ # --- BUILDER DE BAIXO NÍVEL OFICIAL ---
186
+ # ==============================================================================
187
+
188
+ def build_complete_pipeline_on_cpu(checkpoint_path: str, config: Dict) -> LTXVideoPipeline:
189
+ """
190
+ Constrói o pipeline LTX COMPLETO, incluindo o VAE, e o mantém na CPU.
191
+ Esta é a função de construção fundamental usada pelo LTXAducManager.
192
+ """
193
+ logging.info(f"Building complete LTX pipeline from checkpoint: {Path(checkpoint_path).name}")
194
+
195
+ with safe_open(checkpoint_path, framework="pt") as f:
196
+ metadata = f.metadata() or {}
197
+ config_str = metadata.get("config", "{}")
198
+ allowed_inference_steps = json.loads(config_str).get("allowed_inference_steps")
199
+
200
+ precision = config.get("precision", "bfloat16")
201
+
202
+ # Usa a função helper correta para criar o transformer
203
+ transformer = create_transformer(checkpoint_path, precision).to("cpu")
204
+
205
+ scheduler = RectifiedFlowScheduler.from_pretrained(checkpoint_path)
206
+ text_encoder = T5EncoderModel.from_pretrained(config["text_encoder_model_name_or_path"], subfolder="text_encoder").to("cpu")
207
+ tokenizer = T5Tokenizer.from_pretrained(config["text_encoder_model_name_or_path"], subfolder="tokenizer")
208
+ patchifier = SymmetricPatchifier(patch_size=1)
209
+ vae = CausalVideoAutoencoder.from_pretrained(checkpoint_path).to("cpu")
210
+
211
+ if precision == "bfloat16":
212
+ text_encoder.to(torch.bfloat16)
213
+ vae.to(torch.bfloat16)
214
+ # O transformer já foi convertido para bfloat16 dentro de create_transformer, se aplicável
215
+
216
+ pipeline = LTXVideoPipeline(
217
+ transformer=transformer,
218
+ patchifier=patchifier,
219
+ text_encoder=text_encoder,
220
+ tokenizer=tokenizer,
221
+ scheduler=scheduler,
222
+ vae=vae, # VAE é incluído para que o pipeline possa ser auto-suficiente
223
+ allowed_inference_steps=allowed_inference_steps,
224
+ prompt_enhancer_image_caption_model=None,
225
+ prompt_enhancer_image_caption_processor=None,
226
+ prompt_enhancer_llm_model=None,
227
+ prompt_enhancer_llm_tokenizer=None,
228
+ )
229
+
230
+ return pipeline
231
+
232
+ # ==============================================================================
233
+ # --- FUNÇÕES AUXILIARES GENÉRICAS ---
234
+ # ==============================================================================
235
+
236
+ def seed_everything(seed: int):
237
+ """
238
+ Define a semente para PyTorch, NumPy e Python para garantir reprodutibilidade.
239
+ """
240
+ random.seed(seed)
241
+ os.environ['PYTHONHASHSEED'] = str(seed)
242
+ np.random.seed(seed)
243
+ torch.manual_seed(seed)
244
+ torch.cuda.manual_seed_all(seed)
245
+ torch.backends.cudnn.deterministic = True
246
+ torch.backends.cudnn.benchmark = Fals
247
+
248
+ def load_image_to_tensor_with_resize_and_crop(
249
+ image_input: Union[str, Image.Image],
250
+ target_height: int,
251
+ target_width: int,
252
+ ) -> torch.Tensor:
253
+ """
254
+ Carrega, redimensiona, corta e processa uma imagem para um tensor de pixel 5D,
255
+ normalizado para [-1, 1], pronto para ser enviado ao VAE para encoding.
256
+ """
257
+ if isinstance(image_input, str):
258
+ image = Image.open(image_input).convert("RGB")
259
+ elif isinstance(image_input, Image.Image):
260
+ image = image_input.convert("RGB")
261
+ else:
262
+ raise ValueError("image_input must be a file path or a PIL Image object")
263
+
264
+ input_width, input_height = image.size
265
+ aspect_ratio_target = target_width / target_height
266
+ aspect_ratio_frame = input_width / input_height
267
+
268
+ if aspect_ratio_frame > aspect_ratio_target:
269
+ new_width, new_height = int(input_height * aspect_ratio_target), input_height
270
+ x_start = (input_width - new_width) // 2
271
+ image = image.crop((x_start, 0, x_start + new_width, new_height))
272
+ else:
273
+ new_height = int(input_width / aspect_ratio_target)
274
+ y_start = (input_height - new_height) // 2
275
+ image = image.crop((0, y_start, input_width, y_start + new_height))
276
+
277
+ image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
278
+
279
+ frame_tensor = TVF.to_tensor(image)
280
+
281
+ # Esta parte depende de 'crf_compressor', então precisamos importá-lo aqui também
282
+ try:
283
+ from ltx_video.pipelines import crf_compressor
284
+ frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
285
+ frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
286
+ frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
287
+ except ImportError:
288
+ logging.warning("CRF Compressor not found. Skipping compression step.")
289
+
290
+ frame_tensor = (frame_tensor * 2.0) - 1.0
291
+ return frame_tensor.unsqueeze(0).unsqueeze(2)
292
+
293
+
294
  def seed_everything(seed: int):
295
  """
296
  Define a semente para PyTorch, NumPy e Python para garantir reprodutibilidade.