eeuuia committed on
Commit
2f53bb4
·
verified ·
1 Parent(s): 376f545

Upload ltx_utils.py

Browse files
Files changed (1) hide show
  1. api/ltx/ltx_utils.py +5 -43
api/ltx/ltx_utils.py CHANGED
@@ -1,6 +1,6 @@
1
  # FILE: api/ltx/ltx_utils.py
2
  # DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
3
- # Handles dependency path injection, model loading, data structures, and helper functions.
4
 
5
  import os
6
  import random
@@ -10,8 +10,6 @@ import time
10
  import sys
11
  from pathlib import Path
12
  from typing import Dict, Optional, Tuple, Union
13
- from dataclasses import dataclass
14
- from enum import Enum, auto
15
 
16
  import numpy as np
17
  import torch
@@ -51,34 +49,11 @@ try:
51
  from ltx_video.models.transformers.transformer3d import Transformer3DModel
52
  from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
53
  from ltx_video.schedulers.rf import RectifiedFlowScheduler
54
- from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
55
  import ltx_video.pipelines.crf_compressor as crf_compressor
56
  except ImportError as e:
57
  raise ImportError(f"Could not import from LTX-Video library even after setting sys.path. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
58
 
59
 
60
- # ==============================================================================
61
- # --- ESTRUTURAS DE DADOS E ENUMS (Centralizadas aqui) ---
62
- # ==============================================================================
63
-
64
- @dataclass
65
- class ConditioningItem:
66
- """Define a single frame-conditioning item, used to guide the generation pipeline."""
67
- media_item: torch.Tensor
68
- media_frame_number: int
69
- conditioning_strength: float
70
- media_x: Optional[int] = None
71
- media_y: Optional[int] = None
72
-
73
-
74
- class SkipLayerStrategy(Enum):
75
- """Defines the strategy for how spatio-temporal guidance is applied across transformer blocks."""
76
- AttentionSkip = auto()
77
- AttentionValues = auto()
78
- Residual = auto()
79
- TransformerBlock = auto()
80
-
81
-
82
  # ==============================================================================
83
  # --- FUNÇÕES DE CONSTRUÇÃO DE MODELO E PIPELINE ---
84
  # ==============================================================================
@@ -141,22 +116,9 @@ def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[
141
 
142
 
143
  # ==============================================================================
144
- # --- FUNÇÕES AUXILIARES (Latent Processing, Seed, Image Prep) ---
145
  # ==============================================================================
146
 
147
- def adain_filter_latent(
148
- latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
149
- ) -> torch.Tensor:
150
- """Applies AdaIN to transfer the style from a reference latent to another."""
151
- result = latents.clone()
152
- for i in range(latents.size(0)):
153
- for c in range(latents.size(1)):
154
- r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)
155
- i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
156
- if i_sd > 1e-6:
157
- result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
158
- return torch.lerp(latents, result, factor)
159
-
160
  def seed_everything(seed: int):
161
  """Sets the seed for reproducibility."""
162
  random.seed(seed)
@@ -172,7 +134,7 @@ def load_image_to_tensor_with_resize_and_crop(
172
  target_height: int,
173
  target_width: int,
174
  ) -> torch.Tensor:
175
- """Loads and processes an image into a 5D tensor compatible with the LTX pipeline."""
176
  if isinstance(image_input, str):
177
  image = Image.open(image_input).convert("RGB")
178
  elif isinstance(image_input, Image.Image):
@@ -194,13 +156,13 @@ def load_image_to_tensor_with_resize_and_crop(
194
  image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
195
  image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
196
 
197
- frame_tensor = TVF.to_tensor(image)
198
  frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
199
 
200
  frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
201
  frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
202
  frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
203
- # Normalize to [-1, 1] range
204
  frame_tensor = (frame_tensor * 2.0) - 1.0
205
 
206
  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
 
1
  # FILE: api/ltx/ltx_utils.py
2
  # DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
3
+ # Handles dependency path injection, model loading, pipeline creation, and tensor preparation.
4
 
5
  import os
6
  import random
 
10
  import sys
11
  from pathlib import Path
12
  from typing import Dict, Optional, Tuple, Union
 
 
13
 
14
  import numpy as np
15
  import torch
 
49
  from ltx_video.models.transformers.transformer3d import Transformer3DModel
50
  from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
51
  from ltx_video.schedulers.rf import RectifiedFlowScheduler
 
52
  import ltx_video.pipelines.crf_compressor as crf_compressor
53
  except ImportError as e:
54
  raise ImportError(f"Could not import from LTX-Video library even after setting sys.path. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # ==============================================================================
58
  # --- FUNÇÕES DE CONSTRUÇÃO DE MODELO E PIPELINE ---
59
  # ==============================================================================
 
116
 
117
 
118
  # ==============================================================================
119
+ # --- FUNÇÕES AUXILIARES (Seed, Preparação de Imagem) ---
120
  # ==============================================================================
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  def seed_everything(seed: int):
123
  """Sets the seed for reproducibility."""
124
  random.seed(seed)
 
134
  target_height: int,
135
  target_width: int,
136
  ) -> torch.Tensor:
137
+ """Loads and processes an image into a 5D pixel tensor compatible with the LTX pipeline."""
138
  if isinstance(image_input, str):
139
  image = Image.open(image_input).convert("RGB")
140
  elif isinstance(image_input, Image.Image):
 
156
  image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
157
  image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
158
 
159
+ frame_tensor = TVF.to_tensor(image) # PIL -> tensor (C, H, W) in [0, 1] range
160
  frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
161
 
162
  frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
163
  frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
164
  frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
165
+ # Normalize to [-1, 1] range, which the VAE expects for encoding
166
  frame_tensor = (frame_tensor * 2.0) - 1.0
167
 
168
  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)