# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# This file is modified from https://github.com/feifeibear/long-context-attention
# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
# This file is also partly modified from https://github.com/microsoft/DeepSpeed
# Implementation refers to Ulysses Paper: https://arxiv.org/abs/2309.14509

import copy
from typing import Any, Tuple

import deepspeed.comm as dist
import torch
import torch.distributed as torch_dist
from flash_attn import flash_attn_func
from torch import Tensor
from torch.nn import Module

from llava.train.sequence_parallel.globals import get_ulysses_seq_len, get_ulysses_sp_rank, get_ulysses_sp_size

from .all_to_all import SeqAllGather, SeqAllToAll4D, SeqAllToAll5D
class _ExpandKVFunction(torch.autograd.Function):
    """
    Copy each KV head repeat_times times to extend sequence parallel support for Ulysses.

    Args:
        kv: input kv.
        repeat_times: the number of copies made of each head.
        num_head_dim: the dimension index of the head count.
    """

    @staticmethod
    def forward(ctx, k, v, repeat_times, num_head_dim):
        kv_shape = k.shape
        num_heads_kv = kv_shape[num_head_dim]

        ctx.num_head_dim = num_head_dim
        ctx.num_heads_kv = num_heads_kv

        # Construct a repeat index indicating which dimension should be copied.
        repeat_index = [1] * k.ndim
        repeat_index[num_head_dim] = repeat_times

        # Split kv into one chunk per head.
        k_splits = torch.chunk(k, chunks=num_heads_kv, dim=num_head_dim)
        v_splits = torch.chunk(v, chunks=num_heads_kv, dim=num_head_dim)

        k_repeats, v_repeats = [], []

        # Copy each split repeat_times times.
        for split in k_splits:
            k_split_repeat = split.repeat(repeat_index)
            k_repeats.append(k_split_repeat)

        for split in v_splits:
            v_split_repeat = split.repeat(repeat_index)
            v_repeats.append(v_split_repeat)

        return torch.cat(k_repeats, dim=num_head_dim), torch.cat(v_repeats, dim=num_head_dim)

    @staticmethod
    def backward(ctx, grad_output_k, grad_output_v):
        """
        For backward, we sum the copied heads inside a query group.
        """
        num_head_dim = ctx.num_head_dim
        num_heads_kv = ctx.num_heads_kv

        # Split the grads into one chunk per query group.
        grad_output_k_splits = torch.chunk(grad_output_k, chunks=num_heads_kv, dim=num_head_dim)
        grad_output_v_splits = torch.chunk(grad_output_v, chunks=num_heads_kv, dim=num_head_dim)

        grad_output_k_sums, grad_output_v_sums = [], []

        # Sum over the copied heads within each split.
        for grad_output_k_split in grad_output_k_splits:
            grad_output_k_sum = grad_output_k_split.sum(dim=num_head_dim, keepdim=True)
            grad_output_k_sums.append(grad_output_k_sum)

        for grad_output_v_split in grad_output_v_splits:
            grad_output_v_sum = grad_output_v_split.sum(dim=num_head_dim, keepdim=True)
            grad_output_v_sums.append(grad_output_v_sum)

        # Concatenate the per-group sums along the head dimension.
        grad_k = torch.cat(grad_output_k_sums, dim=num_head_dim)
        grad_v = torch.cat(grad_output_v_sums, dim=num_head_dim)

        return grad_k, grad_v, None, None


expandKV = _ExpandKVFunction.apply
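
# Illustrative sketch (not part of the original module): how expandKV replicates KV heads
# when the Ulysses SP degree exceeds the number of KV heads. The shapes are made-up examples.
#
#   k = torch.randn(2, 128, 4, 64)       # (bs, seq_len/N, num_kv_heads=4, head_size)
#   v = torch.randn(2, 128, 4, 64)
#   # With sp_size = 8 and 4 KV heads, each head is copied 8 // 4 = 2 times along dim 2:
#   k_rep, v_rep = expandKV(k, v, 2, 2)   # -> (2, 128, 8, 64), heads ordered h0, h0, h1, h1, ...
#   # backward() sums the gradients of the copies back onto their original head.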
class UlyssesAttention(torch.nn.Module):
    """Ulysses sequence-parallel attention.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:
        super().__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.ulysses_degree = get_ulysses_sp_size()

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *args: Any,
        attention_mask=None,
        dropout_p=0.0,
        softmax_scale=None,
        seqlens_in_batch=None,
        causal=False,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
    ) -> Tensor:
        """forward

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: other args

        Returns:
            * output (Tensor): context output
        """
        # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)

        # KV replication for GQA: if the SP degree exceeds the number of KV heads,
        # replicate heads so that every rank receives at least one KV head.
        head_dim = 2
        num_head_kv = key.shape[head_dim]
        if self.ulysses_degree > num_head_kv:
            assert (
                self.ulysses_degree % num_head_kv == 0
            ), "Ulysses requires the SP degree to be divisible by the number of KV heads."
            key, value = expandKV(key, value, self.ulysses_degree // num_head_kv, head_dim)

        # scatter 2, gather 1
        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)
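        # With the default scatter_idx=2 / gather_idx=1, the all-to-all above scatters the head
        # dimension and gathers the sequence dimension, so each rank now holds the full sequence
        # for head_cnt/N heads: (bs, seq_len, head_cnt/N, head_size).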
        if attention_mask is not None:
            local_attention_mask = copy.deepcopy(attention_mask)
            shard_seqlen = local_attention_mask.size(1)
            ulysses_seq_len = get_ulysses_seq_len()
            max_global_length = max(ulysses_seq_len)
            global_attention_mask_list = []
            for i in range(get_ulysses_sp_size()):
                if i == get_ulysses_sp_rank():
                    global_attention_mask_list.append(
                        torch.cat(
                            [
                                local_attention_mask,
                                torch.zeros(
                                    (local_attention_mask.size(0), max_global_length - shard_seqlen),
                                    dtype=local_attention_mask.dtype,
                                    device=local_attention_mask.device,
                                ),
                            ],
                            dim=1,
                        )
                    )
                else:
                    global_attention_mask_list.append(
                        torch.zeros(
                            (local_attention_mask.size(0), max_global_length),
                            dtype=local_attention_mask.dtype,
                            device=local_attention_mask.device,
                        )
                    )

            global_attention_mask = torch.stack(global_attention_mask_list, dim=0)
            torch_dist.all_reduce(global_attention_mask, group=self.spg)
            torch_dist.barrier(group=self.spg)
            new_global_attention_mask_list = list(torch.unbind(global_attention_mask, dim=0))
            # Unpad the global attention mask list and concatenate them
            for i in range(len(new_global_attention_mask_list)):
                new_global_attention_mask_list[i] = new_global_attention_mask_list[i][:, : ulysses_seq_len[i]]
            global_attention_mask = torch.cat(new_global_attention_mask_list, dim=1)
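            # Note: the zero-padded all_reduce above effectively all-gathers the (possibly uneven)
            # per-rank mask shards, so global_attention_mask now covers the full, unsharded sequence.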
            context_layer = self.local_attn(
                q,
                k,
                v,
                *args,
                attention_mask=global_attention_mask,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                seqlens_in_batch=seqlens_in_batch,
                causal=causal,
            )
        else:
            context_layer = self.local_attn(
                q,
                k,
                v,
                *args,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        if isinstance(context_layer, tuple):
            context_layer = context_layer[0]

        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        # scatter 1, gather 2
        output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

        # out e.g., [s/p::h]
        return output
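
# Illustrative usage sketch (assumptions, not part of the original file): wrapping a
# flash-attn-style local attention callable (e.g., flash_attn_func) with UlyssesAttention
# inside an already-initialized sequence-parallel process group `sp_group`, which the
# training code is assumed to set up elsewhere.
#
#   dist_attn = UlyssesAttention(flash_attn_func, sequence_process_group=sp_group)
#   # q, k, v on each rank: (bs, seq_len / sp_size, head_cnt, head_size)
#   out = dist_attn(q, k, v, causal=True)
#   # out on each rank: (bs, seq_len / sp_size, head_cnt, head_size)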