eeuuia committed on
Commit
60431a3
verified
1 Parent(s): eaa00b5

Update api/ltx/ltx_utils.py

Files changed (1)
  1. api/ltx/ltx_utils.py +37 -111
api/ltx/ltx_utils.py CHANGED
@@ -1,6 +1,6 @@
  # FILE: api/ltx/ltx_utils.py
- # DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
- # Handles dependency path injection, model loading, pipeline creation, and tensor preparation.

  import os
  import random
@@ -10,42 +10,30 @@ import time
  import sys
  from pathlib import Path
  from typing import Dict, Optional, Tuple, Union
- from huggingface_hub import hf_hub_download

  import numpy as np
  import torch
  import torchvision.transforms.functional as TVF
  from PIL import Image
- from safetensors import safe_open
  from transformers import T5EncoderModel, T5Tokenizer

  # ==============================================================================
- # --- CRITICAL: DEPENDENCY PATH INJECTION ---
  # ==============================================================================

- # Defines the path to the cloned repository
  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
- LTX_REPO_ID = "Lightricks/LTX-Video"
- CACHE_DIR = os.environ.get("HF_HOME")
-

  def add_deps_to_path():
-     """
-     Adds the LTX repository directory to sys.path to ensure that its
-     libraries can be imported.
-     """
      repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
      if repo_path not in sys.path:
          sys.path.insert(0, repo_path)
          logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")

- # Runs the function immediately to configure the environment before any imports.
  add_deps_to_path()

-
- # ==============================================================================
- # --- LTX-VIDEO LIBRARY IMPORTS (after path configuration) ---
- # ==============================================================================
  try:
      from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
      from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
@@ -55,93 +43,65 @@ try:
      from ltx_video.schedulers.rf import RectifiedFlowScheduler
      import ltx_video.pipelines.crf_compressor as crf_compressor
  except ImportError as e:
-     raise ImportError(f"Could not import from LTX-Video library even after setting sys.path. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
-

  # ==============================================================================
- # --- MODEL AND PIPELINE BUILD FUNCTIONS ---
  # ==============================================================================

- def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
-     """Loads the Latent Upsampler model from a checkpoint path."""
-     logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
-     latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
-     latent_upsampler.to(device)
-     latent_upsampler.eval()
-     return latent_upsampler
-
- def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
-     """Builds the complete LTX pipeline and upsampler on the CPU."""
-     t0 = time.perf_counter()
-     logging.info("Building LTX pipeline on CPU...")
-
-     ckpt_path_str = hf_hub_download(repo_id=LTX_REPO_ID, filename=config["checkpoint_path"], cache_dir=CACHE_DIR)
-     ckpt_path = Path(ckpt_path_str)
-     if not ckpt_path.is_file():
-         raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")

-     logging.info(f"Building LTX pipeline ckpt:{ckpt_path_str}")
-
-     with safe_open(ckpt_path, framework="pt") as f:
          metadata = f.metadata() or {}
          config_str = metadata.get("config", "{}")
          configs = json.loads(config_str)
          allowed_inference_steps = configs.get("allowed_inference_steps")
-
-     vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
-     transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
-     scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

      text_encoder_path = config["text_encoder_model_name_or_path"]
      text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
      tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
      patchifier = SymmetricPatchifier(patch_size=1)

-     precision = config.get("precision", "bfloat16")
      if precision == "bfloat16":
-         vae.to(torch.bfloat16)
          transformer.to(torch.bfloat16)
          text_encoder.to(torch.bfloat16)

      pipeline = LTXVideoPipeline(
          transformer=transformer, patchifier=patchifier, text_encoder=text_encoder,
-         tokenizer=tokenizer, scheduler=scheduler, vae=vae,
          allowed_inference_steps=allowed_inference_steps,
          prompt_enhancer_image_caption_model=None, prompt_enhancer_image_caption_processor=None,
          prompt_enhancer_llm_model=None, prompt_enhancer_llm_tokenizer=None,
      )
-
-     vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
-     if precision == "bfloat16":
-         vae.to(torch.bfloat16)
-
-     latent_upsampler = None
-     if config.get("spatial_upscaler_model_path"):
-         spatial_path = config["spatial_upscaler_model_path"]
-         spatial_path_str = hf_hub_download(repo_id=LTX_REPO_ID, filename=config["spatial_upscaler_model_path"], cache_dir=CACHE_DIR)
-         spatial_path = Path(spatial_path_str)
-         if not spatial_path.is_file():
-             raise FileNotFoundError(f"Main checkpoint upscaler file not found: {spatial_path_str}")
-         logging.info(f"Building UPSCALER pipeline ckpt:{spatial_path_str}")
-         latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
-         if precision == "bfloat16":
-             latent_upsampler.to(torch.bfloat16)
-
-     logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
-     return pipeline, latent_upsampler, vae
-

  # ==============================================================================
- # --- HELPER FUNCTIONS (seed, image preparation) ---
  # ==============================================================================

  def seed_everything(seed: int):
-     """Sets the seed for reproducibility."""
      random.seed(seed)
      os.environ['PYTHONHASHSEED'] = str(seed)
      np.random.seed(seed)
@@ -150,41 +110,7 @@ def seed_everything(seed: int):
      torch.backends.cudnn.deterministic = True
      torch.backends.cudnn.benchmark = False

- def load_image_to_tensor_with_resize_and_crop(
-     image_input: Union[str, Image.Image],
-     target_height: int,
-     target_width: int,
- ) -> torch.Tensor:
-     """Loads and processes an image into a 5D pixel tensor compatible with the LTX pipeline."""
-     if isinstance(image_input, str):
-         image = Image.open(image_input).convert("RGB")
-     elif isinstance(image_input, Image.Image):
-         image = image_input
-     else:
-         raise ValueError("image_input must be a file path or a PIL Image object")
-
-     input_width, input_height = image.size
-     aspect_ratio_target = target_width / target_height
-     aspect_ratio_frame = input_width / input_height
-
-     if aspect_ratio_frame > aspect_ratio_target:
-         new_width, new_height = int(input_height * aspect_ratio_target), input_height
-         x_start, y_start = (input_width - new_width) // 2, 0
-     else:
-         new_width, new_height = input_width, int(input_width / aspect_ratio_target)
-         x_start, y_start = 0, (input_height - new_height) // 2
-
-     image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
-     image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
-
-     frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W) in [0, 1] range
-     frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
-
-     frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
-     frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
-     frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
-     # Normalize to [-1, 1] range, which the VAE expects for encoding
-     frame_tensor = (frame_tensor * 2.0) - 1.0
-
-     # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
-     return frame_tensor.unsqueeze(0).unsqueeze(2)
 
  # FILE: api/ltx/ltx_utils.py
+ # DESCRIPTION: A pure utility library for the LTX ecosystem.
+ # Contains low-level, stateless builders for core components and general-purpose helper functions.

  import os
  import random
  ...
  import sys
  from pathlib import Path
  from typing import Dict, Optional, Tuple, Union

+ from huggingface_hub import hf_hub_download
  import numpy as np
  import torch
  import torchvision.transforms.functional as TVF
  from PIL import Image
+ from safetensors import safe_open
  from transformers import T5EncoderModel, T5Tokenizer
  # ==============================================================================
+ # --- PATH CONFIGURATION AND LTX LIBRARY IMPORTS ---
  # ==============================================================================

  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")

  def add_deps_to_path():
+     """Adds the LTX repository directory to sys.path so its libraries can be imported."""
      repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
      if repo_path not in sys.path:
          sys.path.insert(0, repo_path)
          logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")

  add_deps_to_path()

  try:
      from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
      from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
      ...
      from ltx_video.schedulers.rf import RectifiedFlowScheduler
      import ltx_video.pipelines.crf_compressor as crf_compressor
  except ImportError as e:
+     raise ImportError(f"Could not import from LTX-Video library. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")

  # ==============================================================================
+ # --- LOW-LEVEL BUILDER LAYER ---
+ # (Each builder constructs a single component on the CPU)
  # ==============================================================================

+ def _build_vae(checkpoint_path: str, precision: str) -> CausalVideoAutoencoder:
+     """Builds the CausalVideoAutoencoder from a checkpoint, always on the CPU."""
+     logging.info(f"Building VAE from checkpoint: {Path(checkpoint_path).name}")
+     vae = CausalVideoAutoencoder.from_pretrained(checkpoint_path).to("cpu")
+     if precision == "bfloat16":
+         vae.to(torch.bfloat16)
+     return vae

+ def _build_latent_upscaler(upscaler_path: str, precision: str) -> LatentUpsampler:
+     """Builds the LatentUpsampler from a checkpoint, always on the CPU."""
+     logging.info(f"Building Latent Upscaler from: {Path(upscaler_path).name}")
+     upscaler = LatentUpsampler.from_pretrained(upscaler_path).to("cpu")
+     if precision == "bfloat16":
+         upscaler.to(torch.bfloat16)
+     return upscaler

+ def _build_ltx_transformer_pipeline(checkpoint_path: str, config: Dict, precision: str) -> LTXVideoPipeline:
+     """Builds the main LTXVideoPipeline (without the VAE), always on the CPU."""
+     logging.info(f"Building core LTX Transformer Pipeline from checkpoint: {Path(checkpoint_path).name}")

+     with safe_open(checkpoint_path, framework="pt") as f:
          metadata = f.metadata() or {}
          config_str = metadata.get("config", "{}")
          configs = json.loads(config_str)
          allowed_inference_steps = configs.get("allowed_inference_steps")

+     transformer = Transformer3DModel.from_pretrained(checkpoint_path).to("cpu")
+     scheduler = RectifiedFlowScheduler.from_pretrained(checkpoint_path)
      text_encoder_path = config["text_encoder_model_name_or_path"]
      text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
      tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
      patchifier = SymmetricPatchifier(patch_size=1)

      if precision == "bfloat16":
          transformer.to(torch.bfloat16)
          text_encoder.to(torch.bfloat16)

      pipeline = LTXVideoPipeline(
          transformer=transformer, patchifier=patchifier, text_encoder=text_encoder,
+         tokenizer=tokenizer, scheduler=scheduler, vae=None,  # the VAE is explicitly None
          allowed_inference_steps=allowed_inference_steps,
          prompt_enhancer_image_caption_model=None, prompt_enhancer_image_caption_processor=None,
          prompt_enhancer_llm_model=None, prompt_enhancer_llm_tokenizer=None,
      )
+     return pipeline
  # ==============================================================================
+ # --- GENERIC HELPER FUNCTIONS ---
  # ==============================================================================

  def seed_everything(seed: int):
+     """Sets the seed for PyTorch, NumPy, and Python to guarantee reproducibility."""
      random.seed(seed)
      os.environ['PYTHONHASHSEED'] = str(seed)
      np.random.seed(seed)
      ...
      torch.backends.cudnn.deterministic = True
      torch.backends.cudnn.benchmark = False

+ # NOTE: load_image_to_tensor_with_resize_and_crop was moved to the vae_aduc_pipeline.py
+ # client, since it is a direct preprocessing dependency of the VAE tasks, which makes that
+ # module more self-contained. If it is needed elsewhere, this is the place to centralize it again.
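For orientation, a minimal sketch of how a caller might compose the new low-level builders. The wrapper name build_all_on_cpu is hypothetical and not part of this commit; the hf_hub_download calls, repo id, and config keys are carried over from the removed build_ltx_pipeline_on_cpu, and the builders are assumed to be imported from api/ltx/ltx_utils.py.

from pathlib import Path
from typing import Dict, Optional, Tuple

from huggingface_hub import hf_hub_download

LTX_REPO_ID = "Lightricks/LTX-Video"  # same repo id the removed code used

def build_all_on_cpu(config: Dict) -> Tuple["LTXVideoPipeline", Optional["LatentUpsampler"], "CausalVideoAutoencoder"]:
    # Download the single-file checkpoint referenced by the config, as the old code did.
    ckpt_path = hf_hub_download(repo_id=LTX_REPO_ID, filename=config["checkpoint_path"])
    precision = config.get("precision", "bfloat16")

    pipeline = _build_ltx_transformer_pipeline(ckpt_path, config, precision)  # transformer + T5 + scheduler, vae=None
    vae = _build_vae(ckpt_path, precision)                                    # standalone VAE, stays on CPU

    latent_upsampler = None
    if config.get("spatial_upscaler_model_path"):
        up_path = hf_hub_download(repo_id=LTX_REPO_ID, filename=config["spatial_upscaler_model_path"])
        latent_upsampler = _build_latent_upscaler(up_path, precision)

    return pipeline, latent_upsampler, vae

Keeping the download and device/offload policy in the caller is what lets these builders stay stateless and single-purpose, which is the stated goal of this refactor.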