Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| """Tests for pg_train. | |
| These tests exercise code paths available through configuration options. | |
| Training will be run for just a few steps with the goal being to check that | |
| nothing crashes. | |
| """ | |
| from absl import flags | |
| import tensorflow as tf | |
| from single_task import defaults # brain coder | |
| from single_task import run # brain coder | |
| FLAGS = flags.FLAGS | |
class TrainTest(tf.test.TestCase):
  """Smoke tests that launch short training runs under various configs.

  Each test passes a config string to `RunTrainingSteps`; nothing is asserted
  beyond the run completing without raising.
  """

  def RunTrainingSteps(self, config_string, num_steps=10):
    """Launch a short training run configured by `config_string`.

    The config string is parsed only to read `batch_size`, which is used to
    derive the value assigned to `FLAGS.max_npe`. The remaining flags are
    pointed at a temporary log directory and the raw config string, after
    which the default graph is reset and `run.main` is invoked. Success means
    the call returns without raising.

    Args:
      config_string: Config encoded in a string. See
          $REPO_PATH/common/config_lib.py
      num_steps: Number of training steps to run. Defaults to 10.
    """
    parsed_config = defaults.default_config_with_updates(config_string)

    # Applied in this exact order, one flag at a time.
    flag_overrides = (
        ('master', ''),
        ('max_npe', num_steps * parsed_config.batch_size),
        ('summary_interval', 1),
        ('logdir', tf.test.get_temp_dir()),
        ('config', config_string),
    )
    for flag_name, flag_value in flag_overrides:
      setattr(FLAGS, flag_name, flag_value)

    # Start from a clean graph so earlier tests do not leak ops into this run.
    tf.reset_default_graph()
    run.main(None)

  # Agent options: only algorithm="pg".
  def testVanillaPolicyGradient(self):
    config_string = (
        'env=c(task="reverse"),'
        'agent=c(algorithm="pg"),'
        'timestep_limit=90,batch_size=64')
    self.RunTrainingSteps(config_string)

  # Agent options: algorithm="pg" with eos_token=False.
  def testVanillaPolicyGradient_VariableLengthSequences(self):
    config_string = (
        'env=c(task="reverse"),'
        'agent=c(algorithm="pg",eos_token=False),'
        'timestep_limit=90,batch_size=64')
    self.RunTrainingSteps(config_string)

  # Agent options: algorithm="pg" with ema_baseline_decay=0.0.
  def testVanillaActorCritic(self):
    config_string = (
        'env=c(task="reverse"),'
        'agent=c(algorithm="pg",ema_baseline_decay=0.0),'
        'timestep_limit=90,batch_size=64')
    self.RunTrainingSteps(config_string)

  # Agent options: algorithm="pg" with topk_loss_hparam=1.0 and topk=10.
  def testPolicyGradientWithTopK(self):
    config_string = (
        'env=c(task="reverse"),'
        'agent=c(algorithm="pg",topk_loss_hparam=1.0,topk=10),'
        'timestep_limit=90,batch_size=64')
    self.RunTrainingSteps(config_string)

  # Agent options: ema_baseline_decay=0.0 combined with the top-k settings.
  def testVanillaActorCriticWithTopK(self):
    config_string = (
        'env=c(task="reverse"),'
        'agent=c(algorithm="pg",ema_baseline_decay=0.0,topk_loss_hparam=1.0,'
        'topk=10),'
        'timestep_limit=90,batch_size=64')
    self.RunTrainingSteps(config_string)

  # Agent options: the top-k settings combined with eos_token=False.
  def testPolicyGradientWithTopK_VariableLengthSequences(self):
    config_string = (
        'env=c(task="reverse"),'
        'agent=c(algorithm="pg",topk_loss_hparam=1.0,topk=10,eos_token=False),'
        'timestep_limit=90,batch_size=64')
    self.RunTrainingSteps(config_string)

  # Agent options: algorithm="pg" with alpha=0.5.
  def testPolicyGradientWithImportanceSampling(self):
    config_string = (
        'env=c(task="reverse"),'
        'agent=c(algorithm="pg",alpha=0.5),'
        'timestep_limit=90,batch_size=64')
    self.RunTrainingSteps(config_string)
# Script entry point: when this file is executed directly (not imported),
# hand control to TensorFlow's test runner.
if __name__ == '__main__':
  tf.test.main()