Spaces:
Configuration error
Configuration error
| import torch | |
| # import flashinfer | |
| import matplotlib.pyplot as plt | |
| # from sparse_sageattn import sparse_sageattn | |
| from einops import rearrange, repeat | |
| from sageattention import sageattn | |
| from spas_sage_attn import block_sparse_sage2_attn_cuda | |
| def get_cuda_arch_versions(): | |
| cuda_archs = [] | |
| for i in range(torch.cuda.device_count()): | |
| major, minor = torch.cuda.get_device_capability(i) | |
| cuda_archs.append(f"sm{major}{minor}") | |
| return cuda_archs | |
| from spas_sage_attn import block_sparse_sage2_attn_cuda | |
| def sparge_mask_convert(mask: torch.Tensor, block_size: int = 128, arch="sm") -> torch.Tensor: | |
| assert block_size in [128, 64], "Radial Attention only supports block size of 128 or 64" | |
| assert mask.shape[0] == mask.shape[1], "Input mask must be square." | |
| if block_size == 128: | |
| if arch == "sm90": | |
| new_mask = torch.repeat_interleave(mask, 2, dim=0) | |
| else: | |
| new_mask = torch.repeat_interleave(mask, 2, dim=1) | |
| elif block_size == 64: | |
| if arch == "sm90": | |
| num_row, num_col = mask.shape | |
| reshaped_mask = mask.view(num_row, num_col // 2, 2) | |
| new_mask = torch.max(reshaped_mask, dim=2).values | |
| else: | |
| num_row, num_col = mask.shape | |
| reshaped_mask = mask.view(num_row // 2, 2, num_col) | |
| new_mask = torch.max(reshaped_mask, dim=1).values | |
| return new_mask | |
| def get_indptr_from_mask(mask, query): | |
| # query shows the device of the indptr | |
| # indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on row dimension, | |
| # shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension. | |
| # The first element is always 0, and the last element is the number of blocks in the row dimension. | |
| # The rest of the elements are the number of blocks in each row. | |
| # the mask is already a block sparse mask | |
| indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32) | |
| indptr[0] = 0 | |
| row_counts = mask.sum(dim=1).flatten() # Ensure 1D output [num_blocks_row] | |
| indptr[1:] = torch.cumsum(row_counts, dim=0) | |
| return indptr | |
| def get_indices_from_mask(mask, query): | |
| # indices (torch.Tensor) - the block indices of the block-sparse matrix on column dimension, | |
| # shape `(nnz,),` where `nnz` is the number of non-zero blocks. | |
| # The elements in `indices` array should be less than `NB`: the number of blocks in the column dimension. | |
| nonzero_indices = torch.nonzero(mask) | |
| indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device) | |
| return indices | |
| def shrinkMaskStrict(mask, block_size=128): | |
| seqlen = mask.shape[0] | |
| block_num = seqlen // block_size | |
| mask = mask[:block_num * block_size, :block_num * block_size].view(block_num, block_size, block_num, block_size) | |
| col_densities = mask.sum(dim = 1) / block_size | |
| # we want the minimum non-zero column density in the block | |
| non_zero_densities = col_densities > 0 | |
| high_density_cols = col_densities > 1/3 | |
| frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9) | |
| block_mask = frac_high_density_cols > 0.6 | |
| block_mask[0:0] = True | |
| block_mask[-1:-1] = True | |
| return block_mask | |
| def pad_qkv(input_tensor, block_size=128): | |
| """ | |
| Pad the input tensor to be a multiple of the block size. | |
| input shape: (seqlen, num_heads, hidden_dim) | |
| """ | |
| seqlen, num_heads, hidden_dim = input_tensor.shape | |
| # Calculate the necessary padding | |
| padding_length = (block_size - (seqlen % block_size)) % block_size | |
| # Create a padded tensor with zeros | |
| padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype) | |
| # Copy the original tensor into the padded tensor | |
| padded_tensor[:seqlen, :, :] = input_tensor | |
| return padded_tensor | |
| def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query): | |
| assert(sparse_type in ["radial"]) | |
| dist = abs(i - j) | |
| group = dist.bit_length() | |
| threshold = 128 # hardcoded threshold for now, which is equal to block-size | |
| decay_length = 2 ** token_per_frame.bit_length() / 2 ** group | |
| if decay_length >= threshold: | |
| return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) | |
| split_factor = int(threshold / decay_length) | |
| modular = dist % split_factor | |
| if modular == 0: | |
| return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) | |
| else: | |
| return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) | |
| def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None): | |
| assert(sparse_type in ["radial"]) | |
| dist = abs(i - j) | |
| if model_type == "wan": | |
| if dist < 1: | |
| return token_per_frame | |
| if dist == 1: | |
| return token_per_frame // 2 | |
| elif model_type == "hunyuan": | |
| if dist <= 1: | |
| return token_per_frame | |
| else: | |
| raise ValueError(f"Unknown model type: {model_type}") | |
| group = dist.bit_length() | |
| decay_length = 2 ** token_per_frame.bit_length() / 2 ** group * decay_factor | |
| threshold = block_size | |
| if decay_length >= threshold: | |
| return decay_length | |
| else: | |
| return threshold | |
| def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None): | |
| """ | |
| A more memory friendly version, we generate the attention mask of each frame pair at a time, | |
| shrinks it, and stores it into the final result | |
| """ | |
| final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool) | |
| token_per_frame = video_token_num // num_frame | |
| video_text_border = video_token_num // block_size | |
| col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1) | |
| row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1) | |
| final_log_mask[video_text_border:] = True | |
| final_log_mask[:, video_text_border:] = True | |
| for i in range(num_frame): | |
| for j in range(num_frame): | |
| local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) | |
| if j == 0 and model_type == "wan": # this is attention sink | |
| local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) | |
| else: | |
| window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type) | |
| local_mask = torch.abs(col_indices - row_indices) <= window_width | |
| split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query) | |
| local_mask = torch.logical_and(local_mask, split_mask) | |
| remainder_row = (i * token_per_frame) % block_size | |
| remainder_col = (j * token_per_frame) % block_size | |
| # get the padded size | |
| all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size | |
| all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size | |
| padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool) | |
| padded_local_mask[remainder_row:remainder_row + token_per_frame, remainder_col:remainder_col + token_per_frame] = local_mask | |
| # shrink the mask | |
| block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size) | |
| # set the block mask to the final log mask | |
| block_row_start = (i * token_per_frame) // block_size | |
| block_col_start = (j * token_per_frame) // block_size | |
| block_row_end = block_row_start + block_mask.shape[0] | |
| block_col_end = block_col_start + block_mask.shape[1] | |
| final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or( | |
| final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask) | |
| print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}") | |
| return final_log_mask | |
| class MaskMap: | |
| _log_mask = None | |
| def __init__(self, video_token_num=25440, num_frame=16): | |
| self.video_token_num = video_token_num | |
| self.num_frame = num_frame | |
| def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None): | |
| if MaskMap._log_mask is None: | |
| MaskMap._log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool) | |
| MaskMap._log_mask = gen_log_mask_shrinked(query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size) | |
| return MaskMap._log_mask | |
| def SpargeSageAttnBackend(query, key, value, mask_map=None, video_mask=None, pre_defined_mask=None, block_size=128): | |
| if video_mask.all(): | |
| # dense case | |
| kv_border = pre_defined_mask[0].sum() if pre_defined_mask is not None else key.shape[0] | |
| output_video = sageattn( | |
| query[:mask_map.video_token_num, :, :].unsqueeze(0), | |
| key[:kv_border, :, :].unsqueeze(0), | |
| value[:kv_border, :, :].unsqueeze(0), | |
| tensor_layout="NHD", | |
| )[0] | |
| if pre_defined_mask is not None: | |
| output_text = flashinfer.single_prefill_with_kv_cache( | |
| q=query[mask_map.video_token_num:, :, :], | |
| k=key[:pre_defined_mask[0].sum(), :, :], | |
| v=value[:pre_defined_mask[0].sum(), :, :], | |
| causal=False, | |
| return_lse=False, | |
| ) | |
| return torch.cat([output_video, output_text], dim=0) | |
| else: | |
| return output_video | |
| # sparse-sageattention only supports (b, h, s, d) layout, need rearrange first | |
| query_hnd = rearrange(query.unsqueeze(0), "b s h d -> b h s d") | |
| key_hnd = rearrange(key.unsqueeze(0), "b s h d -> b h s d") | |
| value_hnd = rearrange(value.unsqueeze(0), "b s h d -> b h s d") | |
| arch = get_cuda_arch_versions()[query.device.index] | |
| converted_mask = repeat(sparge_mask_convert(mask=video_mask, block_size=block_size, arch=arch), "s t -> b h s t", b=query_hnd.shape[0], h=query_hnd.shape[1]) | |
| converted_mask = converted_mask.to(torch.int8) | |
| if pre_defined_mask is None: | |
| # wan case | |
| output = block_sparse_sage2_attn_cuda( | |
| query_hnd[:, :, :mask_map.video_token_num, :], | |
| key_hnd[:, :, :mask_map.video_token_num, :], | |
| value_hnd[:, :, :mask_map.video_token_num, :], | |
| mask_id=converted_mask, | |
| tensor_layout="HND", | |
| ) | |
| # rearrange back to (s, h, d), we know that b = 1 | |
| output = rearrange(output, "b h s d -> s (b h) d", b=1) | |
| return output | |
| query_video = query_hnd[:, :, :mask_map.video_token_num, :] | |
| key_video = key_hnd | |
| value_video = value_hnd | |
| kv_border = (pre_defined_mask[0].sum() + 63) // 64 | |
| converted_mask[:, :, :, kv_border:] = False | |
| output_video = block_sparse_sage2_attn_cuda( | |
| query_video, | |
| key_video, | |
| value_video, | |
| mask_id=converted_mask[:, :, :mask_map.video_token_num // block_size, :].contiguous(), | |
| tensor_layout="HND", | |
| ) | |
| # rearrange back to (s, h, d), we know that b = 1 | |
| output_video = rearrange(output_video, "b h s d -> s (b h) d", b=1) | |
| # gt = sparse_sageattn( | |
| # query_video, | |
| # key_video, | |
| # value_video, | |
| # mask_id=None, | |
| # is_causal=False, | |
| # tensor_layout="HND", | |
| # )[0] | |
| # import pdb; pdb.set_trace() | |
| output_text = flashinfer.single_prefill_with_kv_cache( | |
| q=query[mask_map.video_token_num:, :, :], | |
| k=key[:pre_defined_mask[0].sum(), :, :], | |
| v=value[:pre_defined_mask[0].sum(), :, :], | |
| causal=False, | |
| return_lse=False, | |
| ) | |
| return torch.cat([output_video, output_text], dim=0) | |
| def FlashInferBackend(query, key, value, mask_map=None, pre_defined_mask=None, bsr_wrapper=None): | |
| if pre_defined_mask is not None: | |
| video_video_o, video_video_o_lse = bsr_wrapper.run( | |
| query[:mask_map.video_token_num, :, :], | |
| key[:mask_map.video_token_num, :, :], | |
| value[:mask_map.video_token_num, :, :], | |
| return_lse=True | |
| ) | |
| # perform non-causal flashinfer on the text tokens | |
| video_text_o, video_text_o_lse = flashinfer.single_prefill_with_kv_cache( | |
| q=query[:mask_map.video_token_num, :, :], | |
| k=key[mask_map.video_token_num:, :, :], | |
| v=value[mask_map.video_token_num:, :, :], | |
| causal=False, | |
| return_lse=True, | |
| custom_mask=pre_defined_mask[:mask_map.video_token_num, mask_map.video_token_num:] | |
| ) | |
| # merge the two results | |
| o_video, _ = flashinfer.merge_state(v_a=video_video_o, s_a=video_video_o_lse, v_b=video_text_o, s_b=video_text_o_lse) | |
| o_text = flashinfer.single_prefill_with_kv_cache( | |
| q=query[mask_map.video_token_num:, :, :], | |
| k=key, | |
| v=value, | |
| causal=False, | |
| return_lse=False, | |
| custom_mask=pre_defined_mask[mask_map.video_token_num:, :] | |
| ) | |
| return torch.cat([o_video, o_text], dim=0) | |
| else: | |
| o = bsr_wrapper.run( | |
| query[:mask_map.video_token_num, :, :], | |
| key[:mask_map.video_token_num, :, :], | |
| value[:mask_map.video_token_num, :, :] | |
| ) | |
| return o | |
| def RadialAttention(query, key, value, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_type=None, pre_defined_mask=None, use_sage_attention=False): | |
| orig_seqlen, num_head, hidden_dim = query.shape | |
| if sparsity_type == "dense": | |
| video_mask = torch.ones((mask_map.video_token_num // block_size, mask_map.video_token_num // block_size), device=query.device, dtype=torch.bool) | |
| else: | |
| video_mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_type) if mask_map else None | |
| backend = "sparse_sageattn" if use_sage_attention else "flashinfer" | |
| if backend == "flashinfer": | |
| video_mask = video_mask[:mask_map.video_token_num // block_size, :mask_map.video_token_num // block_size] | |
| # perform block-sparse attention on the video tokens | |
| workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8) | |
| bsr_wrapper = flashinfer.BlockSparseAttentionWrapper( | |
| workspace_buffer, | |
| backend="fa2", | |
| ) | |
| indptr = get_indptr_from_mask(video_mask, query) | |
| indices = get_indices_from_mask(video_mask, query) | |
| bsr_wrapper.plan( | |
| indptr=indptr, | |
| indices=indices, | |
| M=mask_map.video_token_num, | |
| N=mask_map.video_token_num, | |
| R=block_size, | |
| C=block_size, | |
| num_qo_heads=num_head, | |
| num_kv_heads=num_head, | |
| head_dim=hidden_dim, | |
| q_data_type=query.dtype, | |
| kv_data_type=key.dtype, | |
| o_data_type=query.dtype, | |
| ) | |
| return FlashInferBackend(query, key, value, mask_map, pre_defined_mask, bsr_wrapper) | |
| elif backend == "sparse_sageattn": | |
| return SpargeSageAttnBackend(query, key, value, mask_map, video_mask, pre_defined_mask, block_size=block_size) | |
| if __name__ == "__main__": | |
| query = torch.randn(1, 2, 4, 64).cuda() | |
| # mask = torch.tensor([ | |
| # [True, False, True, False], | |
| # [False, True, False, True], | |
| # [True, False, False, True], | |
| # [False, True, True, False] | |
| # ], dtype=torch.bool) | |
| # indices = get_indices_from_mask(mask, query) | |
| # indptr = get_indptr_from_mask(mask, query) | |
| # print("Indices: ", indices) | |
| # print("Indptr: ", indptr) | |
| video_token_num = 3840 * 30 | |
| num_frame = 30 | |
| token_per_frame = video_token_num / num_frame | |
| padded_video_token_num = ((video_token_num + 1) // 128 + 1) * 128 | |
| print("padded: ", padded_video_token_num) | |
| temporal_mask = gen_log_mask_shrinked(query, padded_video_token_num, video_token_num, num_frame, sparse_type="radial", decay_factor=1, model_type="hunyuan") | |
| plt.figure(figsize=(10, 8), dpi=500) | |
| plt.imshow(temporal_mask.cpu().numpy()[:, :], cmap='hot') | |
| plt.colorbar() | |
| plt.title("Temporal Mask") | |
| plt.savefig("temporal_mask.png", | |
| dpi=300, | |
| bbox_inches='tight', | |
| pad_inches=0.1) | |
| plt.close() | |
| # save the mask tensor | |
| torch.save(temporal_mask, "temporal_mask.pt") | |