| """ | |
| ----------------------------------------------------------------------------- | |
| Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | |
| NVIDIA CORPORATION and its licensors retain all intellectual property | |
| and proprietary rights in and to this software, related documentation | |
| and any modifications thereto. Any use, reproduction, disclosure or | |
| distribution of this software and related documentation without an express | |
| license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| ----------------------------------------------------------------------------- | |
| """ | |

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import (  # , unpad_input  # noqa
        index_first_axis,
        pad_input,
    )

    FLASH_ATTN_AVAILABLE = True
except Exception as e:
    print("[WARN] flash_attn not available, using torch/naive implementation")
    FLASH_ATTN_AVAILABLE = False


# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py#L98
# flash-attn 2.7.0 changed the API; we override unpad_input here.
def unpad_input(hidden_states, attention_mask):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask.
        indices: (total_nnz,), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1,), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because PyTorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
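

# Illustrative usage sketch (added for clarity, not part of the original module; the
# concrete values below are assumptions). For a toy boolean mask, unpad_input gathers
# only the valid tokens and returns the bookkeeping tensors that flash_attn_varlen_func
# expects; pad_input scatters the result back to the padded layout:
#
#   hidden = torch.randn(2, 4, 8)                       # (batch=2, seqlen=4, dim=8)
#   mask = torch.tensor([[1, 1, 1, 0],
#                        [1, 1, 0, 0]], dtype=torch.bool)
#   tokens, indices, cu_seqlens, max_len = unpad_input(hidden, mask)
#   # tokens:     (5, 8)  -> 3 valid tokens from row 0 + 2 from row 1
#   # indices:    tensor([0, 1, 2, 4, 5])
#   # cu_seqlens: tensor([0, 3, 5], dtype=torch.int32)
#   # max_len:    3
#   padded = pad_input(tokens, indices, 2, 4)            # back to (2, 4, 8), zeros at masked slots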


def attention(q, k, v, mask_q=None, mask_kv=None, dropout=0, causal=False, window_size=(-1, -1), backend="torch"):
    # q: (B, N, H, D)
    # k: (B, M, H, D)
    # v: (B, M, H, D)
    # mask_q: (B, N)
    # mask_kv: (B, M)
    # return: (B, N, H, D)
    B, N, H, D = q.shape
    M = k.shape[1]

    if causal:
        assert N == 1 or N == M, "Causal mask only supports self-attention"

    # unmasked case (usually inference)
    # window_size is ignored except by the flash-attn backend. Only provide the effective window!
    if mask_q is None and mask_kv is None:
        if backend == "flash-attn" and FLASH_ATTN_AVAILABLE:
            return flash_attn_func(q, k, v, dropout, causal=causal, window_size=window_size)  # [B, N, H, D]
        elif backend == "torch":  # torch implementation
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=dropout, is_causal=causal)
            out = out.permute(0, 2, 1, 3).contiguous()
            return out
        else:  # naive implementation
            q = q.transpose(1, 2).reshape(B * H, N, D)
            k = k.transpose(1, 2).reshape(B * H, M, D)
            v = v.transpose(1, 2).reshape(B * H, M, D)
            w = torch.bmm(q, k.transpose(1, 2)) / (D**0.5)  # [B*H, N, M]
            if causal and N > 1:
                causal_mask = torch.full((N, M), float("-inf"), device=w.device, dtype=w.dtype)
                causal_mask = torch.triu(causal_mask, diagonal=1)
                w = w + causal_mask.unsqueeze(0)
            w = F.softmax(w, dim=-1)
            if dropout > 0:
                w = F.dropout(w, p=dropout)
            out = torch.bmm(w, v)  # [B*H, N, D]
            out = out.reshape(B, H, N, D).transpose(1, 2).contiguous()  # [B, N, H, D]
            return out

    # at least one of q or kv is masked (training)
    # only support flash-attn for now...
    if mask_q is None:
        mask_q = torch.ones(B, N, dtype=torch.bool, device=q.device)
    elif mask_kv is None:
        mask_kv = torch.ones(B, M, dtype=torch.bool, device=q.device)

    if FLASH_ATTN_AVAILABLE:
        # unpad (gather) input
        # mask_q: [B, N], first row has N1 ones, second row has N2 ones, ...
        # indices_q: [Ns,], Ns = N1 + N2 + ...
        # cu_seqlens_q: [B+1,], (0, N1, N1+N2, ...), cu = cumulative
        # max_len_q: scalar, max(N1, N2, ...)
        q, indices_q, cu_seqlens_q, max_len_q = unpad_input(q, mask_q)
        k, indices_kv, cu_seqlens_kv, max_len_kv = unpad_input(k, mask_kv)
        v = index_first_axis(v.reshape(-1, H, D), indices_kv)  # same indices as k
        # call varlen_func
        out = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_kv,
            max_seqlen_q=max_len_q,
            max_seqlen_k=max_len_kv,
            dropout_p=dropout,
            causal=causal,
            window_size=window_size,
        )
        # pad (put back) output
        out = pad_input(out, indices_q, B, N)
        return out
    else:
        raise NotImplementedError("masked attention requires flash_attn!")
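

# Illustrative usage sketch (added for clarity, not part of the original module).
# With no masks the default torch backend is used; as soon as a mask is given, the
# varlen flash-attn path is taken, which needs flash-attn, a CUDA device, and fp16/bf16:
#
#   q = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)  # (B, N, H, D)
#   k = torch.randn(2, 32, 8, 64, device="cuda", dtype=torch.float16)  # (B, M, H, D)
#   v = torch.randn(2, 32, 8, 64, device="cuda", dtype=torch.float16)
#   out = attention(q, k, v, backend="torch")        # (2, 16, 8, 64), unmasked path
#   mask_kv = torch.ones(2, 32, dtype=torch.bool, device="cuda")
#   mask_kv[1, 20:] = False                          # second sample has only 20 valid keys
#   out = attention(q, k, v, mask_kv=mask_kv)        # (2, 16, 8, 64), via flash_attn_varlen_func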


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rnorm = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x * rnorm).to(dtype=self.weight.dtype) * self.weight
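

# Note (added for clarity): RMSNorm rescales by the root-mean-square of the last
# dimension, y = x / sqrt(mean(x**2) + eps) * weight, i.e. LayerNorm without the
# mean subtraction and without a bias term. The statistic is computed in fp32 and
# the result is cast to the weight dtype.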


class SelfAttention(nn.Module):
    def __init__(
        self,
        hidden_dim,
        num_heads,
        input_dim=None,
        output_dim=None,
        dropout=0,
        causal=False,
        qknorm=False,
        qknorm_type="LayerNorm",
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim if input_dim is not None else hidden_dim
        self.output_dim = output_dim if output_dim is not None else hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.head_dim = hidden_dim // num_heads
        self.causal = causal
        self.dropout = dropout
        self.qknorm = qknorm

        self.qkv_proj = nn.Linear(self.input_dim, 3 * self.hidden_dim)
        self.out_proj = nn.Linear(self.hidden_dim, self.output_dim)

        if self.qknorm:
            if qknorm_type == "RMSNorm":
                self.q_norm = RMSNorm(self.hidden_dim, eps=1e-6)
                self.k_norm = RMSNorm(self.hidden_dim, eps=1e-6)
            else:
                self.q_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)
                self.k_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x, mask=None):
        # x: [B, N, C]
        # mask: [B, N]
        B, N, C = x.shape

        qkv = self.qkv_proj(x)  # [B, N, C] -> [B, N, 3 * D]
        qkv = qkv.reshape(B, N, 3, -1).permute(2, 0, 1, 3)  # [3, B, N, D]
        q, k, v = qkv.chunk(3, dim=0)  # [3, B, N, D] -> 3 * [1, B, N, D]
        q = q.squeeze(0)
        k = k.squeeze(0)
        v = v.squeeze(0)

        if self.qknorm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        q = q.reshape(B, N, self.num_heads, self.head_dim)
        k = k.reshape(B, N, self.num_heads, self.head_dim)
        v = v.reshape(B, N, self.num_heads, self.head_dim)

        x = attention(q, k, v, mask_q=mask, mask_kv=mask, dropout=self.dropout, causal=self.causal)  # [B, N, H, D]
        x = self.out_proj(x.reshape(B, N, -1))

        return x
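

# Illustrative masked usage (added for clarity, not part of the original module;
# requires flash-attn and a CUDA device because masked attention takes the varlen path):
#
#   attn = SelfAttention(hidden_dim=128, num_heads=8, causal=True).cuda().half()
#   x = torch.randn(2, 16, 128, device="cuda", dtype=torch.float16)
#   mask = torch.ones(2, 16, dtype=torch.bool, device="cuda")
#   mask[1, 10:] = False           # second sample has only 10 valid tokens
#   y = attn(x, mask=mask)         # [2, 16, 128]; rows masked out in `mask` are not meaningful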


class CrossAttention(nn.Module):
    def __init__(
        self,
        hidden_dim,
        num_heads,
        input_dim=None,
        context_dim=None,
        output_dim=None,
        dropout=0,
        qknorm=False,
        qknorm_type="LayerNorm",
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim if input_dim is not None else hidden_dim
        self.context_dim = context_dim if context_dim is not None else hidden_dim
        self.output_dim = output_dim if output_dim is not None else hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.head_dim = hidden_dim // num_heads
        self.dropout = dropout
        self.qknorm = qknorm

        self.q_proj = nn.Linear(self.input_dim, self.hidden_dim)
        self.k_proj = nn.Linear(self.context_dim, self.hidden_dim)
        self.v_proj = nn.Linear(self.context_dim, self.hidden_dim)
        self.out_proj = nn.Linear(self.hidden_dim, self.output_dim)

        if self.qknorm:
            if qknorm_type == "RMSNorm":
                self.q_norm = RMSNorm(self.hidden_dim, eps=1e-6)
                self.k_norm = RMSNorm(self.hidden_dim, eps=1e-6)
            else:
                self.q_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)
                self.k_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x, context, mask_q=None, mask_kv=None):
        # x: [B, N, C]
        # context: [B, M, C']
        # mask_q: [B, N]
        # mask_kv: [B, M]
        B, N, C = x.shape
        M = context.shape[1]

        q = self.q_proj(x)
        k = self.k_proj(context)
        v = self.v_proj(context)

        if self.qknorm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        q = q.reshape(B, N, self.num_heads, self.head_dim)
        k = k.reshape(B, M, self.num_heads, self.head_dim)
        v = v.reshape(B, M, self.num_heads, self.head_dim)

        x = attention(q, k, v, mask_q=mask_q, mask_kv=mask_kv, dropout=self.dropout, causal=False)  # [B, N, H, D]
        x = self.out_proj(x.reshape(B, N, -1))

        return x
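

if __name__ == "__main__":
    # Minimal smoke test (added for illustration, not part of the original module).
    # Without masks both modules fall through to the torch SDPA backend, so this
    # runs on CPU even when flash-attn is not installed.
    torch.manual_seed(0)
    x = torch.randn(2, 16, 128)        # [B, N, C]
    context = torch.randn(2, 24, 96)   # [B, M, C']

    self_attn = SelfAttention(hidden_dim=128, num_heads=8, qknorm=True, qknorm_type="RMSNorm")
    cross_attn = CrossAttention(hidden_dim=128, num_heads=8, context_dim=96)

    print(self_attn(x).shape)            # torch.Size([2, 16, 128])
    print(cross_attn(x, context).shape)  # torch.Size([2, 16, 128])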