# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Pretraining functions.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| # Dependency imports | |
| import numpy as np | |
| import tensorflow as tf | |
| from data import imdb_loader | |
| from data import ptb_loader | |
| # Data. | |
| from model_utils import model_utils | |
| from models import evaluation_utils | |
tf.app.flags.DEFINE_integer(
    'gen_pretrain_steps', None,
    'The number of steps to pretrain the generator with cross entropy loss.')
tf.app.flags.DEFINE_integer(
    'dis_pretrain_steps', None,
    'The number of steps to pretrain the discriminator.')

FLAGS = tf.app.flags.FLAGS
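
# Example invocation (illustrative; the training binary name and step counts
# below are assumptions, not defined in this file):
#
#   python train_mask_gan.py --gen_pretrain_steps=5000 --dis_pretrain_steps=2500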


def pretrain_generator(sv, sess, model, data, log, id_to_word,
                       data_ngram_counts, is_chief):
  """Pretrain the generator with classic language modeling training."""
  print('\nPretraining generator for %d steps.' % FLAGS.gen_pretrain_steps)
  log.write(
      '\nPretraining generator for %d steps.\n' % FLAGS.gen_pretrain_steps)

  is_pretraining = True

  while is_pretraining:
    costs = 0.
    iters = 0

    if FLAGS.data_set == 'ptb':
      iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                         FLAGS.sequence_length,
                                         FLAGS.epoch_size_override)
    elif FLAGS.data_set == 'imdb':
      iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                           FLAGS.sequence_length)

    for x, y, _ in iterator:
      # For pretraining with cross entropy loss, we have all tokens in the
      # forward sequence present (all True).
      model_utils.assign_percent_real(sess, model.percent_real_update,
                                      model.new_rate, 1.0)
      p = np.ones(shape=[FLAGS.batch_size, FLAGS.sequence_length], dtype=bool)

      pretrain_feed = {model.inputs: x, model.targets: y, model.present: p}

      [losses, cost_eval, _, step] = sess.run(
          [
              model.fake_cross_entropy_losses, model.avg_log_perplexity,
              model.gen_pretrain_op, model.global_step
          ],
          feed_dict=pretrain_feed)

      costs += cost_eval
      iters += FLAGS.sequence_length

      # Calculate rolling perplexity.
      perplexity = np.exp(costs / iters)

      # Summaries.
      if is_chief and step % FLAGS.summaries_every == 0:
        # Graph summaries.
        summary_str = sess.run(
            model.merge_summaries_op, feed_dict=pretrain_feed)
        sv.SummaryComputed(sess, summary_str)

        # Additional summary.
        for n, data_ngram_count in data_ngram_counts.iteritems():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          summary_percent_str = tf.Summary(value=[
              tf.Summary.Value(
                  tag='general/%s-grams_percent_correct' % n,
                  simple_value=avg_percent_captured)
          ])
          sv.SummaryComputed(sess, summary_percent_str, global_step=step)

        summary_perplexity_str = tf.Summary(value=[
            tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
        ])
        sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)

      # Printing and logging.
      if is_chief and step % FLAGS.print_every == 0:
        print('global_step: %d' % step)
        print(' generator loss: %.3f' % np.mean(losses))
        print(' perplexity: %.3f' % perplexity)
        log.write('global_step: %d\n' % step)
        log.write(' generator loss: %.3f\n' % np.mean(losses))
        log.write(' perplexity: %.3f\n' % perplexity)

        for n, data_ngram_count in data_ngram_counts.iteritems():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          print(' percent of %s-grams captured: %.3f.\n' %
                (n, avg_percent_captured))
          log.write(' percent of %s-grams captured: %.3f.\n\n' %
                    (n, avg_percent_captured))

        evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                       pretrain_feed)

      if step >= FLAGS.gen_pretrain_steps:
        is_pretraining = False
        break
  return


def pretrain_discriminator(sv, sess, model, data, log, id_to_word,
                           data_ngram_counts, is_chief):
  """Pretrain the discriminator on sequences with randomly masked tokens."""
  print('\nPretraining discriminator for %d steps.' % FLAGS.dis_pretrain_steps)
  log.write(
      '\nPretraining discriminator for %d steps.\n' % FLAGS.dis_pretrain_steps)

  is_pretraining = True

  while is_pretraining:
    cumulative_costs = 0.
    iters = 0

    if FLAGS.data_set == 'ptb':
      iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                         FLAGS.sequence_length,
                                         FLAGS.epoch_size_override)
    elif FLAGS.data_set == 'imdb':
      iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                           FLAGS.sequence_length)

    for x, y, _ in iterator:
      is_present_rate = FLAGS.is_present_rate
      # is_present_rate = np.random.uniform(low=0.0, high=1.0)
      model_utils.assign_percent_real(sess, model.percent_real_update,
                                      model.new_rate, is_present_rate)

      # Randomly mask out tokens.
      p = model_utils.generate_mask()

      pretrain_feed = {model.inputs: x, model.targets: y, model.present: p}

      [_, dis_loss_eval, gen_log_perplexity_eval, step] = sess.run(
          [
              model.dis_pretrain_op, model.dis_loss, model.avg_log_perplexity,
              model.global_step
          ],
          feed_dict=pretrain_feed)

      cumulative_costs += gen_log_perplexity_eval
      iters += 1

      # Calculate rolling perplexity.
      perplexity = np.exp(cumulative_costs / iters)

      # Summaries.
      if is_chief and step % FLAGS.summaries_every == 0:
        # Graph summaries.
        summary_str = sess.run(
            model.merge_summaries_op, feed_dict=pretrain_feed)
        sv.SummaryComputed(sess, summary_str)

        # Additional summary.
        for n, data_ngram_count in data_ngram_counts.iteritems():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          summary_percent_str = tf.Summary(value=[
              tf.Summary.Value(
                  tag='general/%s-grams_percent_correct' % n,
                  simple_value=avg_percent_captured)
          ])
          sv.SummaryComputed(sess, summary_percent_str, global_step=step)

        summary_perplexity_str = tf.Summary(value=[
            tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
        ])
        sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)

      # Printing and logging.
      if is_chief and step % FLAGS.print_every == 0:
        print('global_step: %d' % step)
        print(' discriminator loss: %.3f' % dis_loss_eval)
        print(' perplexity: %.3f' % perplexity)
        log.write('global_step: %d\n' % step)
        log.write(' discriminator loss: %.3f\n' % dis_loss_eval)
        log.write(' perplexity: %.3f\n' % perplexity)

        for n, data_ngram_count in data_ngram_counts.iteritems():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          print(' percent of %s-grams captured: %.3f.\n' %
                (n, avg_percent_captured))
          log.write(' percent of %s-grams captured: %.3f.\n\n' %
                    (n, avg_percent_captured))

        evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                       pretrain_feed)

      if step >= FLAGS.dis_pretrain_steps + int(FLAGS.gen_pretrain_steps or 0):
        is_pretraining = False
        break
  return
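

# Illustrative call sequence (a sketch, not part of the original module): the
# training driver that builds `sv`, `sess`, `model`, the raw `data`, `log`,
# `id_to_word`, and `data_ngram_counts` is assumed to live elsewhere.
#
#   if FLAGS.gen_pretrain_steps:
#     pretrain_generator(sv, sess, model, data, log, id_to_word,
#                        data_ngram_counts, is_chief)
#   if FLAGS.dis_pretrain_steps:
#     pretrain_discriminator(sv, sess, model, data, log, id_to_word,
#                            data_ngram_counts, is_chief)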