Spaces:
Runtime error
Runtime error
| # Copyright 2018 Google, Inc. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """ Script that iteratively applies the unsupervised update rule and evaluates the | |
| meta-objective performance. | |
| """ | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| from absl import flags | |
| from absl import app | |
| from learning_unsupervised_learning import evaluation | |
| from learning_unsupervised_learning import datasets | |
| from learning_unsupervised_learning import architectures | |
| from learning_unsupervised_learning import summary_utils | |
| from learning_unsupervised_learning import meta_objective | |
| import tensorflow as tf | |
| import sonnet as snt | |
| from tensorflow.contrib.framework.python.framework import checkpoint_utils | |
# Command-line flags consumed by `main` below.
flags.DEFINE_string(
    name="checkpoint_dir",
    default=None,
    help="Dir to load pretrained update rule from")
flags.DEFINE_string(
    name="train_log_dir",
    default=None,
    help="Training log directory")

FLAGS = flags.FLAGS
def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
  """Iteratively applies the update rule and logs meta-objective summaries.

  Builds the evaluation graph with `evaluation.construct_evaluation_graph`,
  optionally restores variables found in a checkpoint, and then runs the
  one-step training op until the global step reaches `num_steps`.

  Args:
    train_log_dir: Directory handed to `summary_utils.LoggingFileWriter`.
    checkpoint_dir: Directory containing checkpoints, or a path to a single
      checkpoint, to restore the pretrained update rule from. When falsy,
      nothing is restored.
    eval_every_n_steps: Merged summaries are evaluated and written whenever
      the current step is a multiple of this value.
    num_steps: The loop stops once the global step is no longer below this.

  Raises:
    AssertionError: If a global variable under the "LocalWeightUpdateProcess"
      scope is not among the variables restored from the checkpoint.
  """
  dataset_fn = datasets.mnist.TinyMnist
  w_learner_fn = (
      architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner)
  theta_process_fn = (
      architectures.more_local_weight_update.MoreLocalWeightUpdateProcess)

  meta_objectives = [
      meta_objective.linear_regression.LinearRegressionMetaObjective,
      meta_objective.sklearn.LogisticRegression,
  ]

  _, train_one_step_op, (
      base_model, dataset) = evaluation.construct_evaluation_graph(
          theta_process_fn=theta_process_fn,
          w_learner_fn=w_learner_fn,
          dataset_fn=dataset_fn,
          meta_objectives=meta_objectives)

  # The outputs are not used here, but the call is kept because connecting the
  # model to a batch adds ops to the graph before summaries are merged.
  base_model(dataset())

  global_step = tf.train.get_or_create_global_step()

  tf.logging.info("all vars")
  for v in tf.global_variables():
    tf.logging.info(" %s", v)

  accumulate_global_step = global_step.assign_add(1)
  reset_global_step = global_step.assign(0)

  # Every training step also advances the global step by one.
  train_op = tf.group(
      train_one_step_op, accumulate_global_step, name="train_op")

  summary_op = tf.summary.merge_all()

  file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])

  saver = None
  restore_path = None
  if checkpoint_dir:
    # Restore only the graph variables that are actually present in the
    # checkpoint.
    str_var_list = checkpoint_utils.list_variables(checkpoint_dir)
    name_to_v_map = {v.op.name: v for v in tf.global_variables()}
    var_list = [
        name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
    ]
    saver = tf.train.Saver(var_list)
    missed_variables = [
        v.op.name for v in set(
            snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                       tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_list)
    ]
    assert len(missed_variables) == 0, "Missed a theta variable."

    # `Saver.restore` needs a concrete checkpoint path, whereas
    # `list_variables` also accepts a directory. Resolve a directory to its
    # latest checkpoint; fall back to the given value when it already is a
    # checkpoint path (latest_checkpoint returns None in that case).
    restore_path = tf.train.latest_checkpoint(checkpoint_dir) or checkpoint_dir

  with tf.train.SingularMonitoredSession(master="", hooks=[]) as sess:
    # global step should be restored from the evals job checkpoint or zero for
    # fresh.
    step = sess.run(global_step)

    if step == 0 and checkpoint_dir:
      tf.logging.info("force restore")
      saver.restore(sess, restore_path)
      tf.logging.info("force restore done")
      # Start counting from zero regardless of what the restore loaded.
      sess.run(reset_global_step)
      step = sess.run(global_step)

    while step < num_steps:
      if step % eval_every_n_steps == 0:
        s, _, step = sess.run([summary_op, train_op, global_step])
        file_writer.add_summary(s, step)
      else:
        _, step = sess.run([train_op, global_step])
def main(argv):
  """absl entry point: runs `train` with the directories given by the flags."""
  del argv  # Unused.
  train(
      train_log_dir=FLAGS.train_log_dir,
      checkpoint_dir=FLAGS.checkpoint_dir)


if __name__ == "__main__":
  app.run(main)