import torch
from torch.nn import Linear
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

from .functional import multi_head_attention_forward

class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O
        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

        Note: if kdim and vdim are None, they will be set to embed_dim such that
        query, key, and value have the same number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    A fuller, runnable usage sketch is given at the end of this file.
    """
    __annotations__ = {
        'bias_k': torch._jit_internal.Optional[torch.Tensor],
        'bias_v': torch._jit_internal.Optional[torch.Tensor],
    }
    __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
                 add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)
    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is a binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive
                mask (i.e. the values will be added to the attention layer). A 2D mask is broadcast
                across all batches, while a 3D mask allows a different mask for each entry in the batch.
                (A mask example is included in the usage sketch at the end of this file.)

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source
              sequence length.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source
              sequence length. 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the
              target sequence length, S is the source sequence length.

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence
              length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            # kdim/vdim differ from embed_dim, so separate Q/K/V projection weights are used.
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            # Q, K and V share embed_dim, so a single packed in_proj_weight covers all three projections.
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
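

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It exercises the
# interface documented in the class and forward() docstrings above, using the
# equivalent upstream torch.nn.MultiheadAttention. Because this file uses a
# relative import, the snippet is meant to be read here and copied into a
# standalone script rather than executed by running this file directly.
if __name__ == "__main__":
    import torch.nn as nn

    L, S, N, E, num_heads = 5, 7, 2, 16, 4  # target len, source len, batch, embed dim, heads

    multihead_attn = nn.MultiheadAttention(embed_dim=E, num_heads=num_heads)

    query = torch.randn(L, N, E)   # (target sequence length, batch, embedding dim)
    key = torch.randn(S, N, E)     # (source sequence length, batch, embedding dim)
    value = torch.randn(S, N, E)

    attn_output, attn_output_weights = multihead_attn(query, key, value)
    print(attn_output.shape)           # torch.Size([5, 2, 16]) -> (L, N, E)
    print(attn_output_weights.shape)   # torch.Size([2, 5, 7])  -> (N, L, S)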
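
    # Mask usage, following the forward() docstring: key_padding_mask is a
    # (N, S) boolean mask where True marks padded key positions to ignore,
    # and attn_mask is an additive (L, S) float mask; here it is a
    # causal-style mask that blocks attention to source positions beyond
    # the query index.
    key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
    key_padding_mask[:, -1] = True  # treat the last source position as padding

    attn_mask = torch.triu(torch.full((L, S), float('-inf')), diagonal=1)

    attn_output, attn_output_weights = multihead_attn(
        query, key, value,
        key_padding_mask=key_padding_mask,
        attn_mask=attn_mask)
    print(attn_output_weights[0, 0])   # the first query attends only to source position 0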