# -*- coding: utf-8 -*-
""" Implement a PyTorch LSTM with hard sigmoid recurrent activation functions.

Adapted from the non-CUDA variant of the PyTorch LSTM at
https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py
"""
from __future__ import print_function, division

import math

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence
import torch.nn.functional as F


class LSTMHardSigmoid(Module):

    def __init__(self, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0, bidirectional=False):
        super(LSTMHardSigmoid, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.dropout_state = {}
        self.bidirectional = bidirectional
        num_directions = 2 if bidirectional else 1
        gate_size = 4 * hidden_size

        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
                w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
                b_ih = Parameter(torch.Tensor(gate_size))
                b_hh = Parameter(torch.Tensor(gate_size))
                layer_params = (w_ih, w_hh, b_ih, b_hh)

                suffix = '_reverse' if direction == 1 else ''
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]

                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._all_weights.append(param_names)

        self.flatten_parameters()
        self.reset_parameters()

    def flatten_parameters(self):
        """Resets parameter data pointer so that they can use faster code paths.

        Right now, this is a no-op since we don't use CUDA acceleration.
        """
        self._data_ptrs = []

    def _apply(self, fn):
        ret = super(LSTMHardSigmoid, self)._apply(fn)
        self.flatten_parameters()
        return ret

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, hx=None):
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            input, batch_sizes, _, _ = input
            max_batch_size = batch_sizes[0]
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.autograd.Variable(input.data.new(self.num_layers *
                                                        num_directions,
                                                        max_batch_size,
                                                        self.hidden_size).zero_(), requires_grad=False)
            hx = (hx, hx)

        has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs
        if has_flat_weights:
            first_data = next(self.parameters()).data
            assert first_data.storage().size() == self._param_buf_size
            flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size]))
        else:
            flat_weight = None

        func = AutogradRNN(
            self.input_size,
            self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
            dropout=self.dropout,
            train=self.training,
            bidirectional=self.bidirectional,
            batch_sizes=batch_sizes,
            dropout_state=self.dropout_state,
            flat_weight=flat_weight
        )
        output, hidden = func(input, self.all_weights, hx)
        if is_packed:
            output = PackedSequence(output, batch_sizes)
        return output, hidden

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def __setstate__(self, d):
        super(LSTMHardSigmoid, self).__setstate__(d)
        self.__dict__.setdefault('_data_ptrs', [])
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        if isinstance(self._all_weights[0][0], str):
            return
        num_layers = self.num_layers
        num_directions = 2 if self.bidirectional else 1
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(layer, suffix) for x in weights]
                if self.bias:
                    self._all_weights += [weights]
                else:
                    self._all_weights += [weights[:2]]

    @property
    def all_weights(self):
        return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]


def AutogradRNN(input_size, hidden_size, num_layers=1, batch_first=False,
                dropout=0, train=True, bidirectional=False, batch_sizes=None,
                dropout_state=None, flat_weight=None):

    cell = LSTMCell

    if batch_sizes is None:
        rec_factory = Recurrent
    else:
        rec_factory = variable_recurrent_factory(batch_sizes)

    if bidirectional:
        layer = (rec_factory(cell), rec_factory(cell, reverse=True))
    else:
        layer = (rec_factory(cell),)

    func = StackedRNN(layer,
                      num_layers,
                      True,
                      dropout=dropout,
                      train=train)

    def forward(input, weight, hidden):
        if batch_first and batch_sizes is None:
            input = input.transpose(0, 1)

        nexth, output = func(input, hidden, weight)

        if batch_first and batch_sizes is None:
            output = output.transpose(0, 1)

        return output, nexth

    return forward


def Recurrent(inner, reverse=False):
    def forward(input, hidden, weight):
        output = []
        steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
        for i in steps:
            hidden = inner(input[i], hidden, *weight)
            # hack to handle LSTM
            output.append(hidden[0] if isinstance(hidden, tuple) else hidden)

        if reverse:
            output.reverse()
        output = torch.cat(output, 0).view(input.size(0), *output[0].size())

        return hidden, output

    return forward


def variable_recurrent_factory(batch_sizes):
    def fac(inner, reverse=False):
        if reverse:
            return VariableRecurrentReverse(batch_sizes, inner)
        else:
            return VariableRecurrent(batch_sizes, inner)
    return fac
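

# VariableRecurrent runs `inner` over a packed sequence whose per-step batch
# sizes are non-increasing. Whenever the effective batch shrinks, the hidden
# states of the sequences that just ended are set aside and, after the loop,
# stitched back together so the returned hidden state covers all
# batch_sizes[0] sequences.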
def VariableRecurrent(batch_sizes, inner):
    def forward(input, hidden, weight):
        output = []
        input_offset = 0
        last_batch_size = batch_sizes[0]
        hiddens = []
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
        for batch_size in batch_sizes:
            step_input = input[input_offset:input_offset + batch_size]
            input_offset += batch_size

            dec = last_batch_size - batch_size
            if dec > 0:
                hiddens.append(tuple(h[-dec:] for h in hidden))
                hidden = tuple(h[:-dec] for h in hidden)
            last_batch_size = batch_size

            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)

            output.append(hidden[0])
        hiddens.append(hidden)
        hiddens.reverse()

        hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
        assert hidden[0].size(0) == batch_sizes[0]
        if flat_hidden:
            hidden = hidden[0]
        output = torch.cat(output, 0)

        return hidden, output

    return forward
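

# VariableRecurrentReverse walks the packed data from the last time step back
# to the first. It starts from the hidden rows of the smallest batch and, as
# the effective batch grows at earlier time steps, splices in the matching
# rows of the initial hidden state so every sequence begins its backward pass
# from the supplied initial state.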
def VariableRecurrentReverse(batch_sizes, inner):
    def forward(input, hidden, weight):
        output = []
        input_offset = input.size(0)
        last_batch_size = batch_sizes[-1]
        initial_hidden = hidden
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
            initial_hidden = (initial_hidden,)
        hidden = tuple(h[:batch_sizes[-1]] for h in hidden)
        for batch_size in reversed(batch_sizes):
            inc = batch_size - last_batch_size
            if inc > 0:
                hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0)
                               for h, ih in zip(hidden, initial_hidden))
            last_batch_size = batch_size
            step_input = input[input_offset - batch_size:input_offset]
            input_offset -= batch_size

            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)
            output.append(hidden[0])

        output.reverse()
        output = torch.cat(output, 0)
        if flat_hidden:
            hidden = hidden[0]
        return hidden, output

    return forward
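

# StackedRNN composes the per-direction inner loops into a multi-layer RNN:
# direction outputs are concatenated along the feature dimension and fed to
# the next layer (with optional inter-layer dropout), and the per-layer hidden
# states are stacked into tensors of shape (num_layers * num_directions, ...).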
def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):

    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight):
        assert(len(weight) == total_layers)
        next_hidden = []

        if lstm:
            hidden = list(zip(*hidden))

        for i in range(num_layers):
            all_output = []
            for j, inner in enumerate(inners):
                l = i * num_directions + j

                hy, output = inner(input, hidden[l], weight[l])
                next_hidden.append(hy)
                all_output.append(output)

            input = torch.cat(all_output, input.dim() - 1)

            if dropout != 0 and i < num_layers - 1:
                input = F.dropout(input, p=dropout, training=train, inplace=False)

        if lstm:
            next_h, next_c = zip(*next_hidden)
            next_hidden = (
                torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
                torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
            )
        else:
            next_hidden = torch.cat(next_hidden, 0).view(
                total_layers, *next_hidden[0].size())

        return next_hidden, input

    return forward


def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """
    A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.
    """
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = hard_sigmoid(ingate)
    forgetgate = hard_sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = hard_sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)

    return hy, cy


def hard_sigmoid(x):
    """
    Computes element-wise hard sigmoid of x.
    See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
    """
    x = (0.2 * x) + 0.5
    x = F.threshold(-x, -1, -1)
    x = F.threshold(-x, 0, 0)
    return x
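

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): the shapes and hyper-parameters
    # below are arbitrary assumptions, not values prescribed by this module.
    seq_len, batch, input_size, hidden_size = 5, 3, 10, 20
    lstm = LSTMHardSigmoid(input_size, hidden_size, num_layers=2, bidirectional=True)

    # hard_sigmoid clamps 0.2 * x + 0.5 to [0, 1], so the endpoints saturate at 0 and 1.
    print(hard_sigmoid(torch.autograd.Variable(torch.Tensor([-5.0, 0.0, 5.0]))))

    # Forward pass over a padded batch laid out as (seq_len, batch, input_size).
    x = torch.autograd.Variable(torch.randn(seq_len, batch, input_size))
    output, (h_n, c_n) = lstm(x)
    # output: (seq_len, batch, 2 * hidden_size); h_n, c_n: (num_layers * 2, batch, hidden_size)
    print(output.size(), h_n.size(), c_n.size())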