from typing import Any, Dict, List, Optional, Tuple, Union
import math

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.attention.flex_attention import flex_attention

import flash_attn
from flash_attn import flash_attn_func

# When True, the KV-cached attention paths call ``flash_attn.flash_attn_func``
# and unpack a 2-tuple from it. NOTE(review): that unpacking matches an
# FA3-style API returning (out, softmax_lse), not FA2's default single-tensor
# return -- confirm before enabling. It is hard-coded off here.
FLASH_ATTN_3_AVAILABLE = False
# Not read anywhere in this module; kept so any external import of it keeps working.
DISABLE_COMPILE = False

# Wrap flex_attention with torch.compile (compiled lazily on first call).
# dynamic=False specializes on input shapes; the padded flex-attention helper
# below rounds sequence lengths up to a multiple of 32 before calling it.
flex_attention = torch.compile(
    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
)


def _to_tuple(x, dim=2):
    """Broadcast an int to a ``dim``-tuple, or return ``x`` if it already has length ``dim``.

    Raises:
        ValueError: if ``x`` is a sequence whose length is not ``dim``.
    """
    if isinstance(x, int):
        return (x,) * dim
    if len(x) == dim:
        return x
    raise ValueError(f"Expected length {dim} or int, but got {x}")


def get_meshgrid_nd(start, *args, dim=2):
    """
    Get an n-D meshgrid from start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num;
            If len(args) == 1, start is start, args[0] is stop, step is 1;
            If len(args) == 2, start is start, args[0] is stop, args[1] is num.
            For n-dim, start/stop/num should be int or n-tuple. If an n-tuple
            is provided, the meshgrid is stacked following the dim order of
            the n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (torch.Tensor): float32 tensor of shape [dim, *num], allocated on
            the current CUDA device.

    Raises:
        ValueError: if more than two positional ``args`` are given.
    """
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start, dim=dim)  # Left-Top eg: 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom eg: 20,32
        num = _to_tuple(args[1], dim=dim)  # Target Size eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch equivalent of np.linspace(start[i], stop[i], num[i], endpoint=False):
    # sample n + 1 points including the endpoint, then drop the last one.
    axis_grid = []
    for a, b, n in zip(start, stop, num):
        g = torch.linspace(
            a, b, n + 1, dtype=torch.float32, device=torch.cuda.current_device()
        )[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [num_0, num_1, ...]
    grid = torch.stack(grid, dim=0)  # [dim, num_0, num_1, ...]

    return grid


#################################################################################
#                   Rotary Positional Embedding Functions                       #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80


def reshape_for_broadcast(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    x: torch.Tensor,
    head_first=False,
):
    """
    Reshape a frequency tensor so it broadcasts against ``x``.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Either a
            (cos, sin) tuple of real tables or a single complex table.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        A (cos, sin) tuple of reshaped tables for tuple input, otherwise the
        reshaped complex table.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [
                d if i == ndim - 2 or i == ndim - 1 else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            # View the [S_total, D] table as [1, S_total, 1, D]. S_total is not
            # checked against x.shape[1]: apply_rotary_emb slices the returned
            # table itself using start_offset, so it may hold more positions.
            shape = [1, freqs_cis[0].shape[0], 1, freqs_cis[0].shape[1]]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)

    # freqs_cis: values in complex space
    if head_first:
        assert freqs_cis.shape == (
            x.shape[-2],
            x.shape[-1],
        ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
        shape = [
            d if i == ndim - 2 or i == ndim - 1 else 1
            for i, d in enumerate(x.shape)
        ]
    else:
        assert freqs_cis.shape == (
            x.shape[1],
            x.shape[-1],
        ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def rotate_half(x):
    """Rotate interleaved (real, imag) pairs in the last dim: (a, b) -> (-b, a).

    The result is float32 and has the same shape as ``x``.
    """
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [..., D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
    start_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to query and key tensors.

    Args:
        xq (torch.Tensor): Query tensor. [B, S, H, D]
        xk (torch.Tensor): Key tensor. [B, S, H, D]
        freqs_cis: A (cos, sin) tuple of real tables (asserted). A complex
            table is only accepted when assertions are disabled, and that path
            ignores ``start_offset``.
        head_first (bool): head dimension first (except batch dim) or not.
        start_offset (int): Position of the first token of ``xq``/``xk`` in the
            frequency tables; rows [start_offset, start_offset + S) are used.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated query and key, with the
        input dtypes.
    """
    assert isinstance(freqs_cis, tuple)
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [1, S_total, 1, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        q_pos = slice(start_offset, start_offset + xq.shape[1])
        k_pos = slice(start_offset, start_offset + xk.shape[1])
        # real * cos - imag * sin
        # imag * cos + real * sin
        xq_out = (
            xq.float() * cos[:, q_pos, :, :]
            + rotate_half(xq.float()) * sin[:, q_pos, :, :]
        ).type_as(xq)
        xk_out = (
            xk.float() * cos[:, k_pos, :, :]
            + rotate_half(xk.float()) * sin[:, k_pos, :, :]
        ).type_as(xk)
    else:
        # view_as_complex packs [..., D/2, 2] (real) into [..., D/2] (complex)
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], -1, 2)
        )  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
            xq.device
        )  # [S, D//2] --> [1, S, 1, D//2]
        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
        # view_as_real expands [..., D/2] (complex) back to [..., D/2, 2] (real)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], -1, 2)
        )  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out


def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """
    n-D version of precompute_freqs_cis: RoPE for tokens with n-D structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list)
            should equal n; sum(rope_dim_list) should equal the attention head_dim.
        start (int | tuple of int | list of int): If len(args) == 0, start is num;
            If len(args) == 1, start is start, args[0] is stop, step is 1;
            If len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return real and imaginary parts separately
            (useful for backends without complex64 support such as TensorRT).
            Otherwise, return complex numbers.
        theta_rescale_factor (float or list): Rescale factor for theta, scalar or per axis.
        interpolation_factor (float or list): Position scale factor, scalar or per axis.

    Returns:
        use_real=True: (cos, sin), each [prod(num), sum(rope_dim_list)].
        use_real=False: complex tensor [prod(num), sum(rope_dim_list) / 2].
    """
    n_axes = len(rope_dim_list)
    grid = get_meshgrid_nd(start, *args, dim=n_axes)  # [n_axes, *num]

    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * n_axes
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * n_axes
    assert (
        len(theta_rescale_factor) == n_axes
    ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * n_axes
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * n_axes
    assert (
        len(interpolation_factor) == n_axes
    ), "len(interpolation_factor) should equal to len(rope_dim_list)"

    # Each axis is encoded with its own slice (rope_dim_list[i]) of the head dim.
    embs = [
        get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
        )
        for i in range(n_axes)
    ]

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)
        sin = torch.cat([emb[1] for emb in embs], dim=1)
        return cos, sin
    return torch.cat(embs, dim=1)


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the rotary frequency table (cis = cos + i*sin) for 1-D positions.

    Args:
        dim (int): Rotary dimension (number of channels encoded).
        pos (int or torch.FloatTensor): Positions [S]; an int means arange(pos).
        theta (float, optional): Base for the frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return cos and sin separately.
            Otherwise, return a complex64 tensor.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
        interpolation_factor (float, optional): Multiplier applied to ``pos``. Defaults to 1.0.

    Returns:
        freqs_cis: complex tensor [S, D/2] when ``use_real`` is False.
        freqs_cos, freqs_sin: real tensors [S, D] (each frequency repeated twice,
            interleaved) when ``use_real`` is True.
    """
    if isinstance(pos, int):
        pos = torch.arange(pos, device=torch.cuda.current_device()).float()

    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    freqs = 1.0 / (
        theta
        ** (
            torch.arange(0, dim, 2, device=torch.cuda.current_device())[: (dim // 2)].float()
            / dim
        )
    )  # [D/2]
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64 [S, D/2]


class MatrixGameWanRMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learnable per-channel scale."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (plus ``eps``)."""
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        r"""
        Args:
            x (Tensor): Shape [..., dim]; normalized over the last dimension.
        """
        # Normalize in float32 for numerical stability, cast back to the input
        # dtype, then apply the learnable scale so loaded ``weight`` values take effect.
        return self._norm(x.float()).type_as(x) * self.weight


def _padded_flex_attention(q, k, v, block_mask):
    """Run the compiled flex_attention on [B, L, H, D] inputs padded to a multiple of 32.

    All three inputs are zero-padded along dim 1 by the same amount (computed
    from q's length), transposed to the [B, H, L, D] layout flex_attention
    expects, and the output is trimmed back to q's length and returned as
    [B, L, H, D].
    """
    seq_len = q.shape[1]
    padded_length = math.ceil(seq_len / 32) * 32 - seq_len

    def _pad(t):
        zeros = torch.zeros(
            [t.shape[0], padded_length, t.shape[2], t.shape[3]],
            device=t.device,
            dtype=v.dtype,
        )
        return torch.cat([t, zeros], dim=1)

    out = flex_attention(
        query=_pad(q).transpose(2, 1),
        key=_pad(k).transpose(2, 1),
        value=_pad(v).transpose(2, 1),
        block_mask=block_mask,
    )
    # Slice to the real length: ``[:-padded_length]`` would be empty when
    # padded_length == 0.
    return out[:, :, :seq_len].transpose(2, 1)


def _write_kv_cache(kv_cache, k_new, v_new, current_end, sink_tokens=0):
    """Write new keys/values into a rolling KV cache; return the new local end index.

    ``kv_cache`` holds "k"/"v" buffers (sequence on dim 1) and single-element
    tensors "global_end_index" and "local_end_index". If the new tokens
    advance past ``global_end_index`` and would overflow the buffer, the
    oldest entries after the first ``sink_tokens`` are evicted by shifting the
    remainder left. The index tensors are NOT updated here (see
    _commit_kv_cache_indices).
    """
    num_new_tokens = k_new.shape[1]
    kv_cache_size = kv_cache["k"].shape[1]
    global_end = kv_cache["global_end_index"].item()
    local_end = kv_cache["local_end_index"].item()

    num_evicted_tokens = 0
    if current_end > global_end and num_new_tokens + local_end > kv_cache_size:
        num_evicted_tokens = num_new_tokens + local_end - kv_cache_size
        num_rolled_tokens = local_end - num_evicted_tokens - sink_tokens
        src = slice(
            sink_tokens + num_evicted_tokens,
            sink_tokens + num_evicted_tokens + num_rolled_tokens,
        )
        dst = slice(sink_tokens, sink_tokens + num_rolled_tokens)
        # clone() because source and destination ranges overlap.
        kv_cache["k"][:, dst] = kv_cache["k"][:, src].clone()
        kv_cache["v"][:, dst] = kv_cache["v"][:, src].clone()

    # Insert the new keys/values at the end.
    local_end_index = local_end + current_end - global_end - num_evicted_tokens
    local_start_index = local_end_index - num_new_tokens
    kv_cache["k"][:, local_start_index:local_end_index] = k_new
    kv_cache["v"][:, local_start_index:local_end_index] = v_new
    return local_end_index


def _cache_window(kv_cache, local_end_index, max_attention_size):
    """Return the last ``max_attention_size`` cached (k, v) entries ending at ``local_end_index``."""
    window = slice(max(0, local_end_index - max_attention_size), local_end_index)
    return kv_cache["k"][:, window], kv_cache["v"][:, window]


def _commit_kv_cache_indices(kv_cache, current_end, local_end_index):
    """Record the new global/local end positions in ``kv_cache`` (in place)."""
    kv_cache["global_end_index"].fill_(current_end)
    kv_cache["local_end_index"].fill_(local_end_index)


def _flash_attn_cached(q, k, v):
    """Attention over cached keys/values, honoring the FLASH_ATTN_3_AVAILABLE switch."""
    if FLASH_ATTN_3_AVAILABLE:
        attn, _attn_prob = flash_attn.flash_attn_func(q, k, v)
        return attn
    return flash_attn_func(q, k, v)


class ActionModule(nn.Module):
    """
    Action-conditioning module (https://arxiv.org/pdf/2501.08325).

    Mouse and keyboard control signals each arrive as a per-frame sequence of
    shape [B, L, D]. Mouse actions are fused with the video tokens and run
    through a temporal self-attention at each spatial location; keyboard
    actions are injected through cross-attention queried by the video tokens.
    """

    def __init__(
        self,
        mouse_dim_in: int = 2,
        keyboard_dim_in: int = 6,
        hidden_size: int = 128,
        img_hidden_size: int = 1536,
        keyboard_hidden_dim: int = 1024,
        mouse_hidden_dim: int = 1024,
        vae_time_compression_ratio: int = 4,
        windows_size: int = 3,
        heads_num: int = 16,
        patch_size: list = [1, 2, 2],
        qk_norm: bool = True,
        qkv_bias: bool = False,
        rope_dim_list: list = [8, 28, 28],
        rope_theta=256,
        mouse_qk_dim_list=[8, 28, 28],
        enable_mouse=True,
        enable_keyboard=True,
        local_attn_size=6,
        blocks=[],
    ):
        # NOTE: the list defaults are shared across instances; this class stores
        # but never mutates them. ``blocks`` is accepted but not used here.
        super().__init__()
        self.local_attn_size = local_attn_size
        self.enable_mouse = enable_mouse
        self.enable_keyboard = enable_keyboard

        self.rope_dim_list = rope_dim_list
        self.rope_theta = rope_theta
        # Submodules are created in a fixed order (which fixes RNG consumption
        # during initialization).
        if self.enable_keyboard:
            self.keyboard_embed = nn.Sequential(
                nn.Linear(keyboard_dim_in, hidden_size, bias=True),
                nn.SiLU(),
                nn.Linear(hidden_size, hidden_size, bias=True),
            )

        self.mouse_qk_dim_list = mouse_qk_dim_list
        self.heads_num = heads_num
        if self.enable_mouse:
            c = mouse_hidden_dim
            self.mouse_mlp = torch.nn.Sequential(
                torch.nn.Linear(
                    mouse_dim_in * vae_time_compression_ratio * windows_size
                    + img_hidden_size,
                    c,
                    bias=True,
                ),
                torch.nn.GELU(approximate="tanh"),
                torch.nn.Linear(c, c),
                torch.nn.LayerNorm(c),
            )

            head_dim = c // heads_num
            self.t_qkv = nn.Linear(c, c * 3, bias=qkv_bias)
            self.img_attn_q_norm = (
                MatrixGameWanRMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
            )
            self.img_attn_k_norm = (
                MatrixGameWanRMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
            )
            self.proj_mouse = nn.Linear(c, img_hidden_size, bias=qkv_bias)

        if self.enable_keyboard:
            head_dim_key = keyboard_hidden_dim // heads_num
            self.key_attn_q_norm = (
                MatrixGameWanRMSNorm(head_dim_key, eps=1e-6)
                if qk_norm
                else nn.Identity()
            )
            self.key_attn_k_norm = (
                MatrixGameWanRMSNorm(head_dim_key, eps=1e-6)
                if qk_norm
                else nn.Identity()
            )
            self.mouse_attn_q = nn.Linear(
                img_hidden_size, keyboard_hidden_dim, bias=qkv_bias
            )
            self.keyboard_attn_kv = nn.Linear(
                hidden_size * windows_size * vae_time_compression_ratio,
                keyboard_hidden_dim * 2,
                bias=qkv_bias,
            )
            self.proj_keyboard = nn.Linear(
                keyboard_hidden_dim, img_hidden_size, bias=qkv_bias
            )

        self.vae_time_compression_ratio = vae_time_compression_ratio
        self.windows_size = windows_size
        self.patch_size = patch_size
        # RoPE table for 7500 temporal positions. height/width equal the spatial
        # patch size, so the spatial grid collapses to a single position.
        # NOTE(review): head_dim 64 must equal sum(mouse_qk_dim_list) (asserted)
        # and presumably the per-head dim of both attentions -- confirm.
        self.freqs_cos, self.freqs_sin = self.get_rotary_pos_embed(
            7500,
            self.patch_size[1],
            self.patch_size[2],
            64,
            self.mouse_qk_dim_list,
            start_offset=0,
        )

    def patchify(self, x, patch_size):
        """
        x: (N, C, T, H, W) -> (N, T'*H'*W', C*pt*ph*pw), using ``self.patch_size``.

        The ``patch_size`` argument is unused; ``self.patch_size`` is used instead.
        """
        pt, ph, pw = self.patch_size
        t, h, w = x.shape[2] // pt, x.shape[3] // ph, x.shape[4] // pw
        c = x.shape[1]
        x = x.reshape(shape=(x.shape[0], c, t, pt, h, ph, w, pw))
        x = torch.einsum("nctohpwq->nthwcopq", x)
        x = x.reshape(shape=(x.shape[0], t * h * w, c * pt * ph * pw))
        return x

    def unpatchify(self, x, t, h, w, patch_size):
        """
        x: (N, t*h*w, C*pt*ph*pw) -> imgs: (N, C, t*pt, h*ph, w*pw).

        The channel count is derived from ``self.patch_size``; the
        ``patch_size`` argument is kept for interface compatibility.
        """
        pt, ph, pw = self.patch_size
        c = x.shape[2] // (pt * ph * pw)
        assert t * h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
        x = torch.einsum("nthwcopq->nctohpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_rotary_pos_embed(
        self, video_length, height, width, head_dim, rope_dim_list=None, start_offset=0
    ):
        """Build real (cos, sin) RoPE tables for a (time, height, width) latent grid.

        The grid covers ``video_length + start_offset`` frames; the returned
        tables keep only the trailing ``video_length * H' * W' // pt`` rows.
        """
        target_ndim = 3
        ndim = 5 - 2
        latents_size = [video_length + start_offset, height, width]

        if isinstance(self.patch_size, int):
            assert all(s % self.patch_size == 0 for s in latents_size), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // self.patch_size for s in latents_size]
            time_patch = self.patch_size
        elif isinstance(self.patch_size, (list, tuple)):
            assert all(
                s % self.patch_size[idx] == 0 for idx, s in enumerate(latents_size)
            ), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // self.patch_size[idx] for idx, s in enumerate(latents_size)
            ]
            time_patch = self.patch_size[0]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis

        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=self.rope_theta,
            use_real=True,
            theta_rescale_factor=1,
        )
        keep_from = -video_length * rope_sizes[1] * rope_sizes[2] // time_patch
        return freqs_cos[keep_from:], freqs_sin[keep_from:]

    def _pad_condition(self, condition):
        """Left-pad ``condition`` [B, N, C] along time by repeating its first frame pad_t times."""
        pad_t = self.vae_time_compression_ratio * self.windows_size
        pad = condition[:, 0:1, :].expand(-1, pad_t, -1)
        return torch.cat([pad, condition], dim=1)

    def _trim_condition_for_cache(self, condition, n_feats, num_frame_per_block):
        """Drop padded condition frames not needed for the newest ``num_frame_per_block`` latents."""
        r = self.vae_time_compression_ratio
        pad_t = r * self.windows_size
        start = r * (n_feats - num_frame_per_block - self.windows_size) + pad_t
        return condition[:, start:, :]

    def _group_condition_windows(self, condition, num_windows):
        """Stack, for each latent frame i, its pad_t-long window of padded condition frames.

        Returns a tensor of shape [B, num_windows, pad_t, C].
        """
        r = self.vae_time_compression_ratio
        pad_t = r * self.windows_size
        windows = [
            condition[:, r * (i - self.windows_size) + pad_t : i * r + pad_t, :]
            for i in range(num_windows)
        ]
        return torch.stack(windows, dim=1)

    def _mouse_attention(
        self,
        x,
        tt,
        th,
        tw,
        mouse_condition,
        freqs_cis,
        n_feats,
        block_mask_mouse,
        is_causal,
        kv_cache_mouse,
        start_frame,
        num_frame_per_block,
    ):
        """Fuse mouse actions into ``x`` via temporal self-attention per spatial location.

        Returns ``x`` plus the projected attention output ([B, tt*th*tw, C_img]).
        """
        S = th * tw
        B, _, C = mouse_condition.shape
        pad_t = self.vae_time_compression_ratio * self.windows_size

        # One temporal sequence per spatial location: [B*S, tt, C_img].
        frames = rearrange(x, "B (T S) C -> (B S) T C", T=tt, S=S)

        condition = self._pad_condition(mouse_condition)
        if is_causal and kv_cache_mouse is not None:
            condition = self._trim_condition_for_cache(
                condition, n_feats, num_frame_per_block
            )
            num_windows = num_frame_per_block
        else:
            num_windows = n_feats
        group_mouse = self._group_condition_windows(condition, num_windows)

        # Broadcast each frame's action window to every spatial location:
        # [B, F, pad_t, C] -> [B*S, F, pad_t*C]. F is taken from the tensor so
        # the non-cached path (F == n_feats) works for any num_frame_per_block.
        n_windows = group_mouse.shape[1]
        group_mouse = group_mouse.unsqueeze(-1).expand(B, n_windows, pad_t, C, S)
        group_mouse = group_mouse.permute(0, 4, 1, 2, 3).reshape(
            B * S, n_windows, pad_t * C
        )
        group_mouse = self.mouse_mlp(torch.cat([frames, group_mouse], dim=-1))

        q, k, v = rearrange(
            self.t_qkv(group_mouse),
            "B L (K H D) -> K B L H D",
            K=3,
            H=self.heads_num,
        )  # each [B*S, F, heads, head_dim]
        q = self.img_attn_q_norm(q).to(v)
        k = self.img_attn_k_norm(k).to(v)
        q, k = apply_rotary_emb(
            q, k, freqs_cis, start_offset=start_frame, head_first=False
        )

        if is_causal:
            if kv_cache_mouse is None:
                # NOTE(review): 880 appears to be the expected th*tw for the
                # target resolution -- confirm against the block-mask builder.
                assert q.shape[0] == k.shape[0] and q.shape[0] % 880 == 0
                attn = _padded_flex_attention(q, k, v, block_mask_mouse)
            else:
                assert q.shape[1] == num_frame_per_block
                current_end = start_frame + q.shape[1]
                local_end_index = _write_kv_cache(kv_cache_mouse, k, v, current_end)
                k_win, v_win = _cache_window(
                    kv_cache_mouse, local_end_index, self.local_attn_size
                )
                attn = _flash_attn_cached(q, k_win, v_win)
                _commit_kv_cache_indices(kv_cache_mouse, current_end, local_end_index)
        else:
            attn = flash_attn_func(q, k, v)

        attn = rearrange(attn, "(b S) T h d -> b (T S) (h d)", b=B)
        return x + self.proj_mouse(attn)

    def _keyboard_attention(
        self,
        hidden_states,
        th,
        tw,
        keyboard_condition,
        freqs_cis,
        n_feats,
        block_mask_keyboard,
        is_causal,
        kv_cache_keyboard,
        start_frame,
        use_rope_keyboard,
        num_frame_per_block,
    ):
        """Inject keyboard actions via cross-attention queried by ``hidden_states``.

        Returns ``hidden_states`` plus the projected attention output.
        """
        condition = self._pad_condition(keyboard_condition)
        if is_causal and kv_cache_keyboard is not None:
            condition = self._trim_condition_for_cache(
                condition, n_feats, num_frame_per_block
            )
            num_windows = num_frame_per_block
        else:
            num_windows = n_feats
        condition = self.keyboard_embed(condition)
        group_keyboard = self._group_condition_windows(condition, num_windows)  # B F RW C
        group_keyboard = group_keyboard.reshape(
            group_keyboard.shape[0], group_keyboard.shape[1], -1
        )

        mouse_q = self.mouse_attn_q(hidden_states)
        keyboard_kv = self.keyboard_attn_kv(group_keyboard)

        B, L_q, HD = mouse_q.shape
        D = HD // self.heads_num
        q = mouse_q.view(B, L_q, self.heads_num, D)

        B, L_kv, _ = keyboard_kv.shape
        k, v = keyboard_kv.view(B, L_kv, 2, self.heads_num, D).permute(2, 0, 1, 3, 4)

        q = self.key_attn_q_norm(q).to(v)
        k = self.key_attn_k_norm(k).to(v)
        S = th * tw
        # NOTE(review): hard-coded spatial token count; see the mouse branch.
        assert S == 880

        if use_rope_keyboard:
            # Regroup queries per spatial location so RoPE is applied over time:
            # [B, T*S, H, D] -> [B*S, T, H, D].
            B, TS, H, D = q.shape
            T_ = TS // S
            q = q.view(B, T_, S, H, D).transpose(1, 2).reshape(B * S, T_, H, D)
            q, k = apply_rotary_emb(
                q, k, freqs_cis, start_offset=start_frame, head_first=False
            )
            # expand(S, ...) only works for a keyboard batch of 1 (or S).
            _, k2, k3, k4 = k.shape
            k = k.expand(S, k2, k3, k4)
            v = v.expand(S, k2, k3, k4)
            if is_causal:
                if kv_cache_keyboard is None:
                    assert q.shape[0] == k.shape[0] and q.shape[0] % 880 == 0
                    attn = _padded_flex_attention(q, k, v, block_mask_keyboard)
                else:
                    assert k.shape[1] == num_frame_per_block
                    current_end = start_frame + k.shape[1]
                    # BS == 1 or the cache should not be saved / load method
                    # should be modified: only one copy of k/v is cached.
                    assert k.shape[0] == 880
                    local_end_index = _write_kv_cache(
                        kv_cache_keyboard, k[:1], v[:1], current_end
                    )
                    k_win, v_win = _cache_window(
                        kv_cache_keyboard, local_end_index, self.local_attn_size
                    )
                    attn = _flash_attn_cached(
                        q, k_win.repeat(S, 1, 1, 1), v_win.repeat(S, 1, 1, 1)
                    )
                    _commit_kv_cache_indices(
                        kv_cache_keyboard, current_end, local_end_index
                    )
            else:
                attn = flash_attn_func(q, k, v, causal=False)
            attn = rearrange(attn, "(B S) T H D -> B (T S) (H D)", S=S)
        else:
            # Unreachable while forward() asserts use_rope_keyboard; kept as-is.
            if is_causal:
                if kv_cache_keyboard is None:
                    attn = _padded_flex_attention(q, k, v, block_mask_keyboard)
                else:
                    assert k.shape[1] == num_frame_per_block
                    current_end = start_frame + k.shape[1]
                    local_end_index = _write_kv_cache(
                        kv_cache_keyboard, k, v, current_end
                    )
                    k_win, v_win = _cache_window(
                        kv_cache_keyboard, local_end_index, self.local_attn_size
                    )
                    attn = flash_attn_func(q, k_win, v_win)
                    _commit_kv_cache_indices(
                        kv_cache_keyboard, current_end, local_end_index
                    )
            else:
                attn = flash_attn_func(q, k, v)
            attn = rearrange(attn, "B L H D -> B L (H D)")

        return hidden_states + self.proj_keyboard(attn)

    def forward(
        self,
        x,
        tt,
        th,
        tw,
        mouse_condition=None,
        keyboard_condition=None,
        block_mask_mouse=None,
        block_mask_keyboard=None,
        is_causal=False,
        kv_cache_mouse=None,
        kv_cache_keyboard=None,
        start_frame=0,
        use_rope_keyboard=True,
        num_frame_per_block=3,
    ):
        """
        Apply mouse and keyboard conditioning to the video tokens.

        Args:
            x: video tokens, [B, tt*th*tw, C_img].
            tt, th, tw: latent grid size (time, height, width) of ``x``.
            mouse_condition: [B, N_frames, C1], or None to skip the mouse branch.
            keyboard_condition: [B, N_frames, C2]. Required even when the
                keyboard branch is disabled, because N_frames is read from it.
            block_mask_mouse / block_mask_keyboard: flex-attention block masks
                for the causal path without a KV cache.
            is_causal: select the causal paths (flex attention or KV cache).
            kv_cache_mouse / kv_cache_keyboard: dicts with "k", "v",
                "global_end_index" and "local_end_index" tensors. When given
                (and is_causal), only the newest ``num_frame_per_block`` latent
                frames are processed and the cache is updated in place.
            start_frame: latent index of the first frame in ``x``; offsets RoPE
                and the cache positions.
            use_rope_keyboard: must be True (asserted).
            num_frame_per_block: latent frames per causal block.

        Returns:
            Updated hidden states, [B, tt*th*tw, C_img].
        """
        assert use_rope_keyboard == True  # noqa: E712
        _, N_frames, _ = keyboard_condition.shape
        assert tt * th * tw == x.shape[1]
        r = self.vae_time_compression_ratio
        # N_frames raw action frames map to (N_frames - 1) / r + 1 latent frames.
        assert ((N_frames - 1) + r) % r == 0
        n_feats = int((N_frames - 1) / r) + 1

        # Shared by the mouse and keyboard branches.
        freqs_cis = (self.freqs_cos, self.freqs_sin)

        assert (
            n_feats == tt and ((is_causal and kv_cache_mouse is None) or not is_causal)
        ) or ((N_frames - 1) // r + 1 == start_frame + num_frame_per_block and is_causal)

        hidden_states = x
        if self.enable_mouse and mouse_condition is not None:
            hidden_states = self._mouse_attention(
                x,
                tt,
                th,
                tw,
                mouse_condition,
                freqs_cis,
                n_feats,
                block_mask_mouse,
                is_causal,
                kv_cache_mouse,
                start_frame,
                num_frame_per_block,
            )

        if self.enable_keyboard and keyboard_condition is not None:
            hidden_states = self._keyboard_attention(
                hidden_states,
                th,
                tw,
                keyboard_condition,
                freqs_cis,
                n_feats,
                block_mask_keyboard,
                is_causal,
                kv_cache_keyboard,
                start_frame,
                use_rope_keyboard,
                num_frame_per_block,
            )

        return hidden_states