Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import math | |
| def _gen_bias_mask(max_length): | |
| """ | |
| Generates bias values (-Inf) to mask future timesteps during attention | |
| """ | |
| np_mask = np.triu(np.full([max_length, max_length], -np.inf), 1) | |
| torch_mask = torch.from_numpy(np_mask).type(torch.FloatTensor) | |
| return torch_mask.unsqueeze(0).unsqueeze(1) | |
| def _gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4): | |
| """ | |
| Generates a [1, length, channels] timing signal consisting of sinusoids | |
| Adapted from: | |
| https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py | |
| """ | |
| position = np.arange(length) | |
| num_timescales = channels // 2 | |
| log_timescale_increment = ( | |
| math.log(float(max_timescale) / float(min_timescale)) / | |
| (float(num_timescales) - 1)) | |
| inv_timescales = min_timescale * np.exp( | |
| np.arange(num_timescales).astype(np.float64) * -log_timescale_increment) | |
| scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0) | |
| signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) | |
| signal = np.pad(signal, [[0, 0], [0, channels % 2]], | |
| 'constant', constant_values=[0.0, 0.0]) | |
| signal = signal.reshape([1, length, channels]) | |
| return torch.from_numpy(signal).type(torch.FloatTensor) | |
class LayerNorm(nn.Module):
    """
    Layer normalisation over the last dimension with a learnable gain
    (``gamma``) and bias (``beta``).

    Borrowed from jekbradbury: https://github.com/pytorch/pytorch/issues/1959

    Note: this divides by (unbiased standard deviation + eps). That is
    numerically different from torch.nn.LayerNorm, which uses the biased
    variance with eps inside the square root.
    """

    def __init__(self, features, eps=1e-6):
        """
        Parameters:
            features: Size of the last dimension being normalised
            eps: Constant added to the standard deviation for stability
        """
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centred = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.gamma * centred / spread + self.beta
class OutputLayer(nn.Module):
    """
    Base class for output layers.

    Holds a linear projection from hidden states to output labels. Subclasses
    provide ``forward`` and ``loss``.
    """

    def __init__(self, hidden_size, output_size, probs_out=False):
        """
        Parameters:
            hidden_size: Size of the incoming hidden representation
            output_size: Number of output labels
            probs_out: Flag stored for subclasses to interpret in ``forward``
        """
        super(OutputLayer, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.probs_out = probs_out
        # Submodules are registered projection-first, then LSTM, so that
        # parameters()/state_dict ordering stays the same.
        self.output_projection = nn.Linear(hidden_size, output_size)
        # Bidirectional LSTM whose per-direction width is half of
        # hidden_size. NOTE(review): not referenced by SoftmaxOutputLayer;
        # presumably kept for checkpoint compatibility -- confirm before removing.
        half_width = int(hidden_size / 2)
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=half_width,
                            batch_first=True, bidirectional=True)

    def loss(self, hidden, labels):
        """Compute the training loss; subclasses must override this."""
        raise NotImplementedError('Must implement {}.loss'.format(self.__class__.__name__))
class SoftmaxOutputLayer(OutputLayer):
    """
    Softmax-based output layer.

    Projects hidden states to label logits and decodes the two most likely
    labels at each position.
    """

    def forward(self, hidden):
        """
        Parameters:
            hidden: Hidden states; the projection acts on the last dimension
        Returns:
            If ``probs_out`` is True: the raw logits (not probabilities,
            despite the flag's name).
            Otherwise: a tuple ``(predictions, second)`` holding the indices
            of the most likely and second most likely label at each position.
        """
        logits = self.output_projection(hidden)
        if self.probs_out is True:
            # Return before the softmax/top-k work. This also keeps
            # output_size == 1 from failing in topk(k=2) when only logits
            # are wanted.
            return logits
        probs = F.softmax(logits, -1)
        _, indices = torch.topk(probs, 2)
        # Ellipsis indexing works for any leading shape, not just rank 3.
        predictions = indices[..., 0]
        second = indices[..., 1]
        return predictions, second

    def loss(self, hidden, labels):
        """
        Negative log-likelihood of ``labels`` under the softmax over the
        projected logits. Uses F.nll_loss defaults, i.e. mean reduction.
        """
        logits = self.output_projection(hidden)
        log_probs = F.log_softmax(logits, -1)
        # reshape (rather than view) also accepts non-contiguous labels.
        return F.nll_loss(log_probs.reshape(-1, self.output_size), labels.reshape(-1))
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention, following Figure 2 of
    "Attention Is All You Need" (https://arxiv.org/pdf/1706.03762.pdf).
    """

    def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth,
                 num_heads, bias_mask=None, dropout=0.0, attention_map=False):
        """
        Parameters:
            input_depth: Size of the last dimension of the inputs
            total_key_depth: Size of the last dimension of queries/keys; must be
                             divisible by num_heads
            total_value_depth: Size of the last dimension of values; must be
                               divisible by num_heads
            output_depth: Size of the last dimension of the final output
            num_heads: Number of attention heads
            bias_mask: Optional additive mask (e.g. -inf on future positions),
                       sliced to the logits' last two dimensions in forward
            dropout: Dropout probability on the attention weights (should be
                     non-zero only during training)
            attention_map: If True, forward also returns the attention weights
        Raises:
            ValueError: if either total depth is not divisible by num_heads
        """
        super(MultiHeadAttention, self).__init__()
        # Divisibility checks borrowed from tensor2tensor's common_attention.py
        key_remainder = total_key_depth % num_heads
        value_remainder = total_value_depth % num_heads
        if key_remainder != 0:
            raise ValueError("Key depth (%d) must be divisible by the number of "
                             "attention heads (%d)." % (total_key_depth, num_heads))
        if value_remainder != 0:
            raise ValueError("Value depth (%d) must be divisible by the number of "
                             "attention heads (%d)." % (total_value_depth, num_heads))

        self.attention_map = attention_map
        self.num_heads = num_heads
        self.query_scale = (total_key_depth // num_heads) ** -0.5
        self.bias_mask = bias_mask

        # Queries and keys share a depth so they can be dot-multiplied.
        # Registration order is kept stable for state_dict compatibility.
        self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False)
        self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """
        Reshape [batch, seq, depth] into [batch, num_heads, seq, depth // num_heads].

        Raises:
            ValueError: if x is not rank 3
        """
        if x.dim() != 3:
            raise ValueError("x must have rank 3")
        batch, length, depth = x.shape
        per_head = depth // self.num_heads
        split = x.view(batch, length, self.num_heads, per_head)
        return split.permute(0, 2, 1, 3)

    def _merge_heads(self, x):
        """
        Reshape [batch, num_heads, seq, d] back into [batch, seq, num_heads * d].

        Raises:
            ValueError: if x is not rank 4
        """
        if x.dim() != 4:
            raise ValueError("x must have rank 4")
        batch, _, length, per_head = x.shape
        seq_major = x.permute(0, 2, 1, 3).contiguous()
        return seq_major.view(batch, length, per_head * self.num_heads)

    def forward(self, queries, keys, values):
        """
        Attend from ``queries`` over ``keys``/``values``.

        Returns:
            The projected outputs, or ``(outputs, weights)`` when
            ``attention_map`` is True. ``weights`` are the post-dropout
            attention weights of shape [batch, num_heads, q_len, k_len].
        """
        projected_q = self.query_linear(queries)
        projected_k = self.key_linear(keys)
        projected_v = self.value_linear(values)

        q = self._split_heads(projected_q)
        k = self._split_heads(projected_k)
        v = self._split_heads(projected_v)

        # Scaled dot product: (q * scale) @ k^T
        q = q * self.query_scale
        logits = torch.matmul(q, k.transpose(-1, -2))

        if self.bias_mask is not None:
            q_len, k_len = logits.shape[-2], logits.shape[-1]
            logits = logits + self.bias_mask[:, :, :q_len, :k_len].type_as(logits)

        weights = self.dropout(logits.softmax(dim=-1))
        contexts = self._merge_heads(torch.matmul(weights, v))
        outputs = self.output_linear(contexts)

        if self.attention_map is True:
            return outputs, weights
        return outputs
class Conv(nn.Module):
    """
    1-D convolution with explicit zero padding, for inputs laid out as
    [batch_size, sequence_length, hidden_size].
    """

    def __init__(self, input_size, output_size, kernel_size, pad_type):
        """
        Parameters:
            input_size: Input feature size
            output_size: Output feature size
            kernel_size: Kernel width
            pad_type: 'left' -> pad only on the left, so an output position
                      never sees later input positions;
                      any other value -> pad on both sides
        """
        super(Conv, self).__init__()
        if pad_type == 'left':
            padding = (kernel_size - 1, 0)
        else:
            # Left gets the extra element when kernel_size is even; total
            # padding is kernel_size - 1, so sequence length is preserved.
            padding = (kernel_size // 2, (kernel_size - 1) // 2)
        self.pad = nn.ConstantPad1d(padding, 0)
        self.conv = nn.Conv1d(input_size, output_size, kernel_size=kernel_size, padding=0)

    def forward(self, inputs):
        # Conv1d expects [batch, channels, length]; swap in and back out.
        channels_first = inputs.permute(0, 2, 1)
        padded = self.pad(channels_first)
        convolved = self.conv(padded)
        return convolved.permute(0, 2, 1)
class PositionwiseFeedForward(nn.Module):
    """
    Applies a stack of per-timestep layers (Linear or Conv). Each layer is
    followed by ReLU and dropout; see ``activate_last_layer`` for the final one.
    """

    def __init__(self, input_depth, filter_size, output_depth, layer_config='ll',
                 padding='left', dropout=0.0, activate_last_layer=True):
        """
        Parameters:
            input_depth: Size of last dimension of input
            filter_size: Hidden size of the middle layer(s)
            output_depth: Size of last dimension of the final output
            layer_config: One character per layer: 'l' -> nn.Linear,
                          'c' -> Conv with kernel width 3 (e.g. 'll', 'cc', 'lcl')
            padding: Conv padding: 'left' pads only on the left (so a position
                     cannot see future positions); other values pad both sides
            dropout: Dropout probability (Should be non-zero only during training)
            activate_last_layer: If True (default, legacy behavior), ReLU and
                dropout are applied after the final layer as well. Pass False
                for the textbook Linear + ReLU + Linear block.
        Raises:
            ValueError: if layer_config contains a character other than 'l' or 'c'
        """
        super(PositionwiseFeedForward, self).__init__()
        sizes = ([(input_depth, filter_size)] +
                 [(filter_size, filter_size)] * (len(layer_config) - 2) +
                 [(filter_size, output_depth)])
        layers = []
        for lc, s in zip(layer_config, sizes):
            if lc == 'l':
                layers.append(nn.Linear(*s))
            elif lc == 'c':
                layers.append(Conv(*s, kernel_size=3, pad_type=padding))
            else:
                raise ValueError("Unknown layer type {}".format(lc))
        self.layers = nn.ModuleList(layers)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.activate_last_layer = activate_last_layer

    def forward(self, inputs):
        x = inputs
        last_index = len(self.layers) - 1
        # getattr default keeps modules pickled before this attribute existed
        # working, with their original (activate-everything) behavior.
        activate_last = getattr(self, 'activate_last_layer', True)
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # By default ReLU/dropout also follow the final layer. This is the
            # legacy behavior, which existing trained checkpoints rely on.
            if i < last_index or activate_last:
                x = self.relu(x)
                x = self.dropout(x)
        return x