eeuuia committed
Commit 9f023b7 · verified · 1 Parent(s): b54d196

Update api/ltx/vae_aduc_pipeline.py

Files changed (1):
  api/ltx/vae_aduc_pipeline.py  +40 -80
api/ltx/vae_aduc_pipeline.py CHANGED
@@ -46,8 +46,8 @@ if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
 
 from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode, latent_to_pixel_coords
 from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
-from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem
-
+from pipeline_ltx_video import ConditioningItem as PipelineConditioningItem
+
 
 @dataclass
 class LatentConditioningItem:
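
Note: the import now resolves `ConditioningItem` from a top-level `pipeline_ltx_video` module (aliased so it cannot clash with the local `LatentConditioningItem`) instead of the installed `ltx_video` package. If both layouts need to keep working, a guarded import is one option; a minimal sketch, where the fallback order is an assumption rather than part of this commit:

```python
# Sketch only: try the repo-local module first, fall back to the package path.
# The fallback order is an assumption, not part of this commit.
try:
    from pipeline_ltx_video import ConditioningItem as PipelineConditioningItem
except ImportError:
    from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem as PipelineConditioningItem
```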
@@ -59,7 +59,6 @@ class LatentConditioningItem:
 # --- MAIN VAE SERVICE CLASS ---
 # ==============================================================================
 
-@log_function_io
 class VaeAducPipeline:
     _instance = None
     _lock = threading.Lock()
@@ -71,7 +70,6 @@ class VaeAducPipeline:
             cls._instance._initialized = False
         return cls._instance
 
-    @log_function_io
     def __init__(self):
         if hasattr(self, '_initialized') and self._initialized: return
         with self._lock:
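
Note: both removed `@log_function_io` decorators sat on the singleton plumbing, where `__new__` hands back one shared instance and `__init__` exits early once `_initialized` is set, all guarded by a class-level lock. A self-contained sketch of that double-checked pattern (names are illustrative, not from the file):

```python
import threading

class Singleton:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:  # re-check under the lock
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        with self._lock:
            if self._initialized:
                return
            # expensive one-time setup would happen here
            self._initialized = True

assert Singleton() is Singleton()
```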
@@ -101,68 +99,24 @@ class VaeAducPipeline:
     # --- PUBLIC SERVICE METHODS ---
 
     @log_function_io
-    def encode_video(
-        self,
-        video_tensor: torch.Tensor,
-        vae_per_channel_normalize: bool = True
-    ) -> torch.Tensor:
-        """
-        [NEW] Encodes a video tensor (pixels) into latent space.
-
-        Args:
-            video_tensor (torch.Tensor): Video tensor in (B, C, F, H, W) format and [0, 1] range.
-            vae_per_channel_normalize (bool): Whether to normalize the latents per channel.
-
-        Returns:
-            torch.Tensor: The resulting latent tensor on the CPU.
-        """
+    def encode_video(self, video_tensor: torch.Tensor, vae_per_channel_normalize: bool = True) -> torch.Tensor:
         logging.info(f"VaeAducPipeline: Encoding video with shape {video_tensor.shape}")
         if not (video_tensor.ndim == 5):
             raise ValueError(f"Input video tensor must be 5D (B, C, F, H, W), but got shape {video_tensor.shape}")
-
-        # Normalize the tensor from [0, 1] to [-1, 1]
         video_tensor_normalized = (video_tensor * 2.0) - 1.0
-
         try:
             video_gpu = video_tensor_normalized.to(self.device, dtype=self.dtype)
             with torch.no_grad():
-                latents = vae_encode(
-                    video_gpu,
-                    self.vae,
-                    vae_per_channel_normalize=vae_per_channel_normalize
-                )
+                latents = vae_encode(video_gpu, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)
             logging.info(f"VaeAducPipeline: Successfully encoded video to latents of shape {latents.shape}")
             return latents.cpu()
         finally:
             self._cleanup_gpu()
 
     @log_function_io
-    def decode_and_resize_video(
-        self,
-        latent_tensor: torch.Tensor,
-        target_height: int,
-        target_width: int,
-        decode_timestep: float = 0.05
-    ) -> torch.Tensor:
-        """
-        [NEW] Decodes a latent tensor to pixels and resizes it to the final resolution.
-
-        Args:
-            latent_tensor (torch.Tensor): The latent tensor to decode.
-            target_height (int): The final height of the video.
-            target_width (int): The final width of the video.
-            decode_timestep (float): Timestep for the VAE decoder, if applicable.
-
-        Returns:
-            torch.Tensor: The pixel video tensor, resized and on the CPU.
-        """
+    def decode_and_resize_video(self, latent_tensor: torch.Tensor, target_height: int, target_width: int, decode_timestep: float = 0.05) -> torch.Tensor:
         logging.info(f"VaeAducPipeline: Decoding latents {latent_tensor.shape} and resizing to {target_height}x{target_width}")
-
-        # 1. Decode to pixels (using the existing function)
-        # The result will already come back on the CPU
         pixel_video = self.decode_to_pixels(latent_tensor, decode_timestep)
-
-        # 2. Resize to the final size
         num_frames = pixel_video.shape[2]
         current_height, current_width = pixel_video.shape[3:]
 
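Note: `encode_video` keeps its contract after the condensation: a 5D `(B, C, F, H, W)` tensor in `[0, 1]`, shifted to `[-1, 1]` before `vae_encode`. The input handling in isolation, with arbitrary dummy sizes:

```python
import torch

# Dummy clip: batch=1, RGB, 9 frames, 64x64 pixels, values in [0, 1].
video = torch.rand(1, 3, 9, 64, 64)
assert video.ndim == 5  # (B, C, F, H, W), as encode_video requires

# The same normalization encode_video applies before calling vae_encode:
normalized = (video * 2.0) - 1.0
assert normalized.min() >= -1.0 and normalized.max() <= 1.0
```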
@@ -170,33 +124,21 @@ class VaeAducPipeline:
             logging.info("VaeAducPipeline: Resizing skipped, already at target resolution.")
             return pixel_video
 
-        # Apply interpolation to resize
         videos_flat = rearrange(pixel_video, "b c f h w -> (b f) c h w")
-        videos_resized = F.interpolate(
-            videos_flat,
-            size=(target_height, target_width),
-            mode="bilinear",
-            align_corners=False,
-        )
+        videos_resized = F.interpolate(videos_flat, size=(target_height, target_width), mode="bilinear", align_corners=False)
         final_video = rearrange(videos_resized, "(b f) c h w -> b c f h w", f=num_frames)
-
         logging.info(f"VaeAducPipeline: Resized video to final shape {final_video.shape}")
         return final_video
 
     @log_function_io
     def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
-        """Decodes a latent tensor into a pixel tensor, returning it on the CPU."""
         t0 = time.time()
         try:
             latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
             num_items = latent_tensor_gpu.shape[0]
             timestep_tensor = torch.tensor([decode_timestep] * num_items, device=self.device, dtype=self.dtype)
-
             with torch.no_grad():
-                pixels = vae_decode(
-                    latent_tensor_gpu, self.vae, is_video=True,
-                    timestep=timestep_tensor, vae_per_channel_normalize=True
-                )
+                pixels = vae_decode(latent_tensor_gpu, self.vae, is_video=True, timestep=timestep_tensor, vae_per_channel_normalize=True)
             logging.info(f"VaeAducPipeline: Decoded latents {latent_tensor.shape} in {time.time() - t0:.2f}s.")
             return pixels.cpu()
         finally:
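
Note: `decode_and_resize_video` resizes per frame by folding the frame axis into the batch axis, since bilinear `F.interpolate` expects 4D `(N, C, H, W)` input. The round trip on a dummy tensor:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

pixel_video = torch.rand(2, 3, 5, 32, 32)          # (b, c, f, h, w)
num_frames = pixel_video.shape[2]

flat = rearrange(pixel_video, "b c f h w -> (b f) c h w")
resized = F.interpolate(flat, size=(64, 64), mode="bilinear", align_corners=False)
final = rearrange(resized, "(b f) c h w -> b c f h w", f=num_frames)
assert final.shape == (2, 3, 5, 64, 64)
```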
@@ -213,7 +155,6 @@ class VaeAducPipeline:
         vae_per_channel_normalize: bool = True,
         generator: Optional[torch.Generator] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
-        """Prepares conditioning tensors from a list of pixel or latent items."""
         init_latents = init_latents.to(self.device, dtype=self.dtype)
 
         if not conditioning_items:
@@ -236,9 +177,14 @@ class VaeAducPipeline:
                     init_latents[..., :f, :h, :w] = torch.lerp(init_latents[..., :f, :h, :w], latents, item.conditioning_strength)
                     mask[..., :f, :h, :w] = item.conditioning_strength
                 else:
-                    latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
-                    extra_latents.append(latents_p); extra_coords.append(coords_p); extra_masks.append(new_mask)
-                    num_extra_latents += num_new
+                    if latents.shape[2] > 1:
+                        init_latents, mask, latents = self._handle_non_first_sequence(
+                            init_latents, mask, latents, item.media_frame_number, item.conditioning_strength
+                        )
+                    if latents is not None:
+                        latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
+                        extra_latents.append(latents_p); extra_coords.append(coords_p); extra_masks.append(new_mask)
+                        num_extra_latents += num_new
         else:
             for item in conditioning_items:
                 item_resized = self._resize_conditioning_item(item, height, width)
@@ -252,7 +198,9 @@ class VaeAducPipeline:
                     mask[..., :f, ly:ly+h, lx:lx+w] = item.conditioning_strength
                 else:
                     if media_item.shape[2] > 1:
-                        init_latents, mask, latents = self._handle_non_first_sequence(init_latents, mask, latents, item)
+                        init_latents, mask, latents = self._handle_non_first_sequence(
+                            init_latents, mask, latents, item.media_frame_number, item.conditioning_strength
+                        )
                     if latents is not None:
                         latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
                         extra_latents.append(latents_p); extra_coords.append(coords_p); extra_masks.append(new_mask)
@@ -282,7 +230,7 @@ class VaeAducPipeline:
         if torch.cuda.is_available():
             with torch.cuda.device(self.device): torch.cuda.empty_cache()
 
-    def _latent_to_pixel_coords(self, c): return latent_to_pixel_coords(c, self.vae, self.transformer.config.causal_temporal_positioning)
+    def _latent_to_pixel_coords(self, c): return latent_to_pixel_coords(c, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
 
     @staticmethod
     def _resize_tensor(m, h, w):
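
Note: passing the causal flag as `causal_fix=` pins it to the right parameter by name instead of by position. Conceptually, `latent_to_pixel_coords` rescales latent-grid coordinates by the VAE's downscale factors; an illustrative sketch in which the 8x factors and the helper name `to_pixel_coords` are assumptions, not the library API:

```python
import torch

# Illustrative only: map latent-grid (f, y, x) coordinates to pixel space by
# the VAE downscale factors. The 8x factors are assumed example values.
def to_pixel_coords(latent_coords: torch.Tensor, scale_f: int = 8, scale_hw: int = 8) -> torch.Tensor:
    scales = torch.tensor([scale_f, scale_hw, scale_hw]).view(1, 3, 1)
    return latent_coords * scales

coords = torch.tensor([[[0, 1, 2], [0, 0, 4], [0, 2, 4]]])  # (batch=1, 3 axes, 3 tokens)
print(to_pixel_coords(coords))  # tensor([[[ 0,  8, 16], [ 0,  0, 32], [ 0, 16, 32]]])
```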
@@ -307,15 +255,27 @@ class VaeAducPipeline:
         if (ys + hi) < h: l = l[..., :-1, :]
         return l, xs // s, ys // s
 
-    def _handle_non_first_sequence(self, il, m, l, i, np=2, mode="concat"):
-        fl, flp = l.shape[2], np
+    def _handle_non_first_sequence(
+        self,
+        init_latents: torch.Tensor,
+        mask: torch.Tensor,
+        latents: torch.Tensor,
+        media_frame_number: int,
+        conditioning_strength: float,
+        num_prefix=2,
+        mode="concat"
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        fl, flp = latents.shape[2], num_prefix
         if fl > flp:
-            s, e = i.media_frame_number // 8 + flp, i.media_frame_number // 8 + fl
-            il[..., s:e, :, :] = torch.lerp(il[..., s:e, :, :], l[..., flp:, :, :], i.conditioning_strength)
-            m[..., s:e, :, :] = i.conditioning_strength
-        if mode == "concat": l = l[..., :flp, :, :]
-        else: l = None
-        return il, m, l
+            start = media_frame_number // 8 + flp
+            end = start + fl - flp
+            init_latents[..., start:end, :, :] = torch.lerp(init_latents[..., start:end, :, :], latents[..., flp:, :, :], conditioning_strength)
+            mask[..., start:end, :, :] = conditioning_strength
+        if mode == "concat":
+            latents = latents[..., :flp, :, :]
+        else:
+            latents = None
+        return init_latents, mask, latents
 
     def _process_extra_item(self, l, i, g):
         n = randn_tensor(l.shape, generator=g, device=self.device, dtype=self.dtype)
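
Note: the rewritten `_handle_non_first_sequence` takes `media_frame_number` and `conditioning_strength` directly, so the latent-item and pixel-item branches above can both call it without sharing an item type. Its core move, blending the trailing conditioning frames into a temporal window of `init_latents`, on dummy tensors (sizes arbitrary; the `// 8` temporal downscale mirrors the code):

```python
import torch

init_latents = torch.zeros(1, 4, 16, 8, 8)   # (b, c, f_latent, h, w), dummy sizes
cond = torch.ones(1, 4, 5, 8, 8)             # conditioning latents, 5 frames
media_frame_number, strength, num_prefix = 16, 0.9, 2

fl, flp = cond.shape[2], num_prefix
start = media_frame_number // 8 + flp        # same window arithmetic as the diff
end = start + fl - flp
init_latents[..., start:end, :, :] = torch.lerp(
    init_latents[..., start:end, :, :], cond[..., flp:, :, :], strength
)
# "concat" mode keeps the first num_prefix frames, which are later handed to
# _process_extra_item as extra conditioning latents.
leftover = cond[..., :flp, :, :]
print(leftover.shape)                        # torch.Size([1, 4, 2, 8, 8])
```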
@@ -326,5 +286,5 @@ class VaeAducPipeline:
         nm = torch.full(lp.shape[:2], i.conditioning_strength, dtype=torch.float32, device=self.device)
         return lp, cp, nm, nl
 
-# --- Instanciação do Singleton ---
+# --- Instânciação do Singleton ---
 vae_aduc_pipeline = VaeAducPipeline()