from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

r"""Train RL agent on coding tasks."""

import contextlib
import cPickle
import cProfile
import marshal
import os
import time

from absl import flags
from absl import logging
import tensorflow as tf

# internal session lib import

from single_task import data  # brain coder
from single_task import defaults  # brain coder
from single_task import pg_agent as agent_lib  # brain coder
from single_task import results_lib  # brain coder

FLAGS = flags.FLAGS
flags.DEFINE_string(
    'master', '',
    'URL of the TensorFlow master to use.')
flags.DEFINE_integer(
    'ps_tasks', 0,
    'Number of parameter server tasks. Only set to 0 for '
    'single worker training.')
flags.DEFINE_integer(
    'summary_interval', 10,
    'How often to write summaries.')
flags.DEFINE_integer(
    'summary_tasks', 16,
    'If greater than 0 only tasks 0 through summary_tasks - 1 '
    'will write summaries. If 0, all tasks will write '
    'summaries.')
flags.DEFINE_bool(
    'stop_on_success', True,
    'If True, training will stop as soon as a solution is found. '
    'If False, training will continue indefinitely until another '
    'stopping condition is reached.')
flags.DEFINE_bool(
    'do_profiling', False,
    'If True, cProfile profiler will run and results will be '
    'written to logdir. WARNING: Results will not be written if '
    'the code crashes. Make sure it exits successfully.')
flags.DEFINE_integer('model_v', 0, 'Model verbosity level.')
flags.DEFINE_bool(
    'delayed_graph_cleanup', True,
    'If true, container for n-th run will not be reset until the (n+1)-th run '
    'is complete. This greatly reduces the chance that a worker is still '
    'using the n-th container when it is cleared.')


def define_tuner_hparam_space(hparam_space_type):
  """Define tunable hparams for grid search."""
  if hparam_space_type not in ('pg', 'pg-topk', 'topk', 'is'):
    raise ValueError('Hparam space is not valid: "%s"' % hparam_space_type)

  # Discrete hparam space is stored as a dict from hparam name to discrete
  # values.
  hparam_space = {}

  if hparam_space_type in ('pg', 'pg-topk', 'is'):
    # Add a floating point parameter named learning rate.
    hparam_space['lr'] = [1e-5, 1e-4, 1e-3]
    hparam_space['entropy_beta'] = [0.005, 0.01, 0.05, 0.10]
  else:  # 'topk'
    # Add a floating point parameter named learning rate.
    hparam_space['lr'] = [1e-5, 1e-4, 1e-3]
    hparam_space['entropy_beta'] = [0.0, 0.005, 0.01, 0.05, 0.10]

  if hparam_space_type in ('topk', 'pg-topk'):
    # topk tuning will be enabled.
    hparam_space['topk'] = [10]
    hparam_space['topk_loss_hparam'] = [1.0, 10.0, 50.0, 200.0]
  elif hparam_space_type == 'is':
    # importance sampling tuning will be enabled.
    hparam_space['replay_temperature'] = [0.25, 0.5, 1.0, 2.0]
    hparam_space['alpha'] = [0.5, 0.75, 63/64.]

  return hparam_space


def write_hparams_to_config(config, hparams, hparam_space_type):
  """Write hparams given by the tuner into the Config object."""
  if hparam_space_type not in ('pg', 'pg-topk', 'topk', 'is'):
    raise ValueError('Hparam space is not valid: "%s"' % hparam_space_type)

  config.agent.lr = hparams.lr
  config.agent.entropy_beta = hparams.entropy_beta

  if hparam_space_type in ('topk', 'pg-topk'):
    # topk tuning will be enabled.
    config.agent.topk = hparams.topk
    config.agent.topk_loss_hparam = hparams.topk_loss_hparam
  elif hparam_space_type == 'is':
    # importance sampling tuning will be enabled.
    config.agent.replay_temperature = hparams.replay_temperature
    config.agent.alpha = hparams.alpha
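

# Illustrative sketch, not part of the original module: a tuner doing a plain
# grid search over the space returned by define_tuner_hparam_space could drive
# the two helpers above roughly as follows. `HParamsLike` and `launch_trial`
# are hypothetical stand-ins for whatever the tuning framework provides.
#
#   import itertools
#
#   space = define_tuner_hparam_space('pg-topk')
#   names = sorted(space)
#   for values in itertools.product(*(space[n] for n in names)):
#     hparams = HParamsLike(**dict(zip(names, values)))  # exposes .lr, .topk, ...
#     config = defaults.default_config_with_updates(FLAGS.config)
#     write_hparams_to_config(config, hparams, hparam_space_type='pg-topk')
#     launch_trial(config)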


def make_initialized_variable(value, name, shape=None, dtype=tf.float32):
  """Create a tf.Variable with a constant initializer.

  Args:
    value: Constant value to initialize the variable with. This is the value
        that the variable starts with.
    name: Name of the variable in the TF graph.
    shape: Shape of the variable. If None, variable will be a scalar.
    dtype: Data type of the variable. Should be a TF dtype. Defaults to
        tf.float32.

  Returns:
    tf.Variable instance.
  """
  if shape is None:
    shape = []
  return tf.get_variable(
      name=name, shape=shape, initializer=tf.constant_initializer(value),
      dtype=dtype, trainable=False)
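

# For example, the call below (used throughout AsyncTrainer) creates a
# non-trainable scalar int64 counter that starts at zero:
#
#   global_step = make_initialized_variable(0, 'global_step', dtype=tf.int64)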


class AsyncTrainer(object):
  """Manages graph creation and training.

  This async trainer creates a global model on the parameter server, and a
  local model (for this worker). Gradient updates are sent to the global
  model, and the updated weights are synced to the local copy.
  """

  def __init__(self, config, task_id, ps_tasks, num_workers, is_chief=True,
               summary_writer=None,
               dtype=tf.float32,
               summary_interval=1,
               run_number=0,
               logging_dir='/tmp', model_v=0):
    self.config = config
    self.data_manager = data.DataManager(
        config, run_number=run_number,
        do_code_simplification=not FLAGS.stop_on_success)
    self.task_id = task_id
    self.ps_tasks = ps_tasks
    self.is_chief = is_chief
    if ps_tasks == 0:
      assert task_id == 0, 'No parameter servers specified. Expecting 1 task.'
      assert num_workers == 1, (
          'No parameter servers specified. Expecting 1 task.')
      worker_device = '/job:localhost/replica:%d/task:0/cpu:0' % task_id
      # worker_device = '/cpu:0'
      # ps_device = '/cpu:0'
    else:
      assert num_workers > 0, 'There must be at least 1 training worker.'
      worker_device = '/job:worker/replica:%d/task:0/cpu:0' % task_id
      # ps_device = '/job:ps/replica:0/task:0/cpu:0'
    logging.info('worker_device: %s', worker_device)

    logging_file = os.path.join(
        logging_dir, 'solutions_%d.txt' % task_id)
    experience_replay_file = os.path.join(
        logging_dir, 'replay_buffer_%d.pickle' % task_id)
    self.topk_file = os.path.join(
        logging_dir, 'topk_buffer_%d.pickle' % task_id)

    tf.get_variable_scope().set_use_resource(True)

    # global model
    with tf.device(tf.train.replica_device_setter(ps_tasks,
                                                  ps_device='/job:ps/replica:0',
                                                  worker_device=worker_device)):
      with tf.variable_scope('global'):
        global_model = agent_lib.LMAgent(config, dtype=dtype, is_local=False)
        global_params_dict = {p.name: p
                              for p in global_model.sync_variables}
        self.global_model = global_model
        self.global_step = make_initialized_variable(
            0, 'global_step', dtype=tf.int64)

        self.global_best_reward = make_initialized_variable(
            -10.0, 'global_best_reward', dtype=tf.float64)
        self.is_best_model = make_initialized_variable(
            False, 'is_best_model', dtype=tf.bool)
        self.reset_is_best_model = self.is_best_model.assign(False)
        self.global_best_reward_placeholder = tf.placeholder(
            tf.float64, [], name='global_best_reward_placeholder')
        self.assign_global_best_reward_op = tf.group(
            self.global_best_reward.assign(
                self.global_best_reward_placeholder),
            self.is_best_model.assign(True))

        def assign_global_best_reward_fn(session, reward):
          reward = round(reward, 10)
          best_reward = round(session.run(self.global_best_reward), 10)
          is_best = reward > best_reward
          if is_best:
            session.run(self.assign_global_best_reward_op,
                        {self.global_best_reward_placeholder: reward})
          return is_best
        self.assign_global_best_reward_fn = assign_global_best_reward_fn

        # Any worker will set to true when it finds a solution.
        self.found_solution_flag = make_initialized_variable(
            False, 'found_solution_flag', dtype=tf.bool)
        self.found_solution_op = self.found_solution_flag.assign(True)

        self.run_number = make_initialized_variable(
            run_number, 'run_number', dtype=tf.int32)

        # Store a solution when found.
        self.code_solution_variable = tf.get_variable(
            'code_solution', [], tf.string,
            initializer=tf.constant_initializer(''))
        self.code_solution_ph = tf.placeholder(
            tf.string, [], name='code_solution_ph')
        self.code_solution_assign_op = self.code_solution_variable.assign(
            self.code_solution_ph)

        def assign_code_solution_fn(session, code_solution_string):
          session.run(self.code_solution_assign_op,
                      {self.code_solution_ph: code_solution_string})
        self.assign_code_solution_fn = assign_code_solution_fn

        # Count all programs sampled from policy. This does not include
        # programs sampled from replay buffer.
        # This equals NPE (number of programs executed). Only programs sampled
        # from the policy need to be executed.
        self.program_count = make_initialized_variable(
            0, 'program_count', dtype=tf.int64)

    # local model
    with tf.device(worker_device):
      with tf.variable_scope('local'):
        self.model = model = agent_lib.LMAgent(
            config,
            task_id=task_id,
            logging_file=logging_file,
            experience_replay_file=experience_replay_file,
            dtype=dtype,
            global_best_reward_fn=self.assign_global_best_reward_fn,
            found_solution_op=self.found_solution_op,
            assign_code_solution_fn=self.assign_code_solution_fn,
            program_count=self.program_count,
            stop_on_success=FLAGS.stop_on_success,
            verbose_level=model_v)
        local_params = model.trainable_variables
        local_params_dict = {p.name: p for p in local_params}

      # Pull global params to local model.
      def _global_to_local_scope(name):
        assert name.startswith('global/')
        return 'local' + name[6:]
      sync_dict = {
          local_params_dict[_global_to_local_scope(p_name)]: p
          for p_name, p in global_params_dict.items()}
      self.sync_op = tf.group(*[v_local.assign(v_global)
                                for v_local, v_global
                                in sync_dict.items()])

      # Pair local gradients with global params.
      grad_var_dict = {
          gradient: sync_dict[local_var]
          for local_var, gradient in model.gradients_dict.items()}

      # local model
      model.make_summary_ops()  # Don't put summaries under 'local' scope.
      with tf.variable_scope('local'):
        self.train_op = model.optimizer.apply_gradients(
            grad_var_dict.items(), global_step=self.global_step)
        self.local_init_op = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              tf.get_variable_scope().name))

    self.local_step = 0
    self.last_summary_time = time.time()
    self.summary_interval = summary_interval
    self.summary_writer = summary_writer
    self.cached_global_step = -1
    self.cached_global_npe = -1

    logging.info('summary_interval: %d', self.summary_interval)

    # Load top-k buffer.
    if self.model.top_episodes is not None and tf.gfile.Exists(self.topk_file):
      try:
        with tf.gfile.FastGFile(self.topk_file, 'r') as f:
          self.model.top_episodes = cPickle.loads(f.read())
        logging.info(
            'Loaded top-k buffer from disk with %d items. Location: "%s"',
            len(self.model.top_episodes), self.topk_file)
      except (cPickle.UnpicklingError, EOFError) as e:
        logging.warn(
            'Failed to load existing top-k buffer from disk. Removing bad '
            'file.\nLocation: "%s"\nException: %s', self.topk_file, str(e))
        tf.gfile.Remove(self.topk_file)

  def initialize(self, session):
    """Run initialization ops."""
    session.run(self.local_init_op)
    session.run(self.sync_op)
    self.cached_global_step, self.cached_global_npe = session.run(
        [self.global_step, self.program_count])

  def update_global_model(self, session):
    """Run an update step.

    1) Asynchronously copy global weights to local model.
    2) Call into local model's update_step method, which does the following:
        a) Sample batch of programs from policy.
        b) Compute rewards.
        c) Compute gradients and update the global model asynchronously.
    3) Write tensorboard summaries to disk.

    Args:
      session: tf.Session instance.
    """
    session.run(self.sync_op)  # Copy weights from global to local.

    with session.as_default():
      result = self.model.update_step(
          session, self.data_manager.sample_rl_batch(), self.train_op,
          self.global_step)
      global_step = result.global_step
      global_npe = result.global_npe
      summaries = result.summaries_list
    self.cached_global_step = global_step
    self.cached_global_npe = global_npe
    self.local_step += 1

    if self.summary_writer and self.local_step % self.summary_interval == 0:
      if not isinstance(summaries, (tuple, list)):
        summaries = [summaries]
      summaries.append(self._local_step_summary())
      if self.is_chief:
        (global_best_reward,
         found_solution_flag,
         program_count) = session.run(
             [self.global_best_reward,
              self.found_solution_flag,
              self.program_count])
        summaries.append(
            tf.Summary(
                value=[tf.Summary.Value(
                    tag='model/best_reward',
                    simple_value=global_best_reward)]))
        summaries.append(
            tf.Summary(
                value=[tf.Summary.Value(
                    tag='model/solution_found',
                    simple_value=int(found_solution_flag))]))
        summaries.append(
            tf.Summary(
                value=[tf.Summary.Value(
                    tag='model/program_count',
                    simple_value=program_count)]))
      for s in summaries:
        self.summary_writer.add_summary(s, global_step)
      self.last_summary_time = time.time()

  def _local_step_summary(self):
    """Compute number of local steps per time increment."""
    dt = time.time() - self.last_summary_time
    steps_per_time = self.summary_interval / float(dt)
    return tf.Summary(value=[
        tf.Summary.Value(
            tag='local_step/per_sec',
            simple_value=steps_per_time),
        tf.Summary.Value(
            tag='local_step/step',
            simple_value=self.local_step)])

  def maybe_save_best_model(self, session, saver, checkpoint_file):
    """Check if this model got the highest reward and save to disk if so."""
    if self.is_chief and session.run(self.is_best_model):
      logging.info('Saving best model to "%s"', checkpoint_file)
      saver.save(session, checkpoint_file)
      session.run(self.reset_is_best_model)

  def save_replay_buffer(self):
    """Save replay buffer to disk.

    Call this periodically so that training can recover if jobs go down.
    """
    if self.model.experience_replay is not None:
      logging.info('Saving experience replay buffer to "%s".',
                   self.model.experience_replay.save_file)
      self.model.experience_replay.incremental_save(True)

  def delete_replay_buffer(self):
    """Delete replay buffer from disk.

    Call this at the end of training to clean up. Replay buffer can get very
    large.
    """
    if self.model.experience_replay is not None:
      logging.info('Deleting experience replay buffer at "%s".',
                   self.model.experience_replay.save_file)
      tf.gfile.Remove(self.model.experience_replay.save_file)

  def save_topk_buffer(self):
    """Save top-k buffer to disk.

    Call this periodically so that training can recover if jobs go down.
    """
    if self.model.top_episodes is not None:
      logging.info('Saving top-k buffer to "%s".', self.topk_file)
      # Overwrite previous data each time.
      with tf.gfile.FastGFile(self.topk_file, 'w') as f:
        f.write(cPickle.dumps(self.model.top_episodes))


@contextlib.contextmanager
def managed_session(sv, master='', config=None,
                    start_standard_services=True,
                    close_summary_writer=True,
                    max_wait_secs=7200):
  # Same as Supervisor.managed_session, but with configurable timeout.
  try:
    sess = sv.prepare_or_wait_for_session(
        master=master, config=config,
        start_standard_services=start_standard_services,
        max_wait_secs=max_wait_secs)
    yield sess
  except tf.errors.DeadlineExceededError:
    raise
  except Exception as e:  # pylint: disable=broad-except
    sv.request_stop(e)
  finally:
    try:
      # Request all the threads to stop and wait for them to do so. Any
      # exception raised by the threads is raised again from stop().
      # Passing stop_grace_period_secs is for blocked enqueue/dequeue
      # threads which are not checking for `should_stop()`. They
      # will be stopped when we close the session further down.
      sv.stop(close_summary_writer=close_summary_writer)
    finally:
      # Close the session to finish up all pending calls. We do not care
      # about exceptions raised when closing. This takes care of
      # blocked enqueue/dequeue calls.
      try:
        sess.close()
      except Exception:  # pylint: disable=broad-except
        # Silently ignore exceptions raised by close().
        pass
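

# Usage mirrors tf.train.Supervisor.managed_session; `train` below uses it as:
#
#   with managed_session(sv, FLAGS.master, max_wait_secs=60) as session:
#     ...  # run training ops against `session`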


def train(config, is_chief, tuner=None, run_dir=None, run_number=0,
          results_writer=None):
  """Run training loop.

  Args:
    config: config_lib.Config instance containing global config (agent and
        env).
    is_chief: True if this worker is chief. Chief worker manages writing some
        data to disk and initialization of the global model.
    tuner: A tuner instance. If not tuning, leave as None.
    run_dir: Directory where all data for this run will be written. If None,
        run_dir = FLAGS.logdir. Set this argument when doing multiple runs.
    run_number: Which run this is.
    results_writer: Manages writing training results to disk. Results are a
        dict of metric names and values.

  Returns:
    The trainer object used to run training updates.
  """
  logging.info('Will run asynchronous training.')

  if run_dir is None:
    run_dir = FLAGS.logdir
  train_dir = os.path.join(run_dir, 'train')
  best_model_checkpoint = os.path.join(train_dir, 'best.ckpt')
  events_dir = '%s/events_%d' % (run_dir, FLAGS.task_id)
  logging.info('Events directory: %s', events_dir)

  logging_dir = os.path.join(run_dir, 'logs')
  if not tf.gfile.Exists(logging_dir):
    tf.gfile.MakeDirs(logging_dir)
  status_file = os.path.join(logging_dir, 'status.txt')

  if FLAGS.summary_tasks and FLAGS.task_id < FLAGS.summary_tasks:
    summary_writer = tf.summary.FileWriter(events_dir)
  else:
    summary_writer = None

  # Only profile task 0.
  if FLAGS.do_profiling:
    logging.info('Profiling enabled')
    profiler = cProfile.Profile()
    profiler.enable()
  else:
    profiler = None

  trainer = AsyncTrainer(
      config, FLAGS.task_id, FLAGS.ps_tasks, FLAGS.num_workers,
      is_chief=is_chief,
      summary_interval=FLAGS.summary_interval,
      summary_writer=summary_writer,
      logging_dir=logging_dir,
      run_number=run_number,
      model_v=FLAGS.model_v)

  variables_to_save = [v for v in tf.global_variables()
                       if v.name.startswith('global')]
  global_init_op = tf.variables_initializer(variables_to_save)
  saver = tf.train.Saver(variables_to_save)

  var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               tf.get_variable_scope().name)
  logging.info('Trainable vars:')
  for v in var_list:
    logging.info('  %s, %s, %s', v.name, v.device, v.get_shape())
  logging.info('All vars:')
  for v in tf.global_variables():
    logging.info('  %s, %s, %s', v.name, v.device, v.get_shape())

  def init_fn(unused_sess):
    logging.info('No checkpoint found. Initialized global params.')

  sv = tf.train.Supervisor(is_chief=is_chief,
                           logdir=train_dir,
                           saver=saver,
                           summary_op=None,
                           init_op=global_init_op,
                           init_fn=init_fn,
                           summary_writer=summary_writer,
                           ready_op=tf.report_uninitialized_variables(
                               variables_to_save),
                           ready_for_local_init_op=None,
                           global_step=trainer.global_step,
                           save_model_secs=30,
                           save_summaries_secs=30)

  # Add a thread that periodically checks if this Trial should stop
  # based on an early stopping policy.
  if tuner:
    sv.Loop(60, tuner.check_for_stop, (sv.coord,))

  last_replay_save_time = time.time()

  global_step = -1
  logging.info(
      'Starting session. '
      'If this hangs, we\'re most likely waiting to connect '
      'to the parameter server. One common cause is that the parameter '
      'server DNS name isn\'t resolving yet, or is misspecified.')
  should_retry = True
  supervisor_deadline_exceeded = False
  while should_retry:
    try:
      with managed_session(
          sv, FLAGS.master, max_wait_secs=60) as session, session.as_default():
        should_retry = False
        do_training = True
        try:
          trainer.initialize(session)
          if session.run(trainer.run_number) != run_number:
            # If we loaded existing model from disk, and the saved run number
            # is different, throw an exception.
            raise RuntimeError(
                'Expecting to be on run %d, but is actually on run %d. '
                'run_dir: "%s"'
                % (run_number, session.run(trainer.run_number), run_dir))
          global_step = trainer.cached_global_step
          logging.info('Starting training at step=%d', global_step)
          while do_training:
            trainer.update_global_model(session)
            if is_chief:
              trainer.maybe_save_best_model(
                  session, saver, best_model_checkpoint)
            global_step = trainer.cached_global_step
            global_npe = trainer.cached_global_npe

            if time.time() - last_replay_save_time >= 30:
              trainer.save_replay_buffer()
              trainer.save_topk_buffer()
              last_replay_save_time = time.time()

            # Stopping conditions.
            if tuner and tuner.should_trial_stop():
              logging.info('Tuner requested early stopping. Finishing.')
              do_training = False
            if is_chief and FLAGS.stop_on_success:
              found_solution = session.run(trainer.found_solution_flag)
              if found_solution:
                do_training = False
                logging.info('Solution found. Finishing.')
            if FLAGS.max_npe and global_npe >= FLAGS.max_npe:
              # Max NPE (number of programs executed) reached.
              logging.info('Max NPE reached. Finishing.')
              do_training = False
            if sv.should_stop():
              logging.info('Supervisor issued stop. Finishing.')
              do_training = False

        except tf.errors.NotFoundError:
          # Catch "Error while reading resource variable".
          # The chief worker likely destroyed the container, so do not retry.
          logging.info('Caught NotFoundError. Quitting.')
          do_training = False
          should_retry = False
          break
        except tf.errors.InternalError as e:
          # Catch "Invalid variable reference."
          if str(e).startswith('Invalid variable reference.'):
            # The chief worker likely destroyed the container, so do not
            # retry.
            logging.info(
                'Caught "InternalError: Invalid variable reference.". '
                'Quitting.')
            do_training = False
            should_retry = False
            break
          else:
            # Pass exception through.
            raise

        # Exited training loop. Write results to disk.
        if is_chief and results_writer:
          assert not should_retry
          with tf.gfile.FastGFile(status_file, 'w') as f:
            f.write('done')
          (program_count,
           found_solution,
           code_solution,
           best_reward,
           global_step) = session.run(
               [trainer.program_count,
                trainer.found_solution_flag,
                trainer.code_solution_variable,
                trainer.global_best_reward,
                trainer.global_step])
          results_dict = {
              'max_npe': FLAGS.max_npe,
              'batch_size': config.batch_size,
              'max_batches': FLAGS.max_npe // config.batch_size,
              'npe': program_count,
              'max_global_repetitions': FLAGS.num_repetitions,
              'max_local_repetitions': FLAGS.num_repetitions,
              'code_solution': code_solution,
              'best_reward': best_reward,
              'num_batches': global_step,
              'found_solution': found_solution,
              'task': trainer.data_manager.task_name,
              'global_rep': run_number}
          logging.info('results_dict: %s', results_dict)
          results_writer.append(results_dict)
    except tf.errors.AbortedError:
      # Catch "Graph handle is not found" error due to preempted jobs.
      logging.info('Caught AbortedError. Retrying.')
      should_retry = True
    except tf.errors.DeadlineExceededError:
      supervisor_deadline_exceeded = True
      should_retry = False

  if is_chief:
    logging.info('This is chief worker. Stopping all workers.')
    sv.stop()

  if supervisor_deadline_exceeded:
    logging.info('Supervisor timed out. Quitting.')
  else:
    logging.info('Reached %s steps. Worker stopped.', global_step)

  # Dump profiling.
  """
  How to use profiling data.

  Download the profiler dump to your local machine, say to PROF_FILE_PATH.
  In a separate script, run something like the following:

    import pstats
    p = pstats.Stats(PROF_FILE_PATH)
    p.strip_dirs().sort_stats('cumtime').print_stats()

  This will sort by 'cumtime', which "is the cumulative time spent in this and
  all subfunctions (from invocation till exit)."
  https://docs.python.org/2/library/profile.html#instant-user-s-manual
  """  # pylint: disable=pointless-string-statement
  if profiler:
    prof_file = os.path.join(run_dir, 'task_%d.prof' % FLAGS.task_id)
    logging.info('Done profiling.\nDumping to "%s".', prof_file)
    profiler.create_stats()
    with tf.gfile.Open(prof_file, 'w') as f:
      f.write(marshal.dumps(profiler.stats))

  return trainer


def run_training(config=None, tuner=None, logdir=None, trial_name=None,
                 is_chief=True):
  """Do all training runs.

  This is the top level training function for policy gradient based models.
  Run this from the main function.

  Args:
    config: config_lib.Config instance containing global config (agent and
        environment hparams). If None, config will be parsed from FLAGS.config.
    tuner: A tuner instance. Leave as None if not tuning.
    logdir: Parent directory where all data from all runs will be written. If
        None, FLAGS.logdir will be used.
    trial_name: If tuning, set this to a unique string that identifies this
        trial. If `tuner` is not None, this also must be set.
    is_chief: True if this worker is the chief.

  Returns:
    List of results dicts which were written to disk. Each training run gets a
    results dict. Results dict contains metrics, i.e. (name, value) pairs which
    give information about the training run.

  Raises:
    ValueError: If results dicts read from disk contain invalid data.
  """
  if not config:
    # If custom config is not given, get it from flags.
    config = defaults.default_config_with_updates(FLAGS.config)
  if not logdir:
    logdir = FLAGS.logdir

  if not tf.gfile.Exists(logdir):
    tf.gfile.MakeDirs(logdir)
  assert FLAGS.num_repetitions > 0
  results = results_lib.Results(logdir)
  results_list, _ = results.read_all()

  logging.info('Starting experiment. Directory: "%s"', logdir)
  if results_list:
    if results_list[0]['max_npe'] != FLAGS.max_npe:
      raise ValueError(
          'Cannot resume training. Max-NPE changed. Was %s, now %s'
          % (results_list[0]['max_npe'], FLAGS.max_npe))
    if results_list[0]['max_global_repetitions'] != FLAGS.num_repetitions:
      raise ValueError(
          'Cannot resume training. Number of repetitions changed. Was %s, '
          'now %s'
          % (results_list[0]['max_global_repetitions'],
             FLAGS.num_repetitions))

  while len(results_list) < FLAGS.num_repetitions:
    run_number = len(results_list)
    rep_container_name = trial_name if trial_name else 'container'
    if FLAGS.num_repetitions > 1:
      rep_dir = os.path.join(logdir, 'run_%d' % run_number)
      rep_container_name = rep_container_name + '_run_' + str(run_number)
    else:
      rep_dir = logdir
    logging.info(
        'Starting repetition %d (%d out of %d)', run_number, run_number + 1,
        FLAGS.num_repetitions)

    # Train will write result to disk.
    with tf.container(rep_container_name):
      trainer = train(config, is_chief, tuner, rep_dir, run_number, results)
    logging.info('Done training.')

    if is_chief:
      # Destroy current container immediately (clears current graph).
      logging.info('Clearing shared variables.')
      tf.Session.reset(FLAGS.master, containers=[rep_container_name])
      logging.info('Shared variables cleared.')

      # Delete replay buffer on disk.
      assert trainer
      trainer.delete_replay_buffer()
    else:
      # Give chief worker time to clean up.
      sleep_sec = 30.0
      logging.info('Sleeping for %s sec.', sleep_sec)
      time.sleep(sleep_sec)
    tf.reset_default_graph()
    logging.info('Default graph reset.')

    # Expecting that train wrote new result to disk before returning.
    results_list, _ = results.read_all()

  return results_list
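

# Illustrative sketch, not part of this module: a separate launcher script is
# expected to define the remaining flags referenced above (e.g. task_id,
# num_workers, max_npe, num_repetitions, config, logdir) and then call
# run_training, roughly:
#
#   from absl import app
#
#   def main(argv):
#     del argv  # Unused.
#     run_training(is_chief=(FLAGS.task_id == 0))
#
#   if __name__ == '__main__':
#     app.run(main)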