|
|
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange
from flash_attn import flash_attn_func
from torch.nn.attention.flex_attention import flex_attention

# The `flash_attn` package module itself is only needed by the FlashAttention-3
# code paths guarded by FLASH_ATTN_3_AVAILABLE below.
try:
    import flash_attn
except ImportError:
    flash_attn = None

FLASH_ATTN_3_AVAILABLE = False
|
|
|
|
|
|
|
|
DISABLE_COMPILE = False
if not DISABLE_COMPILE:
    flex_attention = torch.compile(
        flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
def _to_tuple(x, dim=2): |
|
|
if isinstance(x, int): |
|
|
return (x,) * dim |
|
|
elif len(x) == dim: |
|
|
return x |
|
|
else: |
|
|
raise ValueError(f"Expected length {dim} or int, but got {x}") |
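
# Illustrative examples:
#   _to_tuple(8, dim=3)       -> (8, 8, 8)
#   _to_tuple((4, 8), dim=2)  -> (4, 8)
#   _to_tuple((4, 8), dim=3)  -> raises ValueError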
|
|
|
|
|
|
|
|
def get_meshgrid_nd(start, *args, dim=2): |
|
|
""" |
|
|
Get n-D meshgrid with start, stop and num. |
|
|
|
|
|
Args: |
|
|
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, |
|
|
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num |
|
|
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in |
|
|
n-tuples. |
|
|
*args: See above. |
|
|
dim (int): Dimension of the meshgrid. Defaults to 2. |
|
|
|
|
|
    Returns:
        grid (torch.Tensor): [dim, ...]
|
|
""" |
|
|
if len(args) == 0: |
|
|
|
|
|
num = _to_tuple(start, dim=dim) |
|
|
start = (0,) * dim |
|
|
stop = num |
|
|
elif len(args) == 1: |
|
|
|
|
|
start = _to_tuple(start, dim=dim) |
|
|
stop = _to_tuple(args[0], dim=dim) |
|
|
num = [stop[i] - start[i] for i in range(dim)] |
|
|
elif len(args) == 2: |
|
|
|
|
|
start = _to_tuple(start, dim=dim) |
|
|
stop = _to_tuple(args[0], dim=dim) |
|
|
num = _to_tuple(args[1], dim=dim) |
|
|
else: |
|
|
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") |
|
|
|
|
|
|
|
|
axis_grid = [] |
|
|
for i in range(dim): |
|
|
a, b, n = start[i], stop[i], num[i] |
|
|
g = torch.linspace(a, b, n + 1, dtype=torch.float32, device=torch.cuda.current_device())[:n] |
|
|
axis_grid.append(g) |
|
|
grid = torch.meshgrid(*axis_grid, indexing="ij") |
|
|
grid = torch.stack(grid, dim=0) |
|
|
|
|
|
return grid |
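
# For example (assuming a CUDA device, since the grid is built on
# torch.cuda.current_device()):
#   get_meshgrid_nd((2, 3), dim=2)  # tensor of shape [2, 2, 3], i.e. [dim, *num]
# where grid[0] holds the coordinates along the first axis and grid[1] those
# along the second.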
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reshape_for_broadcast( |
|
|
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], |
|
|
x: torch.Tensor, |
|
|
head_first=False, |
|
|
): |
|
|
""" |
|
|
Reshape frequency tensor for broadcasting it with another tensor. |
|
|
|
|
|
This function reshapes the frequency tensor to have the same shape as the target tensor 'x' |
|
|
for the purpose of broadcasting the frequency tensor during element-wise operations. |
|
|
|
|
|
Notes: |
|
|
When using FlashMHAModified, head_first should be False. |
|
|
When using Attention, head_first should be True. |
|
|
|
|
|
Args: |
|
|
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. |
|
|
x (torch.Tensor): Target tensor for broadcasting compatibility. |
|
|
head_first (bool): head dimension first (except batch dim) or not. |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Reshaped frequency tensor. |
|
|
|
|
|
Raises: |
|
|
AssertionError: If the frequency tensor doesn't match the expected shape. |
|
|
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. |
|
|
""" |
|
|
ndim = x.ndim |
|
|
    assert ndim > 1
|
|
|
|
|
if isinstance(freqs_cis, tuple): |
|
|
|
|
|
if head_first: |
|
|
assert freqs_cis[0].shape == ( |
|
|
x.shape[-2], |
|
|
x.shape[-1], |
|
|
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" |
|
|
shape = [ |
|
|
d if i == ndim - 2 or i == ndim - 1 else 1 |
|
|
for i, d in enumerate(x.shape) |
|
|
] |
|
|
else: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shape = [1, freqs_cis[0].shape[0], 1, freqs_cis[0].shape[1]] |
|
|
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) |
|
|
else: |
|
|
|
|
|
if head_first: |
|
|
assert freqs_cis.shape == ( |
|
|
x.shape[-2], |
|
|
x.shape[-1], |
|
|
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" |
|
|
shape = [ |
|
|
d if i == ndim - 2 or i == ndim - 1 else 1 |
|
|
for i, d in enumerate(x.shape) |
|
|
] |
|
|
else: |
|
|
assert freqs_cis.shape == ( |
|
|
x.shape[1], |
|
|
x.shape[-1], |
|
|
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" |
|
|
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] |
|
|
return freqs_cis.view(*shape) |
|
|
|
|
|
|
|
|
def rotate_half(x): |
|
|
x_real, x_imag = ( |
|
|
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) |
|
|
) |
|
|
return torch.stack([-x_imag, x_real], dim=-1).flatten(3) |
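
# rotate_half maps every adjacent pair (x0, x1) of the last dimension to (-x1, x0),
# i.e. a 90-degree rotation in each 2-D RoPE plane. The hard-coded flatten(3)
# assumes 4-D [B, S, H, D] inputs, matching how apply_rotary_emb() calls it.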
|
|
|
|
|
|
|
|
def apply_rotary_emb( |
|
|
xq: torch.Tensor, |
|
|
xk: torch.Tensor, |
|
|
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], |
|
|
head_first: bool = False, |
|
|
start_offset: int = 0, |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Apply rotary embeddings to input tensors using the given frequency tensor. |
|
|
|
|
|
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided |
|
|
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor |
|
|
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are |
|
|
returned as real tensors. |
|
|
|
|
|
Args: |
|
|
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] |
|
|
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] |
|
|
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. |
|
|
        head_first (bool): head dimension first (except batch dim) or not.
        start_offset (int): Offset into the frame axis of the precomputed cos/sin
            tables; used when decoding block by block with a KV cache.
|
|
|
|
|
Returns: |
|
|
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. |
|
|
|
|
|
""" |
|
|
|
|
|
xk_out = None |
|
|
assert isinstance(freqs_cis, tuple) |
|
|
if isinstance(freqs_cis, tuple): |
|
|
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) |
|
|
cos, sin = cos.to(xq.device), sin.to(xq.device) |
|
|
|
|
|
|
|
|
xq_out = (xq.float() * cos[:, start_offset:start_offset + xq.shape[1], :, :] + rotate_half(xq.float()) * sin[:, start_offset:start_offset + xq.shape[1], :, :]).type_as(xq) |
|
|
xk_out = (xk.float() * cos[:, start_offset:start_offset + xk.shape[1], :, :] + rotate_half(xk.float()) * sin[:, start_offset:start_offset + xk.shape[1], :, :]).type_as(xk) |
|
|
else: |
|
|
|
|
|
xq_ = torch.view_as_complex( |
|
|
xq.float().reshape(*xq.shape[:-1], -1, 2) |
|
|
) |
|
|
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to( |
|
|
xq.device |
|
|
) |
|
|
|
|
|
|
|
|
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) |
|
|
xk_ = torch.view_as_complex( |
|
|
xk.float().reshape(*xk.shape[:-1], -1, 2) |
|
|
) |
|
|
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) |
|
|
|
|
|
return xq_out, xk_out |
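
# The real-valued branch above implements the standard RoPE rotation
#   out = x * cos(theta) + rotate_half(x) * sin(theta)
# with the cos/sin tables sliced by `start_offset` along the frame axis, which is
# what lets block-wise causal decoding reuse one set of precomputed tables.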
|
|
|
|
|
|
|
|
def get_nd_rotary_pos_embed( |
|
|
rope_dim_list, |
|
|
start, |
|
|
*args, |
|
|
theta=10000.0, |
|
|
use_real=False, |
|
|
theta_rescale_factor: Union[float, List[float]] = 1.0, |
|
|
interpolation_factor: Union[float, List[float]] = 1.0, |
|
|
): |
|
|
""" |
|
|
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. |
|
|
|
|
|
Args: |
|
|
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. |
|
|
sum(rope_dim_list) should equal to head_dim of attention layer. |
|
|
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, |
|
|
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. |
|
|
*args: See above. |
|
|
theta (float): Scaling factor for frequency computation. Defaults to 10000.0. |
|
|
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. |
|
|
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real |
|
|
part and an imaginary part separately. |
|
|
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. |
|
|
|
|
|
Returns: |
|
|
pos_embed (torch.Tensor): [HW, D/2] |
|
|
""" |
|
|
|
|
|
grid = get_meshgrid_nd( |
|
|
start, *args, dim=len(rope_dim_list) |
|
|
) |
|
|
|
|
|
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): |
|
|
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) |
|
|
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: |
|
|
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) |
|
|
assert len(theta_rescale_factor) == len( |
|
|
rope_dim_list |
|
|
), "len(theta_rescale_factor) should equal to len(rope_dim_list)" |
|
|
|
|
|
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): |
|
|
interpolation_factor = [interpolation_factor] * len(rope_dim_list) |
|
|
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: |
|
|
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) |
|
|
assert len(interpolation_factor) == len( |
|
|
rope_dim_list |
|
|
), "len(interpolation_factor) should equal to len(rope_dim_list)" |
|
|
|
|
|
|
|
|
embs = [] |
|
|
for i in range(len(rope_dim_list)): |
|
|
emb = get_1d_rotary_pos_embed( |
|
|
rope_dim_list[i], |
|
|
grid[i].reshape(-1), |
|
|
theta, |
|
|
use_real=use_real, |
|
|
theta_rescale_factor=theta_rescale_factor[i], |
|
|
interpolation_factor=interpolation_factor[i], |
|
|
) |
|
|
embs.append(emb) |
|
|
|
|
|
if use_real: |
|
|
cos = torch.cat([emb[0] for emb in embs], dim=1) |
|
|
sin = torch.cat([emb[1] for emb in embs], dim=1) |
|
|
return cos, sin |
|
|
else: |
|
|
emb = torch.cat(embs, dim=1) |
|
|
return emb |
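
# Example of how the per-axis tables are assembled (hypothetical sizes): with
# rope_dim_list=[8, 28, 28] and a (T, H, W) grid of (3, 22, 40), each axis yields a
# [3*22*40, d_i] cos/sin pair, and the dim=1 concatenation has shape
# [2640, 64] == [T*H*W, sum(rope_dim_list)].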
|
|
|
|
|
|
|
|
def get_1d_rotary_pos_embed( |
|
|
dim: int, |
|
|
pos: Union[torch.FloatTensor, int], |
|
|
theta: float = 10000.0, |
|
|
use_real: bool = False, |
|
|
theta_rescale_factor: float = 1.0, |
|
|
interpolation_factor: float = 1.0, |
|
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
|
|
""" |
|
|
Precompute the frequency tensor for complex exponential (cis) with given dimensions. |
|
|
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) |
|
|
|
|
|
This function calculates a frequency tensor with complex exponential using the given dimension 'dim' |
|
|
and the end index 'end'. The 'theta' parameter scales the frequencies. |
|
|
The returned tensor contains complex values in complex64 data type. |
|
|
|
|
|
Args: |
|
|
dim (int): Dimension of the frequency tensor. |
|
|
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar |
|
|
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. |
|
|
use_real (bool, optional): If True, return real part and imaginary part separately. |
|
|
Otherwise, return complex numbers. |
|
|
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. |
|
|
|
|
|
Returns: |
|
|
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] |
|
|
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] |
|
|
""" |
|
|
if isinstance(pos, int): |
|
|
pos = torch.arange(pos, device=torch.cuda.current_device()).float() |
|
|
|
|
|
|
|
|
|
|
|
if theta_rescale_factor != 1.0: |
|
|
theta *= theta_rescale_factor ** (dim / (dim - 2)) |
|
|
|
|
|
freqs = 1.0 / ( |
|
|
theta ** (torch.arange(0, dim, 2, device=torch.cuda.current_device())[: (dim // 2)].float() / dim) |
|
|
) |
|
|
|
|
|
freqs = torch.outer(pos * interpolation_factor, freqs) |
|
|
if use_real: |
|
|
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) |
|
|
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) |
|
|
return freqs_cos, freqs_sin |
|
|
else: |
|
|
freqs_cis = torch.polar( |
|
|
torch.ones_like(freqs), freqs |
|
|
) |
|
|
return freqs_cis |
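
# With use_real=True the tables come back at full head width [S, dim] (each of the
# dim//2 frequencies is repeated twice), so they can be applied to q/k element-wise
# in apply_rotary_emb() instead of through the complex multiply of the other branch.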
|
|
|
|
|
|
|
|
class MatrixGameWanRMSNorm(nn.Module): |
|
|
def __init__(self, dim, eps=1e-5): |
|
|
super().__init__() |
|
|
self.dim = dim |
|
|
self.eps = eps |
|
|
self.weight = nn.Parameter(torch.ones(dim)) |
|
|
|
|
|
    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        # RMS-normalize over the channel dim and scale by the learnable gain.
        return (
            x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        ) * self.weight
|
|
|
|
|
|
|
|
class ActionModule(nn.Module): |
|
|
""" |
|
|
action module from https://arxiv.org/pdf/2501.08325 |
|
|
鼠标控制信号的输入是一个 L*D 的向量 |
|
|
键盘同样 |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
mouse_dim_in: int = 2, |
|
|
keyboard_dim_in: int = 6, |
|
|
hidden_size: int = 128, |
|
|
img_hidden_size: int = 1536, |
|
|
keyboard_hidden_dim: int = 1024, |
|
|
mouse_hidden_dim: int = 1024, |
|
|
vae_time_compression_ratio: int = 4, |
|
|
windows_size: int = 3, |
|
|
heads_num: int = 16, |
|
|
patch_size: list = [1, 2, 2], |
|
|
qk_norm: bool = True, |
|
|
qkv_bias: bool = False, |
|
|
rope_dim_list: list = [8, 28, 28], |
|
|
rope_theta=256, |
|
|
mouse_qk_dim_list=[8, 28, 28], |
|
|
enable_mouse=True, |
|
|
enable_keyboard=True, |
|
|
local_attn_size=6, |
|
|
blocks=[], |
|
|
): |
|
|
        super().__init__()
|
|
self.local_attn_size = local_attn_size |
|
|
self.enable_mouse = enable_mouse |
|
|
self.enable_keyboard = enable_keyboard |
|
|
|
|
|
self.rope_dim_list = rope_dim_list |
|
|
self.rope_theta = rope_theta |
|
|
if self.enable_keyboard: |
|
|
self.keyboard_embed = nn.Sequential( |
|
|
nn.Linear(keyboard_dim_in, hidden_size, bias=True), |
|
|
nn.SiLU(), |
|
|
nn.Linear(hidden_size, hidden_size, bias=True), |
|
|
) |
|
|
|
|
|
self.mouse_qk_dim_list = mouse_qk_dim_list |
|
|
self.heads_num = heads_num |
|
|
if self.enable_mouse: |
|
|
c = mouse_hidden_dim |
|
|
self.mouse_mlp = torch.nn.Sequential( |
|
|
torch.nn.Linear( |
|
|
mouse_dim_in * vae_time_compression_ratio * windows_size |
|
|
+ img_hidden_size, |
|
|
c, |
|
|
bias=True, |
|
|
), |
|
|
torch.nn.GELU(approximate="tanh"), |
|
|
torch.nn.Linear(c, c), |
|
|
torch.nn.LayerNorm(c), |
|
|
) |
|
|
|
|
|
head_dim = c // heads_num |
|
|
self.t_qkv = nn.Linear(c, c * 3, bias=qkv_bias) |
|
|
self.img_attn_q_norm = ( |
|
|
MatrixGameWanRMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity() |
|
|
) |
|
|
self.img_attn_k_norm = ( |
|
|
MatrixGameWanRMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity() |
|
|
) |
|
|
self.proj_mouse = nn.Linear(c, img_hidden_size, bias=qkv_bias) |
|
|
|
|
|
if self.enable_keyboard: |
|
|
head_dim_key = keyboard_hidden_dim // heads_num |
|
|
self.key_attn_q_norm = ( |
|
|
MatrixGameWanRMSNorm(head_dim_key, eps=1e-6) if qk_norm else nn.Identity() |
|
|
) |
|
|
self.key_attn_k_norm = ( |
|
|
MatrixGameWanRMSNorm(head_dim_key, eps=1e-6) if qk_norm else nn.Identity() |
|
|
) |
|
|
|
|
|
self.mouse_attn_q = nn.Linear( |
|
|
img_hidden_size, keyboard_hidden_dim, bias=qkv_bias |
|
|
) |
|
|
self.keyboard_attn_kv = nn.Linear( |
|
|
hidden_size * windows_size * vae_time_compression_ratio, |
|
|
keyboard_hidden_dim * 2, |
|
|
bias=qkv_bias, |
|
|
) |
|
|
self.proj_keyboard = nn.Linear( |
|
|
keyboard_hidden_dim, img_hidden_size, bias=qkv_bias |
|
|
) |
|
|
|
|
|
self.vae_time_compression_ratio = vae_time_compression_ratio |
|
|
self.windows_size = windows_size |
|
|
self.patch_size = patch_size |
|
|
self.freqs_cos, self.freqs_sin = self.get_rotary_pos_embed( |
|
|
7500, |
|
|
self.patch_size[1], |
|
|
self.patch_size[2], |
|
|
64, |
|
|
self.mouse_qk_dim_list, |
|
|
start_offset=0, |
|
|
) |
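        # The height/width passed above equal the spatial patch sizes, so rope_sizes
        # collapses to [7500, 1, 1] and these tables are effectively temporal-only,
        # with shape [7500, sum(mouse_qk_dim_list)]. apply_rotary_emb() later slices
        # them along the frame axis via start_offset.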
|
|
|
|
|
def patchify(self, x, patch_size): |
|
|
""" |
|
|
x : (N C T H W) |
|
|
""" |
|
|
pt, ph, pw = self.patch_size |
|
|
t, h, w = x.shape[2] // pt, x.shape[3] // ph, x.shape[4] // pw |
|
|
c = x.shape[1] |
|
|
x = x.reshape(shape=(x.shape[0], c, t, pt, h, ph, w, pw)) |
|
|
x = torch.einsum("nctohpwq->nthwcopq", x) |
|
|
x = x.reshape(shape=(x.shape[0], t * h * w, c * pt * ph * pw)) |
|
|
return x |
|
|
|
|
|
    def unpatchify(self, x, t, h, w, patch_size):
        """
        x: (N, T*H*W, pt*ph*pw * C)
        imgs: (N, C, T*pt, H*ph, W*pw)
        """
        pt, ph, pw = self.patch_size
        c = x.shape[2] // (pt * ph * pw)
        assert t * h * w == x.shape[1]
|
|
|
|
|
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) |
|
|
x = torch.einsum("nthwcopq->nctohpwq", x) |
|
|
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) |
|
|
|
|
|
return imgs |
|
|
|
|
|
def get_rotary_pos_embed( |
|
|
self, video_length, height, width, head_dim, rope_dim_list=None, start_offset=0 |
|
|
): |
|
|
target_ndim = 3 |
|
|
ndim = 5 - 2 |
|
|
|
|
|
latents_size = [video_length + start_offset, height, width] |
|
|
|
|
|
if isinstance(self.patch_size, int): |
|
|
assert all(s % self.patch_size == 0 for s in latents_size), ( |
|
|
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), " |
|
|
f"but got {latents_size}." |
|
|
) |
|
|
rope_sizes = [s // self.patch_size for s in latents_size] |
|
|
elif isinstance(self.patch_size, list): |
|
|
assert all( |
|
|
s % self.patch_size[idx] == 0 for idx, s in enumerate(latents_size) |
|
|
), ( |
|
|
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), " |
|
|
f"but got {latents_size}." |
|
|
) |
|
|
rope_sizes = [ |
|
|
s // self.patch_size[idx] for idx, s in enumerate(latents_size) |
|
|
] |
|
|
|
|
|
if len(rope_sizes) != target_ndim: |
|
|
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes |
|
|
|
|
|
if rope_dim_list is None: |
|
|
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] |
|
|
assert ( |
|
|
sum(rope_dim_list) == head_dim |
|
|
), "sum(rope_dim_list) should equal to head_dim of attention layer" |
|
|
freqs_cos, freqs_sin = get_nd_rotary_pos_embed( |
|
|
rope_dim_list, |
|
|
rope_sizes, |
|
|
theta=self.rope_theta, |
|
|
use_real=True, |
|
|
theta_rescale_factor=1, |
|
|
) |
|
|
return freqs_cos[ |
|
|
-video_length * rope_sizes[1] * rope_sizes[2] // self.patch_size[0] : |
|
|
], freqs_sin[ |
|
|
-video_length * rope_sizes[1] * rope_sizes[2] // self.patch_size[0] : |
|
|
] |
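        # NOTE: the slice above keeps only the rows belonging to the last
        # `video_length` frames of a table built for `video_length + start_offset`
        # frames.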
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x, |
|
|
tt, |
|
|
th, |
|
|
tw, |
|
|
mouse_condition=None, |
|
|
keyboard_condition=None, |
|
|
block_mask_mouse=None, |
|
|
block_mask_keyboard=None, |
|
|
is_causal=False, |
|
|
kv_cache_mouse=None, |
|
|
kv_cache_keyboard=None, |
|
|
start_frame=0, |
|
|
use_rope_keyboard=True, |
|
|
num_frame_per_block=3, |
|
|
): |
|
|
""" |
|
|
hidden_states: B, tt*th*tw, C |
|
|
mouse_condition: B, N_frames, C1 |
|
|
keyboard_condition: B, N_frames, C2 |
|
|
""" |
|
|
        assert use_rope_keyboard, "only the RoPE keyboard path is currently supported"
|
|
|
|
|
B, N_frames, C = keyboard_condition.shape |
|
|
|
|
|
assert tt * th * tw == x.shape[1] |
|
|
assert ( |
|
|
(N_frames - 1) + self.vae_time_compression_ratio |
|
|
) % self.vae_time_compression_ratio == 0 |
|
|
N_feats = int((N_frames - 1) / self.vae_time_compression_ratio) + 1 |
|
|
|
|
|
|
|
|
freqs_cis = (self.freqs_cos, self.freqs_sin) |
|
|
|
|
|
        assert (
            N_feats == tt and ((is_causal and kv_cache_mouse is None) or not is_causal)
        ) or (
            (N_frames - 1) // self.vae_time_compression_ratio + 1
            == start_frame + num_frame_per_block
            and is_causal
        )
|
|
|
|
|
if self.enable_mouse and mouse_condition is not None: |
|
|
hidden_states = rearrange( |
|
|
x, "B (T S) C -> (B S) T C", T=tt, S=th * tw |
|
|
) |
|
|
B, N_frames, C = mouse_condition.shape |
|
|
else: |
|
|
hidden_states = x |
|
|
|
|
|
|
|
|
pad_t = self.vae_time_compression_ratio * self.windows_size |
|
|
if self.enable_mouse and mouse_condition is not None: |
|
|
pad = mouse_condition[:, 0:1, :].expand(-1, pad_t, -1) |
|
|
mouse_condition = torch.cat([pad, mouse_condition], dim=1) |
|
|
if is_causal and kv_cache_mouse is not None: |
|
|
mouse_condition = mouse_condition[ |
|
|
:, |
|
|
self.vae_time_compression_ratio |
|
|
* (N_feats - num_frame_per_block - self.windows_size) |
|
|
+ pad_t :, |
|
|
:, |
|
|
] |
|
|
group_mouse = [ |
|
|
mouse_condition[ |
|
|
:, |
|
|
self.vae_time_compression_ratio * (i - self.windows_size) |
|
|
+ pad_t : i * self.vae_time_compression_ratio + pad_t, |
|
|
:, |
|
|
] |
|
|
for i in range(num_frame_per_block) |
|
|
] |
|
|
else: |
|
|
group_mouse = [ |
|
|
mouse_condition[ |
|
|
:, |
|
|
self.vae_time_compression_ratio * (i - self.windows_size) |
|
|
+ pad_t : i * self.vae_time_compression_ratio + pad_t, |
|
|
:, |
|
|
] |
|
|
for i in range(N_feats) |
|
|
] |
|
|
|
|
|
            group_mouse = torch.stack(group_mouse, dim=1)  # [B, T_m, pad_t, C]
            # T_m is num_frame_per_block when decoding with a KV cache, otherwise N_feats.
            T_m = group_mouse.shape[1]

            S = th * tw
            # Broadcast every frame's action window to all spatial tokens of that frame.
            group_mouse = group_mouse.unsqueeze(-1).expand(B, T_m, pad_t, C, S)
            group_mouse = group_mouse.permute(0, 4, 1, 2, 3).reshape(
                B * S, T_m, pad_t * C
            )
|
|
|
|
|
group_mouse = torch.cat([hidden_states, group_mouse], dim=-1) |
|
|
group_mouse = self.mouse_mlp(group_mouse) |
|
|
|
|
|
|
|
|
mouse_qkv = self.t_qkv(group_mouse) |
|
|
q, k, v = rearrange( |
|
|
mouse_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num |
|
|
) |
|
|
q = self.img_attn_q_norm(q).to(v) |
|
|
k = self.img_attn_k_norm(k).to(v) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
q, k = apply_rotary_emb( |
|
|
q, k, freqs_cis, start_offset=start_frame, head_first=False |
|
|
) |
|
|
|
|
|
if is_causal: |
|
|
if kv_cache_mouse is None: |
|
|
assert ( |
|
|
q.shape[0] == k.shape[0] and q.shape[0] % 880 == 0 |
|
|
) |
|
|
padded_length = math.ceil(q.shape[1] / 32) * 32 - q.shape[1] |
|
|
padded_q = torch.cat( |
|
|
[ |
|
|
q, |
|
|
torch.zeros( |
|
|
[q.shape[0], padded_length, q.shape[2], q.shape[3]], |
|
|
device=q.device, |
|
|
dtype=v.dtype, |
|
|
), |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
padded_k = torch.cat( |
|
|
[ |
|
|
k, |
|
|
torch.zeros( |
|
|
[k.shape[0], padded_length, k.shape[2], k.shape[3]], |
|
|
device=k.device, |
|
|
dtype=v.dtype, |
|
|
), |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
padded_v = torch.cat( |
|
|
[ |
|
|
v, |
|
|
torch.zeros( |
|
|
[v.shape[0], padded_length, v.shape[2], v.shape[3]], |
|
|
device=v.device, |
|
|
dtype=v.dtype, |
|
|
), |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
                    attn = flex_attention(
                        query=padded_q.transpose(2, 1),
                        key=padded_k.transpose(2, 1),
                        value=padded_v.transpose(2, 1),
                        block_mask=block_mask_mouse,
                    )[:, :, : q.shape[1]].transpose(2, 1)  # drop padding (safe even when padded_length == 0)
|
|
else: |
|
|
current_start = start_frame |
|
|
current_end = current_start + q.shape[1] |
|
|
|
|
|
assert q.shape[1] == num_frame_per_block |
|
|
sink_size = 0 |
|
|
max_attention_size = self.local_attn_size |
|
|
sink_tokens = sink_size * 1 |
|
|
kv_cache_size = kv_cache_mouse["k"].shape[1] |
|
|
num_new_tokens = q.shape[1] |
|
|
|
|
|
if (current_end > kv_cache_mouse["global_end_index"].item()) and ( |
|
|
num_new_tokens + kv_cache_mouse["local_end_index"].item() |
|
|
> kv_cache_size |
|
|
): |
|
|
num_evicted_tokens = ( |
|
|
num_new_tokens |
|
|
+ kv_cache_mouse["local_end_index"].item() |
|
|
- kv_cache_size |
|
|
) |
|
|
num_rolled_tokens = ( |
|
|
kv_cache_mouse["local_end_index"].item() |
|
|
- num_evicted_tokens |
|
|
- sink_tokens |
|
|
) |
|
|
kv_cache_mouse["k"][ |
|
|
:, sink_tokens : sink_tokens + num_rolled_tokens |
|
|
] = kv_cache_mouse["k"][ |
|
|
:, |
|
|
sink_tokens + num_evicted_tokens : sink_tokens |
|
|
+ num_evicted_tokens |
|
|
+ num_rolled_tokens, |
|
|
].clone() |
|
|
kv_cache_mouse["v"][ |
|
|
:, sink_tokens : sink_tokens + num_rolled_tokens |
|
|
] = kv_cache_mouse["v"][ |
|
|
:, |
|
|
sink_tokens + num_evicted_tokens : sink_tokens |
|
|
+ num_evicted_tokens |
|
|
+ num_rolled_tokens, |
|
|
].clone() |
|
|
|
|
|
local_end_index = ( |
|
|
kv_cache_mouse["local_end_index"].item() |
|
|
+ current_end |
|
|
- kv_cache_mouse["global_end_index"].item() |
|
|
- num_evicted_tokens |
|
|
) |
|
|
local_start_index = local_end_index - num_new_tokens |
|
|
else: |
|
|
local_end_index = ( |
|
|
kv_cache_mouse["local_end_index"].item() |
|
|
+ current_end |
|
|
- kv_cache_mouse["global_end_index"].item() |
|
|
) |
|
|
local_start_index = local_end_index - num_new_tokens |
|
|
|
|
|
kv_cache_mouse["k"][:, local_start_index:local_end_index] = k |
|
|
kv_cache_mouse["v"][:, local_start_index:local_end_index] = v |
|
|
|
|
|
if FLASH_ATTN_3_AVAILABLE: |
|
|
attn, attn_prob = flash_attn.flash_attn_func( |
|
|
q, |
|
|
kv_cache_mouse["k"][ |
|
|
:, |
|
|
max( |
|
|
0, local_end_index - max_attention_size |
|
|
) : local_end_index, |
|
|
], |
|
|
kv_cache_mouse["v"][ |
|
|
:, |
|
|
max( |
|
|
0, local_end_index - max_attention_size |
|
|
) : local_end_index, |
|
|
], |
|
|
) |
|
|
else: |
|
|
attn = flash_attn_func( |
|
|
q, |
|
|
kv_cache_mouse["k"][ |
|
|
:, |
|
|
max( |
|
|
0, local_end_index - max_attention_size |
|
|
) : local_end_index, |
|
|
], |
|
|
kv_cache_mouse["v"][ |
|
|
:, |
|
|
max( |
|
|
0, local_end_index - max_attention_size |
|
|
) : local_end_index, |
|
|
], |
|
|
) |
|
|
kv_cache_mouse["global_end_index"].fill_(current_end) |
|
|
kv_cache_mouse["local_end_index"].fill_(local_end_index) |
|
|
else: |
|
|
attn = flash_attn_func( |
|
|
q, |
|
|
k, |
|
|
v, |
|
|
) |
|
|
|
|
|
|
|
|
attn = rearrange(attn, "(b S) T h d -> b (T S) (h d)", b=B) |
|
|
|
|
|
            # x is [B, T*S, C]; with S inferred as 1 this rearrange is an identity, so the
            # residual below is applied on the original token layout of x.
            hidden_states = rearrange(x, "(B S) T C -> B (T S) C", B=B)
|
|
attn = self.proj_mouse(attn) |
|
|
|
|
|
hidden_states = hidden_states + attn |
|
|
|
|
|
if self.enable_keyboard and keyboard_condition is not None: |
|
|
pad = keyboard_condition[:, 0:1, :].expand(-1, pad_t, -1) |
|
|
keyboard_condition = torch.cat([pad, keyboard_condition], dim=1) |
|
|
if is_causal and kv_cache_keyboard is not None: |
|
|
keyboard_condition = keyboard_condition[ |
|
|
:, |
|
|
self.vae_time_compression_ratio |
|
|
* (N_feats - num_frame_per_block - self.windows_size) |
|
|
+ pad_t :, |
|
|
:, |
|
|
] |
|
|
keyboard_condition = self.keyboard_embed(keyboard_condition) |
|
|
group_keyboard = [ |
|
|
keyboard_condition[ |
|
|
:, |
|
|
self.vae_time_compression_ratio * (i - self.windows_size) |
|
|
+ pad_t : i * self.vae_time_compression_ratio + pad_t, |
|
|
:, |
|
|
] |
|
|
for i in range(num_frame_per_block) |
|
|
] |
|
|
else: |
|
|
keyboard_condition = self.keyboard_embed(keyboard_condition) |
|
|
group_keyboard = [ |
|
|
keyboard_condition[ |
|
|
:, |
|
|
self.vae_time_compression_ratio * (i - self.windows_size) |
|
|
+ pad_t : i * self.vae_time_compression_ratio + pad_t, |
|
|
:, |
|
|
] |
|
|
for i in range(N_feats) |
|
|
] |
|
|
group_keyboard = torch.stack(group_keyboard, dim=1) |
|
|
group_keyboard = group_keyboard.reshape( |
|
|
shape=(group_keyboard.shape[0], group_keyboard.shape[1], -1) |
|
|
) |
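            # group_keyboard: [B, T, windows_size * vae_time_compression_ratio * hidden_size];
            # each latent frame gets a flattened window of recent keyboard embeddings.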
|
|
|
|
|
mouse_q = self.mouse_attn_q(hidden_states) |
|
|
keyboard_kv = self.keyboard_attn_kv(group_keyboard) |
|
|
|
|
|
B, L, HD = mouse_q.shape |
|
|
D = HD // self.heads_num |
|
|
q = mouse_q.view(B, L, self.heads_num, D) |
|
|
|
|
|
B, L, KHD = keyboard_kv.shape |
|
|
k, v = keyboard_kv.view(B, L, 2, self.heads_num, D).permute(2, 0, 1, 3, 4) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
q = self.key_attn_q_norm(q).to(v) |
|
|
k = self.key_attn_k_norm(k).to(v) |
|
|
S = th * tw |
|
|
assert S == 880 |
|
|
|
|
|
if use_rope_keyboard: |
|
|
B, TS, H, D = q.shape |
|
|
T_ = TS // S |
|
|
q = q.view(B, T_, S, H, D).transpose(1, 2).reshape(B * S, T_, H, D) |
|
|
q, k = apply_rotary_emb( |
|
|
q, k, freqs_cis, start_offset=start_frame, head_first=False |
|
|
) |
|
|
|
|
|
k1, k2, k3, k4 = k.shape |
|
|
k = k.expand(S, k2, k3, k4) |
|
|
v = v.expand(S, k2, k3, k4) |
|
|
|
|
|
if is_causal: |
|
|
if kv_cache_keyboard is None: |
|
|
assert q.shape[0] == k.shape[0] and q.shape[0] % 880 == 0 |
|
|
|
|
|
padded_length = math.ceil(q.shape[1] / 32) * 32 - q.shape[1] |
|
|
padded_q = torch.cat( |
|
|
[ |
|
|
q, |
|
|
torch.zeros( |
|
|
[q.shape[0], padded_length, q.shape[2], q.shape[3]], |
|
|
device=q.device, |
|
|
dtype=v.dtype, |
|
|
), |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
padded_k = torch.cat( |
|
|
[ |
|
|
k, |
|
|
torch.zeros( |
|
|
[k.shape[0], padded_length, k.shape[2], k.shape[3]], |
|
|
device=k.device, |
|
|
dtype=v.dtype, |
|
|
), |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
padded_v = torch.cat( |
|
|
[ |
|
|
v, |
|
|
torch.zeros( |
|
|
[v.shape[0], padded_length, v.shape[2], v.shape[3]], |
|
|
device=v.device, |
|
|
dtype=v.dtype, |
|
|
), |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
                        attn = flex_attention(
                            query=padded_q.transpose(2, 1),
                            key=padded_k.transpose(2, 1),
                            value=padded_v.transpose(2, 1),
                            block_mask=block_mask_keyboard,
                        )[:, :, : q.shape[1]].transpose(2, 1)  # drop padding (safe even when padded_length == 0)
|
|
else: |
|
|
current_start = start_frame |
|
|
current_end = current_start + k.shape[1] |
|
|
assert k.shape[1] == num_frame_per_block |
|
|
sink_size = 0 |
|
|
max_attention_size = self.local_attn_size |
|
|
sink_tokens = sink_size * 1 |
|
|
kv_cache_size = kv_cache_keyboard["k"].shape[1] |
|
|
num_new_tokens = k.shape[1] |
|
|
|
|
|
if ( |
|
|
current_end > kv_cache_keyboard["global_end_index"].item() |
|
|
) and ( |
|
|
num_new_tokens + kv_cache_keyboard["local_end_index"].item() |
|
|
> kv_cache_size |
|
|
): |
|
|
num_evicted_tokens = ( |
|
|
num_new_tokens |
|
|
+ kv_cache_keyboard["local_end_index"].item() |
|
|
- kv_cache_size |
|
|
) |
|
|
num_rolled_tokens = ( |
|
|
kv_cache_keyboard["local_end_index"].item() |
|
|
- num_evicted_tokens |
|
|
- sink_tokens |
|
|
) |
|
|
kv_cache_keyboard["k"][ |
|
|
:, sink_tokens : sink_tokens + num_rolled_tokens |
|
|
] = kv_cache_keyboard["k"][ |
|
|
:, |
|
|
sink_tokens + num_evicted_tokens : sink_tokens |
|
|
+ num_evicted_tokens |
|
|
+ num_rolled_tokens, |
|
|
].clone() |
|
|
kv_cache_keyboard["v"][ |
|
|
:, sink_tokens : sink_tokens + num_rolled_tokens |
|
|
] = kv_cache_keyboard["v"][ |
|
|
:, |
|
|
sink_tokens + num_evicted_tokens : sink_tokens |
|
|
+ num_evicted_tokens |
|
|
+ num_rolled_tokens, |
|
|
].clone() |
|
|
|
|
|
local_end_index = ( |
|
|
kv_cache_keyboard["local_end_index"].item() |
|
|
+ current_end |
|
|
- kv_cache_keyboard["global_end_index"].item() |
|
|
- num_evicted_tokens |
|
|
) |
|
|
local_start_index = local_end_index - num_new_tokens |
|
|
else: |
|
|
local_end_index = ( |
|
|
kv_cache_keyboard["local_end_index"].item() |
|
|
+ current_end |
|
|
- kv_cache_keyboard["global_end_index"].item() |
|
|
) |
|
|
local_start_index = local_end_index - num_new_tokens |
|
|
assert ( |
|
|
k.shape[0] == 880 |
|
|
) |
|
|
kv_cache_keyboard["k"][:, local_start_index:local_end_index] = ( |
|
|
k[:1] |
|
|
) |
|
|
kv_cache_keyboard["v"][:, local_start_index:local_end_index] = ( |
|
|
v[:1] |
|
|
) |
|
|
|
|
|
if FLASH_ATTN_3_AVAILABLE: |
|
|
attn, attn_prob = flash_attn.flash_attn_func( |
|
|
q, |
|
|
kv_cache_keyboard["k"][ |
|
|
:, |
|
|
max( |
|
|
0, local_end_index - max_attention_size |
|
|
) : local_end_index, |
|
|
].repeat(S, 1, 1, 1), |
|
|
kv_cache_keyboard["v"][ |
|
|
:, |
|
|
max( |
|
|
0, local_end_index - max_attention_size |
|
|
) : local_end_index, |
|
|
].repeat(S, 1, 1, 1), |
|
|
) |
|
|
else: |
|
|
attn = flash_attn_func( |
|
|
q, |
|
|
kv_cache_keyboard["k"][ |
|
|
:, |
|
|
max( |
|
|
0, local_end_index - max_attention_size |
|
|
) : local_end_index, |
|
|
].repeat(S, 1, 1, 1), |
|
|
kv_cache_keyboard["v"][ |
|
|
:, |
|
|
max( |
|
|
0, local_end_index - max_attention_size |
|
|
) : local_end_index, |
|
|
].repeat(S, 1, 1, 1), |
|
|
) |
|
|
|
|
|
kv_cache_keyboard["global_end_index"].fill_(current_end) |
|
|
kv_cache_keyboard["local_end_index"].fill_(local_end_index) |
|
|
else: |
|
|
attn = flash_attn_func( |
|
|
q, |
|
|
k, |
|
|
v, |
|
|
causal=False, |
|
|
) |
|
|
attn = rearrange(attn, "(B S) T H D -> B (T S) (H D)", S=S) |
|
|
        # NOTE: this non-RoPE keyboard path is unreachable while the
        # `assert use_rope_keyboard` at the top of forward() holds.
        else:
|
|
if is_causal: |
|
|
if kv_cache_keyboard is None: |
|
|
padded_length = math.ceil(q.shape[1] / 32) * 32 - q.shape[1] |
|
|
padded_q = torch.cat( |
|
|
[ |
|
|
q, |
|
|
torch.zeros( |
|
|
[q.shape[0], padded_length, q.shape[2], q.shape[3]], |
|
|
device=q.device, |
|
|
dtype=v.dtype, |
|
|
), |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
padded_k = torch.cat( |
|
|
[ |
|
|
k, |
|
|
torch.zeros( |
|
|
[k.shape[0], padded_length, k.shape[2], k.shape[3]], |
|
|
device=k.device, |
|
|
dtype=v.dtype, |
|
|
), |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
padded_v = torch.cat( |
|
|
[ |
|
|
v, |
|
|
torch.zeros( |
|
|
[v.shape[0], padded_length, v.shape[2], v.shape[3]], |
|
|
device=v.device, |
|
|
dtype=v.dtype, |
|
|
), |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
                        attn = flex_attention(
                            query=padded_q.transpose(2, 1),
                            key=padded_k.transpose(2, 1),
                            value=padded_v.transpose(2, 1),
                            block_mask=block_mask_keyboard,
                        )[:, :, : q.shape[1]].transpose(2, 1)  # drop padding (safe even when padded_length == 0)
|
|
else: |
|
|
current_start = start_frame |
|
|
current_end = current_start + k.shape[1] |
|
|
assert k.shape[1] == num_frame_per_block |
|
|
sink_size = 0 |
|
|
local_attn_size = self.local_attn_size |
|
|
max_attention_size = self.local_attn_size |
|
|
sink_tokens = sink_size * 1 |
|
|
kv_cache_size = kv_cache_keyboard["k"].shape[1] |
|
|
num_new_tokens = k.shape[1] |
|
|
|
|
|
if ( |
|
|
current_end > kv_cache_keyboard["global_end_index"].item() |
|
|
) and ( |
|
|
num_new_tokens + kv_cache_keyboard["local_end_index"].item() |
|
|
> kv_cache_size |
|
|
): |
|
|
num_evicted_tokens = ( |
|
|
num_new_tokens |
|
|
+ kv_cache_keyboard["local_end_index"].item() |
|
|
- kv_cache_size |
|
|
) |
|
|
num_rolled_tokens = ( |
|
|
kv_cache_keyboard["local_end_index"].item() |
|
|
- num_evicted_tokens |
|
|
- sink_tokens |
|
|
) |
|
|
kv_cache_keyboard["k"][ |
|
|
:, sink_tokens : sink_tokens + num_rolled_tokens |
|
|
] = kv_cache_keyboard["k"][ |
|
|
:, |
|
|
sink_tokens + num_evicted_tokens : sink_tokens |
|
|
+ num_evicted_tokens |
|
|
+ num_rolled_tokens, |
|
|
].clone() |
|
|
kv_cache_keyboard["v"][ |
|
|
:, sink_tokens : sink_tokens + num_rolled_tokens |
|
|
] = kv_cache_keyboard["v"][ |
|
|
:, |
|
|
sink_tokens + num_evicted_tokens : sink_tokens |
|
|
+ num_evicted_tokens |
|
|
+ num_rolled_tokens, |
|
|
].clone() |
|
|
|
|
|
local_end_index = ( |
|
|
kv_cache_keyboard["local_end_index"].item() |
|
|
+ current_end |
|
|
- kv_cache_keyboard["global_end_index"].item() |
|
|
- num_evicted_tokens |
|
|
) |
|
|
local_start_index = local_end_index - num_new_tokens |
|
|
|
|
|
else: |
|
|
local_end_index = ( |
|
|
kv_cache_keyboard["local_end_index"].item() |
|
|
+ current_end |
|
|
- kv_cache_keyboard["global_end_index"].item() |
|
|
) |
|
|
local_start_index = local_end_index - num_new_tokens |
|
|
kv_cache_keyboard["k"][:, local_start_index:local_end_index] = k |
|
|
kv_cache_keyboard["v"][:, local_start_index:local_end_index] = v |
|
|
attn = flash_attn_func( |
|
|
q, |
|
|
kv_cache_keyboard["k"][ |
|
|
:, |
|
|
max( |
|
|
0, local_end_index - max_attention_size |
|
|
) : local_end_index, |
|
|
], |
|
|
kv_cache_keyboard["v"][ |
|
|
:, |
|
|
max( |
|
|
0, local_end_index - max_attention_size |
|
|
) : local_end_index, |
|
|
], |
|
|
|
|
|
) |
|
|
kv_cache_keyboard["global_end_index"].fill_(current_end) |
|
|
kv_cache_keyboard["local_end_index"].fill_(local_end_index) |
|
|
else: |
|
|
attn = flash_attn_func( |
|
|
q, |
|
|
k, |
|
|
v, |
|
|
|
|
|
) |
|
|
attn = rearrange(attn, "B L H D -> B L (H D)") |
|
|
attn = self.proj_keyboard(attn) |
|
|
hidden_states = hidden_states + attn |
|
|
return hidden_states |
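
    # Sketch of the KV-cache dict layout consumed above (inferred from the reads and
    # writes in forward(); the caches themselves are allocated by the caller):
    #   kv_cache_mouse = {
    #       "k": torch.zeros(B * th * tw, cache_len, heads_num, head_dim),
    #       "v": torch.zeros(B * th * tw, cache_len, heads_num, head_dim),
    #       "global_end_index": torch.tensor([0], dtype=torch.long),
    #       "local_end_index": torch.tensor([0], dtype=torch.long),
    #   }
    #   kv_cache_keyboard uses the same keys with a leading dim of 1 (see the `k[:1]`
    #   write and the `.repeat(S, 1, 1, 1)` reads).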
|
|
|