Update LTX-Video/ltx_video/pipelines/pipeline_ltx_video (1).py
LTX-Video/ltx_video/pipelines/pipeline_ltx_video (1).py
CHANGED
@@ -190,32 +190,72 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
+from typing import Union, Optional
+from PIL import Image, ImageOps
+import torch
+from dataclasses import dataclass
 
 @dataclass
-class LatentConditioningItem:
-
-    latent_tensor: torch.Tensor
+class ConditioningItem:
+    media_item_latents: torch.Tensor
     media_frame_number: int
     conditioning_strength: float
+    media_x: Optional[int] = None
+    media_y: Optional[int] = None
 
-
-
+def encode_conditioning_item(
+    self,
+    raw_item: Union[Image.Image, torch.Tensor],
+    frame_number: int,
+    strength: float,
+    height: int,
+    width: int,
+    vae_per_channel_normalize: bool = False,
+) -> ConditioningItem:
     """
-
-
-
-
-
-
-
-
+    Converts a PIL Image or a latent tensor into a ConditioningItem with encoded latents.
+
+    Args:
+        raw_item: PIL.Image.Image or torch.Tensor ([B, C, f, H, W] or [B, C, H, W]).
+        frame_number: starting frame index in the video.
+        strength: conditioning weight (0.0-1.0).
+        height, width: target video resolution.
+        vae_per_channel_normalize: whether to normalize across the VAE channels.
+
+    Returns:
+        ConditioningItem with correctly formatted media_item_latents.
     """
+    # 1) If it is a PIL image, resize it and convert it to pixel latents
+    if isinstance(raw_item, Image.Image):
+        pil = ImageOps.fit(raw_item, (width, height), Image.LANCZOS)
+        # image_to_latents: converts PIL -> tensor [B, C, H_lat, W_lat]
+        pixel_latents = image_to_latents(pil)  # provided by your utility code
+        # add a frame dimension if needed
+        if pixel_latents.ndim == 4:
+            pixel_latents = pixel_latents.unsqueeze(2)  # [B, C, 1, H_lat, W_lat]
+        latents = pixel_latents.to(dtype=self.vae.dtype, device=self.vae.device)
+
+        # encode via the video VAE
+        latents = vae_encode(
+            latents,
+            self.vae,
+            vae_per_channel_normalize=vae_per_channel_normalize,
+        ).to(dtype=latents.dtype, device=latents.device)
 
-
-
-
-
-
+    # 2) If it is already a latent tensor
+    elif isinstance(raw_item, torch.Tensor):
+        latents = raw_item
+        # optional: validate shape == (B, C, f, H_lat, W_lat)
+
+    else:
+        raise TypeError(f"Unsupported type: {type(raw_item)}")
+
+    return ConditioningItem(
+        media_item_latents=latents,
+        media_frame_number=frame_number,
+        conditioning_strength=strength,
+    )
+
 
 
 class LTXVideoPipeline(DiffusionPipeline):
@@ -1390,242 +1430,106 @@ class LTXVideoPipeline(DiffusionPipeline):
 
 
     def prepare_conditioning(
-        self: "LTXVideoPipeline",
-        conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
-        init_latents: torch.Tensor,
-        num_frames: int,
-        height: int,
-        width: int,
-        vae_per_channel_normalize: bool = False,
-        generator=None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-        if not conditioning_items:
-            init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-            init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-            return init_latents, init_pixel_coords, None, 0
-
-        init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
-        extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
-        extra_conditioning_num_latents = 0
-
-        for item in conditioning_items:
-            if not isinstance(item, LatentConditioningItem):
-                logger.warning("ADUC patch: conditioning item is not a LatentConditioningItem and will be ignored.")
-                continue
-
-            media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
-            media_frame_number, strength = item.media_frame_number, item.conditioning_strength
-
-            if media_frame_number == 0:
-                f_l, h_l, w_l = media_item_latents.shape[-3:]
-                init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
-                init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
-            else:
-                noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
-                media_item_latents = torch.lerp(noise, media_item_latents, strength)
-                patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
-                pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-                pixel_coords[:, 0] += media_frame_number
-                extra_conditioning_num_latents += patched_latents.shape[1]
-                new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
-                extra_conditioning_latents.append(patched_latents)
-                extra_conditioning_pixel_coords.append(pixel_coords)
-                extra_conditioning_mask.append(new_mask)
-
-        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-        init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-        init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
-        init_conditioning_mask = init_conditioning_mask.squeeze(-1)
-
-        if extra_conditioning_latents:
-            init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
-            init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
-            init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
-
-        return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
-
-
-
-    def prepare_conditioning1(
         self,
         conditioning_items: Optional[List[ConditioningItem]],
         init_latents: torch.Tensor,
         num_frames: int,
         height: int,
         width: int,
-        vae_per_channel_normalize: bool = False,
         generator=None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-        """
-        Prepare conditioning tokens based on the provided conditioning items.
-
-        This method encodes provided conditioning items (video frames or single frames) into latents
-        and integrates them with the initial latent tensor. It also calculates corresponding pixel
-        coordinates, a mask indicating the influence of conditioning latents, and the total number of
-        conditioning latents.
-
-        Args:
-            conditioning_items (Optional[List[ConditioningItem]]): A list of ConditioningItem objects.
-            init_latents (torch.Tensor): The initial latent tensor of shape (b, c, f_l, h_l, w_l), where
-                `f_l` is the number of latent frames, and `h_l` and `w_l` are latent spatial dimensions.
-            num_frames, height, width: The dimensions of the generated video.
-            vae_per_channel_normalize (bool, optional): Whether to normalize channels during VAE encoding.
-                Defaults to `False`.
-            generator: The random generator
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-                - `init_latents` (torch.Tensor): The updated latent tensor including conditioning latents,
-                  patchified into (b, n, c) shape.
-                - `init_pixel_coords` (torch.Tensor): The pixel coordinates corresponding to the updated
-                  latent tensor.
-                - `conditioning_mask` (torch.Tensor): A mask indicating the conditioning-strength of each
-                  latent token.
-                - `num_cond_latents` (int): The total number of latent tokens added from conditioning items.
-
-        Raises:
-            AssertionError: If input shapes, dimensions, or conditions for applying conditioning are invalid.
-        """
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         assert isinstance(self.vae, CausalVideoAutoencoder)
-
         if conditioning_items:
             batch_size, _, num_latent_frames = init_latents.shape[:3]
-
             init_conditioning_mask = torch.zeros(
                 init_latents[:, 0, :, :, :].shape,
                 dtype=torch.float32,
                 device=init_latents.device,
             )
-
             extra_conditioning_latents = []
             extra_conditioning_pixel_coords = []
             extra_conditioning_mask = []
-            extra_conditioning_num_latents = 0
-
-
-            for conditioning_item in conditioning_items:
-                conditioning_item = self._resize_conditioning_item(
-                    conditioning_item, height, width
-                )
-                media_item = conditioning_item.media_item
-                media_frame_number = conditioning_item.media_frame_number
-                strength = conditioning_item.conditioning_strength
-                assert media_item.ndim == 5  # (b, c, f, h, w)
-                b, c, n_frames, h, w = media_item.shape
-                assert (
-                    height == h and width == w
-                ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
-                assert n_frames % 8 == 1
-                assert (
-                    media_frame_number >= 0
-                    and media_frame_number + n_frames <= num_frames
+            extra_conditioning_num_latents = 0
+
+            for item in conditioning_items:
+                media_latents = item.media_item_latents.to(
+                    dtype=init_latents.dtype, device=init_latents.device
                 )
-
-
-
-
-
-
-
-
-                # Handle the different conditioning cases
-                if media_frame_number == 0:
-                    # Get the target spatial position of the latent conditioning item
-                    media_item_latents, l_x, l_y = self._get_latent_spatial_position(
-                        media_item_latents,
-                        conditioning_item,
-                        height,
-                        width,
-                        strip_latent_border=True,
+                strength = item.conditioning_strength
+                frame_idx = item.media_frame_number
+
+                if frame_idx == 0:
+                    # spatial placement
+                    media_latents, l_x, l_y = self._get_latent_spatial_position(
+                        media_latents, item, height, width, strip_latent_border=True
                     )
-                    b, c_l, f_l, h_l, w_l = media_item_latents.shape
-
-
-
-
-
-                        media_item_latents,
-                        strength,
-                    )
+                    b, c_l, f_l, h_l, w_l = media_latents.shape
+
+                    init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
+                        init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l],
+                        media_latents,
+                        strength,
                     )
-                    init_conditioning_mask[
-                        :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l
-                    ] = strength
+                    init_conditioning_mask[:, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = strength
                 else:
-                    #
-                    if
-
-
-
-
-
-
-
-
-                        init_conditioning_mask,
-                        media_item_latents,
-                        media_frame_number,
-                        strength,
+                    # non-initial sequences
+                    if media_latents.shape[2] > 1:
+                        init_latents, init_conditioning_mask, media_latents = (
+                            self._handle_non_first_conditioning_sequence(
+                                init_latents,
+                                init_conditioning_mask,
+                                media_latents,
+                                frame_idx,
+                                strength,
+                            )
                         )
-
-                    # Single frame or sequence-prefix latents
-                    if media_item_latents is not None:
+                    if media_latents is not None:
                         noise = randn_tensor(
-                            media_item_latents.shape,
+                            media_latents.shape,
                             generator=generator,
-                            device=media_item_latents.device,
-                            dtype=media_item_latents.dtype,
+                            device=media_latents.device,
+                            dtype=media_latents.dtype,
                         )
-
-
-
-
-
-                        # Patchify the extra conditioning latents and calculate their pixel coordinates
-                        media_item_latents, latent_coords = self.patchifier.patchify(
-                            latents=media_item_latents
+                        media_latents = torch.lerp(noise, media_latents, strength)
+                        # patchify
+                        media_latents, latent_coords = self.patchifier.patchify(
+                            latents=media_latents
                         )
                         pixel_coords = latent_to_pixel_coords(
                             latent_coords,
                             self.vae,
                             causal_fix=self.transformer.config.causal_temporal_positioning,
                         )
-
-
-
-
-
-                        conditioning_mask = torch.full(
-                            media_item_latents.shape[:2],
+                        pixel_coords[:, 0] += frame_idx
+                        extra_conditioning_num_latents += media_latents.shape[1]
+                        mask = torch.full(
+                            media_latents.shape[:2],
                             strength,
                             dtype=torch.float32,
                             device=init_latents.device,
                         )
-
-                        extra_conditioning_latents.append(media_item_latents)
+                        extra_conditioning_latents.append(media_latents)
                         extra_conditioning_pixel_coords.append(pixel_coords)
-                        extra_conditioning_mask.append(
-
-            #
-            init_latents, init_latent_coords = self.patchifier.patchify(
-                latents=init_latents
-            )
+                        extra_conditioning_mask.append(mask)
+
+        # patchify init_latents
+        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
         init_pixel_coords = latent_to_pixel_coords(
             init_latent_coords,
             self.vae,
             causal_fix=self.transformer.config.causal_temporal_positioning,
         )
-
+
         if not conditioning_items:
             return init_latents, init_pixel_coords, None, 0
-
+
+        # patchify mask
         init_conditioning_mask, _ = self.patchifier.patchify(
             latents=init_conditioning_mask.unsqueeze(1)
         )
         init_conditioning_mask = init_conditioning_mask.squeeze(-1)
-
+
         if extra_conditioning_latents:
-            # Stack the extra conditioning latents, pixel coordinates and mask
             init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
            init_pixel_coords = torch.cat(
                 [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2
@@ -1633,25 +1537,15 @@ class LTXVideoPipeline(DiffusionPipeline):
             init_conditioning_mask = torch.cat(
                 [*extra_conditioning_mask, init_conditioning_mask], dim=1
             )
-
         if self.transformer.use_tpu_flash_attention:
-            # When flash attention is used, keep the original number of tokens by removing
-            # tokens from the end.
             init_latents = init_latents[:, :-extra_conditioning_num_latents]
-            init_pixel_coords = init_pixel_coords[
-
-
-
-
-            ]
-
-        return (
-            init_latents,
-            init_pixel_coords,
-            init_conditioning_mask,
-            extra_conditioning_num_latents,
-        )
+            init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
+            init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]
+
+        return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
+
 
+
     @staticmethod
     def _resize_conditioning_item(
         conditioning_item: ConditioningItem,