Spaces:
Runtime error
Runtime error
| # Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """A script to run training for sequential latent variable models. | |
| """ | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import tensorflow as tf | |
| from fivo import ghmm_runners | |
| from fivo import runners | |
# Shared flags: binary mode, model and dataset selection, batch/particle
# counts, output directory, seeding and loop parallelism.
tf.app.flags.DEFINE_enum(
    "mode",
    "train",
    ["train", "eval", "sample"],
    "The mode of the binary.",
)
tf.app.flags.DEFINE_enum(
    "model",
    "vrnn",
    ["vrnn", "ghmm", "srnn"],
    "Model choice.",
)
tf.app.flags.DEFINE_integer(
    "latent_size",
    64,
    "The size of the latent state of the model.",
)
tf.app.flags.DEFINE_enum(
    "dataset_type",
    "pianoroll",
    ["pianoroll", "speech", "pose"],
    "The type of dataset.",
)
tf.app.flags.DEFINE_string(
    "dataset_path",
    "",
    "Path to load the dataset from.",
)
# Left unset (None) by default; main() fills in a per-dataset default.
tf.app.flags.DEFINE_integer(
    "data_dimension",
    None,
    ("The dimension of each vector in the data sequence. "
     "Defaults to 88 for pianoroll datasets and 200 for speech datasets. "
     "Should not need to be changed except for testing."),
)
tf.app.flags.DEFINE_integer(
    "batch_size",
    4,
    "Batch size.",
)
tf.app.flags.DEFINE_integer(
    "num_samples",
    4,
    "The number of samples (or particles) for multisample algorithms.",
)
tf.app.flags.DEFINE_string(
    "logdir",
    "/tmp/smc_vi",
    "The directory to keep checkpoints and summaries in.",
)
tf.app.flags.DEFINE_integer(
    "random_seed",
    None,
    "A random seed for seeding the TensorFlow graph.",
)
tf.app.flags.DEFINE_integer(
    "parallel_iterations",
    30,
    ("The number of parallel iterations to use for the while loop "
     "that computes the bounds."),
)
# Training flags: objective, optimizer settings, summary cadence, and the
# resampling / proposal strategies.
tf.app.flags.DEFINE_enum(
    "bound",
    "fivo",
    ["elbo", "iwae", "fivo", "fivo-aux"],
    "The bound to optimize.",
)
tf.app.flags.DEFINE_boolean(
    "normalize_by_seq_len",
    True,
    ("If true, normalize the loss by the number of timesteps per "
     "sequence."),
)
tf.app.flags.DEFINE_float(
    "learning_rate",
    2e-4,
    "The learning rate for ADAM.",
)
tf.app.flags.DEFINE_integer(
    "max_steps",
    10 ** 9,
    "The number of gradient update steps to train for.",
)
tf.app.flags.DEFINE_integer(
    "summarize_every",
    50,
    "The number of steps between summaries.",
)
tf.app.flags.DEFINE_enum(
    "resampling_type",
    "multinomial",
    ["multinomial", "relaxed"],
    "The resampling strategy to use for training.",
)
tf.app.flags.DEFINE_float(
    "relaxed_resampling_temperature",
    0.5,
    "The relaxation temperature for relaxed resampling.",
)
tf.app.flags.DEFINE_enum(
    "proposal_type",
    "filtering",
    ["prior", "filtering", "smoothing", "true-filtering", "true-smoothing"],
    ("The type of proposal to use. "
     "true-filtering and true-smoothing are only available for the GHMM. "
     "The specific implementation of each proposal type is left to "
     "model-writers."),
)
# Distributed training flags: master address, replica task id, parameter
# server count and worker staggering.
tf.app.flags.DEFINE_string(
    "master",
    "",
    "The BNS name of the TensorFlow master to use.",
)
tf.app.flags.DEFINE_integer(
    "task",
    0,
    "Task id of the replica running the training.",
)
tf.app.flags.DEFINE_integer(
    "ps_tasks",
    0,
    "Number of tasks in the ps job. If 0 no ps job is used.",
)
tf.app.flags.DEFINE_boolean(
    "stagger_workers",
    True,
    "If true, bring one worker online every 1000 steps.",
)
# Evaluation flags: which dataset split to evaluate on.
tf.app.flags.DEFINE_enum(
    "split",
    "train",
    ["train", "test", "valid"],
    "Split to evaluate the model on.",
)
# Sampling flags: sample length, conditioning prefix length and output
# directory.
tf.app.flags.DEFINE_integer(
    "sample_length",
    50,
    "The number of timesteps to sample for.",
)
tf.app.flags.DEFINE_integer(
    "prefix_length",
    25,
    "The number of timesteps to condition the model on before sampling.",
)
tf.app.flags.DEFINE_string(
    "sample_out_dir",
    None,
    "The directory to write the samples to. Defaults to logdir.",
)
# GHMM flags: variance and number of timesteps for the ghmm model.
tf.app.flags.DEFINE_float(
    "variance",
    0.1,
    "The variance of the ghmm.",
)
tf.app.flags.DEFINE_integer(
    "num_timesteps",
    5,
    "The number of timesteps to run the gmp for.",
)
# Fallback values for --data_dimension, chosen by --dataset_type in main().
SPEECH_DEFAULT_DATA_DIMENSION = 200
PIANOROLL_DEFAULT_DATA_DIMENSION = 88

# Module-level handle on the flag values defined above.
FLAGS = tf.app.flags.FLAGS
def main(unused_argv):
    """Dispatch to the runner selected by --model and --mode.

    For the "vrnn" and "srnn" models, an unset --data_dimension is replaced by
    the pianoroll or speech default before the matching function from
    `runners` is called. The "ghmm" model is routed to `ghmm_runners`, which
    is only called here for the "train" and "eval" modes.

    Args:
      unused_argv: Positional arguments supplied by the caller (`tf.app.run`
        in this script); ignored.

    Raises:
      ValueError: If --model=ghmm is combined with a mode other than "train"
        or "eval". Previously this combination returned silently without
        doing any work.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    if FLAGS.model in ["vrnn", "srnn"]:
        if FLAGS.data_dimension is None:
            # NOTE(review): only pianoroll and speech have defaults; with
            # --dataset_type=pose the dimension stays None. Confirm that the
            # runners handle that, or require --data_dimension for pose.
            if FLAGS.dataset_type == "pianoroll":
                FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
            elif FLAGS.dataset_type == "speech":
                FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
        if FLAGS.mode == "train":
            runners.run_train(FLAGS)
        elif FLAGS.mode == "eval":
            runners.run_eval(FLAGS)
        elif FLAGS.mode == "sample":
            runners.run_sample(FLAGS)
    elif FLAGS.model == "ghmm":
        if FLAGS.mode == "train":
            ghmm_runners.run_train(FLAGS)
        elif FLAGS.mode == "eval":
            ghmm_runners.run_eval(FLAGS)
        else:
            # Fail loudly instead of exiting successfully having done nothing.
            raise ValueError(
                "Mode '%s' is not supported for the ghmm model; "
                "use --mode=train or --mode=eval." % FLAGS.mode)
if __name__ == "__main__":
    # Script entry point: hand control to tf.app.run, which invokes main.
    tf.app.run(main=main)