# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| r"""""" | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import os
import time

import tensorflow as tf

import gin.tf

flags = tf.app.flags

flags.DEFINE_multi_string('config_file', None,
                          'List of paths to the config files.')
flags.DEFINE_multi_string('params', None,
                          'Newline separated list of Gin parameter bindings.')
flags.DEFINE_string('train_dir', None,
                    'Directory for writing logs/summaries during training.')
flags.DEFINE_string('master', 'local',
                    'BNS name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'Task id.')
flags.DEFINE_integer('save_interval_secs', 300, 'The frequency at which '
                     'checkpoints are saved, in seconds.')
flags.DEFINE_integer('save_summaries_secs', 30, 'The frequency at which '
                     'summaries are saved, in seconds.')
flags.DEFINE_boolean('summarize_gradients', False,
                     'Whether to generate gradient summaries.')

FLAGS = flags.FLAGS
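
# A minimal invocation sketch (hypothetical entry-point name, paths, and Gin
# binding; this module only defines the flags):
#   python train.py \
#     --config_file=configs/train.gin \
#     --train_dir=/tmp/train \
#     --params='TrainStep.log_every_n_steps = 100'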

TrainOps = namedtuple('TrainOps',
                      ['train_op', 'meta_train_op', 'collect_experience_op'])
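
# A construction sketch (hypothetical op names; the real ops come from the
# agent graph built by the caller):
#   train_ops = TrainOps(
#       train_op=agent_train_op,
#       meta_train_op=meta_agent_train_op,
#       collect_experience_op=collect_experience_op)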


class TrainStep(object):
  """Handles training steps."""

  def __init__(self,
               max_number_of_steps=0,
               num_updates_per_observation=1,
               num_collect_per_update=1,
               num_collect_per_meta_update=1,
               log_every_n_steps=1,
               policy_save_fn=None,
               save_policy_every_n_steps=0,
               should_stop_early=None):
| """Returns a function that is executed at each step of slim training. | |
| Args: | |
| max_number_of_steps: Optional maximum number of train steps to take. | |
| num_updates_per_observation: Number of updates per observation. | |
| log_every_n_steps: The frequency, in terms of global steps, that the loss | |
| and global step and logged. | |
| policy_save_fn: A tf.Saver().save function to save the policy. | |
| save_policy_every_n_steps: How frequently to save the policy. | |
| should_stop_early: Optional hook to report whether training should stop. | |
| Raises: | |
| ValueError: If policy_save_fn is not provided when | |
| save_policy_every_n_steps > 0. | |
| """ | |
    if save_policy_every_n_steps and policy_save_fn is None:
      raise ValueError(
          'policy_save_fn is required when save_policy_every_n_steps > 0')
    self.max_number_of_steps = max_number_of_steps
    self.num_updates_per_observation = num_updates_per_observation
    self.num_collect_per_update = num_collect_per_update
    self.num_collect_per_meta_update = num_collect_per_meta_update
    self.log_every_n_steps = log_every_n_steps
    self.policy_save_fn = policy_save_fn
    self.save_policy_every_n_steps = save_policy_every_n_steps
    self.should_stop_early = should_stop_early
    self.last_global_step_val = 0
    self.train_op_fn = None
    self.collect_and_train_fn = None
    tf.logging.info('Training for %d max_number_of_steps',
                    self.max_number_of_steps)

  def train_step(self, sess, train_ops, global_step, _):
    """Runs one step of training; called by slim at each global step.

    This represents one step of the DDPG algorithm and can include:
      1. collect a <state, action, reward, next_state> transition
      2. update the target network
      3. train the actor
      4. train the critic

    Args:
      sess: A TensorFlow session.
      train_ops: A TrainOps tuple of ops to run.
      global_step: The global step tensor.
      _: Unused additional keyword args (slim's train_step_kwargs).

    Returns:
      A scalar total loss.
      A boolean should stop.
    """
    start_time = time.time()
    if self.train_op_fn is None:
      # Build the session callables once and reuse them; make_callable avoids
      # re-validating the fetch list on every step.
      self.train_op_fn = sess.make_callable([train_ops.train_op, global_step])
      self.meta_train_op_fn = sess.make_callable(
          [train_ops.meta_train_op, global_step])
      self.collect_fn = sess.make_callable(
          [train_ops.collect_experience_op, global_step])
      self.collect_and_train_fn = sess.make_callable(
          [train_ops.train_op, global_step, train_ops.collect_experience_op])
      self.collect_and_meta_train_fn = sess.make_callable(
          [train_ops.meta_train_op, global_step,
           train_ops.collect_experience_op])

    # Extra collect and train calls beyond the single fused call below.
    for _ in range(self.num_collect_per_update - 1):
      self.collect_fn()
    for _ in range(self.num_updates_per_observation - 1):
      self.train_op_fn()

    total_loss, global_step_val, _ = self.collect_and_train_fn()
    # Run a meta-agent update whenever the global step crosses a
    # num_collect_per_meta_update boundary.
    if (global_step_val // self.num_collect_per_meta_update !=
        self.last_global_step_val // self.num_collect_per_meta_update):
      self.meta_train_op_fn()
    time_elapsed = time.time() - start_time

    should_stop = False
    if self.max_number_of_steps:
      should_stop = global_step_val >= self.max_number_of_steps

    if global_step_val != self.last_global_step_val:
      if (self.save_policy_every_n_steps and
          global_step_val // self.save_policy_every_n_steps !=
          self.last_global_step_val // self.save_policy_every_n_steps):
        self.policy_save_fn(sess)

      if (self.log_every_n_steps and
          global_step_val % self.log_every_n_steps == 0):
        tf.logging.info(
            'global step %d: loss = %.4f (%.3f sec/step) (%.1f steps/sec)',
            global_step_val, total_loss, time_elapsed, 1 / time_elapsed)

    self.last_global_step_val = global_step_val
    stop_early = bool(self.should_stop_early and self.should_stop_early())
    return total_loss, should_stop or stop_early
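
# A wiring sketch (hypothetical; assumes TF 1.x tf.contrib.slim and the ops
# above built by an agent graph). slim passes the `train_op` argument through
# to train_step_fn, so the TrainOps tuple can ride along in its place:
#   train_step = TrainStep(
#       max_number_of_steps=100000,
#       policy_save_fn=policy_saver_fn,
#       save_policy_every_n_steps=1000)
#   slim.learning.train(
#       train_ops,
#       logdir=FLAGS.train_dir,
#       master=FLAGS.master,
#       train_step_fn=train_step.train_step,
#       save_interval_secs=FLAGS.save_interval_secs,
#       save_summaries_secs=FLAGS.save_summaries_secs)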


def create_counter_summaries(counters):
  """Add named summaries to counters, a list of tuples (name, counter)."""
  if counters:
    with tf.name_scope('Counters/'):
      for name, counter in counters:
        tf.summary.scalar(name, counter)
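
# Example (hypothetical counter tensors):
#   create_counter_summaries([('episodes', num_episodes),
#                             ('env_steps', num_env_steps)])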


def gen_debug_batch_summaries(batch):
  """Generates summaries for the sampled replay batch."""
  states, actions, rewards, _, next_states = batch
  with tf.name_scope('batch'):
    for s in range(states.get_shape()[-1]):
      tf.summary.histogram('states_%d' % s, states[:, s])
    for s in range(next_states.get_shape()[-1]):
      tf.summary.histogram('next_states_%d' % s, next_states[:, s])
    for a in range(actions.get_shape()[-1]):
      tf.summary.histogram('actions_%d' % a, actions[:, a])
    tf.summary.histogram('rewards', rewards)
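
# Example (hypothetical replay buffer API): `batch` is the 5-tuple
# (states, actions, rewards, discounts, next_states) with shapes
# [B, state_dim], [B, action_dim], [B], [B], [B, state_dim]; the fourth
# element is unused here:
#   gen_debug_batch_summaries(replay_buffer.get_random_batch(batch_size))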