| """Fastformer attention definition. | |
| Reference: | |
| Wu et al., "Fastformer: Additive Attention Can Be All You Need" | |
| https://arxiv.org/abs/2108.09084 | |
| https://github.com/wuch15/Fastformer | |
| """ | |
| import numpy | |
| import torch | |


class FastSelfAttention(torch.nn.Module):
    """Fast self-attention used in Fastformer."""

    def __init__(
        self,
        size,
        attention_heads,
        dropout_rate,
    ):
        super().__init__()
        if size % attention_heads != 0:
            raise ValueError(
                f"Hidden size ({size}) is not an integer multiple "
                f"of attention heads ({attention_heads})"
            )
        self.attention_head_size = size // attention_heads
        self.num_attention_heads = attention_heads

        self.query = torch.nn.Linear(size, size)
        self.query_att = torch.nn.Linear(size, attention_heads)
        self.key = torch.nn.Linear(size, size)
        self.key_att = torch.nn.Linear(size, attention_heads)
        self.transform = torch.nn.Linear(size, size)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, torch.nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def transpose_for_scores(self, x):
        """Reshape and transpose to compute scores.

        Args:
            x: (batch, time, size = n_heads * attn_dim)

        Returns:
            (batch, n_heads, time, attn_dim)
        """
        new_x_shape = x.shape[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.reshape(*new_x_shape).transpose(1, 2)

    def forward(self, xs_pad, mask):
        """Forward method.

        Args:
            xs_pad: (batch, time, size = n_heads * attn_dim)
            mask: (batch, 1, time), nonpadding is 1, padding is 0

        Returns:
            torch.Tensor: (batch, time, size)
        """
        batch_size, seq_len, _ = xs_pad.shape
        mixed_query_layer = self.query(xs_pad)  # (batch, time, size)
        mixed_key_layer = self.key(xs_pad)  # (batch, time, size)

        if mask is not None:
            mask = mask.eq(0)  # padding is 1, nonpadding is 0

        # (batch, n_heads, time)
        query_for_score = (
            self.query_att(mixed_query_layer).transpose(1, 2)
            / self.attention_head_size**0.5
        )
        if mask is not None:
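            # Comment added for clarity: padded positions are filled with the
            # most negative finite value of the score dtype so they receive
            # (near-)zero weight after the softmax; the masked_fill with 0.0
            # afterwards removes any residual probability mass.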
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
                ).min
            )
            query_for_score = query_for_score.masked_fill(mask, min_value)
            query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
        else:
            query_weight = torch.softmax(query_for_score, dim=-1)
        query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)

        query_layer = self.transpose_for_scores(
            mixed_query_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_query = (
            torch.matmul(query_weight, query_layer)
            .transpose(1, 2)
            .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
        )  # (batch, 1, size = n_heads * attn_dim)
        pooled_query = self.dropout(pooled_query)
        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)

        mixed_query_key_layer = (
            mixed_key_layer * pooled_query_repeat
        )  # (batch, time, size)
        # (batch, n_heads, time)
        query_key_score = (
            self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
        ).transpose(1, 2)
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
                ).min
            )
            query_key_score = query_key_score.masked_fill(mask, min_value)
            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
                mask, 0.0
            )
        else:
            query_key_weight = torch.softmax(query_key_score, dim=-1)
        query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)

        key_layer = self.transpose_for_scores(
            mixed_query_key_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_key = torch.matmul(
            query_key_weight, key_layer
        )  # (batch, n_heads, 1, attn_dim)
        pooled_key = self.dropout(pooled_key)

        # NOTE: value = query, due to param sharing
        weighted_value = (pooled_key * query_layer).transpose(
            1, 2
        )  # (batch, time, n_heads, attn_dim)
        weighted_value = weighted_value.reshape(
            weighted_value.shape[:-2]
            + (self.num_attention_heads * self.attention_head_size,)
        )  # (batch, time, size)
        weighted_value = (
            self.dropout(self.transform(weighted_value)) + mixed_query_layer
        )

        return weighted_value
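

# Minimal usage sketch (not part of the original module): it assumes a hidden
# size of 256 split across 4 heads and a (batch, 1, time) mask in which 1 marks
# real frames and 0 marks padding, matching the convention documented in
# `forward`.
if __name__ == "__main__":
    torch.manual_seed(0)
    attn = FastSelfAttention(size=256, attention_heads=4, dropout_rate=0.1)
    attn.espnet_initialization_fn()
    attn.eval()  # disable dropout for a deterministic check

    xs_pad = torch.randn(2, 10, 256)  # (batch, time, size)
    lengths = torch.tensor([10, 7])  # second utterance has 3 padded frames
    mask = (torch.arange(10).unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(
        1
    )  # (batch, 1, time)
    out = attn(xs_pad, mask)
    print(out.shape)  # torch.Size([2, 10, 256])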