euiiiia committed on
Commit 4aa7f1b · verified · 1 Parent(s): e6bfe26

Update LTX-Video/ltx_video/pipelines/pipeline_ltx_video (1).py

LTX-Video/ltx_video/pipelines/pipeline_ltx_video (1).py CHANGED
@@ -190,32 +190,72 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
 
 @dataclass
-class LatentConditioningItem:
-    """Data item for conditioning the LTX pipeline."""
-    latent_tensor: torch.Tensor
     media_frame_number: int
     conditioning_strength: float
 
-@dataclass
-class ConditioningItem:
     """
-    Defines a single frame-conditioning item - a single frame or a sequence of frames.
-
-    Attributes:
-        media_item (torch.Tensor): shape=(b, 3, f, h, w). The media item to condition on.
-        media_frame_number (int): The start-frame number of the media item in the generated video.
-        conditioning_strength (float): The strength of the conditioning (1.0 = full conditioning).
-        media_x (Optional[int]): Optional left x coordinate of the media item in the generated frame.
-        media_y (Optional[int]): Optional top y coordinate of the media item in the generated frame.
     """
 
-    media_item: torch.Tensor
-    media_frame_number: int
-    conditioning_strength: float
-    media_x: Optional[int] = None
-    media_y: Optional[int] = None
 
 
 class LTXVideoPipeline(DiffusionPipeline):
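The hunk above replaces the pixel-space `ConditioningItem` (and the interim `LatentConditioningItem`) with a single latent-first dataclass. As a reference for readers tracking the schema change, here is a minimal, self-contained sketch of the new shape, with field names taken from the diff and a placeholder tensor:

```python
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class ConditioningItem:
    media_item_latents: torch.Tensor   # (B, C, f, H_lat, W_lat), already VAE-encoded
    media_frame_number: int            # start frame in the generated video
    conditioning_strength: float       # 1.0 = full conditioning
    media_x: Optional[int] = None      # optional left x position in the frame
    media_y: Optional[int] = None      # optional top y position in the frame

# Placeholder latent shape; the old schema carried raw pixels (b, 3, f, h, w) instead.
item = ConditioningItem(torch.randn(1, 128, 1, 16, 16), media_frame_number=0, conditioning_strength=1.0)
```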
@@ -1390,242 +1430,106 @@ class LTXVideoPipeline(DiffusionPipeline):
 
 
     def prepare_conditioning(
-        self: "LTXVideoPipeline",
-        conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
-        init_latents: torch.Tensor,
-        num_frames: int,
-        height: int,
-        width: int,
-        vae_per_channel_normalize: bool = False,
-        generator=None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-        if not conditioning_items:
-            init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-            init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-            return init_latents, init_pixel_coords, None, 0
-
-        init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
-        extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
-        extra_conditioning_num_latents = 0
-
-        for item in conditioning_items:
-            if not isinstance(item, LatentConditioningItem):
-                logger.warning("Patch ADUC: conditioning item is not a LatentConditioningItem and will be ignored.")
-                continue
-
-            media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
-            media_frame_number, strength = item.media_frame_number, item.conditioning_strength
-
-            if media_frame_number == 0:
-                f_l, h_l, w_l = media_item_latents.shape[-3:]
-                init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
-                init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
-            else:
-                noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
-                media_item_latents = torch.lerp(noise, media_item_latents, strength)
-                patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
-                pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-                pixel_coords[:, 0] += media_frame_number
-                extra_conditioning_num_latents += patched_latents.shape[1]
-                new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
-                extra_conditioning_latents.append(patched_latents)
-                extra_conditioning_pixel_coords.append(pixel_coords)
-                extra_conditioning_mask.append(new_mask)
-
-        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-        init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-        init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
-        init_conditioning_mask = init_conditioning_mask.squeeze(-1)
-
-        if extra_conditioning_latents:
-            init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
-            init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
-            init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
-
-        return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
-
-
-
-    def prepare_conditioning1(
         self,
         conditioning_items: Optional[List[ConditioningItem]],
         init_latents: torch.Tensor,
         num_frames: int,
         height: int,
         width: int,
-        vae_per_channel_normalize: bool = False,
         generator=None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-        """
-        Prepare conditioning tokens based on the provided conditioning items.
-
-        This method encodes provided conditioning items (video frames or single frames) into latents
-        and integrates them with the initial latent tensor. It also calculates corresponding pixel
-        coordinates, a mask indicating the influence of conditioning latents, and the total number of
-        conditioning latents.
-
-        Args:
-            conditioning_items (Optional[List[ConditioningItem]]): A list of ConditioningItem objects.
-            init_latents (torch.Tensor): The initial latent tensor of shape (b, c, f_l, h_l, w_l), where
-                `f_l` is the number of latent frames, and `h_l` and `w_l` are latent spatial dimensions.
-            num_frames, height, width: The dimensions of the generated video.
-            vae_per_channel_normalize (bool, optional): Whether to normalize channels during VAE encoding.
-                Defaults to `False`.
-            generator: The random generator.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-                - `init_latents` (torch.Tensor): The updated latent tensor including conditioning latents,
-                  patchified into (b, n, c) shape.
-                - `init_pixel_coords` (torch.Tensor): The pixel coordinates corresponding to the updated
-                  latent tensor.
-                - `conditioning_mask` (torch.Tensor): A mask indicating the conditioning strength of each
-                  latent token.
-                - `num_cond_latents` (int): The total number of latent tokens added from conditioning items.
-
-        Raises:
-            AssertionError: If input shapes, dimensions, or conditions for applying conditioning are invalid.
-        """
         assert isinstance(self.vae, CausalVideoAutoencoder)
-
         if conditioning_items:
             batch_size, _, num_latent_frames = init_latents.shape[:3]
-
             init_conditioning_mask = torch.zeros(
                 init_latents[:, 0, :, :, :].shape,
                 dtype=torch.float32,
                 device=init_latents.device,
             )
-
             extra_conditioning_latents = []
             extra_conditioning_pixel_coords = []
             extra_conditioning_mask = []
-            extra_conditioning_num_latents = 0  # Number of extra conditioning latents added (should be removed before decoding)
-
-            # Process each conditioning item
-            for conditioning_item in conditioning_items:
-                conditioning_item = self._resize_conditioning_item(
-                    conditioning_item, height, width
-                )
-                media_item = conditioning_item.media_item
-                media_frame_number = conditioning_item.media_frame_number
-                strength = conditioning_item.conditioning_strength
-                assert media_item.ndim == 5  # (b, c, f, h, w)
-                b, c, n_frames, h, w = media_item.shape
-                assert (
-                    height == h and width == w
-                ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
-                assert n_frames % 8 == 1
-                assert (
-                    media_frame_number >= 0
-                    and media_frame_number + n_frames <= num_frames
                 )
-
-                # Encode the provided conditioning media item
-                media_item_latents = vae_encode(
-                    media_item.to(dtype=self.vae.dtype, device=self.vae.device),
-                    self.vae,
-                    vae_per_channel_normalize=vae_per_channel_normalize,
-                ).to(dtype=init_latents.dtype)
-
-                # Handle the different conditioning cases
-                if media_frame_number == 0:
-                    # Get the target spatial position of the latent conditioning item
-                    media_item_latents, l_x, l_y = self._get_latent_spatial_position(
-                        media_item_latents,
-                        conditioning_item,
-                        height,
-                        width,
-                        strip_latent_border=True,
                     )
-                    b, c_l, f_l, h_l, w_l = media_item_latents.shape
-
-                    # First frame or sequence - just update the initial noise latents and the mask
-                    init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = (
-                        torch.lerp(
-                            init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l],
-                            media_item_latents,
-                            strength,
-                        )
                     )
-                    init_conditioning_mask[
-                        :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l
-                    ] = strength
                 else:
-                    # Non-first frame or sequence
-                    if n_frames > 1:
-                        # Handle non-first sequence.
-                        # Encoded latents are either fully consumed, or the prefix is handled separately below.
-                        (
-                            init_latents,
-                            init_conditioning_mask,
-                            media_item_latents,
-                        ) = self._handle_non_first_conditioning_sequence(
-                            init_latents,
-                            init_conditioning_mask,
-                            media_item_latents,
-                            media_frame_number,
-                            strength,
                         )
-
-                    # Single frame or sequence-prefix latents
-                    if media_item_latents is not None:
                         noise = randn_tensor(
-                            media_item_latents.shape,
                             generator=generator,
-                            device=media_item_latents.device,
-                            dtype=media_item_latents.dtype,
                         )
-
-                        media_item_latents = torch.lerp(
-                            noise, media_item_latents, strength
-                        )
-
-                        # Patchify the extra conditioning latents and calculate their pixel coordinates
-                        media_item_latents, latent_coords = self.patchifier.patchify(
-                            latents=media_item_latents
                         )
                         pixel_coords = latent_to_pixel_coords(
                             latent_coords,
                             self.vae,
                             causal_fix=self.transformer.config.causal_temporal_positioning,
                         )
-
-                        # Update the frame numbers to match the target frame number
-                        pixel_coords[:, 0] += media_frame_number
-                        extra_conditioning_num_latents += media_item_latents.shape[1]
-
-                        conditioning_mask = torch.full(
-                            media_item_latents.shape[:2],
                             strength,
                             dtype=torch.float32,
                             device=init_latents.device,
                         )
-
-                        extra_conditioning_latents.append(media_item_latents)
                         extra_conditioning_pixel_coords.append(pixel_coords)
-                        extra_conditioning_mask.append(conditioning_mask)
-
-        # Patchify the updated latents and calculate their pixel coordinates
-        init_latents, init_latent_coords = self.patchifier.patchify(
-            latents=init_latents
-        )
         init_pixel_coords = latent_to_pixel_coords(
             init_latent_coords,
             self.vae,
             causal_fix=self.transformer.config.causal_temporal_positioning,
         )
-
         if not conditioning_items:
             return init_latents, init_pixel_coords, None, 0
-
         init_conditioning_mask, _ = self.patchifier.patchify(
             latents=init_conditioning_mask.unsqueeze(1)
         )
         init_conditioning_mask = init_conditioning_mask.squeeze(-1)
-
         if extra_conditioning_latents:
-            # Stack the extra conditioning latents, pixel coordinates and mask
             init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
             init_pixel_coords = torch.cat(
                 [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2
@@ -1633,25 +1537,15 @@ class LTXVideoPipeline(DiffusionPipeline):
             init_conditioning_mask = torch.cat(
                 [*extra_conditioning_mask, init_conditioning_mask], dim=1
             )
-
         if self.transformer.use_tpu_flash_attention:
-            # When flash attention is used, keep the original number of tokens by removing
-            # tokens from the end.
             init_latents = init_latents[:, :-extra_conditioning_num_latents]
-            init_pixel_coords = init_pixel_coords[
-                :, :, :-extra_conditioning_num_latents
-            ]
-            init_conditioning_mask = init_conditioning_mask[
-                :, :-extra_conditioning_num_latents
-            ]
-
-        return (
-            init_latents,
-            init_pixel_coords,
-            init_conditioning_mask,
-            extra_conditioning_num_latents,
-        )
 
     @staticmethod
     def _resize_conditioning_item(
         conditioning_item: ConditioningItem,
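Both removed implementations rely on the same blend: `torch.lerp` mixes conditioning latents into noise (or into the initial latents) weighted by `conditioning_strength`. A self-contained sketch of that operation, with placeholder shapes:

```python
import torch

# torch.lerp(a, b, w) == a + w * (b - a): w=0 keeps the noise, w=1 keeps the latents.
noise = torch.randn(1, 128, 1, 16, 16)        # placeholder latent shape
media_item_latents = torch.randn_like(noise)
strength = 0.75                               # conditioning_strength

blended = torch.lerp(noise, media_item_latents, strength)
assert torch.allclose(blended, noise + strength * (media_item_latents - noise))
```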
 
@@ -190,32 +190,72 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
+from typing import Union, Optional
+from PIL import Image, ImageOps
+import torch
+from dataclasses import dataclass
 
 @dataclass
+class ConditioningItem:
+    media_item_latents: torch.Tensor
     media_frame_number: int
     conditioning_strength: float
+    media_x: Optional[int] = None
+    media_y: Optional[int] = None
 
+def encode_conditioning_item(
+    self,
+    raw_item: Union[Image.Image, torch.Tensor],
+    frame_number: int,
+    strength: float,
+    height: int,
+    width: int,
+    vae_per_channel_normalize: bool = False,
+) -> ConditioningItem:
     """
+    Converts a PIL Image or a latent tensor into a ConditioningItem with encoded latents.
+
+    Args:
+        raw_item: PIL.Image.Image or torch.Tensor ([B, C, f, H, W] or [B, C, H, W]).
+        frame_number: start index in the video.
+        strength: conditioning weight (0.0–1.0).
+        height, width: target video resolution.
+        vae_per_channel_normalize: per-channel normalization in the VAE.
+
+    Returns:
+        ConditioningItem with correctly formatted media_item_latents.
     """
+    # 1) If a PIL image: resize and convert to pixel latents
+    if isinstance(raw_item, Image.Image):
+        pil = ImageOps.fit(raw_item, (width, height), Image.LANCZOS)
+        # image_to_latents: converts PIL → tensor [B, C, H_lat, W_lat]
+        pixel_latents = image_to_latents(pil)  # provided by your utility
+        # add a frame dimension if needed
+        if pixel_latents.ndim == 4:
+            pixel_latents = pixel_latents.unsqueeze(2)  # [B, C, 1, H_lat, W_lat]
+        latents = pixel_latents.to(dtype=self.vae.dtype, device=self.vae.device)
+
+        # encode via the video VAE
+        latents = vae_encode(
+            latents,
+            self.vae,
+            vae_per_channel_normalize=vae_per_channel_normalize,
+        ).to(dtype=latents.dtype, device=latents.device)
 
+    # 2) If already a latent tensor
+    elif isinstance(raw_item, torch.Tensor):
+        latents = raw_item
+        # optional: validate shape == (B, C, f, H_lat, W_lat)
+
+    else:
+        raise TypeError(f"Unsupported type: {type(raw_item)}")
+
+    return ConditioningItem(
+        media_item_latents=latents,
+        media_frame_number=frame_number,
+        conditioning_strength=strength,
+    )
+
 
 
 class LTXVideoPipeline(DiffusionPipeline):
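A possible call pattern for the `encode_conditioning_item` helper added above. This is a sketch, not pipeline API: `pipeline` and the input filename are assumptions, and `image_to_latents` is the utility the diff itself leaves to the caller. Because the helper is defined at module level with a `self` parameter, the pipeline object is passed explicitly:

```python
from PIL import Image

# Hypothetical usage; assumes `pipeline` is an LTXVideoPipeline instance and that
# image_to_latents (referenced in the diff) is available in scope.
image = Image.open("first_frame.png").convert("RGB")
item = encode_conditioning_item(
    pipeline,          # bound by hand, since the def is not attached to a class
    raw_item=image,
    frame_number=0,    # condition the start of the video
    strength=1.0,      # full conditioning
    height=512,
    width=768,
)
# item.media_item_latents now holds VAE-encoded latents ready for prepare_conditioning.
```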
 
@@ -1390,242 +1430,106 @@ class LTXVideoPipeline(DiffusionPipeline):
 
 
     def prepare_conditioning(
         self,
         conditioning_items: Optional[List[ConditioningItem]],
         init_latents: torch.Tensor,
         num_frames: int,
         height: int,
         width: int,
         generator=None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         assert isinstance(self.vae, CausalVideoAutoencoder)
         if conditioning_items:
             batch_size, _, num_latent_frames = init_latents.shape[:3]
             init_conditioning_mask = torch.zeros(
                 init_latents[:, 0, :, :, :].shape,
                 dtype=torch.float32,
                 device=init_latents.device,
             )
             extra_conditioning_latents = []
             extra_conditioning_pixel_coords = []
             extra_conditioning_mask = []
+            extra_conditioning_num_latents = 0
+
+            for item in conditioning_items:
+                media_latents = item.media_item_latents.to(
+                    dtype=init_latents.dtype, device=init_latents.device
                 )
+                strength = item.conditioning_strength
+                frame_idx = item.media_frame_number
+
+                if frame_idx == 0:
+                    # spatial placement
+                    media_latents, l_x, l_y = self._get_latent_spatial_position(
+                        media_latents, item, height, width, strip_latent_border=True
                     )
+                    b, c_l, f_l, h_l, w_l = media_latents.shape
+
+                    init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
+                        init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l],
+                        media_latents,
+                        strength,
                     )
+                    init_conditioning_mask[:, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = strength
                 else:
+                    # non-first sequences
+                    if media_latents.shape[2] > 1:
+                        init_latents, init_conditioning_mask, media_latents = (
+                            self._handle_non_first_conditioning_sequence(
+                                init_latents,
+                                init_conditioning_mask,
+                                media_latents,
+                                frame_idx,
+                                strength,
+                            )
                         )
+                    if media_latents is not None:
                         noise = randn_tensor(
+                            media_latents.shape,
                             generator=generator,
+                            device=media_latents.device,
+                            dtype=media_latents.dtype,
                         )
+                        media_latents = torch.lerp(noise, media_latents, strength)
+                        # patchify
+                        media_latents, latent_coords = self.patchifier.patchify(
+                            latents=media_latents
                         )
                         pixel_coords = latent_to_pixel_coords(
                             latent_coords,
                             self.vae,
                             causal_fix=self.transformer.config.causal_temporal_positioning,
                         )
+                        pixel_coords[:, 0] += frame_idx
+                        extra_conditioning_num_latents += media_latents.shape[1]
+                        mask = torch.full(
+                            media_latents.shape[:2],
                             strength,
                             dtype=torch.float32,
                             device=init_latents.device,
                         )
+                        extra_conditioning_latents.append(media_latents)
                         extra_conditioning_pixel_coords.append(pixel_coords)
+                        extra_conditioning_mask.append(mask)
+
+        # patchify init_latents
+        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
         init_pixel_coords = latent_to_pixel_coords(
             init_latent_coords,
             self.vae,
             causal_fix=self.transformer.config.causal_temporal_positioning,
         )
+
         if not conditioning_items:
             return init_latents, init_pixel_coords, None, 0
+
+        # patchify mask
         init_conditioning_mask, _ = self.patchifier.patchify(
             latents=init_conditioning_mask.unsqueeze(1)
         )
         init_conditioning_mask = init_conditioning_mask.squeeze(-1)
+
         if extra_conditioning_latents:
             init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
             init_pixel_coords = torch.cat(
                 [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2
@@ -1633,25 +1537,15 @@ class LTXVideoPipeline(DiffusionPipeline):
             init_conditioning_mask = torch.cat(
                 [*extra_conditioning_mask, init_conditioning_mask], dim=1
             )
         if self.transformer.use_tpu_flash_attention:
             init_latents = init_latents[:, :-extra_conditioning_num_latents]
+            init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
+            init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]
+
+        return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
+
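One piece of bookkeeping in the rewritten `prepare_conditioning` is easy to miss: extra conditioning tokens are prepended along the token dimension, and under TPU flash attention the same count is trimmed from the end so the sequence length stays fixed. A toy sketch with made-up shapes:

```python
import torch

init_latents = torch.randn(1, 1024, 128)          # (b, n_tokens, channels), placeholder
extra = torch.randn(1, 96, 128)                   # 96 extra conditioning tokens (assumed)

merged = torch.cat([extra, init_latents], dim=1)  # prepend: (1, 1120, 128)
trimmed = merged[:, :-extra.shape[1]]             # TPU path: drop from the end
assert trimmed.shape == init_latents.shape        # token count unchanged
```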
 
+
     @staticmethod
     def _resize_conditioning_item(
         conditioning_item: ConditioningItem,