eeuuia commited on
Commit
b42e494
·
verified ·
1 Parent(s): 655068e

Update api/ltx/ltx_utils.py

Browse files
Files changed (1) hide show
  1. api/ltx/ltx_utils.py +84 -219
api/ltx/ltx_utils.py CHANGED
@@ -1,263 +1,144 @@
1
  # FILE: api/ltx/ltx_utils.py
2
- # DESCRIPTION: A pure utility library for the LTX ecosystem.
3
- # Contains the official low-level builder function for the complete pipeline
4
- # and other stateless helper functions.
5
 
6
  import os
7
  import random
8
  import json
9
  import logging
 
10
  import sys
11
  from pathlib import Path
12
- from typing import Dict, Tuple, Union
13
- import torchvision.transforms.functional as TVF
14
- from PIL import Image
15
 
 
16
  import torch
 
 
17
  from safetensors import safe_open
18
  from transformers import T5EncoderModel, T5Tokenizer
19
 
20
  # ==============================================================================
21
- # --- CONFIGURAÇÃO DE PATH E IMPORTS DA BIBLIOTECA LTX ---
22
  # ==============================================================================
23
 
 
24
  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
25
 
26
  def add_deps_to_path():
27
- """Adiciona o diretório do repositório LTX ao sys.path para importação de suas bibliotecas."""
 
 
 
28
  repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
29
  if repo_path not in sys.path:
30
  sys.path.insert(0, repo_path)
31
  logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
32
 
 
33
  add_deps_to_path()
34
 
35
- try:
36
- from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
37
- from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
38
- from ltx_video.models.transformers.transformer3d import Transformer3DModel
39
- from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
40
- from ltx_video.schedulers.rf import RectifiedFlowScheduler
41
- except ImportError as e:
42
- logging.critical("Failed to import a core LTX-Video library component.", exc_info=True)
43
- raise ImportError(f"Could not import from LTX-Video library. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
44
-
45
- # ==============================================================================
46
- # --- FUNÇÃO HELPER 'create_transformer' (Essencial) ---
47
- # ==============================================================================
48
-
49
- def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
50
- """
51
- Cria e carrega o modelo Transformer3D com a lógica de precisão correta,
52
- incluindo suporte para a otimização float8_e4m3fn.
53
- """
54
- if precision == "float8_e4m3fn":
55
- try:
56
- from q8_kernels.integration.patch_transformer import patch_diffusers_transformer as patch_transformer_for_q8_kernels
57
- transformer = Transformer3DModel.from_pretrained(ckpt_path, dtype=torch.float8_e4m3fn)
58
- patch_transformer_for_q8_kernels(transformer)
59
- return transformer
60
- except ImportError:
61
- raise ValueError("Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from the project's wheels.")
62
- elif precision == "bfloat16":
63
- return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
64
- else:
65
- return Transformer3DModel.from_pretrained(ckpt_path)
66
-
67
- # ==============================================================================
68
- # --- BUILDER DE BAIXO NÍVEL OFICIAL ---
69
- # ==============================================================================
70
-
71
- def build_complete_pipeline_on_cpu(checkpoint_path: str, config: Dict) -> LTXVideoPipeline:
72
- """
73
- Constrói o pipeline LTX COMPLETO, incluindo o VAE, e o mantém na CPU.
74
- Esta é a função de construção fundamental usada pelo LTXAducManager.
75
- """
76
- logging.info(f"Building complete LTX pipeline from checkpoint: {Path(checkpoint_path).name}")
77
-
78
- with safe_open(checkpoint_path, framework="pt") as f:
79
- metadata = f.metadata() or {}
80
- config_str = metadata.get("config", "{}")
81
- allowed_inference_steps = json.loads(config_str).get("allowed_inference_steps")
82
-
83
- precision = config.get("precision", "bfloat16")
84
-
85
- # Usa a função helper correta para criar o transformer
86
- transformer = create_transformer(checkpoint_path, precision).to("cpu")
87
-
88
- scheduler = RectifiedFlowScheduler.from_pretrained(checkpoint_path)
89
- text_encoder = T5EncoderModel.from_pretrained(config["text_encoder_model_name_or_path"], subfolder="text_encoder").to("cpu")
90
- tokenizer = T5Tokenizer.from_pretrained(config["text_encoder_model_name_or_path"], subfolder="tokenizer")
91
- patchifier = SymmetricPatchifier(patch_size=1)
92
- vae = CausalVideoAutoencoder.from_pretrained(checkpoint_path).to("cpu")
93
-
94
- if precision == "bfloat16":
95
- text_encoder.to(torch.bfloat16)
96
- vae.to(torch.bfloat16)
97
- # O transformer já foi convertido para bfloat16 dentro de create_transformer, se aplicável
98
-
99
- pipeline = LTXVideoPipeline(
100
- transformer=transformer,
101
- patchifier=patchifier,
102
- text_encoder=text_encoder,
103
- tokenizer=tokenizer,
104
- scheduler=scheduler,
105
- vae=vae, # VAE é incluído para que o pipeline possa ser auto-suficiente
106
- allowed_inference_steps=allowed_inference_steps,
107
- prompt_enhancer_image_caption_model=None,
108
- prompt_enhancer_image_caption_processor=None,
109
- prompt_enhancer_llm_model=None,
110
- prompt_enhancer_llm_tokenizer=None,
111
- )
112
-
113
- return pipeline
114
-
115
- # ==============================================================================
116
- # --- FUNÇÕES AUXILIARES GENÉRICAS ---
117
- # ==============================================================================
118
-
119
-
120
- # # FILE: api/ltx/ltx_utils.py
121
- # DESCRIPTION: A pure utility library for the LTX ecosystem.
122
- # Contains the official low-level builder function for the complete pipeline
123
- # and other stateless helper functions.
124
-
125
- import os
126
- import random
127
- import json
128
- import logging
129
- import sys
130
- from pathlib import Path
131
- from typing import Dict, Tuple
132
-
133
- import torch
134
- from safetensors import safe_open
135
- from transformers import T5EncoderModel, T5Tokenizer
136
 
137
  # ==============================================================================
138
- # --- CONFIGURAÇÃO DE PATH E IMPORTS DA BIBLIOTECA LTX ---
139
  # ==============================================================================
140
-
141
- LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
142
-
143
- def add_deps_to_path():
144
- """Adiciona o diretório do repositório LTX ao sys.path para importação de suas bibliotecas."""
145
- repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
146
- if repo_path not in sys.path:
147
- sys.path.insert(0, repo_path)
148
- logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
149
-
150
- add_deps_to_path()
151
-
152
  try:
153
  from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
 
154
  from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
155
  from ltx_video.models.transformers.transformer3d import Transformer3DModel
156
  from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
157
  from ltx_video.schedulers.rf import RectifiedFlowScheduler
 
158
  except ImportError as e:
159
- logging.critical("Failed to import a core LTX-Video library component.", exc_info=True)
160
- raise ImportError(f"Could not import from LTX-Video library. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
161
 
162
  # ==============================================================================
163
- # --- FUNÇÃO HELPER 'create_transformer' (Essencial) ---
164
  # ==============================================================================
165
 
166
- def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
167
- """
168
- Cria e carrega o modelo Transformer3D com a lógica de precisão correta,
169
- incluindo suporte para a otimização float8_e4m3fn.
170
- """
171
- if precision == "float8_e4m3fn":
172
- try:
173
- from q8_kernels.integration.patch_transformer import patch_diffusers_transformer as patch_transformer_for_q8_kernels
174
- transformer = Transformer3DModel.from_pretrained(ckpt_path, dtype=torch.float8_e4m3fn)
175
- patch_transformer_for_q8_kernels(transformer)
176
- return transformer
177
- except ImportError:
178
- raise ValueError("Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from the project's wheels.")
179
- elif precision == "bfloat16":
180
- return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
181
- else:
182
- return Transformer3DModel.from_pretrained(ckpt_path)
183
 
184
- # ==============================================================================
185
- # --- BUILDER DE BAIXO NÍVEL OFICIAL ---
186
- # ==============================================================================
 
187
 
188
- def build_complete_pipeline_on_cpu(checkpoint_path: str, config: Dict) -> LTXVideoPipeline:
189
- """
190
- Constrói o pipeline LTX COMPLETO, incluindo o VAE, e o mantém na CPU.
191
- Esta é a função de construção fundamental usada pelo LTXAducManager.
192
- """
193
- logging.info(f"Building complete LTX pipeline from checkpoint: {Path(checkpoint_path).name}")
194
-
195
- with safe_open(checkpoint_path, framework="pt") as f:
196
  metadata = f.metadata() or {}
197
  config_str = metadata.get("config", "{}")
198
- allowed_inference_steps = json.loads(config_str).get("allowed_inference_steps")
 
199
 
200
- precision = config.get("precision", "bfloat16")
 
 
201
 
202
- # Usa a função helper correta para criar o transformer
203
- transformer = create_transformer(checkpoint_path, precision).to("cpu")
204
-
205
- scheduler = RectifiedFlowScheduler.from_pretrained(checkpoint_path)
206
- text_encoder = T5EncoderModel.from_pretrained(config["text_encoder_model_name_or_path"], subfolder="text_encoder").to("cpu")
207
- tokenizer = T5Tokenizer.from_pretrained(config["text_encoder_model_name_or_path"], subfolder="tokenizer")
208
  patchifier = SymmetricPatchifier(patch_size=1)
209
- vae = CausalVideoAutoencoder.from_pretrained(checkpoint_path).to("cpu")
210
 
 
211
  if precision == "bfloat16":
212
- text_encoder.to(torch.bfloat16)
213
  vae.to(torch.bfloat16)
214
- # O transformer já foi convertido para bfloat16 dentro de create_transformer, se aplicável
215
-
 
216
  pipeline = LTXVideoPipeline(
217
- transformer=transformer,
218
- patchifier=patchifier,
219
- text_encoder=text_encoder,
220
- tokenizer=tokenizer,
221
- scheduler=scheduler,
222
- vae=vae, # VAE é incluído para que o pipeline possa ser auto-suficiente
223
  allowed_inference_steps=allowed_inference_steps,
224
- prompt_enhancer_image_caption_model=None,
225
- prompt_enhancer_image_caption_processor=None,
226
- prompt_enhancer_llm_model=None,
227
- prompt_enhancer_llm_tokenizer=None,
228
  )
229
-
230
- return pipeline
 
 
 
 
 
 
 
 
 
231
 
232
  # ==============================================================================
233
- # --- FUNÇÕES AUXILIARES GENÉRICAS ---
234
  # ==============================================================================
235
 
236
  def seed_everything(seed: int):
237
- """
238
- Define a semente para PyTorch, NumPy e Python para garantir reprodutibilidade.
239
- """
240
  random.seed(seed)
241
  os.environ['PYTHONHASHSEED'] = str(seed)
242
  np.random.seed(seed)
243
  torch.manual_seed(seed)
244
  torch.cuda.manual_seed_all(seed)
245
  torch.backends.cudnn.deterministic = True
246
- torch.backends.cudnn.benchmark = Fals
247
-
248
  def load_image_to_tensor_with_resize_and_crop(
249
  image_input: Union[str, Image.Image],
250
  target_height: int,
251
  target_width: int,
252
  ) -> torch.Tensor:
253
- """
254
- Carrega, redimensiona, corta e processa uma imagem para um tensor de pixel 5D,
255
- normalizado para [-1, 1], pronto para ser enviado ao VAE para encoding.
256
- """
257
  if isinstance(image_input, str):
258
  image = Image.open(image_input).convert("RGB")
259
  elif isinstance(image_input, Image.Image):
260
- image = image_input.convert("RGB")
261
  else:
262
  raise ValueError("image_input must be a file path or a PIL Image object")
263
 
@@ -267,38 +148,22 @@ def load_image_to_tensor_with_resize_and_crop(
267
 
268
  if aspect_ratio_frame > aspect_ratio_target:
269
  new_width, new_height = int(input_height * aspect_ratio_target), input_height
270
- x_start = (input_width - new_width) // 2
271
- image = image.crop((x_start, 0, x_start + new_width, new_height))
272
  else:
273
- new_height = int(input_width / aspect_ratio_target)
274
- y_start = (input_height - new_height) // 2
275
- image = image.crop((0, y_start, input_width, y_start + new_height))
276
-
277
  image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
 
 
 
278
 
279
- frame_tensor = TVF.to_tensor(image)
280
-
281
- # Esta parte depende de 'crf_compressor', então precisamos importá-lo aqui também
282
- try:
283
- from ltx_video.pipelines import crf_compressor
284
- frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
285
- frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
286
- frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
287
- except ImportError:
288
- logging.warning("CRF Compressor not found. Skipping compression step.")
289
-
290
  frame_tensor = (frame_tensor * 2.0) - 1.0
291
- return frame_tensor.unsqueeze(0).unsqueeze(2)
292
-
293
-
294
- def seed_everything(seed: int):
295
- """
296
- Define a semente para PyTorch, NumPy e Python para garantir reprodutibilidade.
297
- """
298
- random.seed(seed)
299
- os.environ['PYTHONHASHSEED'] = str(seed)
300
- np.random.seed(seed)
301
- torch.manual_seed(seed)
302
- torch.cuda.manual_seed_all(seed)
303
- torch.backends.cudnn.deterministic = True
304
- torch.backends.cudnn.benchmark = False
 
1
  # FILE: api/ltx/ltx_utils.py
2
+ # DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
3
+ # Handles dependency path injection, model loading, pipeline creation, and tensor preparation.
 
4
 
5
  import os
6
  import random
7
  import json
8
  import logging
9
+ import time
10
  import sys
11
  from pathlib import Path
12
+ from typing import Dict, Optional, Tuple, Union
 
 
13
 
14
+ import numpy as np
15
  import torch
16
+ import torchvision.transforms.functional as TVF
17
+ from PIL import Image
18
  from safetensors import safe_open
19
  from transformers import T5EncoderModel, T5Tokenizer
20
 
21
  # ==============================================================================
22
+ # --- CRITICAL: DEPENDENCY PATH INJECTION ---
23
  # ==============================================================================
24
 
25
+ # Define o caminho para o repositório clonado
26
  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
27
 
28
  def add_deps_to_path():
29
+ """
30
+ Adiciona o diretório do repositório LTX ao sys.path para garantir que suas
31
+ bibliotecas possam ser importadas.
32
+ """
33
  repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
34
  if repo_path not in sys.path:
35
  sys.path.insert(0, repo_path)
36
  logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
37
 
38
+ # Executa a função imediatamente para configurar o ambiente antes de qualquer importação.
39
  add_deps_to_path()
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # ==============================================================================
43
+ # --- IMPORTAÇÕES DA BIBLIOTECA LTX-VIDEO (Após configuração do path) ---
44
  # ==============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
45
  try:
46
  from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
47
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
48
  from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
49
  from ltx_video.models.transformers.transformer3d import Transformer3DModel
50
  from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
51
  from ltx_video.schedulers.rf import RectifiedFlowScheduler
52
+ import ltx_video.pipelines.crf_compressor as crf_compressor
53
  except ImportError as e:
54
+ raise ImportError(f"Could not import from LTX-Video library even after setting sys.path. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
55
+
56
 
57
  # ==============================================================================
58
+ # --- FUNÇÕES DE CONSTRUÇÃO DE MODELO E PIPELINE ---
59
  # ==============================================================================
60
 
61
+ def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
62
+ """Loads the Latent Upsampler model from a checkpoint path."""
63
+ logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
64
+ latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
65
+ latent_upsampler.to(device)
66
+ latent_upsampler.eval()
67
+ return latent_upsampler
 
 
 
 
 
 
 
 
 
 
68
 
69
+ def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
70
+ """Builds the complete LTX pipeline and upsampler on the CPU."""
71
+ t0 = time.perf_counter()
72
+ logging.info("Building LTX pipeline on CPU...")
73
 
74
+ ckpt_path = Path(config["checkpoint_path"])
75
+ if not ckpt_path.is_file():
76
+ raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")
77
+
78
+ with safe_open(ckpt_path, framework="pt") as f:
 
 
 
79
  metadata = f.metadata() or {}
80
  config_str = metadata.get("config", "{}")
81
+ configs = json.loads(config_str)
82
+ allowed_inference_steps = configs.get("allowed_inference_steps")
83
 
84
+ vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
85
+ transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
86
+ scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
87
 
88
+ text_encoder_path = config["text_encoder_model_name_or_path"]
89
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
90
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
 
 
 
91
  patchifier = SymmetricPatchifier(patch_size=1)
 
92
 
93
+ precision = config.get("precision", "bfloat16")
94
  if precision == "bfloat16":
 
95
  vae.to(torch.bfloat16)
96
+ transformer.to(torch.bfloat16)
97
+ text_encoder.to(torch.bfloat16)
98
+
99
  pipeline = LTXVideoPipeline(
100
+ transformer=transformer, patchifier=patchifier, text_encoder=text_encoder,
101
+ tokenizer=tokenizer, scheduler=scheduler, vae=vae,
 
 
 
 
102
  allowed_inference_steps=allowed_inference_steps,
103
+ prompt_enhancer_image_caption_model=None, prompt_enhancer_image_caption_processor=None,
104
+ prompt_enhancer_llm_model=None, prompt_enhancer_llm_tokenizer=None,
 
 
105
  )
106
+
107
+ latent_upsampler = None
108
+ if config.get("spatial_upscaler_model_path"):
109
+ spatial_path = config["spatial_upscaler_model_path"]
110
+ latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
111
+ if precision == "bfloat16":
112
+ latent_upsampler.to(torch.bfloat16)
113
+
114
+ logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
115
+ return pipeline, latent_upsampler
116
+
117
 
118
  # ==============================================================================
119
+ # --- FUNÇÕES AUXILIARES (Seed, Preparação de Imagem) ---
120
  # ==============================================================================
121
 
122
  def seed_everything(seed: int):
123
+ """Sets the seed for reproducibility."""
 
 
124
  random.seed(seed)
125
  os.environ['PYTHONHASHSEED'] = str(seed)
126
  np.random.seed(seed)
127
  torch.manual_seed(seed)
128
  torch.cuda.manual_seed_all(seed)
129
  torch.backends.cudnn.deterministic = True
130
+ torch.backends.cudnn.benchmark = False
131
+
132
  def load_image_to_tensor_with_resize_and_crop(
133
  image_input: Union[str, Image.Image],
134
  target_height: int,
135
  target_width: int,
136
  ) -> torch.Tensor:
137
+ """Loads and processes an image into a 5D pixel tensor compatible with the LTX pipeline."""
 
 
 
138
  if isinstance(image_input, str):
139
  image = Image.open(image_input).convert("RGB")
140
  elif isinstance(image_input, Image.Image):
141
+ image = image_input
142
  else:
143
  raise ValueError("image_input must be a file path or a PIL Image object")
144
 
 
148
 
149
  if aspect_ratio_frame > aspect_ratio_target:
150
  new_width, new_height = int(input_height * aspect_ratio_target), input_height
151
+ x_start, y_start = (input_width - new_width) // 2, 0
 
152
  else:
153
+ new_width, new_height = input_width, int(input_width / aspect_ratio_target)
154
+ x_start, y_start = 0, (input_height - new_height) // 2
155
+
156
+ image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
157
  image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
158
+
159
+ frame_tensor = TVF.to_tensor(image) # PIL -> tensor (C, H, W) in [0, 1] range
160
+ frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
161
 
162
+ frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
163
+ frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
164
+ frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
165
+ # Normalize to [-1, 1] range, which the VAE expects for encoding
 
 
 
 
 
 
 
166
  frame_tensor = (frame_tensor * 2.0) - 1.0
167
+
168
+ # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
169
+ return frame_tensor.unsqueeze(0).unsqueeze(2)