eeuuia committed on
Commit
c2d3ac4
verified
1 Parent(s): fb6f7a3

Update api/ltx/ltx_utils.py

Browse files
Files changed (1) hide show
  1. api/ltx/ltx_utils.py +74 -47
api/ltx/ltx_utils.py CHANGED
@@ -1,16 +1,18 @@
1
  # FILE: api/ltx/ltx_utils.py
2
  # DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
3
- # Handles dependency path injection, model loading, pipeline creation, and tensor preparation.
4
 
5
  import os
6
  import random
7
  import json
8
-
9
  import time
10
  import sys
11
  from pathlib import Path
12
  from typing import Dict, Optional, Tuple, Union
13
- from huggingface_hub import hf_hub_download
 
 
14
  import numpy as np
15
  import torch
16
  import torchvision.transforms.functional as TVF
@@ -18,50 +20,69 @@ from PIL import Image
18
  from safetensors import safe_open
19
  from transformers import T5EncoderModel, T5Tokenizer
20
 
21
- import logging
22
- import warnings
23
- warnings.filterwarnings("ignore", category=UserWarning)
24
- warnings.filterwarnings("ignore", category=FutureWarning)
25
- warnings.filterwarnings("ignore", message=".*")
26
- from huggingface_hub import logging as ll
27
- ll.set_verbosity_error()
28
- ll.set_verbosity_warning()
29
- ll.set_verbosity_info()
30
- from utils.debug_utils import log_function_io
31
-
32
- ll.set_verbosity_debug()
33
-
34
-
35
  # ==============================================================================
36
  # --- CRITICAL: DEPENDENCY PATH INJECTION ---
37
  # ==============================================================================
38
 
39
  # Define o caminho para o repositório clonado
40
  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
41
- LTX_REPO_ID = "Lightricks/LTX-Video"
42
- CACHE_DIR = os.environ.get("HF_HOME")
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # ==============================================================================
45
  # --- IMPORTAÇÕES DA BIBLIOTECA LTX-VIDEO (Após configuração do path) ---
46
  # ==============================================================================
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
49
- if repo_path not in sys.path:
50
- sys.path.insert(0, repo_path)
51
 
52
- from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
53
- from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
54
- from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
55
- from ltx_video.models.transformers.transformer3d import Transformer3DModel
56
- from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
57
- from ltx_video.schedulers.rf import RectifiedFlowScheduler
58
- import ltx_video.pipelines.crf_compressor as crf_compressor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  # ==============================================================================
61
  # --- FUNÇÕES DE CONSTRUÇÃO DE MODELO E PIPELINE ---
62
  # ==============================================================================
63
 
64
- @log_function_io
65
  def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
66
  """Loads the Latent Upsampler model from a checkpoint path."""
67
  logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
@@ -70,17 +91,15 @@ def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> La
70
  latent_upsampler.eval()
71
  return latent_upsampler
72
 
73
- @log_function_io
74
  def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
75
  """Builds the complete LTX pipeline and upsampler on the CPU."""
76
  t0 = time.perf_counter()
77
  logging.info("Building LTX pipeline on CPU...")
78
 
79
- ckpt_path_str = hf_hub_download(repo_id=LTX_REPO_ID, filename=config["checkpoint_path"], cache_dir=CACHE_DIR)
80
- ckpt_path = Path(ckpt_path_str)
81
  if not ckpt_path.is_file():
82
  raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")
83
-
84
  with safe_open(ckpt_path, framework="pt") as f:
85
  metadata = f.metadata() or {}
86
  config_str = metadata.get("config", "{}")
@@ -111,11 +130,8 @@ def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[
111
  )
112
 
113
  latent_upsampler = None
114
- if config.get("spatial_upscaler_model_path"):
115
- spatial_path_str = hf_hub_download(repo_id=LTX_REPO_ID, filename=config["spatial_upscaler_model_path"], cache_dir=CACHE_DIR)
116
- spatial_path = Path(spatial_path_str)
117
- if not ckpt_path.is_file():
118
- raise FileNotFoundError(f"Main checkpoint file not found: {spatial_path}")
119
  latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
120
  if precision == "bfloat16":
121
  latent_upsampler.to(torch.bfloat16)
@@ -123,11 +139,24 @@ def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[
123
  logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
124
  return pipeline, latent_upsampler
125
 
 
126
  # ==============================================================================
127
- # --- FUNÇÕES AUXILIARES (Seed, Preparação de Imagem) ---
128
  # ==============================================================================
129
 
130
- @log_function_io
 
 
 
 
 
 
 
 
 
 
 
 
131
  def seed_everything(seed: int):
132
  """Sets the seed for reproducibility."""
133
  random.seed(seed)
@@ -138,13 +167,12 @@ def seed_everything(seed: int):
138
  torch.backends.cudnn.deterministic = True
139
  torch.backends.cudnn.benchmark = False
140
 
141
- @log_function_io
142
  def load_image_to_tensor_with_resize_and_crop(
143
  image_input: Union[str, Image.Image],
144
  target_height: int,
145
  target_width: int,
146
  ) -> torch.Tensor:
147
- """Loads and processes an image into a 5D pixel tensor compatible with the LTX pipeline."""
148
  if isinstance(image_input, str):
149
  image = Image.open(image_input).convert("RGB")
150
  elif isinstance(image_input, Image.Image):
@@ -166,15 +194,14 @@ def load_image_to_tensor_with_resize_and_crop(
166
  image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
167
  image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
168
 
169
- frame_tensor = TVF.to_tensor(image) # PIL -> tensor (C, H, W) in [0, 1] range
170
  frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
171
 
172
  frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
173
  frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
174
  frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
175
- # Normalize to [-1, 1] range, which the VAE expects for encoding
176
  frame_tensor = (frame_tensor * 2.0) - 1.0
177
 
178
  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
179
- return frame_tensor.unsqueeze(0).unsqueeze(2)
180
-
 
1
  # FILE: api/ltx/ltx_utils.py
2
  # DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
3
+ # Handles dependency path injection, model loading, data structures, and helper functions.
4
 
5
  import os
6
  import random
7
  import json
8
+ import logging
9
  import time
10
  import sys
11
  from pathlib import Path
12
  from typing import Dict, Optional, Tuple, Union
13
+ from dataclasses import dataclass
14
+ from enum import Enum, auto
15
+
16
  import numpy as np
17
  import torch
18
  import torchvision.transforms.functional as TVF
 
20
  from safetensors import safe_open
21
  from transformers import T5EncoderModel, T5Tokenizer
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # ==============================================================================
24
  # --- CRITICAL: DEPENDENCY PATH INJECTION ---
25
  # ==============================================================================
26
 
27
  # Define o caminho para o repositório clonado
28
  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
29
+
30
def add_deps_to_path():
    """Make the cloned LTX-Video repository importable.

    Resolves ``LTX_VIDEO_REPO_DIR`` to an absolute path string and, when that
    string is not already present in ``sys.path``, inserts it at position 0
    and logs the addition.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())

    # Nothing to do when the repository is already on the import path.
    if repo_path in sys.path:
        return

    # Insert at the front so this path is searched before the other entries.
    sys.path.insert(0, repo_path)
    # NOTE(review): indentation was lost in the extracted source; the log call
    # is assumed to belong to the branch that actually modifies sys.path.
    logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")
39
+
40
+ # Executa a função imediatamente para configurar o ambiente antes de qualquer importação.
41
+ add_deps_to_path()
42
+
43
 
44
  # ==============================================================================
45
  # --- IMPORTAÇÕES DA BIBLIOTECA LTX-VIDEO (Após configuração do path) ---
46
  # ==============================================================================
47
+ try:
48
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
49
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
50
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
51
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
52
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
53
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
54
+ from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
55
+ import ltx_video.pipelines.crf_compressor as crf_compressor
56
+ except ImportError as e:
57
+ raise ImportError(f"Could not import from LTX-Video library even after setting sys.path. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
58
 
 
 
 
59
 
60
+ # ==============================================================================
61
+ # --- ESTRUTURAS DE DADOS E ENUMS (Centralizadas aqui) ---
62
+ # ==============================================================================
63
+
64
@dataclass
class ConditioningItem:
    """One frame-conditioning entry used to guide the generation pipeline.

    The first three fields are required; the two position fields are optional
    and default to ``None``.
    """

    # Tensor holding the conditioning media.
    media_item: torch.Tensor
    # Frame number associated with the media item.
    media_frame_number: int
    # Strength of the conditioning.
    conditioning_strength: float
    # Optional x / y values for the media; None when not provided.
    # NOTE(review): presumably a spatial offset - confirm against the pipeline.
    media_x: Optional[int] = None
    media_y: Optional[int] = None
72
+
73
+
74
class SkipLayerStrategy(Enum):
    """Strategies for applying spatio-temporal guidance across transformer blocks.

    Member values are assigned by ``auto()`` in declaration order, so the
    order of the members below determines their integer values.
    """

    AttentionSkip = auto()
    AttentionValues = auto()
    Residual = auto()
    TransformerBlock = auto()
80
+
81
 
82
  # ==============================================================================
83
  # --- FUNÇÕES DE CONSTRUÇÃO DE MODELO E PIPELINE ---
84
  # ==============================================================================
85
 
 
86
  def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
87
  """Loads the Latent Upsampler model from a checkpoint path."""
88
  logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
 
91
  latent_upsampler.eval()
92
  return latent_upsampler
93
 
 
94
  def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
95
  """Builds the complete LTX pipeline and upsampler on the CPU."""
96
  t0 = time.perf_counter()
97
  logging.info("Building LTX pipeline on CPU...")
98
 
99
+ ckpt_path = Path(config["checkpoint_path"])
 
100
  if not ckpt_path.is_file():
101
  raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")
102
+
103
  with safe_open(ckpt_path, framework="pt") as f:
104
  metadata = f.metadata() or {}
105
  config_str = metadata.get("config", "{}")
 
130
  )
131
 
132
  latent_upsampler = None
133
+ if config.get("spatial_upscaler_model_path"):
134
+ spatial_path = config["spatial_upscaler_model_path"]
 
 
 
135
  latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
136
  if precision == "bfloat16":
137
  latent_upsampler.to(torch.bfloat16)
 
139
  logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
140
  return pipeline, latent_upsampler
141
 
142
+
143
  # ==============================================================================
144
+ # --- FUN脟脮ES AUXILIARES (Latent Processing, Seed, Image Prep) ---
145
  # ==============================================================================
146
 
147
def adain_filter_latent(
    latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
) -> torch.Tensor:
    """Apply AdaIN so that ``latents`` take on the statistics of a reference.

    For every (index-0, index-1) slice, the slice of ``latents`` is shifted and
    scaled so that its mean and standard deviation match those of the
    corresponding slice of ``reference_latents``. Slices whose standard
    deviation is not greater than ``1e-6`` are left unchanged, which avoids
    dividing by a (near-)zero value.

    Args:
        latents: Tensor to restyle; it is cloned, never modified in place.
        reference_latents: Tensor supplying the target statistics. It is
            indexed with the same first two indices as ``latents``.
        factor: Interpolation weight passed to ``torch.lerp``; ``0`` returns
            the input values, ``1`` returns the fully restyled values.

    Returns:
        The interpolation between ``latents`` and the restyled copy.
    """
    stylized = latents.clone()
    outer_count = latents.size(0)
    inner_count = latents.size(1)

    for b in range(outer_count):
        for ch in range(inner_count):
            ref_std, ref_mean = torch.std_mean(reference_latents[b, ch], dim=None)
            src_std, src_mean = torch.std_mean(stylized[b, ch], dim=None)

            # Skip slices with (near-)zero or NaN spread: normalising them
            # would divide by a value that is not safely above zero.
            if not src_std > 1e-6:
                continue

            whitened = (stylized[b, ch] - src_mean) / src_std
            stylized[b, ch] = whitened * ref_std + ref_mean

    return torch.lerp(latents, stylized, factor)
159
+
160
  def seed_everything(seed: int):
161
  """Sets the seed for reproducibility."""
162
  random.seed(seed)
 
167
  torch.backends.cudnn.deterministic = True
168
  torch.backends.cudnn.benchmark = False
169
 
 
170
  def load_image_to_tensor_with_resize_and_crop(
171
  image_input: Union[str, Image.Image],
172
  target_height: int,
173
  target_width: int,
174
  ) -> torch.Tensor:
175
+ """Loads and processes an image into a 5D tensor compatible with the LTX pipeline."""
176
  if isinstance(image_input, str):
177
  image = Image.open(image_input).convert("RGB")
178
  elif isinstance(image_input, Image.Image):
 
194
  image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
195
  image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
196
 
197
+ frame_tensor = TVF.to_tensor(image)
198
  frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
199
 
200
  frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
201
  frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
202
  frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
203
+ # Normalize to [-1, 1] range
204
  frame_tensor = (frame_tensor * 2.0) - 1.0
205
 
206
  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
207
+ return frame_tensor.unsqueeze(0).unsqueeze(2)