Eueuiaa commited on
Commit
8e36bc5
·
verified ·
1 Parent(s): daf606d

Update LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py

Browse files
LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py CHANGED
@@ -426,12 +426,12 @@ class LTXVideoPipeline(DiffusionPipeline):
426
 
427
  self.allowed_inference_steps = allowed_inference_steps
428
 
429
- self.spy = SpyLatent(vae=self.pipeline.vae)
430
 
431
  def mask_text_embeddings(self, emb, mask):
432
  if emb.shape[0] == 1:
433
  keep_index = mask.sum().item()
434
- return emb[:, :, :keep_index, :], keep_index
435
  else:
436
  masked_feature = emb * mask[:, None, :, None]
437
  return masked_feature, emb.shape[2]
 
426
 
427
  self.allowed_inference_steps = allowed_inference_steps
428
 
429
+ self.spy = SpyLatent(vae=vae)
430
 
431
  def mask_text_embeddings(self, emb, mask):
432
  if emb.shape[0] == 1:
433
  keep_index = mask.sum().item()
434
+ return emb[, :, :keep_index, :], keep_index
435
  else:
436
  masked_feature = emb * mask[:, None, :, None]
437
  return masked_feature, emb.shape[2]