Spaces:
Configuration error
Configuration error
File size: 2,202 Bytes
2b67076 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from .attn_mask import RadialAttention, MaskMap
def fill_radial_cache(radial_cache, nb_layers, lat_t, lat_h, lat_w):
    """Fill ``radial_cache`` in place with one attention processor per layer.

    Args:
        radial_cache: Container that accepts ``radial_cache[i] = ...`` for
            every ``i`` in ``range(nb_layers)``.
        nb_layers: Number of layers to create processors for.
        lat_t: Latent frame count, forwarded to each processor.
        lat_h: Latent height, forwarded to each processor.
        lat_w: Latent width, forwarded to each processor.
    """
    # Clear the class-level `_log_mask` before building the processors.
    # NOTE(review): presumably a cached mask that must be rebuilt for the new
    # latent shape -- confirm against MaskMap in attn_mask.
    MaskMap._log_mask = None

    for layer_idx in range(nb_layers):
        processor = WanSparseAttnProcessor2_0(layer_idx, lat_t, lat_h, lat_w)
        radial_cache[layer_idx] = processor
class WanSparseAttnProcessor2_0:
    """Per-layer callable that runs ``RadialAttention`` in dense or radial mode.

    The class-level attributes below act as shared defaults for every
    instance; each instance additionally stores its own ``layer_idx`` and
    its own ``mask_map``.
    """

    # Replaced per instance in __init__.
    mask_map = None
    # Calls with timestep_no below this value use dense attention.
    dense_timestep = 0
    # Layers with layer_idx below this value use dense attention.
    dense_block = 0
    # Forwarded unchanged to RadialAttention.
    decay_factor = 1.0
    sparse_type = "radial"  # "radial" by default; set to "dense" to force dense attention
    # Forwarded unchanged to RadialAttention.
    use_sage_attention = True

    def __init__(self, layer_idx, lat_t, lat_h, lat_w):
        """Remember the layer index and build the mask map for this latent size.

        Args:
            layer_idx: Index of the layer this processor serves.
            lat_t: Latent frame count.
            lat_h: Latent height.
            lat_w: Latent width.
        """
        self.layer_idx = layer_idx
        # NOTE(review): the // 4 presumably reflects 2x2 spatial patching of
        # the latent grid -- confirm against the model's patch size.
        video_tokens = lat_t * lat_h * lat_w // 4
        self.mask_map = MaskMap(video_token_num=video_tokens, num_frame=lat_t)

    def __call__(self, qkv_list, timestep_no=0) -> torch.Tensor:
        """Apply attention to the query/key/value triple in ``qkv_list``.

        Args:
            qkv_list: Iterable unpacking to ``(query, key, value)``, each laid
                out as (batch_size, seq_len, num_heads, head_dim).
            timestep_no: Current step index, compared with ``dense_timestep``.

        Returns:
            The attention output rearranged to
            (batch_size, seq_len, num_heads, head_dim).
        """
        query, key, value = qkv_list
        batch_size = query.shape[0]

        # Merge batch and sequence axes:
        # (batch_size, seq_len, heads, dim) -> (batch_size * seq_len, heads, dim).
        flat_q, flat_k, flat_v = (
            rearrange(tensor, "b s h d -> (b s) h d")
            for tensor in (query, key, value)
        )

        # Dense attention is used for early steps, early layers, or when
        # explicitly requested; everything else uses the radial pattern.
        wants_dense = (
            timestep_no < self.dense_timestep
            or self.layer_idx < self.dense_block
            or self.sparse_type == "dense"
        )
        mode = "dense" if wants_dense else "radial"

        attn_out = RadialAttention(
            query=flat_q,
            key=flat_k,
            value=flat_v,
            mask_map=self.mask_map,
            sparsity_type=mode,
            block_size=128,
            decay_factor=self.decay_factor,
            model_type="wan",
            pre_defined_mask=None,
            use_sage_attention=self.use_sage_attention,
        )

        # Split the merged axis back out:
        # (batch_size * seq_len, heads, dim) -> (batch_size, seq_len, heads, dim).
        return rearrange(attn_out, "(b s) h d -> b s h d", b=batch_size)
|