Spaces:
Runtime error
Runtime error
| # Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """ | |
| Tracks and saves training progress (models and other data such as the current | |
| location in the lm1b corpus) for later reloading. | |
| """ | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import tensorflow as tf | |
| from base import utils | |
| from corpus_processing import unlabeled_data | |
class TrainingProgress(object):
  """Saves and restores training state so an interrupted run can resume.

  State consists of TensorFlow checkpoints (written through the supplied
  savers) plus a pickled progress file holding the evaluation history and
  the unlabeled-data reader's position (current file and line).

  `self.history` is a list of entries; each entry is an iterable of
  (metric_name, value) pairs, one of which is expected to be "step".
  """

  def __init__(self, config, sess, checkpoint_saver, best_model_saver,
               restore_if_possible=True):
    """Set up progress tracking, resuming from a saved run when possible.

    Args:
      config: object providing the `checkpoints_dir`, `progress`,
        `checkpoint` and `best_model_checkpoint` paths.
      sess: TensorFlow session into which the latest checkpoint is restored.
      checkpoint_saver: saver used for periodic checkpoints (see `write`).
      best_model_saver: saver used for the best dev-set model.
      restore_if_possible: when True and `config.progress` exists, restore
        the history, the data-reader position and the latest checkpoint;
        otherwise start from scratch.
    """
    self.config = config
    self.checkpoint_saver = checkpoint_saver
    self.best_model_saver = best_model_saver
    tf.gfile.MakeDirs(config.checkpoints_dir)
    if restore_if_possible and tf.gfile.Exists(config.progress):
      history, current_file, current_line = utils.load_cpickle(
          config.progress, memoized=False)
      self.history = history
      # Resume reading the unlabeled corpus where the previous run stopped.
      self.unlabeled_data_reader = unlabeled_data.UnlabeledDataReader(
          config, current_file, current_line)
      # write() can persist progress before any history entry has been
      # recorded, so the restored history may be empty.
      step = dict(self.history[-1])["step"] if self.history else "unknown"
      utils.log("Continuing from global step", step,
                "(lm1b file {:}, line {:})".format(current_file, current_line))
      self.checkpoint_saver.restore(sess, tf.train.latest_checkpoint(
          self.config.checkpoints_dir))
    else:
      utils.log("No previous checkpoint found - starting from scratch")
      self.history = []
      self.unlabeled_data_reader = (
          unlabeled_data.UnlabeledDataReader(config))

  def write(self, sess, global_step):
    """Save a checkpoint and the progress file for the current state.

    The progress file stores (history, current_file, current_line), which
    is exactly what `__init__` loads back when resuming.
    """
    self.checkpoint_saver.save(sess, self.config.checkpoint,
                               global_step=global_step)
    utils.write_cpickle(
        (self.history, self.unlabeled_data_reader.current_file,
         self.unlabeled_data_reader.current_line),
        self.config.progress)

  @staticmethod
  def _dev_score(results):
    """Return the mean of the f1/las/accuracy metrics in one history entry.

    Returns None for training entries (any metric name containing "train")
    and for entries containing no f1/las/accuracy metric, since neither can
    be used to rank dev-set models.
    """
    # Materialize once so the entry can be scanned twice.
    pairs = list(results)
    if any("train" in metric for metric, _ in pairs):
      return None
    scores = [value for metric, value in pairs
              if "f1" in metric or "las" in metric or "accuracy" in metric]
    if not scores:
      return None
    return sum(scores) / len(scores)

  def save_if_best_dev_model(self, sess, global_step):
    """Save the model with `best_model_saver` if the latest entry is best.

    Every scorable history entry is ranked by `_dev_score`. The model is
    saved when the most recent entry scores at least as high as every
    earlier entry (ties count as a new best). Scores are compared against a
    starting floor of 0.
    """
    best_avg_score = 0
    last_index = len(self.history) - 1
    for i, results in enumerate(self.history):
      avg_score = self._dev_score(results)
      if avg_score is None:
        # Training entries and entries without f1/las/accuracy metrics
        # cannot be ranked; averaging them would divide by zero.
        continue
      if avg_score >= best_avg_score:
        best_avg_score = avg_score
        if i == last_index:
          utils.log("New best model! Saving...")
          self.best_model_saver.save(sess, self.config.best_model_checkpoint,
                                     global_step=global_step)