# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

| """Generate samples from the MaskGAN. | |
| Launch command: | |
| python generate_samples.py | |
| --data_dir=/tmp/data/imdb --data_set=imdb | |
| --batch_size=256 --sequence_length=20 --base_directory=/tmp/imdb | |
| --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2, | |
| gen_vd_keep_prob=1.0" --generator_model=seq2seq_vd | |
| --discriminator_model=seq2seq_vd --is_present_rate=0.5 | |
| --maskgan_ckpt=/tmp/model.ckpt-45494 | |
| --seq2seq_share_embedding=True --dis_share_embedding=True | |
| --attention_option=luong --mask_strategy=contiguous --baseline_method=critic | |
| --number_epochs=4 | |
| """ | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
import os

# Dependency imports

import numpy as np
from six.moves import xrange
import tensorflow as tf

import train_mask_gan

# Data.
from data import imdb_loader
from data import ptb_loader

from model_utils import helper
from model_utils import model_utils

SAMPLE_TRAIN = 'TRAIN'
SAMPLE_VALIDATION = 'VALIDATION'

## Sample Generation.
## Binary and setup FLAGS.
tf.app.flags.DEFINE_enum('sample_mode', 'TRAIN',
                         [SAMPLE_TRAIN, SAMPLE_VALIDATION],
                         'Dataset to sample from.')
tf.app.flags.DEFINE_string('output_path', '/tmp', 'Model output directory.')
tf.app.flags.DEFINE_boolean(
    'output_masked_logs', False,
    'Whether to display for human evaluation (show masking).')
tf.app.flags.DEFINE_integer('number_epochs', 1,
                            'The number of epochs to produce.')

FLAGS = tf.app.flags.FLAGS
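# Flags referenced below (e.g. data_dir, data_set, batch_size, sequence_length,
# master, task, is_present_rate, eval_language_model, epoch_size_override) are
# defined by the imported modules (train_mask_gan and its dependencies), not in
# this file.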


def get_iterator(data):
  """Return the data iterator."""
  if FLAGS.data_set == 'ptb':
    iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                       FLAGS.sequence_length,
                                       FLAGS.epoch_size_override)
  elif FLAGS.data_set == 'imdb':
    iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                         FLAGS.sequence_length)
  return iterator


def convert_to_human_readable(id_to_word, arr, p, max_num_to_print):
  """Convert a np.array of indices into words using id_to_word dictionary.

  Return max_num_to_print results.
  """
  assert arr.ndim == 2

  samples = []
  for sequence_id in xrange(min(len(arr), max_num_to_print)):
    sample = []
    for i, index in enumerate(arr[sequence_id, :]):
      if p[sequence_id, i] == 1:
        sample.append(str(id_to_word[index]))
      else:
        sample.append('*' + str(id_to_word[index]))
    buffer_str = ' '.join(sample)
    samples.append(buffer_str)
  return samples


def write_unmasked_log(log, id_to_word, sequence_eval):
  """Helper function for logging evaluated sequences without mask."""
  indices_arr = np.asarray(sequence_eval)
  samples = helper.convert_to_human_readable(id_to_word, indices_arr,
                                             FLAGS.batch_size)
  for sample in samples:
    log.write(sample + '\n')
  log.flush()
  return samples


def write_masked_log(log, id_to_word, sequence_eval, present_eval):
  """Helper function for logging evaluated sequences, marking masked tokens."""
  indices_arr = np.asarray(sequence_eval)
  samples = convert_to_human_readable(id_to_word, indices_arr, present_eval,
                                      FLAGS.batch_size)
  for sample in samples:
    log.write(sample + '\n')
  log.flush()
  return samples


def generate_logs(sess, model, log, id_to_word, feed):
  """Impute Sequences using the model for a particular feed and send it to
  logs.
  """
  # Impute Sequences.
  [p, inputs_eval, sequence_eval] = sess.run(
      [model.present, model.inputs, model.fake_sequence], feed_dict=feed)

  # Add the 0th time-step for coherence.
  first_token = np.expand_dims(inputs_eval[:, 0], axis=1)
  sequence_eval = np.concatenate((first_token, sequence_eval), axis=1)

  # 0th token always present.
  p = np.concatenate((np.ones((FLAGS.batch_size, 1)), p), axis=1)

  if FLAGS.output_masked_logs:
    samples = write_masked_log(log, id_to_word, sequence_eval, p)
  else:
    samples = write_unmasked_log(log, id_to_word, sequence_eval)
  return samples


def generate_samples(hparams, data, id_to_word, log_dir, output_file):
  """Generate samples.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    id_to_word: Dictionary of indices to words.
    log_dir: Log directory.
    output_file: Output file for the samples.
  """
  # Boolean indicating operational mode.
  is_training = False

  # Set a random seed to keep fixed mask.
  np.random.seed(0)

  with tf.Graph().as_default():
    # Construct the model.
    model = train_mask_gan.create_MaskGAN(hparams, is_training)

    ## Retrieve the initial savers.
    init_savers = model_utils.retrieve_init_savers(hparams)

    ## Initial saver function to supervisor.
    init_fn = partial(model_utils.init_fn, init_savers)

    is_chief = FLAGS.task == 0

    # Create the supervisor.  It will take care of initialization, summaries,
    # checkpoints, and recovery.
    sv = tf.train.Supervisor(
        logdir=log_dir,
        is_chief=is_chief,
        saver=model.saver,
        global_step=model.global_step,
        recovery_wait_secs=30,
        summary_op=None,
        init_fn=init_fn)
    # Get an initialized, and possibly recovered session.  Launch the
    # services: Checkpointing, Summaries, step counting.
    #
    # When multiple replicas of this program are running the services are
    # only launched by the 'chief' replica.
    with sv.managed_session(
        FLAGS.master, start_standard_services=False) as sess:

      # Generator statefulness over the epoch.
      [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
          [model.eval_initial_state, model.fake_gen_initial_state])

      for n in xrange(FLAGS.number_epochs):
        print('Epoch number: %d' % n)
        # print('Percent done: %.2f' % float(n) / float(FLAGS.number_epochs))

        iterator = get_iterator(data)

        for x, y, _ in iterator:
          if FLAGS.eval_language_model:
            is_present_rate = 0.
          else:
            is_present_rate = FLAGS.is_present_rate
          tf.logging.info(
              'Evaluating on is_present_rate=%.3f.' % is_present_rate)

          model_utils.assign_percent_real(sess, model.percent_real_update,
                                          model.new_rate, is_present_rate)

          # Randomly mask out tokens.
          p = model_utils.generate_mask()
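          # p is the 'present' mask fed to model.present: entries equal to 1
          # keep the real token, entries equal to 0 are imputed by the
          # Generator (see convert_to_human_readable above).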
          eval_feed = {model.inputs: x, model.targets: y, model.present: p}

          if FLAGS.data_set == 'ptb':
            # Statefulness for *evaluation* Generator.
            for i, (c, h) in enumerate(model.eval_initial_state):
              eval_feed[c] = gen_initial_state_eval[i].c
              eval_feed[h] = gen_initial_state_eval[i].h

            # Statefulness for the Generator.
            for i, (c, h) in enumerate(model.fake_gen_initial_state):
              eval_feed[c] = fake_gen_initial_state_eval[i].c
              eval_feed[h] = fake_gen_initial_state_eval[i].h

          [gen_initial_state_eval, fake_gen_initial_state_eval, _] = sess.run(
              [
                  model.eval_final_state, model.fake_gen_final_state,
                  model.global_step
              ],
              feed_dict=eval_feed)

          generate_logs(sess, model, output_file, id_to_word, eval_feed)

  output_file.close()
  print('Closing output_file.')
  return


def main(_):
  hparams = train_mask_gan.create_hparams()
  log_dir = FLAGS.base_directory

  tf.gfile.MakeDirs(FLAGS.output_path)
  output_file = tf.gfile.GFile(
      os.path.join(FLAGS.output_path, 'reviews.txt'), mode='w')

  # Load data set.
  if FLAGS.data_set == 'ptb':
    raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
    train_data, valid_data, _, _ = raw_data
  elif FLAGS.data_set == 'imdb':
    raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
    train_data, valid_data = raw_data
  else:
    raise NotImplementedError

  # Generating more data on train set.
  if FLAGS.sample_mode == SAMPLE_TRAIN:
    data_set = train_data
  elif FLAGS.sample_mode == SAMPLE_VALIDATION:
    data_set = valid_data
  else:
    raise NotImplementedError
  # Dictionary and reverse dictionary.
  if FLAGS.data_set == 'ptb':
    word_to_id = ptb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
  elif FLAGS.data_set == 'imdb':
    word_to_id = imdb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'vocab.txt'))
  # Use items() (not Python 2's iteritems()) so this also runs under Python 3.
  id_to_word = {v: k for k, v in word_to_id.items()}
  FLAGS.vocab_size = len(id_to_word)
  print('Vocab size: %d' % FLAGS.vocab_size)
  generate_samples(hparams, data_set, id_to_word, log_dir, output_file)


if __name__ == '__main__':
  tf.app.run()