# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re
import sys

import numpy as np
import tensorflow as tf

from lfads import LFADS
import utils

MAX_INT = sys.maxsize
# Lots of hyperparameters, but most are pretty insensitive. The
# explanation of these hyperparameters is found below, in the flags
# section.
CHECKPOINT_PB_LOAD_NAME = "checkpoint"
CHECKPOINT_NAME = "lfads_vae"
CSV_LOG = "fitlog"
OUTPUT_FILENAME_STEM = ""
DEVICE = "gpu:0"  # "cpu:0", or other gpus, e.g. "gpu:1"
MAX_CKPT_TO_KEEP = 5
MAX_CKPT_TO_KEEP_LVE = 5
PS_NEXAMPLES_TO_PROCESS = MAX_INT  # if larger than number of examples, process all
EXT_INPUT_DIM = 0
IC_DIM = 64
FACTORS_DIM = 50
IC_ENC_DIM = 128
GEN_DIM = 200
GEN_CELL_INPUT_WEIGHT_SCALE = 1.0
GEN_CELL_REC_WEIGHT_SCALE = 1.0
CELL_WEIGHT_SCALE = 1.0
BATCH_SIZE = 128
LEARNING_RATE_INIT = 0.01
LEARNING_RATE_DECAY_FACTOR = 0.95
LEARNING_RATE_STOP = 0.00001
LEARNING_RATE_N_TO_COMPARE = 6
INJECT_EXT_INPUT_TO_GEN = False
DO_TRAIN_IO_ONLY = False
DO_TRAIN_ENCODER_ONLY = False
DO_RESET_LEARNING_RATE = False
FEEDBACK_FACTORS_OR_RATES = "factors"
DO_TRAIN_READIN = True

# Calibrated just above the average value for the rnn synthetic data.
MAX_GRAD_NORM = 200.0
CELL_CLIP_VALUE = 5.0
KEEP_PROB = 0.95
TEMPORAL_SPIKE_JITTER_WIDTH = 0
OUTPUT_DISTRIBUTION = 'poisson'  # 'poisson' or 'gaussian'
NUM_STEPS_FOR_GEN_IC = MAX_INT  # set to num_steps if greater than num_steps

DATA_DIR = "/tmp/rnn_synth_data_v1.0/"
DATA_FILENAME_STEM = "chaotic_rnn_inputs_g1p5"
LFADS_SAVE_DIR = "/tmp/lfads_chaotic_rnn_inputs_g1p5/"
CO_DIM = 1
DO_CAUSAL_CONTROLLER = False
DO_FEED_FACTORS_TO_CONTROLLER = True
CONTROLLER_INPUT_LAG = 1
PRIOR_AR_AUTOCORRELATION = 10.0
PRIOR_AR_PROCESS_VAR = 0.1
DO_TRAIN_PRIOR_AR_ATAU = True
DO_TRAIN_PRIOR_AR_NVAR = True
CI_ENC_DIM = 128
CON_DIM = 128
CO_PRIOR_VAR_SCALE = 0.1
KL_INCREASE_STEPS = 2000
L2_INCREASE_STEPS = 2000
L2_GEN_SCALE = 2000.0
L2_CON_SCALE = 0.0
# scale of regularizer on time correlation of inferred inputs
CO_MEAN_CORR_SCALE = 0.0
KL_IC_WEIGHT = 1.0
KL_CO_WEIGHT = 1.0
KL_START_STEP = 0
L2_START_STEP = 0
IC_PRIOR_VAR_MIN = 0.1
IC_PRIOR_VAR_SCALE = 0.1
IC_PRIOR_VAR_MAX = 0.1
IC_POST_VAR_MIN = 0.0001  # protection from KL blowing up

flags = tf.app.flags
| flags.DEFINE_string("kind", "train", | |
| "Type of model to build {train, \ | |
| posterior_sample_and_average, \ | |
| posterior_push_mean, \ | |
| prior_sample, write_model_params") | |
| flags.DEFINE_string("output_dist", OUTPUT_DISTRIBUTION, | |
| "Type of output distribution, 'poisson' or 'gaussian'") | |
| flags.DEFINE_boolean("allow_gpu_growth", False, | |
| "If true, only allocate amount of memory needed for \ | |
| Session. Otherwise, use full GPU memory.") | |
| # DATA | |
| flags.DEFINE_string("data_dir", DATA_DIR, "Data for training") | |
| flags.DEFINE_string("data_filename_stem", DATA_FILENAME_STEM, | |
| "Filename stem for data dictionaries.") | |
| flags.DEFINE_string("lfads_save_dir", LFADS_SAVE_DIR, "model save dir") | |
| flags.DEFINE_string("checkpoint_pb_load_name", CHECKPOINT_PB_LOAD_NAME, | |
| "Name of checkpoint files, use 'checkpoint_lve' for best \ | |
| error") | |
| flags.DEFINE_string("checkpoint_name", CHECKPOINT_NAME, | |
| "Name of checkpoint files (.ckpt appended)") | |
| flags.DEFINE_string("output_filename_stem", OUTPUT_FILENAME_STEM, | |
| "Name of output file (postfix will be added)") | |
| flags.DEFINE_string("device", DEVICE, | |
| "Which device to use (default: \"gpu:0\", can also be \ | |
| \"cpu:0\", \"gpu:1\", etc)") | |
| flags.DEFINE_string("csv_log", CSV_LOG, | |
| "Name of file to keep running log of fit likelihoods, \ | |
| etc (.csv appended)") | |
| flags.DEFINE_integer("max_ckpt_to_keep", MAX_CKPT_TO_KEEP, | |
| "Max # of checkpoints to keep (rolling)") | |
| flags.DEFINE_integer("ps_nexamples_to_process", PS_NEXAMPLES_TO_PROCESS, | |
| "Number of examples to process for posterior sample and \ | |
| average (not number of samples to average over).") | |
| flags.DEFINE_integer("max_ckpt_to_keep_lve", MAX_CKPT_TO_KEEP_LVE, | |
| "Max # of checkpoints to keep for lowest validation error \ | |
| models (rolling)") | |
| flags.DEFINE_integer("ext_input_dim", EXT_INPUT_DIM, "Dimension of external \ | |
| inputs") | |
| flags.DEFINE_integer("num_steps_for_gen_ic", NUM_STEPS_FOR_GEN_IC, | |
| "Number of steps to train the generator initial conditon.") | |
# If there are observed inputs, there are two ways to add that observed
# input to the model. The first is to treat it as something to be
# inferred: the observed input is encoded via the encoders and then enters
# the generator through the "inferred inputs" channel. The second is to
# inject the input directly into the generator. This has the downside of
# making the generation process strictly dependent on knowing the observed
# input for any generated trial.
flags.DEFINE_boolean("inject_ext_input_to_gen",
                     INJECT_EXT_INPUT_TO_GEN,
                     "Should observed inputs be input to model via encoders, \
                     or injected directly into generator?")
# CELL

# The combined recurrent and input weights of the encoder and
# controller cells are by default set to scale at ws/sqrt(#inputs),
# with ws=1.0. You can change this scaling with this parameter.
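# For intuition (illustrative numbers, not taken from this file): with
# ws = 1.0 and a cell receiving 128 inputs, the initial weights end up
# scaled by 1.0/sqrt(128) ~= 0.088.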
| flags.DEFINE_float("cell_weight_scale", CELL_WEIGHT_SCALE, | |
| "Input scaling for input weights in generator.") | |
# GENERATION

# Note that the dimension of the initial conditions is separate from the
# dimension of the generator's initial state (a linear matrix will adapt
# the shapes if necessary). This is just another way to control
# complexity. In all likelihood, setting the ic dims to the size of the
# generator hidden state is just fine.
| flags.DEFINE_integer("ic_dim", IC_DIM, "Dimension of h0") | |
| # Setting the dimensions of the factors to something smaller than the data | |
| # dimension is a way to get a reduced dimensionality representation of your | |
| # data. | |
| flags.DEFINE_integer("factors_dim", FACTORS_DIM, | |
| "Number of factors from generator") | |
| flags.DEFINE_integer("ic_enc_dim", IC_ENC_DIM, | |
| "Cell hidden size, encoder of h0") | |
# Controlling the size of the generator is one way to control the complexity
# of the dynamics (the l2 penalty will also squeeze out unnecessary
# dynamics). The modern deep learning approach is to make these cells
# as large as tolerable (from a waiting perspective), and then regularize
# them to death with dropout or whatever. I don't know if this is correct
# for the LFADS application or not.
flags.DEFINE_integer("gen_dim", GEN_DIM,
                     "Cell hidden size, generator.")
# The weights of the generator cell are by default set to scale at
# ws/sqrt(#inputs), with ws=1.0. You can change ws for
# the input weights or the recurrent weights with these hyperparameters.
flags.DEFINE_float("gen_cell_input_weight_scale", GEN_CELL_INPUT_WEIGHT_SCALE,
                   "Input scaling for input weights in generator.")
flags.DEFINE_float("gen_cell_rec_weight_scale", GEN_CELL_REC_WEIGHT_SCALE,
                   "Input scaling for rec weights in generator.")
# KL DISTRIBUTIONS

# If you don't know what you are doing here, please leave these alone; the
# defaults should be fine for most cases, regardless of other parameters.
#
# If you don't want the prior variance to be learned, set the
# following values to the same thing: ic_prior_var_min,
# ic_prior_var_scale, ic_prior_var_max. The prior mean will be
# learned regardless.
| flags.DEFINE_float("ic_prior_var_min", IC_PRIOR_VAR_MIN, | |
| "Minimum variance in posterior h0 codes.") | |
| flags.DEFINE_float("ic_prior_var_scale", IC_PRIOR_VAR_SCALE, | |
| "Variance of ic prior distribution") | |
| flags.DEFINE_float("ic_prior_var_max", IC_PRIOR_VAR_MAX, | |
| "Maximum variance of IC prior distribution.") | |
# If you really want to limit the information from encoder to decoder,
# increase ic_post_var_min above 0.0.
flags.DEFINE_float("ic_post_var_min", IC_POST_VAR_MIN,
                   "Minimum variance of the IC posterior distribution.")
| flags.DEFINE_float("co_prior_var_scale", CO_PRIOR_VAR_SCALE, | |
| "Variance of control input prior distribution.") | |
| flags.DEFINE_float("prior_ar_atau", PRIOR_AR_AUTOCORRELATION, | |
| "Initial autocorrelation of AR(1) priors.") | |
| flags.DEFINE_float("prior_ar_nvar", PRIOR_AR_PROCESS_VAR, | |
| "Initial noise variance for AR(1) priors.") | |
| flags.DEFINE_boolean("do_train_prior_ar_atau", DO_TRAIN_PRIOR_AR_ATAU, | |
| "Is the value for atau an init, or the constant value?") | |
| flags.DEFINE_boolean("do_train_prior_ar_nvar", DO_TRAIN_PRIOR_AR_NVAR, | |
| "Is the value for noise variance an init, or the constant \ | |
| value?") | |
# CONTROLLER

# This parameter critically controls whether or not there is a controller
# (along with controller encoders) placed into the LFADS graph. If co_dim >
# 0, the controller produces co_dim-dimensional inferred inputs; if it is
# equal to 0, there is no controller.
| flags.DEFINE_integer("co_dim", CO_DIM, | |
| "Number of control net outputs (>0 builds that graph).") | |
# The controller will be more powerful if it can see the encoding of the entire
# trial. However, this allows the controller to create inferred inputs that are
# acausal with respect to the actual data generation process. E.g. the data
# generator could have an input at time t, but the controller, after seeing the
# entirety of the trial, could infer that the input is coming a little before
# time t, because there are no restrictions on the data the controller sees.
# One can force the controller to be causal (with respect to perturbations in
# the data generator) so that it only sees forward encodings of the data at time
# t that originate at times before or at time t. One can also control the data
# the controller sees by using an input lag (forward encoding at time [t-tlag]
# for controller input at time t). The same can be done in the reverse direction
# (controller input at time t from reverse encoding at time [t+tlag], in the
# case of an acausal controller). Setting this lag > 0 (even lag=1) can be a
# powerful way of avoiding very spiky decodes. Finally, one can manually control
# whether the factors at time t-1 are fed to the controller at time t.
#
# If you don't care about any of this, and just want to smooth your data, set
#   do_causal_controller = False
#   do_feed_factors_to_controller = True
#   controller_input_lag = 0
| flags.DEFINE_boolean("do_causal_controller", | |
| DO_CAUSAL_CONTROLLER, | |
| "Restrict the controller create only causal inferred \ | |
| inputs?") | |
# Strictly speaking, feeding either the factors or the rates to the controller
# violates causality, since the g0 gets to see all the data. This may or may not
# be only a theoretical concern.
flags.DEFINE_boolean("do_feed_factors_to_controller",
                     DO_FEED_FACTORS_TO_CONTROLLER,
                     "Should factors[t-1] be input to controller at time t?")
flags.DEFINE_string("feedback_factors_or_rates", FEEDBACK_FACTORS_OR_RATES,
                    "Feedback the factors or the rates to the controller? \
                    Acceptable values: 'factors' or 'rates'.")
flags.DEFINE_integer("controller_input_lag", CONTROLLER_INPUT_LAG,
                     "Time lag on the encoding to controller t-lag for \
                     forward, t+lag for reverse.")
flags.DEFINE_integer("ci_enc_dim", CI_ENC_DIM,
                     "Cell hidden size, encoder of control inputs")
flags.DEFINE_integer("con_dim", CON_DIM,
                     "Cell hidden size, controller")
# OPTIMIZATION
flags.DEFINE_integer("batch_size", BATCH_SIZE,
                     "Batch size to use during training.")
flags.DEFINE_float("learning_rate_init", LEARNING_RATE_INIT,
                   "Learning rate initial value")
flags.DEFINE_float("learning_rate_decay_factor", LEARNING_RATE_DECAY_FACTOR,
                   "Learning rate decay, decay by this fraction every so \
                   often.")
flags.DEFINE_float("learning_rate_stop", LEARNING_RATE_STOP,
                   "The lr is adaptively reduced, stop training at this value.")
# Rather than putting the learning rate on an exponentially decreasing
# schedule, the current algorithm monitors the training cost, and if it
# isn't regularly decreasing, it lowers the learning rate. So far,
# it works fine, though it is not perfect.
flags.DEFINE_integer("learning_rate_n_to_compare", LEARNING_RATE_N_TO_COMPARE,
                     "Number of previous costs current cost has to be worse \
                     than, to lower learning rate.")
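# A rough sketch of that adaptive rule, for intuition only (the actual logic
# lives in lfads.py; the names below are illustrative):
#
#   if all(current_cost > c
#          for c in recent_costs[-FLAGS.learning_rate_n_to_compare:]):
#     lr *= FLAGS.learning_rate_decay_factor   # cost got worse; decay lr
#   if lr < FLAGS.learning_rate_stop:
#     pass                                     # stop training here
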
# This sets a value, above which, the gradients will be clipped. This hp
# is extremely useful to avoid an infrequent, but highly pathological
# problem whereby the gradient is so large that it destroys the
# optimization by setting parameters too large, leading to a vicious cycle
# that ends in NaNs. If it's too large, it's useless; if it's too small,
# it essentially becomes the learning rate. It's pretty insensitive, though.
| flags.DEFINE_float("max_grad_norm", MAX_GRAD_NORM, | |
| "Max norm of gradient before clipping.") | |
| # If your optimizations start "NaN-ing out", reduce this value so that | |
| # the values of the network don't grow out of control. Typically, once | |
| # this parameter is set to a reasonable value, one stops having numerical | |
| # problems. | |
| flags.DEFINE_float("cell_clip_value", CELL_CLIP_VALUE, | |
| "Max value recurrent cell can take before being clipped.") | |
# This flag is used for an experiment where one checks whether a model trained
# on many days' data can be used to learn the dynamics of a held-out day's
# data. If you don't care about that particular experiment, this flag should
# always be false.
flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
                     "Train only the input (readin) and output (readout) \
                     affine functions.")
# This flag is used for an experiment where one wants to know if the dynamics
# learned by the generator generalize across conditions. In that case, you might
# train up a model on one set of data, and then only further train the encoder
# on another set of data (the conditions to be tested) so that the model is
# forced to use the same dynamics to describe that data. If you don't care about
# that particular experiment, this flag should always be false.
flags.DEFINE_boolean("do_train_encoder_only", DO_TRAIN_ENCODER_ONLY,
                     "Train only the encoder weights.")

flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
                     "Reset the learning rate to initial value.")

# For multi-session "stitching" models, the per-session readin matrices map from
# neurons to input factors which are fed into the shared encoder. These are
# initialized by alignment_matrix_cxf and alignment_bias_c in the input .h5
# files. They can be fixed or made trainable.
flags.DEFINE_boolean("do_train_readin", DO_TRAIN_READIN, "Whether to train the \
                     readin matrices and bias vectors. False leaves them fixed \
                     at their initial values specified by the alignment \
                     matrices and vectors.")
# OVERFITTING

# Dropout is done on the input data, on controller inputs (from
# encoder), and on outputs from generator to factors.
flags.DEFINE_float("keep_prob", KEEP_PROB, "Dropout keep probability.")
# It appears that the system will happily fit spikes (blessing or
# curse, depending). You may not want this. Jittering the spikes a
# bit will help (+/- bin size, as specified here).
flags.DEFINE_integer("temporal_spike_jitter_width",
                     TEMPORAL_SPIKE_JITTER_WIDTH,
                     "Shuffle spikes around this window.")
# General note about helping ascribe controller inputs vs. dynamics:
#
# If the controller is heavily penalized, then it won't have any output.
# If the dynamics are heavily penalized, then the generator won't produce
# dynamics. Note this l2 penalty is only on the recurrent portion of
# the RNNs; dropout is also available, penalizing the feed-forward
# connections.
flags.DEFINE_float("l2_gen_scale", L2_GEN_SCALE,
                   "L2 regularization cost for the generator only.")
flags.DEFINE_float("l2_con_scale", L2_CON_SCALE,
                   "L2 regularization cost for the controller only.")
| flags.DEFINE_float("co_mean_corr_scale", CO_MEAN_CORR_SCALE, | |
| "Cost of correlation (thru time)in the means of \ | |
| controller output.") | |
# UNDERFITTING

# If the primary task of LFADS is "filtering" of data and not
# generation, then it is possible that the KL penalty is too strong.
# Empirically, we have found this to be the case. So we add a
# hyperparameter in front of the two KL terms (one for the initial
# conditions to the generator, the other for the controller outputs).
# You should always think of the default values as 1.0, which
# leads to a standard VAE formulation whereby the numbers that are
# optimized are a lower bound on the log-likelihood of the data. When
# these 2 HPs deviate from 1.0, one cannot make any statement about
# what those LL lower bounds mean anymore, and they cannot be compared
# (AFAIK).
| flags.DEFINE_float("kl_ic_weight", KL_IC_WEIGHT, | |
| "Strength of KL weight on initial conditions KL penatly.") | |
| flags.DEFINE_float("kl_co_weight", KL_CO_WEIGHT, | |
| "Strength of KL weight on controller output KL penalty.") | |
# Sometimes the task can be sufficiently hard to learn that the
# optimizer takes the 'easy route', and simply minimizes the KL
# divergence, setting it to near zero, and the optimization gets
# stuck. These two parameters help avoid that by getting the
# optimization to 'latch' on to the main objective first, and only
# turning on the regularizers later.
flags.DEFINE_integer("kl_start_step", KL_START_STEP,
                     "Start increasing weight after this many steps.")
# These are training passes, not epochs; the weight increases by 0.5 every
# kl_increase_steps.
flags.DEFINE_integer("kl_increase_steps", KL_INCREASE_STEPS,
                     "Increase weight of kl cost to avoid local minimum.")
# Same story for the l2 regularizer. One wants a simple generator, for
# scientific reasons, but not at the expense of hosing the optimization.
flags.DEFINE_integer("l2_start_step", L2_START_STEP,
                     "Start increasing l2 weight after this many steps.")
flags.DEFINE_integer("l2_increase_steps", L2_INCREASE_STEPS,
                     "Increase weight of l2 cost to avoid local minimum.")
FLAGS = flags.FLAGS
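
# Illustrative invocations, assuming this file is saved as run_lfads.py
# (paths and values are examples only, matching the defaults above):
#
#   python run_lfads.py --kind=train \
#     --data_dir=/tmp/rnn_synth_data_v1.0/ \
#     --data_filename_stem=chaotic_rnn_inputs_g1p5 \
#     --lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g1p5/
#
#   python run_lfads.py --kind=posterior_sample_and_average \
#     --checkpoint_pb_load_name=checkpoint_lve \
#     --lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g1p5/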


def build_model(hps, kind="train", datasets=None):
  """Builds a model from either random initialization, or saved parameters.

  Args:
    hps: The hyperparameters for the model.
    kind: (optional) The kind of model to build. Training vs inference require
      different graphs.
    datasets: The datasets structure (see top of lfads.py).

  Returns:
    an LFADS model.
  """
  build_kind = kind
  if build_kind == "write_model_params":
    build_kind = "train"
  with tf.variable_scope("LFADS", reuse=None):
    model = LFADS(hps, kind=build_kind, datasets=datasets)

  if not os.path.exists(hps.lfads_save_dir):
    print("Save directory %s does not exist, creating it." % hps.lfads_save_dir)
    os.makedirs(hps.lfads_save_dir)

  cp_pb_ln = hps.checkpoint_pb_load_name
  cp_pb_ln = 'checkpoint' if cp_pb_ln == "" else cp_pb_ln
  if cp_pb_ln == 'checkpoint':
    print("Loading latest training checkpoint in: ", hps.lfads_save_dir)
    saver = model.seso_saver
  elif cp_pb_ln == 'checkpoint_lve':
    print("Loading lowest validation checkpoint in: ", hps.lfads_save_dir)
    saver = model.lve_saver
  else:
    print("Loading checkpoint: ", cp_pb_ln, ", in: ", hps.lfads_save_dir)
    saver = model.seso_saver

  ckpt = tf.train.get_checkpoint_state(hps.lfads_save_dir,
                                       latest_filename=cp_pb_ln)

  session = tf.get_default_session()
  print("ckpt: ", ckpt)
  if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    saver.restore(session, ckpt.model_checkpoint_path)
  else:
    print("Created model with fresh parameters.")
    if kind in ["posterior_sample_and_average", "posterior_push_mean",
                "prior_sample", "write_model_params"]:
      print("Possible error!!! You are running ", kind,
            " on a newly initialized model!")
      # Cannot print ckpt.model_checkpoint_path here, since there is no ckpt.
      print("Are you sure a checkpoint in ", hps.lfads_save_dir, " exists?")

    tf.global_variables_initializer().run()
  if ckpt:
    train_step_str = re.search('-[0-9]+$', ckpt.model_checkpoint_path).group()
  else:
    train_step_str = '-0'

  fname = 'hyperparameters' + train_step_str + '.txt'
  hp_fname = os.path.join(hps.lfads_save_dir, fname)
  hps_for_saving = jsonify_dict(hps)
  utils.write_data(hp_fname, hps_for_saving, use_json=True)

  return model


def jsonify_dict(d):
  """Turns python booleans into strings so hps dict can be written in json.

  Creates a shallow-copied dictionary first, then accomplishes string
  conversion.

  Args:
    d: hyperparameter dictionary

  Returns: hyperparameter dictionary with bools as strings
  """
  d2 = d.copy()   # shallow copy is fine by assumption of d being shallow

  def jsonify_bool(boolean_value):
    if boolean_value:
      return "true"
    else:
      return "false"

  for key in d2.keys():
    if isinstance(d2[key], bool):
      d2[key] = jsonify_bool(d2[key])
  return d2
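
# Example of jsonify_dict (illustrative):
#   jsonify_dict({'do_train_readin': True, 'ic_dim': 64})
#   # -> {'do_train_readin': 'true', 'ic_dim': 64}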


def build_hyperparameter_dict(flags):
  """Simple script for saving hyperparameters. Under the hood the
  flags structure isn't a dictionary, so it has to be simplified, since we
  want to be able to view the file as text.

  Args:
    flags: From tf.app.flags

  Returns:
    dictionary of hyperparameters (ignoring other flag types).
  """
  d = {}
  # Data
  d['output_dist'] = flags.output_dist
  d['data_dir'] = flags.data_dir
  d['lfads_save_dir'] = flags.lfads_save_dir
  d['checkpoint_pb_load_name'] = flags.checkpoint_pb_load_name
  d['checkpoint_name'] = flags.checkpoint_name
  d['output_filename_stem'] = flags.output_filename_stem
  d['max_ckpt_to_keep'] = flags.max_ckpt_to_keep
  d['max_ckpt_to_keep_lve'] = flags.max_ckpt_to_keep_lve
  d['ps_nexamples_to_process'] = flags.ps_nexamples_to_process
  d['ext_input_dim'] = flags.ext_input_dim
  d['data_filename_stem'] = flags.data_filename_stem
  d['device'] = flags.device
  d['csv_log'] = flags.csv_log
  d['num_steps_for_gen_ic'] = flags.num_steps_for_gen_ic
  d['inject_ext_input_to_gen'] = flags.inject_ext_input_to_gen
  # Cell
  d['cell_weight_scale'] = flags.cell_weight_scale
  # Generation
  d['ic_dim'] = flags.ic_dim
  d['factors_dim'] = flags.factors_dim
  d['ic_enc_dim'] = flags.ic_enc_dim
  d['gen_dim'] = flags.gen_dim
  d['gen_cell_input_weight_scale'] = flags.gen_cell_input_weight_scale
  d['gen_cell_rec_weight_scale'] = flags.gen_cell_rec_weight_scale
  # KL distributions
  d['ic_prior_var_min'] = flags.ic_prior_var_min
  d['ic_prior_var_scale'] = flags.ic_prior_var_scale
  d['ic_prior_var_max'] = flags.ic_prior_var_max
  d['ic_post_var_min'] = flags.ic_post_var_min
  d['co_prior_var_scale'] = flags.co_prior_var_scale
  d['prior_ar_atau'] = flags.prior_ar_atau
  d['prior_ar_nvar'] = flags.prior_ar_nvar
  d['do_train_prior_ar_atau'] = flags.do_train_prior_ar_atau
  d['do_train_prior_ar_nvar'] = flags.do_train_prior_ar_nvar
  # Controller
  d['do_causal_controller'] = flags.do_causal_controller
  d['controller_input_lag'] = flags.controller_input_lag
  d['do_feed_factors_to_controller'] = flags.do_feed_factors_to_controller
  d['feedback_factors_or_rates'] = flags.feedback_factors_or_rates
  d['co_dim'] = flags.co_dim
  d['ci_enc_dim'] = flags.ci_enc_dim
  d['con_dim'] = flags.con_dim
  d['co_mean_corr_scale'] = flags.co_mean_corr_scale
  # Optimization
  d['batch_size'] = flags.batch_size
  d['learning_rate_init'] = flags.learning_rate_init
  d['learning_rate_decay_factor'] = flags.learning_rate_decay_factor
  d['learning_rate_stop'] = flags.learning_rate_stop
  d['learning_rate_n_to_compare'] = flags.learning_rate_n_to_compare
  d['max_grad_norm'] = flags.max_grad_norm
  d['cell_clip_value'] = flags.cell_clip_value
  d['do_train_io_only'] = flags.do_train_io_only
  d['do_train_encoder_only'] = flags.do_train_encoder_only
  d['do_reset_learning_rate'] = flags.do_reset_learning_rate
  d['do_train_readin'] = flags.do_train_readin
  # Overfitting
  d['keep_prob'] = flags.keep_prob
  d['temporal_spike_jitter_width'] = flags.temporal_spike_jitter_width
  d['l2_gen_scale'] = flags.l2_gen_scale
  d['l2_con_scale'] = flags.l2_con_scale
  # Underfitting
  d['kl_ic_weight'] = flags.kl_ic_weight
  d['kl_co_weight'] = flags.kl_co_weight
  d['kl_start_step'] = flags.kl_start_step
  d['kl_increase_steps'] = flags.kl_increase_steps
  d['l2_start_step'] = flags.l2_start_step
  d['l2_increase_steps'] = flags.l2_increase_steps
  d['_clip_value'] = 80   # bounds the tf.exp to avoid INF

  return d


class hps_dict_to_obj(dict):
  """Helper class allowing us to access hps dictionary more easily."""

  def __getattr__(self, key):
    if key in self:
      return self[key]
    else:
      assert False, ("%s does not exist." % key)

  def __setattr__(self, key, value):
    self[key] = value
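
# Example of hps_dict_to_obj (illustrative):
#   hps = hps_dict_to_obj({'ic_dim': 64})
#   hps.ic_dim     # -> 64 (attribute access)
#   hps['ic_dim']  # -> 64 (plain dict access still works)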


def train(hps, datasets):
  """Train the LFADS model.

  Args:
    hps: The dictionary of hyperparameters.
    datasets: A dictionary of data dictionaries. The dataset dict is simply a
      name(string)-> data dictionary mapping (See top of lfads.py).
  """
  model = build_model(hps, kind="train", datasets=datasets)
  if hps.do_reset_learning_rate:
    sess = tf.get_default_session()
    sess.run(model.learning_rate.initializer)

  model.train_model(datasets)


def write_model_runs(hps, datasets, output_fname=None, push_mean=False):
  """Run the model on the data in data_dict, and save the computed values.

  LFADS generates a number of outputs for each example, and these are all
  saved. They are:
    The mean and variance of the prior of g0.
    The mean and variance of the approximate posterior of g0.
    The control inputs (if enabled).
    The initial conditions, g0, for all examples.
    The generator states for all time.
    The factors for all time.
    The rates for all time.

  Args:
    hps: The dictionary of hyperparameters.
    datasets: A dictionary of data dictionaries. The dataset dict is simply a
      name(string)-> data dictionary mapping (See top of lfads.py).
    output_fname (optional): output filename stem to write the model runs.
    push_mean: If False (default), generate batch_size samples for each trial
      and average the results. If True, run each trial once without noise,
      pushing the posterior mean initial conditions and control inputs through
      the trained model. False is used for posterior_sample_and_average, True
      is used for posterior_push_mean.
  """
  model = build_model(hps, kind=hps.kind, datasets=datasets)
  model.write_model_runs(datasets, output_fname, push_mean)


def write_model_samples(hps, datasets, dataset_name=None, output_fname=None):
  """Use the prior distribution to generate samples from the model.

  Generates batch_size number of samples (set through FLAGS).

  LFADS generates a number of outputs for each example, and these are all
  saved. They are:
    The mean and variance of the prior of g0.
    The control inputs (if enabled).
    The initial conditions, g0, for all examples.
    The generator states for all time.
    The factors for all time.
    The output distribution parameters (e.g. rates) for all time.

  Args:
    hps: The dictionary of hyperparameters.
    datasets: A dictionary of data dictionaries. The dataset dict is simply a
      name(string)-> data dictionary mapping (See top of lfads.py).
    dataset_name: The name of the dataset to grab the factors -> rates
      alignment matrices from. Only a concern with models trained on
      multi-session data. By default, uses the first dataset in the data dict.
    output_fname: The name prefix of the file in which to save the generated
      samples.
  """
  if not output_fname:
    output_fname = "model_runs_" + hps.kind
  else:
    output_fname = output_fname + "model_runs_" + hps.kind

  if not dataset_name:
    # list() for Python 3 compatibility, where dict.keys() is not indexable.
    dataset_name = list(datasets.keys())[0]
  else:
    if dataset_name not in datasets.keys():
      raise ValueError("Invalid dataset name '%s'." % (dataset_name))
  model = build_model(hps, kind=hps.kind, datasets=datasets)
  model.write_model_samples(dataset_name, output_fname)


def write_model_parameters(hps, output_fname=None, datasets=None):
  """Save all the model parameters.

  Save all the parameters to hps.lfads_save_dir.

  Args:
    hps: The dictionary of hyperparameters.
    output_fname: The prefix of the file in which to save the generated
      samples.
    datasets: A dictionary of data dictionaries. The dataset dict is simply a
      name(string)-> data dictionary mapping (See top of lfads.py).
  """
  if not output_fname:
    output_fname = "model_params"
  else:
    output_fname = output_fname + "_model_params"
  fname = os.path.join(hps.lfads_save_dir, output_fname)
  print("Writing model parameters to: ", fname)
  # save the optimizer params as well
  model = build_model(hps, kind="write_model_params", datasets=datasets)
  model_params = model.eval_model_parameters(use_nested=False,
                                             include_strs="LFADS")
  utils.write_data(fname, model_params, compression=None)
  print("Done.")


def clean_data_dict(data_dict):
  """Add some key/value pairs to the data dict, if they are missing.

  Args:
    data_dict - dictionary containing data for LFADS

  Returns:
    data_dict with some keys filled in, if they are absent.
  """
  keys = ['train_truth', 'train_ext_input', 'valid_data',
          'valid_truth', 'valid_ext_input', 'valid_train']
  for k in keys:
    if k not in data_dict:
      data_dict[k] = None
  return data_dict


def load_datasets(data_dir, data_filename_stem):
  """Load the datasets from a specified directory.

  Example files look like
    >data_dir/my_dataset_first_day
    >data_dir/my_dataset_second_day

  If the filename stem my_dataset is in the directory, the read routine will
  try to load those files. The datasets dictionary will then look like
    dataset['first_day'] -> (first day data dictionary)
    dataset['second_day'] -> (second day data dictionary)

  Args:
    data_dir: The directory from which to load the datasets.
    data_filename_stem: The stem of the filename for the datasets.

  Returns:
    datasets: a dataset dictionary, with one name->data dictionary pair for
      each dataset file.
  """
  print("Reading data from ", data_dir)
  datasets = utils.read_datasets(data_dir, data_filename_stem)
  for k, data_dict in datasets.items():
    datasets[k] = clean_data_dict(data_dict)

    train_total_size = len(data_dict['train_data'])
    if train_total_size == 0:
      print("Did not load training set.")
    else:
      print("Found training set with number examples: ", train_total_size)

    valid_total_size = len(data_dict['valid_data'])
    if valid_total_size == 0:
      print("Did not load validation set.")
    else:
      print("Found validation set with number examples: ", valid_total_size)
  return datasets


def main(_):
  """Get this whole shindig off the ground."""
  d = build_hyperparameter_dict(FLAGS)
  hps = hps_dict_to_obj(d)   # hyperparameters
  kind = FLAGS.kind

  # Read the data, if necessary.
  train_set = valid_set = None
  if kind in ["train", "posterior_sample_and_average", "posterior_push_mean",
              "prior_sample", "write_model_params"]:
    datasets = load_datasets(hps.data_dir, hps.data_filename_stem)
  else:
    raise ValueError('Kind {} is not supported.'.format(kind))
  # Infer the dataset names and dataset dimensions from the loaded files.
  hps.kind = kind   # needs to be added here, because it is not saved as a
                    # hyperparameter
  hps.dataset_names = []
  hps.dataset_dims = {}
  for key in datasets:
    hps.dataset_names.append(key)
    hps.dataset_dims[key] = datasets[key]['data_dim']

  # Also store down the dimensionality of the data
  # - just pull from one set, required to be same for all sets.
  # list() for Python 3 compatibility, where dict.values() is not indexable.
  hps.num_steps = list(datasets.values())[0]['num_steps']
  hps.ndatasets = len(hps.dataset_names)

  if hps.num_steps_for_gen_ic > hps.num_steps:
    hps.num_steps_for_gen_ic = hps.num_steps
  # Build and run the model, for varying purposes.
  config = tf.ConfigProto(allow_soft_placement=True,
                          log_device_placement=False)
  if FLAGS.allow_gpu_growth:
    config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  with sess.as_default():
    with tf.device(hps.device):
      if kind == "train":
        train(hps, datasets)
      elif kind == "posterior_sample_and_average":
        write_model_runs(hps, datasets, hps.output_filename_stem,
                         push_mean=False)
      elif kind == "posterior_push_mean":
        write_model_runs(hps, datasets, hps.output_filename_stem,
                         push_mean=True)
      elif kind == "prior_sample":
        write_model_samples(hps, datasets, hps.output_filename_stem)
      elif kind == "write_model_params":
        write_model_parameters(hps, hps.output_filename_stem, datasets)
      else:
        assert False, ("Kind %s is not implemented. " % kind)


if __name__ == "__main__":
  tf.app.run()