# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Virtual adversarial text models.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import csv | |
| import os | |
| # Dependency imports | |
| import tensorflow as tf | |
| import adversarial_losses as adv_lib | |
| import inputs as inputs_lib | |
| import layers as layers_lib | |
| flags = tf.app.flags | |
| FLAGS = flags.FLAGS | |

# Flags governing adversarial training are defined in adversarial_losses.py.

# Classifier
flags.DEFINE_integer('num_classes', 2, 'Number of classes for classification')

# Data path
flags.DEFINE_string('data_dir', '/tmp/IMDB',
                    'Directory path to preprocessed text dataset.')
flags.DEFINE_string('vocab_freq_path', None,
                    'Path to pre-calculated vocab frequency data. If '
                    'None, use FLAGS.data_dir/vocab_freq.txt.')
flags.DEFINE_integer('batch_size', 64, 'Size of the batch.')
flags.DEFINE_integer('num_timesteps', 100, 'Number of timesteps for BPTT')

# Model architecture
flags.DEFINE_bool('bidir_lstm', False, 'Whether to build a bidirectional LSTM.')
flags.DEFINE_bool('single_label', True, 'Whether the sequence has a single '
                  'label, for optimization.')
flags.DEFINE_integer('rnn_num_layers', 1, 'Number of LSTM layers.')
flags.DEFINE_integer('rnn_cell_size', 512,
                     'Number of hidden units in the LSTM.')
flags.DEFINE_integer('cl_num_layers', 1,
                     'Number of hidden layers of classification model.')
flags.DEFINE_integer('cl_hidden_size', 30,
                     'Number of hidden units in classification layer.')
flags.DEFINE_integer('num_candidate_samples', -1,
                     'Num samples used in the sampled output layer.')
flags.DEFINE_bool('use_seq2seq_autoencoder', False,
                  'If True, seq2seq auto-encoder is used to pretrain. '
                  'If False, standard language model is used.')

# Vocabulary and embeddings
flags.DEFINE_integer('embedding_dims', 256, 'Dimensions of embedded vector.')
flags.DEFINE_integer('vocab_size', 86934,
                     'The size of the vocabulary. This value must exactly '
                     'match the number of vocabulary entries in the '
                     'preprocessed dataset; the last index in that vocabulary '
                     'is always <eos>, so vocab_size - 1 is used as the '
                     '<eos> id.')
flags.DEFINE_bool('normalize_embeddings', True,
                  'Normalize word embeddings by vocab frequency')

# Optimization
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate while fine-tuning.')
flags.DEFINE_float('learning_rate_decay_factor', 1.0,
                   'Learning rate decay factor')
flags.DEFINE_boolean('sync_replicas', False, 'sync_replica or not')
flags.DEFINE_integer('replicas_to_aggregate', 1,
                     'The number of replicas to aggregate')

# Regularization
flags.DEFINE_float('max_grad_norm', 1.0,
                   'Clip the global gradient norm to this value.')
flags.DEFINE_float('keep_prob_emb', 1.0, 'keep probability on embedding layer. '
                   '0.5 is optimal on IMDB with virtual adversarial training.')
flags.DEFINE_float('keep_prob_lstm_out', 1.0,
                   'keep probability on lstm output.')
flags.DEFINE_float('keep_prob_cl_hidden', 1.0,
                   'keep probability on classification hidden layer')


def get_model():
  if FLAGS.bidir_lstm:
    return VatxtBidirModel()
  else:
    return VatxtModel()
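
# Illustrative only (not part of the original module): a minimal sketch of how
# a training binary might drive the classifier graph, assuming a standard TF1
# tf.train.MonitoredTrainingSession (which also starts the input queue runners).
#
#   model = get_model()
#   train_op, loss, global_step = model.classifier_training()
#   with tf.train.MonitoredTrainingSession() as sess:
#     while not sess.should_stop():
#       _, loss_val, step = sess.run([train_op, loss, global_step])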


class VatxtModel(object):
  """Constructs training and evaluation graphs.

  Main methods: `classifier_training()`, `language_model_training()`,
  and `eval_graph()`.

  Variable reuse is a critical part of the model, both for sharing variables
  between the language model and the classifier, and for reusing variables for
  the adversarial loss calculation. To ensure correct variable reuse, all
  variables are created in Keras-style layers, wherein stateful layers (i.e.
  layers with variables) are represented as callable instances of the Layer
  class. Each call to a Layer instance reuses the same variables.

  All Layers are constructed in the __init__ method and reused in the various
  graph-building functions.
  """

  def __init__(self, cl_logits_input_dim=None):
    self.global_step = tf.train.get_or_create_global_step()
    self.vocab_freqs = _get_vocab_freqs()

    # Cache VatxtInput objects
    self.cl_inputs = None
    self.lm_inputs = None

    # Cache intermediate Tensors that are reused
    self.tensors = {}

    # Construct layers which are reused in constructing the LM and
    # Classification graphs. Instantiating them all once here ensures that
    # variable reuse works correctly.
    self.layers = {}
    self.layers['embedding'] = layers_lib.Embedding(
        FLAGS.vocab_size, FLAGS.embedding_dims, FLAGS.normalize_embeddings,
        self.vocab_freqs, FLAGS.keep_prob_emb)
    self.layers['lstm'] = layers_lib.LSTM(
        FLAGS.rnn_cell_size, FLAGS.rnn_num_layers, FLAGS.keep_prob_lstm_out)
    self.layers['lm_loss'] = layers_lib.SoftmaxLoss(
        FLAGS.vocab_size,
        FLAGS.num_candidate_samples,
        self.vocab_freqs,
        name='LM_loss')

    cl_logits_input_dim = cl_logits_input_dim or FLAGS.rnn_cell_size
    self.layers['cl_logits'] = layers_lib.cl_logits_subgraph(
        [FLAGS.cl_hidden_size] * FLAGS.cl_num_layers, cl_logits_input_dim,
        FLAGS.num_classes, FLAGS.keep_prob_cl_hidden)

  @property
  def pretrained_variables(self):
    return (self.layers['embedding'].trainable_weights +
            self.layers['lstm'].trainable_weights)

  def classifier_training(self):
    loss = self.classifier_graph()
    train_op = optimize(loss, self.global_step)
    return train_op, loss, self.global_step

  def language_model_training(self):
    loss = self.language_model_graph()
    train_op = optimize(loss, self.global_step)
    return train_op, loss, self.global_step

  def classifier_graph(self):
    """Constructs classifier graph from inputs to classifier loss.

    * Caches the VatxtInput object in `self.cl_inputs`
    * Caches tensors: `cl_embedded`, `cl_logits`, `cl_loss`

    Returns:
      loss: scalar float.
    """
    inputs = _inputs('train', pretrain=False)
    self.cl_inputs = inputs
    embedded = self.layers['embedding'](inputs.tokens)
    self.tensors['cl_embedded'] = embedded

    _, next_state, logits, loss = self.cl_loss_from_embedding(
        embedded, return_intermediates=True)
    tf.summary.scalar('classification_loss', loss)
    self.tensors['cl_logits'] = logits
    self.tensors['cl_loss'] = loss

    if FLAGS.single_label:
      indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
    else:
      labels = inputs.labels
      weights = inputs.weights
    acc = layers_lib.accuracy(logits, labels, weights)
    tf.summary.scalar('accuracy', acc)

    adv_loss = (self.adversarial_loss() * tf.constant(
        FLAGS.adv_reg_coeff, name='adv_reg_coeff'))
    tf.summary.scalar('adversarial_loss', adv_loss)

    total_loss = loss + adv_loss

    with tf.control_dependencies([inputs.save_state(next_state)]):
      total_loss = tf.identity(total_loss)
      tf.summary.scalar('total_classification_loss', total_loss)

    return total_loss

  def language_model_graph(self, compute_loss=True):
    """Constructs LM graph from inputs to LM loss.

    * Caches the VatxtInput object in `self.lm_inputs`
    * Caches tensors: `lm_embedded`

    Args:
      compute_loss: bool, whether to compute and return the loss or stop after
        the LSTM computation.

    Returns:
      loss: scalar float.
    """
    inputs = _inputs('train', pretrain=True)
    self.lm_inputs = inputs
    return self._lm_loss(inputs, compute_loss=compute_loss)

  def _lm_loss(self,
               inputs,
               emb_key='lm_embedded',
               lstm_layer='lstm',
               lm_loss_layer='lm_loss',
               loss_name='lm_loss',
               compute_loss=True):
    embedded = self.layers['embedding'](inputs.tokens)
    self.tensors[emb_key] = embedded
    lstm_out, next_state = self.layers[lstm_layer](embedded, inputs.state,
                                                   inputs.length)
    if compute_loss:
      loss = self.layers[lm_loss_layer](
          [lstm_out, inputs.labels, inputs.weights])
      with tf.control_dependencies([inputs.save_state(next_state)]):
        loss = tf.identity(loss)
        tf.summary.scalar(loss_name, loss)

      return loss

  def eval_graph(self, dataset='test'):
    """Constructs classifier evaluation graph.

    Args:
      dataset: the labeled dataset to evaluate, {'train', 'test', 'valid'}.

    Returns:
      eval_ops: dict<metric name, tuple(value, update_op)>
      var_restore_dict: dict mapping variable restoration names to variables.
        Trainable variables will be mapped to their moving average names.
    """
    inputs = _inputs(dataset, pretrain=False)
    embedded = self.layers['embedding'](inputs.tokens)
    _, next_state, logits, _ = self.cl_loss_from_embedding(
        embedded, inputs=inputs, return_intermediates=True)

    if FLAGS.single_label:
      indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
    else:
      labels = inputs.labels
      weights = inputs.weights
    eval_ops = {
        'accuracy':
            tf.contrib.metrics.streaming_accuracy(
                layers_lib.predictions(logits), labels, weights)
    }

    with tf.control_dependencies([inputs.save_state(next_state)]):
      acc, acc_update = eval_ops['accuracy']
      acc_update = tf.identity(acc_update)
      eval_ops['accuracy'] = (acc, acc_update)

    var_restore_dict = make_restore_average_vars_dict()
    return eval_ops, var_restore_dict
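
  # Illustrative only (not in the original source): a sketch of how an eval
  # binary could consume eval_graph(), assuming a tf.Session `sess`, a
  # `checkpoint_path`, and a chosen `num_eval_batches`.
  #
  #   eval_ops, var_restore_dict = model.eval_graph('test')
  #   saver = tf.train.Saver(var_restore_dict)
  #   with tf.Session() as sess:
  #     sess.run(tf.local_variables_initializer())
  #     saver.restore(sess, checkpoint_path)
  #     tf.train.start_queue_runners(sess)
  #     for _ in range(num_eval_batches):
  #       sess.run(eval_ops['accuracy'][1])       # update op
  #     print(sess.run(eval_ops['accuracy'][0]))  # final accuracy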

  def cl_loss_from_embedding(self,
                             embedded,
                             inputs=None,
                             return_intermediates=False):
    """Compute classification loss from embedding.

    Args:
      embedded: 3-D float Tensor [batch_size, num_timesteps, embedding_dim]
      inputs: VatxtInput, defaults to self.cl_inputs.
      return_intermediates: bool, whether to return intermediate tensors or
        only the final loss.

    Returns:
      If return_intermediates is True:
        lstm_out, next_state, logits, loss
      Else:
        loss
    """
    if inputs is None:
      inputs = self.cl_inputs

    lstm_out, next_state = self.layers['lstm'](embedded, inputs.state,
                                               inputs.length)
    if FLAGS.single_label:
      indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
      lstm_out = tf.expand_dims(tf.gather_nd(lstm_out, indices), 1)
      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
    else:
      labels = inputs.labels
      weights = inputs.weights
    logits = self.layers['cl_logits'](lstm_out)
    loss = layers_lib.classification_loss(logits, labels, weights)

    if return_intermediates:
      return lstm_out, next_state, logits, loss
    else:
      return loss
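
  # Illustrative note (not in the original source): with FLAGS.single_label,
  # the tf.stack/tf.gather_nd pattern above selects the final valid timestep of
  # each sequence. For example, with batch_size=2 and lengths [3, 5]:
  #
  #   indices = [[0, 2], [1, 4]]          # shape [batch_size, 2]
  #   tf.gather_nd(lstm_out, indices)     # shape [batch_size, cell_size]
  #   tf.expand_dims(..., 1)              # shape [batch_size, 1, cell_size]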

  def adversarial_loss(self):
    """Compute adversarial loss based on FLAGS.adv_training_method."""

    def random_perturbation_loss():
      return adv_lib.random_perturbation_loss(self.tensors['cl_embedded'],
                                              self.cl_inputs.length,
                                              self.cl_loss_from_embedding)

    def adversarial_loss():
      return adv_lib.adversarial_loss(self.tensors['cl_embedded'],
                                      self.tensors['cl_loss'],
                                      self.cl_loss_from_embedding)

    def virtual_adversarial_loss():
      """Computes virtual adversarial loss.

      Uses lm_inputs and constructs the language model graph if it hasn't yet
      been constructed.

      Also ensures that the LM input states are saved for LSTM state-saving
      BPTT.

      Returns:
        loss: float scalar.
      """
      if self.lm_inputs is None:
        self.language_model_graph(compute_loss=False)

      def logits_from_embedding(embedded, return_next_state=False):
        _, next_state, logits, _ = self.cl_loss_from_embedding(
            embedded, inputs=self.lm_inputs, return_intermediates=True)
        if return_next_state:
          return next_state, logits
        else:
          return logits

      next_state, lm_cl_logits = logits_from_embedding(
          self.tensors['lm_embedded'], return_next_state=True)

      va_loss = adv_lib.virtual_adversarial_loss(
          lm_cl_logits, self.tensors['lm_embedded'], self.lm_inputs,
          logits_from_embedding)

      with tf.control_dependencies([self.lm_inputs.save_state(next_state)]):
        va_loss = tf.identity(va_loss)

      return va_loss

    def combo_loss():
      return adversarial_loss() + virtual_adversarial_loss()

    adv_training_methods = {
        # Random perturbation
        'rp': random_perturbation_loss,
        # Adversarial training
        'at': adversarial_loss,
        # Virtual adversarial training
        'vat': virtual_adversarial_loss,
        # Both at and vat
        'atvat': combo_loss,
        '': lambda: tf.constant(0.),
        None: lambda: tf.constant(0.),
    }

    with tf.name_scope('adversarial_loss'):
      return adv_training_methods[FLAGS.adv_training_method]()


class VatxtBidirModel(VatxtModel):
  """Extension of VatxtModel that supports bidirectional input."""

  def __init__(self):
    super(VatxtBidirModel,
          self).__init__(cl_logits_input_dim=FLAGS.rnn_cell_size * 2)

    # Reverse LSTM and LM loss for bidirectional models
    self.layers['lstm_reverse'] = layers_lib.LSTM(
        FLAGS.rnn_cell_size,
        FLAGS.rnn_num_layers,
        FLAGS.keep_prob_lstm_out,
        name='LSTM_Reverse')
    self.layers['lm_loss_reverse'] = layers_lib.SoftmaxLoss(
        FLAGS.vocab_size,
        FLAGS.num_candidate_samples,
        self.vocab_freqs,
        name='LM_loss_reverse')

  @property
  def pretrained_variables(self):
    variables = super(VatxtBidirModel, self).pretrained_variables
    variables.extend(self.layers['lstm_reverse'].trainable_weights)
    return variables

  def classifier_graph(self):
    """Constructs classifier graph from inputs to classifier loss.

    * Caches the VatxtInput objects in `self.cl_inputs`
    * Caches tensors: `cl_embedded` (tuple of forward and reverse), `cl_logits`,
      `cl_loss`

    Returns:
      loss: scalar float.
    """
    inputs = _inputs('train', pretrain=False, bidir=True)
    self.cl_inputs = inputs
    f_inputs, _ = inputs

    # Embed both forward and reverse with a shared embedding
    embedded = [self.layers['embedding'](inp.tokens) for inp in inputs]
    self.tensors['cl_embedded'] = embedded

    _, next_states, logits, loss = self.cl_loss_from_embedding(
        embedded, return_intermediates=True)
    tf.summary.scalar('classification_loss', loss)
    self.tensors['cl_logits'] = logits
    self.tensors['cl_loss'] = loss

    acc = layers_lib.accuracy(logits, f_inputs.labels, f_inputs.weights)
    tf.summary.scalar('accuracy', acc)

    adv_loss = (self.adversarial_loss() * tf.constant(
        FLAGS.adv_reg_coeff, name='adv_reg_coeff'))
    tf.summary.scalar('adversarial_loss', adv_loss)

    total_loss = loss + adv_loss

    saves = [inp.save_state(state) for (inp, state) in zip(inputs, next_states)]
    with tf.control_dependencies(saves):
      total_loss = tf.identity(total_loss)
      tf.summary.scalar('total_classification_loss', total_loss)

    return total_loss

  def language_model_graph(self, compute_loss=True):
    """Constructs forward and reverse LM graphs from inputs to LM losses.

    * Caches the VatxtInput objects in `self.lm_inputs`
    * Caches tensors: `lm_embedded`, `lm_embedded_reverse`

    Args:
      compute_loss: bool, whether to compute and return the loss or stop after
        the LSTM computation.

    Returns:
      loss: scalar float, sum of forward and reverse losses.
    """
    inputs = _inputs('train', pretrain=True, bidir=True)
    self.lm_inputs = inputs
    f_inputs, r_inputs = inputs
    f_loss = self._lm_loss(f_inputs, compute_loss=compute_loss)
    r_loss = self._lm_loss(
        r_inputs,
        emb_key='lm_embedded_reverse',
        lstm_layer='lstm_reverse',
        lm_loss_layer='lm_loss_reverse',
        loss_name='lm_loss_reverse',
        compute_loss=compute_loss)
    if compute_loss:
      return f_loss + r_loss

  def eval_graph(self, dataset='test'):
    """Constructs classifier evaluation graph.

    Args:
      dataset: the labeled dataset to evaluate, {'train', 'test', 'valid'}.

    Returns:
      eval_ops: dict<metric name, tuple(value, update_op)>
      var_restore_dict: dict mapping variable restoration names to variables.
        Trainable variables will be mapped to their moving average names.
    """
    inputs = _inputs(dataset, pretrain=False, bidir=True)
    embedded = [self.layers['embedding'](inp.tokens) for inp in inputs]
    _, next_states, logits, _ = self.cl_loss_from_embedding(
        embedded, inputs=inputs, return_intermediates=True)

    f_inputs, _ = inputs

    eval_ops = {
        'accuracy':
            tf.contrib.metrics.streaming_accuracy(
                layers_lib.predictions(logits), f_inputs.labels,
                f_inputs.weights)
    }

    # Save states on accuracy update
    saves = [inp.save_state(state) for (inp, state) in zip(inputs, next_states)]
    with tf.control_dependencies(saves):
      acc, acc_update = eval_ops['accuracy']
      acc_update = tf.identity(acc_update)
      eval_ops['accuracy'] = (acc, acc_update)

    var_restore_dict = make_restore_average_vars_dict()
    return eval_ops, var_restore_dict

  def cl_loss_from_embedding(self,
                             embedded,
                             inputs=None,
                             return_intermediates=False):
    """Compute classification loss from embedding.

    Args:
      embedded: Length 2 tuple of 3-D float Tensor
        [batch_size, num_timesteps, embedding_dim].
      inputs: Length 2 tuple of VatxtInput, defaults to self.cl_inputs.
      return_intermediates: bool, whether to return intermediate tensors or
        only the final loss.

    Returns:
      If return_intermediates is True:
        lstm_out, next_states, logits, loss
      Else:
        loss
    """
    if inputs is None:
      inputs = self.cl_inputs

    out = []
    for (layer_name, emb, inp) in zip(['lstm', 'lstm_reverse'], embedded,
                                      inputs):
      out.append(self.layers[layer_name](emb, inp.state, inp.length))
    lstm_outs, next_states = zip(*out)

    # Concatenate output of forward and reverse LSTMs
    lstm_out = tf.concat(lstm_outs, 1)

    logits = self.layers['cl_logits'](lstm_out)
    f_inputs, _ = inputs  # pylint: disable=unpacking-non-sequence
    loss = layers_lib.classification_loss(logits, f_inputs.labels,
                                          f_inputs.weights)

    if return_intermediates:
      return lstm_out, next_states, logits, loss
    else:
      return loss

  def adversarial_loss(self):
    """Compute adversarial loss based on FLAGS.adv_training_method."""

    def random_perturbation_loss():
      return adv_lib.random_perturbation_loss_bidir(
          self.tensors['cl_embedded'], self.cl_inputs[0].length,
          self.cl_loss_from_embedding)

    def adversarial_loss():
      return adv_lib.adversarial_loss_bidir(self.tensors['cl_embedded'],
                                            self.tensors['cl_loss'],
                                            self.cl_loss_from_embedding)

    def virtual_adversarial_loss():
      """Computes virtual adversarial loss.

      Uses lm_inputs and constructs the language model graph if it hasn't yet
      been constructed.

      Also ensures that the LM input states are saved for LSTM state-saving
      BPTT.

      Returns:
        loss: float scalar.
      """
      if self.lm_inputs is None:
        self.language_model_graph(compute_loss=False)

      def logits_from_embedding(embedded, return_next_state=False):
        _, next_states, logits, _ = self.cl_loss_from_embedding(
            embedded, inputs=self.lm_inputs, return_intermediates=True)
        if return_next_state:
          return next_states, logits
        else:
          return logits

      lm_embedded = (self.tensors['lm_embedded'],
                     self.tensors['lm_embedded_reverse'])
      next_states, lm_cl_logits = logits_from_embedding(
          lm_embedded, return_next_state=True)

      va_loss = adv_lib.virtual_adversarial_loss_bidir(
          lm_cl_logits, lm_embedded, self.lm_inputs, logits_from_embedding)

      saves = [
          inp.save_state(state)
          for (inp, state) in zip(self.lm_inputs, next_states)
      ]
      with tf.control_dependencies(saves):
        va_loss = tf.identity(va_loss)

      return va_loss

    def combo_loss():
      return adversarial_loss() + virtual_adversarial_loss()

    adv_training_methods = {
        # Random perturbation
        'rp': random_perturbation_loss,
        # Adversarial training
        'at': adversarial_loss,
        # Virtual adversarial training
        'vat': virtual_adversarial_loss,
        # Both at and vat
        'atvat': combo_loss,
        '': lambda: tf.constant(0.),
        None: lambda: tf.constant(0.),
    }

    with tf.name_scope('adversarial_loss'):
      return adv_training_methods[FLAGS.adv_training_method]()


def _inputs(dataset='train', pretrain=False, bidir=False):
  return inputs_lib.inputs(
      data_dir=FLAGS.data_dir,
      phase=dataset,
      bidir=bidir,
      pretrain=pretrain,
      use_seq2seq=pretrain and FLAGS.use_seq2seq_autoencoder,
      state_size=FLAGS.rnn_cell_size,
      num_layers=FLAGS.rnn_num_layers,
      batch_size=FLAGS.batch_size,
      unroll_steps=FLAGS.num_timesteps,
      eos_id=FLAGS.vocab_size - 1)


def _get_vocab_freqs():
  """Returns vocab frequencies.

  Returns:
    List of integers, length=FLAGS.vocab_size.

  Raises:
    ValueError: if the length of the frequency file is not equal to the vocab
      size, or if the file is not found.
  """
  path = FLAGS.vocab_freq_path or os.path.join(FLAGS.data_dir, 'vocab_freq.txt')

  if tf.gfile.Exists(path):
    with tf.gfile.Open(path) as f:
      # Get pre-calculated frequencies of words.
      reader = csv.reader(f, quoting=csv.QUOTE_NONE)
      freqs = [int(row[-1]) for row in reader]
      if len(freqs) != FLAGS.vocab_size:
        raise ValueError('Frequency file length %d != vocab size %d' %
                         (len(freqs), FLAGS.vocab_size))
  else:
    if FLAGS.vocab_freq_path:
      raise ValueError('vocab_freq_path not found')
    freqs = [1] * FLAGS.vocab_size

  return freqs
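
# Illustrative note (not in the original source): _get_vocab_freqs() only reads
# the last CSV column of each row, so a frequency file whose rows end in an
# integer count, for example hypothetical rows like
#
#   the,51233
#   a,48731
#
# would yield freqs == [51233, 48731, ...], with one row per vocabulary entry.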


def make_restore_average_vars_dict():
  """Returns dict mapping moving average names to variables."""
  var_restore_dict = {}
  variable_averages = tf.train.ExponentialMovingAverage(0.999)
  for v in tf.global_variables():
    if v in tf.trainable_variables():
      name = variable_averages.average_name(v)
    else:
      name = v.op.name
    var_restore_dict[name] = v
  return var_restore_dict
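
# Illustrative only (not in the original source): the dict above can be passed
# to tf.train.Saver so that evaluation restores the exponential moving averages
# of trainable variables in place of their raw values, e.g. (assuming a
# tf.Session `sess` and a `checkpoint_path`):
#
#   saver = tf.train.Saver(make_restore_average_vars_dict())
#   saver.restore(sess, checkpoint_path)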


def optimize(loss, global_step):
  return layers_lib.optimize(
      loss, global_step, FLAGS.max_grad_norm, FLAGS.learning_rate,
      FLAGS.learning_rate_decay_factor, FLAGS.sync_replicas,
      FLAGS.replicas_to_aggregate, FLAGS.task)