Spaces:
Runtime error
Runtime error
| # Copyright 2017 The TensorFlow Authors All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Launch example: | |
| [IMDB] | |
| python train_mask_gan.py --data_dir | |
| /tmp/imdb --data_set imdb --batch_size 128 | |
| --sequence_length 20 --base_directory /tmp/maskGAN_v0.01 | |
| --hparams="gen_rnn_size=650,gen_num_layers=2,dis_rnn_size=650,dis_num_layers=2 | |
| ,critic_learning_rate=0.0009756,dis_learning_rate=0.0000585, | |
| dis_train_iterations=8,gen_learning_rate=0.0016624, | |
| gen_full_learning_rate_steps=1e9,gen_learning_rate_decay=0.999999, | |
| rl_discount_rate=0.8835659" --mode TRAIN --max_steps 1000000 | |
| --generator_model seq2seq_vd --discriminator_model seq2seq_vd | |
| --is_present_rate 0.5 --summaries_every 25 --print_every 25 | |
| --max_num_to_print=3 --generator_optimizer=adam | |
| --seq2seq_share_embedding=True --baseline_method=critic | |
| --attention_option=luong --n_gram_eval=4 --mask_strategy=contiguous | |
| --gen_training_strategy=reinforce --dis_pretrain_steps=100 | |
| --perplexity_threshold=1000000 | |
| --dis_share_embedding=True --maskgan_ckpt | |
| /tmp/model.ckpt-171091 | |
| """ | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import collections | |
| from functools import partial | |
| import os | |
| import time | |
| # Dependency imports | |
| import numpy as np | |
| from six.moves import xrange | |
| import tensorflow as tf | |
| import pretrain_mask_gan | |
| from data import imdb_loader | |
| from data import ptb_loader | |
| from model_utils import helper | |
| from model_utils import model_construction | |
| from model_utils import model_losses | |
| from model_utils import model_optimization | |
| # Data. | |
| from model_utils import model_utils | |
| from model_utils import n_gram | |
| from models import evaluation_utils | |
| from models import rollout | |
| np.set_printoptions(precision=3) | |
| np.set_printoptions(suppress=True) | |
| MODE_TRAIN = 'TRAIN' | |
| MODE_TRAIN_EVAL = 'TRAIN_EVAL' | |
| MODE_VALIDATION = 'VALIDATION' | |
| MODE_TEST = 'TEST' | |
| ## Binary and setup FLAGS. | |
| tf.app.flags.DEFINE_enum( | |
| 'mode', 'TRAIN', [MODE_TRAIN, MODE_VALIDATION, MODE_TEST, MODE_TRAIN_EVAL], | |
| 'What this binary will do.') | |
| tf.app.flags.DEFINE_string('master', '', | |
| """Name of the TensorFlow master to use.""") | |
| tf.app.flags.DEFINE_string('eval_master', '', | |
| """Name prefix of the Tensorflow eval master.""") | |
| tf.app.flags.DEFINE_integer('task', 0, | |
| """Task id of the replica running the training.""") | |
| tf.app.flags.DEFINE_integer('ps_tasks', 0, """Number of tasks in the ps job. | |
| If 0 no ps job is used.""") | |
| ## General FLAGS. | |
| tf.app.flags.DEFINE_string( | |
| 'hparams', '', 'Comma separated list of name=value hyperparameter pairs.') | |
| tf.app.flags.DEFINE_integer('batch_size', 20, 'The batch size.') | |
| tf.app.flags.DEFINE_integer('vocab_size', 10000, 'The vocabulary size.') | |
| tf.app.flags.DEFINE_integer('sequence_length', 20, 'The sequence length.') | |
| tf.app.flags.DEFINE_integer('max_steps', 1000000, | |
| 'Maximum number of steps to run.') | |
| tf.app.flags.DEFINE_string( | |
| 'mask_strategy', 'random', 'Strategy for masking the words. Determine the ' | |
| 'characterisitics of how the words are dropped out. One of ' | |
| "['contiguous', 'random'].") | |
| tf.app.flags.DEFINE_float('is_present_rate', 0.5, | |
| 'Percent of tokens present in the forward sequence.') | |
| tf.app.flags.DEFINE_float('is_present_rate_decay', None, 'Decay rate for the ' | |
| 'percent of words that are real (are present).') | |
| tf.app.flags.DEFINE_string( | |
| 'generator_model', 'seq2seq', | |
| "Type of Generator model. One of ['rnn', 'seq2seq', 'seq2seq_zaremba'," | |
| "'rnn_zaremba', 'rnn_nas', 'seq2seq_nas']") | |
| tf.app.flags.DEFINE_string( | |
| 'attention_option', None, | |
| "Attention mechanism. One of [None, 'luong', 'bahdanau']") | |
| tf.app.flags.DEFINE_string( | |
| 'discriminator_model', 'bidirectional', | |
| "Type of Discriminator model. One of ['cnn', 'rnn', 'bidirectional', " | |
| "'rnn_zaremba', 'bidirectional_zaremba', 'rnn_nas', 'rnn_vd', 'seq2seq_vd']" | |
| ) | |
| tf.app.flags.DEFINE_boolean('seq2seq_share_embedding', False, | |
| 'Whether to share the ' | |
| 'embeddings between the encoder and decoder.') | |
| tf.app.flags.DEFINE_boolean( | |
| 'dis_share_embedding', False, 'Whether to share the ' | |
| 'embeddings between the generator and discriminator.') | |
| tf.app.flags.DEFINE_boolean('dis_update_share_embedding', False, 'Whether the ' | |
| 'discriminator should update the shared embedding.') | |
| tf.app.flags.DEFINE_boolean('use_gen_mode', False, | |
| 'Use the mode of the generator ' | |
| 'to produce samples.') | |
| tf.app.flags.DEFINE_boolean('critic_update_dis_vars', False, | |
| 'Whether the critic ' | |
| 'updates the discriminator variables.') | |
| ## Training FLAGS. | |
| tf.app.flags.DEFINE_string( | |
| 'gen_training_strategy', 'reinforce', | |
| "Method for training the Generator. One of ['cross_entropy', 'reinforce']") | |
| tf.app.flags.DEFINE_string( | |
| 'generator_optimizer', 'adam', | |
| "Type of Generator optimizer. One of ['sgd', 'adam']") | |
| tf.app.flags.DEFINE_float('grad_clipping', 10., 'Norm for gradient clipping.') | |
| tf.app.flags.DEFINE_float('advantage_clipping', 5., 'Clipping for advantages.') | |
| tf.app.flags.DEFINE_string( | |
| 'baseline_method', None, | |
| "Approach for baseline. One of ['critic', 'dis_batch', 'ema', None]") | |
| tf.app.flags.DEFINE_float('perplexity_threshold', 15000, | |
| 'Limit for perplexity before terminating job.') | |
| tf.app.flags.DEFINE_float('zoneout_drop_prob', 0.1, | |
| 'Probability for dropping parameter for zoneout.') | |
| tf.app.flags.DEFINE_float('keep_prob', 0.5, | |
| 'Probability for keeping parameter for dropout.') | |
| ## Logging and evaluation FLAGS. | |
| tf.app.flags.DEFINE_integer('print_every', 250, | |
| 'Frequency to print and log the ' | |
| 'outputs of the model.') | |
| tf.app.flags.DEFINE_integer('max_num_to_print', 5, | |
| 'Number of samples to log/print.') | |
| tf.app.flags.DEFINE_boolean('print_verbose', False, 'Whether to print in full.') | |
| tf.app.flags.DEFINE_integer('summaries_every', 100, | |
| 'Frequency to compute summaries.') | |
| tf.app.flags.DEFINE_boolean('eval_language_model', False, | |
| 'Whether to evaluate on ' | |
| 'all words as in language modeling.') | |
| tf.app.flags.DEFINE_float('eval_interval_secs', 60, | |
| 'Delay for evaluating model.') | |
| tf.app.flags.DEFINE_integer( | |
| 'n_gram_eval', 4, """The degree of the n-grams to use for evaluation.""") | |
| tf.app.flags.DEFINE_integer( | |
| 'epoch_size_override', None, | |
| 'If an integer, this dictates the size of the epochs and will potentially ' | |
| 'not iterate over all the data.') | |
| tf.app.flags.DEFINE_integer('eval_epoch_size_override', None, | |
| 'Number of evaluation steps.') | |
| ## Directories and checkpoints. | |
| tf.app.flags.DEFINE_string('base_directory', '/tmp/maskGAN_v0.00', | |
| 'Base directory for the logging, events and graph.') | |
| tf.app.flags.DEFINE_string('data_set', 'ptb', 'Data set to operate on. One of' | |
| "['ptb', 'imdb']") | |
| tf.app.flags.DEFINE_string('data_dir', '/tmp/data/ptb', | |
| 'Directory for the training data.') | |
| tf.app.flags.DEFINE_string( | |
| 'language_model_ckpt_dir', None, | |
| 'Directory storing checkpoints to initialize the model. Pretrained models' | |
| 'are stored at /tmp/maskGAN/pretrained/') | |
| tf.app.flags.DEFINE_string( | |
| 'language_model_ckpt_dir_reversed', None, | |
| 'Directory storing checkpoints of reversed models to initialize the model.' | |
| 'Pretrained models stored at' | |
| 'are stored at /tmp/PTB/pretrained_reversed') | |
| tf.app.flags.DEFINE_string( | |
| 'maskgan_ckpt', None, | |
| 'Override which checkpoint file to use to restore the ' | |
| 'model. A pretrained seq2seq_zaremba model is stored at ' | |
| '/tmp/maskGAN/pretrain/seq2seq_zaremba/train/model.ckpt-64912') | |
| tf.app.flags.DEFINE_boolean('wasserstein_objective', False, | |
| '(DEPRECATED) Whether to use the WGAN training.') | |
| tf.app.flags.DEFINE_integer('num_rollouts', 1, | |
| 'The number of rolled out predictions to make.') | |
| tf.app.flags.DEFINE_float('c_lower', -0.01, 'Lower bound for weights.') | |
| tf.app.flags.DEFINE_float('c_upper', 0.01, 'Upper bound for weights.') | |
| FLAGS = tf.app.flags.FLAGS | |
| def create_hparams(): | |
| """Create the hparams object for generic training hyperparameters.""" | |
| hparams = tf.contrib.training.HParams( | |
| gen_num_layers=2, | |
| dis_num_layers=2, | |
| gen_rnn_size=740, | |
| dis_rnn_size=740, | |
| gen_learning_rate=5e-4, | |
| dis_learning_rate=5e-3, | |
| critic_learning_rate=5e-3, | |
| dis_train_iterations=1, | |
| gen_learning_rate_decay=1.0, | |
| gen_full_learning_rate_steps=1e7, | |
| baseline_decay=0.999999, | |
| rl_discount_rate=0.9, | |
| gen_vd_keep_prob=0.5, | |
| dis_vd_keep_prob=0.5, | |
| dis_pretrain_learning_rate=5e-3, | |
| dis_num_filters=128, | |
| dis_hidden_dim=128, | |
| gen_nas_keep_prob_0=0.85, | |
| gen_nas_keep_prob_1=0.55, | |
| dis_nas_keep_prob_0=0.85, | |
| dis_nas_keep_prob_1=0.55) | |
| # Command line flags override any of the preceding hyperparameter values. | |
| if FLAGS.hparams: | |
| hparams = hparams.parse(FLAGS.hparams) | |
| return hparams | |
| def create_MaskGAN(hparams, is_training): | |
| """Create the MaskGAN model. | |
| Args: | |
| hparams: Hyperparameters for the MaskGAN. | |
| is_training: Boolean indicating operational mode (train/inference). | |
| evaluated with a teacher forcing regime. | |
| Return: | |
| model: Namedtuple for specifying the MaskGAN. | |
| """ | |
| global_step = tf.Variable(0, name='global_step', trainable=False) | |
| new_learning_rate = tf.placeholder(tf.float32, [], name='new_learning_rate') | |
| learning_rate = tf.Variable(0.0, name='learning_rate', trainable=False) | |
| learning_rate_update = tf.assign(learning_rate, new_learning_rate) | |
| new_rate = tf.placeholder(tf.float32, [], name='new_rate') | |
| percent_real_var = tf.Variable(0.0, trainable=False) | |
| percent_real_update = tf.assign(percent_real_var, new_rate) | |
| ## Placeholders. | |
| inputs = tf.placeholder( | |
| tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length]) | |
| targets = tf.placeholder( | |
| tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length]) | |
| present = tf.placeholder( | |
| tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length]) | |
| # TODO(adai): Placeholder for IMDB label. | |
| ## Real Sequence is the targets. | |
| real_sequence = targets | |
| ## Fakse Sequence from the Generator. | |
| # TODO(adai): Generator must have IMDB labels placeholder. | |
| (fake_sequence, fake_logits, fake_log_probs, fake_gen_initial_state, | |
| fake_gen_final_state, _) = model_construction.create_generator( | |
| hparams, | |
| inputs, | |
| targets, | |
| present, | |
| is_training=is_training, | |
| is_validating=False) | |
| (_, eval_logits, _, eval_initial_state, eval_final_state, | |
| _) = model_construction.create_generator( | |
| hparams, | |
| inputs, | |
| targets, | |
| present, | |
| is_training=False, | |
| is_validating=True, | |
| reuse=True) | |
| ## Discriminator. | |
| fake_predictions = model_construction.create_discriminator( | |
| hparams, | |
| fake_sequence, | |
| is_training=is_training, | |
| inputs=inputs, | |
| present=present) | |
| real_predictions = model_construction.create_discriminator( | |
| hparams, | |
| real_sequence, | |
| is_training=is_training, | |
| reuse=True, | |
| inputs=inputs, | |
| present=present) | |
| ## Critic. | |
| # The critic will be used to estimate the forward rewards to the Generator. | |
| if FLAGS.baseline_method == 'critic': | |
| est_state_values = model_construction.create_critic( | |
| hparams, fake_sequence, is_training=is_training) | |
| else: | |
| est_state_values = None | |
| ## Discriminator Loss. | |
| [dis_loss, dis_loss_fake, dis_loss_real] = model_losses.create_dis_loss( | |
| fake_predictions, real_predictions, present) | |
| ## Average log-perplexity for only missing words. However, to do this, | |
| # the logits are still computed using teacher forcing, that is, the ground | |
| # truth tokens are fed in at each time point to be valid. | |
| avg_log_perplexity = model_losses.calculate_log_perplexity( | |
| eval_logits, targets, present) | |
| ## Generator Objective. | |
| # 1. Cross Entropy losses on missing tokens. | |
| fake_cross_entropy_losses = model_losses.create_masked_cross_entropy_loss( | |
| targets, present, fake_logits) | |
| # 2. GAN REINFORCE losses. | |
| [ | |
| fake_RL_loss, fake_log_probs, fake_rewards, fake_advantages, | |
| fake_baselines, fake_averages_op, critic_loss, cumulative_rewards | |
| ] = model_losses.calculate_reinforce_objective( | |
| hparams, fake_log_probs, fake_predictions, present, est_state_values) | |
| ## Pre-training. | |
| if FLAGS.gen_pretrain_steps: | |
| raise NotImplementedError | |
| # # TODO(liamfedus): Rewrite this. | |
| # fwd_cross_entropy_loss = tf.reduce_mean(fwd_cross_entropy_losses) | |
| # gen_pretrain_op = model_optimization.create_gen_pretrain_op( | |
| # hparams, fwd_cross_entropy_loss, global_step) | |
| else: | |
| gen_pretrain_op = None | |
| if FLAGS.dis_pretrain_steps: | |
| dis_pretrain_op = model_optimization.create_dis_pretrain_op( | |
| hparams, dis_loss, global_step) | |
| else: | |
| dis_pretrain_op = None | |
| ## Generator Train Op. | |
| # 1. Cross-Entropy. | |
| if FLAGS.gen_training_strategy == 'cross_entropy': | |
| gen_loss = tf.reduce_mean(fake_cross_entropy_losses) | |
| [gen_train_op, gen_grads, | |
| gen_vars] = model_optimization.create_gen_train_op( | |
| hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE') | |
| # 2. GAN (REINFORCE) | |
| elif FLAGS.gen_training_strategy == 'reinforce': | |
| gen_loss = fake_RL_loss | |
| [gen_train_op, gen_grads, | |
| gen_vars] = model_optimization.create_reinforce_gen_train_op( | |
| hparams, learning_rate, gen_loss, fake_averages_op, global_step) | |
| else: | |
| raise NotImplementedError | |
| ## Discriminator Train Op. | |
| dis_train_op, dis_grads, dis_vars = model_optimization.create_dis_train_op( | |
| hparams, dis_loss, global_step) | |
| ## Critic Train Op. | |
| if critic_loss is not None: | |
| [critic_train_op, _, _] = model_optimization.create_critic_train_op( | |
| hparams, critic_loss, global_step) | |
| dis_train_op = tf.group(dis_train_op, critic_train_op) | |
| ## Summaries. | |
| with tf.name_scope('general'): | |
| tf.summary.scalar('percent_real', percent_real_var) | |
| tf.summary.scalar('learning_rate', learning_rate) | |
| with tf.name_scope('generator_objectives'): | |
| tf.summary.scalar('gen_objective', tf.reduce_mean(gen_loss)) | |
| tf.summary.scalar('gen_loss_cross_entropy', | |
| tf.reduce_mean(fake_cross_entropy_losses)) | |
| with tf.name_scope('REINFORCE'): | |
| with tf.name_scope('objective'): | |
| tf.summary.scalar('fake_RL_loss', tf.reduce_mean(fake_RL_loss)) | |
| with tf.name_scope('rewards'): | |
| helper.variable_summaries(cumulative_rewards, 'rewards') | |
| with tf.name_scope('advantages'): | |
| helper.variable_summaries(fake_advantages, 'advantages') | |
| with tf.name_scope('baselines'): | |
| helper.variable_summaries(fake_baselines, 'baselines') | |
| with tf.name_scope('log_probs'): | |
| helper.variable_summaries(fake_log_probs, 'log_probs') | |
| with tf.name_scope('discriminator_losses'): | |
| tf.summary.scalar('dis_loss', dis_loss) | |
| tf.summary.scalar('dis_loss_fake_sequence', dis_loss_fake) | |
| tf.summary.scalar('dis_loss_prob_fake_sequence', tf.exp(-dis_loss_fake)) | |
| tf.summary.scalar('dis_loss_real_sequence', dis_loss_real) | |
| tf.summary.scalar('dis_loss_prob_real_sequence', tf.exp(-dis_loss_real)) | |
| if critic_loss is not None: | |
| with tf.name_scope('critic_losses'): | |
| tf.summary.scalar('critic_loss', critic_loss) | |
| with tf.name_scope('logits'): | |
| helper.variable_summaries(fake_logits, 'fake_logits') | |
| for v, g in zip(gen_vars, gen_grads): | |
| helper.variable_summaries(v, v.op.name) | |
| helper.variable_summaries(g, 'grad/' + v.op.name) | |
| for v, g in zip(dis_vars, dis_grads): | |
| helper.variable_summaries(v, v.op.name) | |
| helper.variable_summaries(g, 'grad/' + v.op.name) | |
| merge_summaries_op = tf.summary.merge_all() | |
| text_summary_placeholder = tf.placeholder(tf.string) | |
| text_summary_op = tf.summary.text('Samples', text_summary_placeholder) | |
| # Model saver. | |
| saver = tf.train.Saver(keep_checkpoint_every_n_hours=1, max_to_keep=5) | |
| # Named tuple that captures elements of the MaskGAN model. | |
| Model = collections.namedtuple('Model', [ | |
| 'inputs', 'targets', 'present', 'percent_real_update', 'new_rate', | |
| 'fake_sequence', 'fake_logits', 'fake_rewards', 'fake_baselines', | |
| 'fake_advantages', 'fake_log_probs', 'fake_predictions', | |
| 'real_predictions', 'fake_cross_entropy_losses', 'fake_gen_initial_state', | |
| 'fake_gen_final_state', 'eval_initial_state', 'eval_final_state', | |
| 'avg_log_perplexity', 'dis_loss', 'gen_loss', 'critic_loss', | |
| 'cumulative_rewards', 'dis_train_op', 'gen_train_op', 'gen_pretrain_op', | |
| 'dis_pretrain_op', 'merge_summaries_op', 'global_step', | |
| 'new_learning_rate', 'learning_rate_update', 'saver', 'text_summary_op', | |
| 'text_summary_placeholder' | |
| ]) | |
| model = Model( | |
| inputs, targets, present, percent_real_update, new_rate, fake_sequence, | |
| fake_logits, fake_rewards, fake_baselines, fake_advantages, | |
| fake_log_probs, fake_predictions, real_predictions, | |
| fake_cross_entropy_losses, fake_gen_initial_state, fake_gen_final_state, | |
| eval_initial_state, eval_final_state, avg_log_perplexity, dis_loss, | |
| gen_loss, critic_loss, cumulative_rewards, dis_train_op, gen_train_op, | |
| gen_pretrain_op, dis_pretrain_op, merge_summaries_op, global_step, | |
| new_learning_rate, learning_rate_update, saver, text_summary_op, | |
| text_summary_placeholder) | |
| return model | |
| def compute_geometric_average(percent_captured): | |
| """Compute the geometric average of the n-gram metrics.""" | |
| res = 1. | |
| for _, n_gram_percent in percent_captured.iteritems(): | |
| res *= n_gram_percent | |
| return np.power(res, 1. / float(len(percent_captured))) | |
| def compute_arithmetic_average(percent_captured): | |
| """Compute the arithmetic average of the n-gram metrics.""" | |
| N = len(percent_captured) | |
| res = 0. | |
| for _, n_gram_percent in percent_captured.iteritems(): | |
| res += n_gram_percent | |
| return res / float(N) | |
| def get_iterator(data): | |
| """Return the data iterator.""" | |
| if FLAGS.data_set == 'ptb': | |
| iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size, | |
| FLAGS.sequence_length, | |
| FLAGS.epoch_size_override) | |
| elif FLAGS.data_set == 'imdb': | |
| iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size, | |
| FLAGS.sequence_length) | |
| return iterator | |
| def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts): | |
| """Train model. | |
| Args: | |
| hparams: Hyperparameters for the MaskGAN. | |
| data: Data to evaluate. | |
| log_dir: Directory to save checkpoints. | |
| log: Readable log for the experiment. | |
| id_to_word: Dictionary of indices to words. | |
| data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the | |
| data_set. | |
| """ | |
| print('Training model.') | |
| tf.logging.info('Training model.') | |
| # Boolean indicating operational mode. | |
| is_training = True | |
| # Write all the information to the logs. | |
| log.write('hparams\n') | |
| log.write(str(hparams)) | |
| log.flush() | |
| is_chief = FLAGS.task == 0 | |
| with tf.Graph().as_default(): | |
| with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)): | |
| container_name = '' | |
| with tf.container(container_name): | |
| # Construct the model. | |
| if FLAGS.num_rollouts == 1: | |
| model = create_MaskGAN(hparams, is_training) | |
| elif FLAGS.num_rollouts > 1: | |
| model = rollout.create_rollout_MaskGAN(hparams, is_training) | |
| else: | |
| raise ValueError | |
| print('\nTrainable Variables in Graph:') | |
| for v in tf.trainable_variables(): | |
| print(v) | |
| ## Retrieve the initial savers. | |
| init_savers = model_utils.retrieve_init_savers(hparams) | |
| ## Initial saver function to supervisor. | |
| init_fn = partial(model_utils.init_fn, init_savers) | |
| # Create the supervisor. It will take care of initialization, | |
| # summaries, checkpoints, and recovery. | |
| sv = tf.train.Supervisor( | |
| logdir=log_dir, | |
| is_chief=is_chief, | |
| saver=model.saver, | |
| global_step=model.global_step, | |
| save_model_secs=60, | |
| recovery_wait_secs=30, | |
| summary_op=None, | |
| init_fn=init_fn) | |
| # Get an initialized, and possibly recovered session. Launch the | |
| # services: Checkpointing, Summaries, step counting. | |
| # | |
| # When multiple replicas of this program are running the services are | |
| # only launched by the 'chief' replica. | |
| with sv.managed_session(FLAGS.master) as sess: | |
| ## Pretrain the generator. | |
| if FLAGS.gen_pretrain_steps: | |
| pretrain_mask_gan.pretrain_generator(sv, sess, model, data, log, | |
| id_to_word, data_ngram_counts, | |
| is_chief) | |
| ## Pretrain the discriminator. | |
| if FLAGS.dis_pretrain_steps: | |
| pretrain_mask_gan.pretrain_discriminator( | |
| sv, sess, model, data, log, id_to_word, data_ngram_counts, | |
| is_chief) | |
| # Initial indicators for printing and summarizing. | |
| print_step_division = -1 | |
| summary_step_division = -1 | |
| # Run iterative computation in a loop. | |
| while not sv.ShouldStop(): | |
| is_present_rate = FLAGS.is_present_rate | |
| if FLAGS.is_present_rate_decay is not None: | |
| is_present_rate *= (1. - FLAGS.is_present_rate_decay) | |
| model_utils.assign_percent_real(sess, model.percent_real_update, | |
| model.new_rate, is_present_rate) | |
| # GAN training. | |
| avg_epoch_gen_loss, avg_epoch_dis_loss = [], [] | |
| cumulative_costs = 0. | |
| gen_iters = 0 | |
| # Generator and Discriminator statefulness initial evaluation. | |
| # TODO(liamfedus): Throughout the code I am implicitly assuming | |
| # that the Generator and Discriminator are equal sized. | |
| [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run( | |
| [model.eval_initial_state, model.fake_gen_initial_state]) | |
| dis_initial_state_eval = fake_gen_initial_state_eval | |
| # Save zeros state to reset later. | |
| zeros_state = fake_gen_initial_state_eval | |
| ## Offset Discriminator. | |
| if FLAGS.ps_tasks == 0: | |
| dis_offset = 1 | |
| else: | |
| dis_offset = FLAGS.task * 1000 + 1 | |
| dis_iterator = get_iterator(data) | |
| for i in range(dis_offset): | |
| try: | |
| dis_x, dis_y, _ = next(dis_iterator) | |
| except StopIteration: | |
| dis_iterator = get_iterator(data) | |
| dis_initial_state_eval = zeros_state | |
| dis_x, dis_y, _ = next(dis_iterator) | |
| p = model_utils.generate_mask() | |
| # Construct the train feed. | |
| train_feed = { | |
| model.inputs: dis_x, | |
| model.targets: dis_y, | |
| model.present: p | |
| } | |
| if FLAGS.data_set == 'ptb': | |
| # Statefulness of the Generator being used for Discriminator. | |
| for i, (c, h) in enumerate(model.fake_gen_initial_state): | |
| train_feed[c] = dis_initial_state_eval[i].c | |
| train_feed[h] = dis_initial_state_eval[i].h | |
| # Determine the state had the Generator run over real data. We | |
| # use this state for the Discriminator. | |
| [dis_initial_state_eval] = sess.run( | |
| [model.fake_gen_final_state], train_feed) | |
| ## Training loop. | |
| iterator = get_iterator(data) | |
| gen_initial_state_eval = zeros_state | |
| if FLAGS.ps_tasks > 0: | |
| gen_offset = FLAGS.task * 1000 + 1 | |
| for i in range(gen_offset): | |
| try: | |
| next(iterator) | |
| except StopIteration: | |
| dis_iterator = get_iterator(data) | |
| dis_initial_state_eval = zeros_state | |
| next(dis_iterator) | |
| for x, y, _ in iterator: | |
| for _ in xrange(hparams.dis_train_iterations): | |
| try: | |
| dis_x, dis_y, _ = next(dis_iterator) | |
| except StopIteration: | |
| dis_iterator = get_iterator(data) | |
| dis_initial_state_eval = zeros_state | |
| dis_x, dis_y, _ = next(dis_iterator) | |
| if FLAGS.data_set == 'ptb': | |
| [dis_initial_state_eval] = sess.run( | |
| [model.fake_gen_initial_state]) | |
| p = model_utils.generate_mask() | |
| # Construct the train feed. | |
| train_feed = { | |
| model.inputs: dis_x, | |
| model.targets: dis_y, | |
| model.present: p | |
| } | |
| # Statefulness for the Discriminator. | |
| if FLAGS.data_set == 'ptb': | |
| for i, (c, h) in enumerate(model.fake_gen_initial_state): | |
| train_feed[c] = dis_initial_state_eval[i].c | |
| train_feed[h] = dis_initial_state_eval[i].h | |
| _, dis_loss_eval, step = sess.run( | |
| [model.dis_train_op, model.dis_loss, model.global_step], | |
| feed_dict=train_feed) | |
| # Determine the state had the Generator run over real data. | |
| # Use this state for the Discriminator. | |
| [dis_initial_state_eval] = sess.run( | |
| [model.fake_gen_final_state], train_feed) | |
| # Randomly mask out tokens. | |
| p = model_utils.generate_mask() | |
| # Construct the train feed. | |
| train_feed = {model.inputs: x, model.targets: y, model.present: p} | |
| # Statefulness for Generator. | |
| if FLAGS.data_set == 'ptb': | |
| tf.logging.info('Generator is stateful.') | |
| print('Generator is stateful.') | |
| # Statefulness for *evaluation* Generator. | |
| for i, (c, h) in enumerate(model.eval_initial_state): | |
| train_feed[c] = gen_initial_state_eval[i].c | |
| train_feed[h] = gen_initial_state_eval[i].h | |
| # Statefulness for Generator. | |
| for i, (c, h) in enumerate(model.fake_gen_initial_state): | |
| train_feed[c] = fake_gen_initial_state_eval[i].c | |
| train_feed[h] = fake_gen_initial_state_eval[i].h | |
| # Determine whether to decay learning rate. | |
| lr_decay = hparams.gen_learning_rate_decay**max( | |
| step + 1 - hparams.gen_full_learning_rate_steps, 0.0) | |
| # Assign learning rate. | |
| gen_learning_rate = hparams.gen_learning_rate * lr_decay | |
| model_utils.assign_learning_rate(sess, model.learning_rate_update, | |
| model.new_learning_rate, | |
| gen_learning_rate) | |
| [_, gen_loss_eval, gen_log_perplexity_eval, step] = sess.run( | |
| [ | |
| model.gen_train_op, model.gen_loss, | |
| model.avg_log_perplexity, model.global_step | |
| ], | |
| feed_dict=train_feed) | |
| cumulative_costs += gen_log_perplexity_eval | |
| gen_iters += 1 | |
| # Determine the state had the Generator run over real data. | |
| [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run( | |
| [model.eval_final_state, | |
| model.fake_gen_final_state], train_feed) | |
| avg_epoch_dis_loss.append(dis_loss_eval) | |
| avg_epoch_gen_loss.append(gen_loss_eval) | |
| ## Summaries. | |
| # Calulate rolling perplexity. | |
| perplexity = np.exp(cumulative_costs / gen_iters) | |
| if is_chief and (step / FLAGS.summaries_every > | |
| summary_step_division): | |
| summary_step_division = step / FLAGS.summaries_every | |
| # Confirm perplexity is not infinite. | |
| if (not np.isfinite(perplexity) or | |
| perplexity >= FLAGS.perplexity_threshold): | |
| print('Training raising FloatingPoinError.') | |
| raise FloatingPointError( | |
| 'Training infinite perplexity: %.3f' % perplexity) | |
| # Graph summaries. | |
| summary_str = sess.run( | |
| model.merge_summaries_op, feed_dict=train_feed) | |
| sv.SummaryComputed(sess, summary_str) | |
| # Summary: n-gram | |
| avg_percent_captured = {'2': 0., '3': 0., '4': 0.} | |
| for n, data_ngram_count in data_ngram_counts.iteritems(): | |
| batch_percent_captured = evaluation_utils.sequence_ngram_evaluation( | |
| sess, model.fake_sequence, log, train_feed, | |
| data_ngram_count, int(n)) | |
| summary_percent_str = tf.Summary(value=[ | |
| tf.Summary.Value( | |
| tag='general/%s-grams_percent_correct' % n, | |
| simple_value=batch_percent_captured) | |
| ]) | |
| sv.SummaryComputed( | |
| sess, summary_percent_str, global_step=step) | |
| # Summary: geometric_avg | |
| geometric_avg = compute_geometric_average(avg_percent_captured) | |
| summary_geometric_avg_str = tf.Summary(value=[ | |
| tf.Summary.Value( | |
| tag='general/geometric_avg', simple_value=geometric_avg) | |
| ]) | |
| sv.SummaryComputed( | |
| sess, summary_geometric_avg_str, global_step=step) | |
| # Summary: arithmetic_avg | |
| arithmetic_avg = compute_arithmetic_average( | |
| avg_percent_captured) | |
| summary_arithmetic_avg_str = tf.Summary(value=[ | |
| tf.Summary.Value( | |
| tag='general/arithmetic_avg', | |
| simple_value=arithmetic_avg) | |
| ]) | |
| sv.SummaryComputed( | |
| sess, summary_arithmetic_avg_str, global_step=step) | |
| # Summary: perplexity | |
| summary_perplexity_str = tf.Summary(value=[ | |
| tf.Summary.Value( | |
| tag='general/perplexity', simple_value=perplexity) | |
| ]) | |
| sv.SummaryComputed( | |
| sess, summary_perplexity_str, global_step=step) | |
| ## Printing and logging | |
| if is_chief and (step / FLAGS.print_every > print_step_division): | |
| print_step_division = (step / FLAGS.print_every) | |
| print('global_step: %d' % step) | |
| print(' perplexity: %.3f' % perplexity) | |
| print(' gen_learning_rate: %.6f' % gen_learning_rate) | |
| log.write('global_step: %d\n' % step) | |
| log.write(' perplexity: %.3f\n' % perplexity) | |
| log.write(' gen_learning_rate: %.6f' % gen_learning_rate) | |
| # Average percent captured for each of the n-grams. | |
| avg_percent_captured = {'2': 0., '3': 0., '4': 0.} | |
| for n, data_ngram_count in data_ngram_counts.iteritems(): | |
| batch_percent_captured = evaluation_utils.sequence_ngram_evaluation( | |
| sess, model.fake_sequence, log, train_feed, | |
| data_ngram_count, int(n)) | |
| avg_percent_captured[n] = batch_percent_captured | |
| print(' percent of %s-grams captured: %.3f.' % | |
| (n, batch_percent_captured)) | |
| log.write(' percent of %s-grams captured: %.3f.\n' % | |
| (n, batch_percent_captured)) | |
| geometric_avg = compute_geometric_average(avg_percent_captured) | |
| print(' geometric_avg: %.3f.' % geometric_avg) | |
| log.write(' geometric_avg: %.3f.' % geometric_avg) | |
| arithmetic_avg = compute_arithmetic_average( | |
| avg_percent_captured) | |
| print(' arithmetic_avg: %.3f.' % arithmetic_avg) | |
| log.write(' arithmetic_avg: %.3f.' % arithmetic_avg) | |
| evaluation_utils.print_and_log_losses( | |
| log, step, is_present_rate, avg_epoch_dis_loss, | |
| avg_epoch_gen_loss) | |
| if FLAGS.gen_training_strategy == 'reinforce': | |
| evaluation_utils.generate_RL_logs(sess, model, log, | |
| id_to_word, train_feed) | |
| else: | |
| evaluation_utils.generate_logs(sess, model, log, id_to_word, | |
| train_feed) | |
| log.flush() | |
| log.close() | |
| def evaluate_once(data, sv, model, sess, train_dir, log, id_to_word, | |
| data_ngram_counts, eval_saver): | |
| """Evaluate model for a number of steps. | |
| Args: | |
| data: Dataset. | |
| sv: Supervisor. | |
| model: The GAN model we have just built. | |
| sess: A session to use. | |
| train_dir: Path to a directory containing checkpoints. | |
| log: Evaluation log for evaluation. | |
| id_to_word: Dictionary of indices to words. | |
| data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the | |
| data_set. | |
| eval_saver: Evaluation saver.r. | |
| """ | |
| tf.logging.info('Evaluate Once.') | |
| # Load the last model checkpoint, or initialize the graph. | |
| model_save_path = tf.latest_checkpoint(train_dir) | |
| if not model_save_path: | |
| tf.logging.warning('No checkpoint yet in: %s', train_dir) | |
| return | |
| tf.logging.info('Starting eval of: %s' % model_save_path) | |
| tf.logging.info('Only restoring trainable variables.') | |
| eval_saver.restore(sess, model_save_path) | |
| # Run the requested number of evaluation steps | |
| avg_epoch_gen_loss, avg_epoch_dis_loss = [], [] | |
| cumulative_costs = 0. | |
| # Average percent captured for each of the n-grams. | |
| avg_percent_captured = {'2': 0., '3': 0., '4': 0.} | |
| # Set a random seed to keep fixed mask. | |
| np.random.seed(0) | |
| gen_iters = 0 | |
| # Generator statefulness over the epoch. | |
| # TODO(liamfedus): Check this. | |
| [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run( | |
| [model.eval_initial_state, model.fake_gen_initial_state]) | |
| if FLAGS.eval_language_model: | |
| is_present_rate = 0. | |
| tf.logging.info('Overriding is_present_rate=0. for evaluation.') | |
| print('Overriding is_present_rate=0. for evaluation.') | |
| iterator = get_iterator(data) | |
| for x, y, _ in iterator: | |
| if FLAGS.eval_language_model: | |
| is_present_rate = 0. | |
| else: | |
| is_present_rate = FLAGS.is_present_rate | |
| tf.logging.info('Evaluating on is_present_rate=%.3f.' % is_present_rate) | |
| model_utils.assign_percent_real(sess, model.percent_real_update, | |
| model.new_rate, is_present_rate) | |
| # Randomly mask out tokens. | |
| p = model_utils.generate_mask() | |
| eval_feed = {model.inputs: x, model.targets: y, model.present: p} | |
| if FLAGS.data_set == 'ptb': | |
| # Statefulness for *evaluation* Generator. | |
| for i, (c, h) in enumerate(model.eval_initial_state): | |
| eval_feed[c] = gen_initial_state_eval[i].c | |
| eval_feed[h] = gen_initial_state_eval[i].h | |
| # Statefulness for the Generator. | |
| for i, (c, h) in enumerate(model.fake_gen_initial_state): | |
| eval_feed[c] = fake_gen_initial_state_eval[i].c | |
| eval_feed[h] = fake_gen_initial_state_eval[i].h | |
| [ | |
| gen_log_perplexity_eval, dis_loss_eval, gen_loss_eval, | |
| gen_initial_state_eval, fake_gen_initial_state_eval, step | |
| ] = sess.run( | |
| [ | |
| model.avg_log_perplexity, model.dis_loss, model.gen_loss, | |
| model.eval_final_state, model.fake_gen_final_state, | |
| model.global_step | |
| ], | |
| feed_dict=eval_feed) | |
| for n, data_ngram_count in data_ngram_counts.iteritems(): | |
| batch_percent_captured = evaluation_utils.sequence_ngram_evaluation( | |
| sess, model.fake_sequence, log, eval_feed, data_ngram_count, int(n)) | |
| avg_percent_captured[n] += batch_percent_captured | |
| cumulative_costs += gen_log_perplexity_eval | |
| avg_epoch_dis_loss.append(dis_loss_eval) | |
| avg_epoch_gen_loss.append(gen_loss_eval) | |
| gen_iters += 1 | |
| # Calulate rolling metrics. | |
| perplexity = np.exp(cumulative_costs / gen_iters) | |
| for n, _ in avg_percent_captured.iteritems(): | |
| avg_percent_captured[n] /= gen_iters | |
| # Confirm perplexity is not infinite. | |
| if not np.isfinite(perplexity) or perplexity >= FLAGS.perplexity_threshold: | |
| print('Evaluation raising FloatingPointError.') | |
| raise FloatingPointError( | |
| 'Evaluation infinite perplexity: %.3f' % perplexity) | |
| ## Printing and logging. | |
| evaluation_utils.print_and_log_losses(log, step, is_present_rate, | |
| avg_epoch_dis_loss, avg_epoch_gen_loss) | |
| print(' perplexity: %.3f' % perplexity) | |
| log.write(' perplexity: %.3f\n' % perplexity) | |
| for n, n_gram_percent in avg_percent_captured.iteritems(): | |
| n = int(n) | |
| print(' percent of %d-grams captured: %.3f.' % (n, n_gram_percent)) | |
| log.write(' percent of %d-grams captured: %.3f.\n' % (n, n_gram_percent)) | |
| samples = evaluation_utils.generate_logs(sess, model, log, id_to_word, | |
| eval_feed) | |
| ## Summaries. | |
| summary_str = sess.run(model.merge_summaries_op, feed_dict=eval_feed) | |
| sv.SummaryComputed(sess, summary_str) | |
| # Summary: text | |
| summary_str = sess.run(model.text_summary_op, | |
| {model.text_summary_placeholder: '\n\n'.join(samples)}) | |
| sv.SummaryComputed(sess, summary_str, global_step=step) | |
| # Summary: n-gram | |
| for n, n_gram_percent in avg_percent_captured.iteritems(): | |
| n = int(n) | |
| summary_percent_str = tf.Summary(value=[ | |
| tf.Summary.Value( | |
| tag='general/%d-grams_percent_correct' % n, | |
| simple_value=n_gram_percent) | |
| ]) | |
| sv.SummaryComputed(sess, summary_percent_str, global_step=step) | |
| # Summary: geometric_avg | |
| geometric_avg = compute_geometric_average(avg_percent_captured) | |
| summary_geometric_avg_str = tf.Summary(value=[ | |
| tf.Summary.Value(tag='general/geometric_avg', simple_value=geometric_avg) | |
| ]) | |
| sv.SummaryComputed(sess, summary_geometric_avg_str, global_step=step) | |
| # Summary: arithmetic_avg | |
| arithmetic_avg = compute_arithmetic_average(avg_percent_captured) | |
| summary_arithmetic_avg_str = tf.Summary(value=[ | |
| tf.Summary.Value( | |
| tag='general/arithmetic_avg', simple_value=arithmetic_avg) | |
| ]) | |
| sv.SummaryComputed(sess, summary_arithmetic_avg_str, global_step=step) | |
| # Summary: perplexity | |
| summary_perplexity_str = tf.Summary(value=[ | |
| tf.Summary.Value(tag='general/perplexity', simple_value=perplexity) | |
| ]) | |
| sv.SummaryComputed(sess, summary_perplexity_str, global_step=step) | |
| def evaluate_model(hparams, data, train_dir, log, id_to_word, | |
| data_ngram_counts): | |
| """Evaluate MaskGAN model. | |
| Args: | |
| hparams: Hyperparameters for the MaskGAN. | |
| data: Data to evaluate. | |
| train_dir: Path to a directory containing checkpoints. | |
| id_to_word: Dictionary of indices to words. | |
| data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the | |
| data_set. | |
| """ | |
| tf.logging.error('Evaluate model.') | |
| # Boolean indicating operational mode. | |
| is_training = False | |
| if FLAGS.mode == MODE_VALIDATION: | |
| logdir = FLAGS.base_directory + '/validation' | |
| elif FLAGS.mode == MODE_TRAIN_EVAL: | |
| logdir = FLAGS.base_directory + '/train_eval' | |
| elif FLAGS.mode == MODE_TEST: | |
| logdir = FLAGS.base_directory + '/test' | |
| else: | |
| raise NotImplementedError | |
| # Wait for a checkpoint to exist. | |
| print(train_dir) | |
| print(tf.train.latest_checkpoint(train_dir)) | |
| while not tf.train.latest_checkpoint(train_dir): | |
| tf.logging.error('Waiting for checkpoint...') | |
| print('Waiting for checkpoint...') | |
| time.sleep(10) | |
| with tf.Graph().as_default(): | |
| # Use a separate container for each trial | |
| container_name = '' | |
| with tf.container(container_name): | |
| # Construct the model. | |
| if FLAGS.num_rollouts == 1: | |
| model = create_MaskGAN(hparams, is_training) | |
| elif FLAGS.num_rollouts > 1: | |
| model = rollout.create_rollout_MaskGAN(hparams, is_training) | |
| else: | |
| raise ValueError | |
| # Create the supervisor. It will take care of initialization, summaries, | |
| # checkpoints, and recovery. We only pass the trainable variables | |
| # to load since things like baselines keep batch_size which may not | |
| # match between training and evaluation. | |
| evaluation_variables = tf.trainable_variables() | |
| evaluation_variables.append(model.global_step) | |
| eval_saver = tf.train.Saver(var_list=evaluation_variables) | |
| sv = tf.Supervisor(logdir=logdir) | |
| sess = sv.PrepareSession(FLAGS.eval_master, start_standard_services=False) | |
| tf.logging.info('Before sv.Loop.') | |
| sv.Loop(FLAGS.eval_interval_secs, evaluate_once, | |
| (data, sv, model, sess, train_dir, log, id_to_word, | |
| data_ngram_counts, eval_saver)) | |
| sv.WaitForStop() | |
| tf.logging.info('sv.Stop().') | |
| sv.Stop() | |
| def main(_): | |
| hparams = create_hparams() | |
| train_dir = FLAGS.base_directory + '/train' | |
| # Load data set. | |
| if FLAGS.data_set == 'ptb': | |
| raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir) | |
| train_data, valid_data, test_data, _ = raw_data | |
| valid_data_flat = valid_data | |
| elif FLAGS.data_set == 'imdb': | |
| raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir) | |
| # TODO(liamfedus): Get an IMDB test partition. | |
| train_data, valid_data = raw_data | |
| valid_data_flat = [word for review in valid_data for word in review] | |
| else: | |
| raise NotImplementedError | |
| if FLAGS.mode == MODE_TRAIN or FLAGS.mode == MODE_TRAIN_EVAL: | |
| data_set = train_data | |
| elif FLAGS.mode == MODE_VALIDATION: | |
| data_set = valid_data | |
| elif FLAGS.mode == MODE_TEST: | |
| data_set = test_data | |
| else: | |
| raise NotImplementedError | |
| # Dictionary and reverse dictionry. | |
| if FLAGS.data_set == 'ptb': | |
| word_to_id = ptb_loader.build_vocab( | |
| os.path.join(FLAGS.data_dir, 'ptb.train.txt')) | |
| elif FLAGS.data_set == 'imdb': | |
| word_to_id = imdb_loader.build_vocab( | |
| os.path.join(FLAGS.data_dir, 'vocab.txt')) | |
| id_to_word = {v: k for k, v in word_to_id.iteritems()} | |
| # Dictionary of Training Set n-gram counts. | |
| bigram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=2) | |
| trigram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=3) | |
| fourgram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=4) | |
| bigram_counts = n_gram.construct_ngrams_dict(bigram_tuples) | |
| trigram_counts = n_gram.construct_ngrams_dict(trigram_tuples) | |
| fourgram_counts = n_gram.construct_ngrams_dict(fourgram_tuples) | |
| print('Unique %d-grams: %d' % (2, len(bigram_counts))) | |
| print('Unique %d-grams: %d' % (3, len(trigram_counts))) | |
| print('Unique %d-grams: %d' % (4, len(fourgram_counts))) | |
| data_ngram_counts = { | |
| '2': bigram_counts, | |
| '3': trigram_counts, | |
| '4': fourgram_counts | |
| } | |
| # TODO(liamfedus): This was necessary because there was a problem with our | |
| # originally trained IMDB models. The EOS_INDEX was off by one, which means, | |
| # two words were mapping to index 86933. The presence of '</s>' is going | |
| # to throw and out of vocabulary error. | |
| FLAGS.vocab_size = len(id_to_word) | |
| print('Vocab size: %d' % FLAGS.vocab_size) | |
| tf.gfile.MakeDirs(FLAGS.base_directory) | |
| if FLAGS.mode == MODE_TRAIN: | |
| log = tf.gfile.GFile( | |
| os.path.join(FLAGS.base_directory, 'train-log.txt'), mode='w') | |
| elif FLAGS.mode == MODE_VALIDATION: | |
| log = tf.gfile.GFile( | |
| os.path.join(FLAGS.base_directory, 'validation-log.txt'), mode='w') | |
| elif FLAGS.mode == MODE_TRAIN_EVAL: | |
| log = tf.gfile.GFile( | |
| os.path.join(FLAGS.base_directory, 'train_eval-log.txt'), mode='w') | |
| else: | |
| log = tf.gfile.GFile( | |
| os.path.join(FLAGS.base_directory, 'test-log.txt'), mode='w') | |
| if FLAGS.mode == MODE_TRAIN: | |
| train_model(hparams, data_set, train_dir, log, id_to_word, | |
| data_ngram_counts) | |
| elif FLAGS.mode == MODE_VALIDATION: | |
| evaluate_model(hparams, data_set, train_dir, log, id_to_word, | |
| data_ngram_counts) | |
| elif FLAGS.mode == MODE_TRAIN_EVAL: | |
| evaluate_model(hparams, data_set, train_dir, log, id_to_word, | |
| data_ngram_counts) | |
| elif FLAGS.mode == MODE_TEST: | |
| evaluate_model(hparams, data_set, train_dir, log, id_to_word, | |
| data_ngram_counts) | |
| else: | |
| raise NotImplementedError | |
| if __name__ == '__main__': | |
| tf.app.run() | |