# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Contains evaluation plan for the Im2vox model.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import os | |
| import tensorflow as tf | |
| from tensorflow import app | |
| import model_ptn | |
| flags = tf.app.flags | |
| slim = tf.contrib.slim | |
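# Input data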
flags.DEFINE_string('inp_dir', '',
                    'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
    'dataset_name', 'shapenet_chair',
    'Dataset name that is to be used for training and evaluation.')
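# Model architecture and projection parameters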
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
                     'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('vox_size', 32, 'Voxel prediction dimension.')
flags.DEFINE_integer('step_size', 24, '')
flags.DEFINE_integer('batch_size', 1, 'Batch size while training.')
flags.DEFINE_float('focal_length', 0.866, '')
flags.DEFINE_float('focal_range', 1.732, '')
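# Network names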
flags.DEFINE_string('encoder_name', 'ptn_encoder',
                    'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_vox_decoder',
                    'Name of the decoder network being used.')
flags.DEFINE_string('projector_name', 'ptn_projector',
                    'Name of the projector network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn/eval/',
                    'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'ptn_proj',
                    'Name of the model used in naming the TF job. '
                    'Must be different for each run.')
flags.DEFINE_string('eval_set', 'val',
                    'Data partition to perform evaluation on.')
# Optimization
flags.DEFINE_float('proj_weight', 10, 'Weighting factor for projection loss.')
flags.DEFINE_float('volume_weight', 0, 'Weighting factor for volume loss.')
flags.DEFINE_float('viewpoint_weight', 1,
                   'Weighting factor for viewpoint loss.')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, '')
flags.DEFINE_float('clip_gradient_norm', 0, '')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, '')
flags.DEFINE_integer('eval_interval_secs', 60 * 5, '')
# Distribution
flags.DEFINE_string('master', '', '')

FLAGS = flags.FLAGS

def main(argv=()):
  del argv  # Unused.
  eval_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
  log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name,
                         'eval_%s' % FLAGS.eval_set)

  if not os.path.exists(eval_dir):
    os.makedirs(eval_dir)
  if not os.path.exists(log_dir):
    os.makedirs(log_dir)
  g = tf.Graph()

  with g.as_default():
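    # Evaluation overrides: use a single example per batch and step through
    # all of its views at once.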
    eval_params = FLAGS
    eval_params.batch_size = 1
    eval_params.step_size = FLAGS.num_views
    ###########
    ## model ##
    ###########
    model = model_ptn.model_PTN(eval_params)
    ##########
    ## data ##
    ##########
    eval_data = model.get_inputs(
        FLAGS.inp_dir,
        FLAGS.dataset_name,
        eval_params.eval_set,
        eval_params.batch_size,
        eval_params.image_size,
        eval_params.vox_size,
        is_training=False)
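    # Preprocessing keeps all views of each example for evaluation.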
    inputs = model.preprocess_with_all_views(eval_data)
    ##############
    ## model_fn ##
    ##############
    model_fn = model.get_model_fn(is_training=False, run_projection=False)
    outputs = model_fn(inputs)
    #############
    ## metrics ##
    #############
    names_to_values, names_to_updates = model.get_metrics(inputs, outputs)
    del names_to_values
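    # Only the metric update ops are passed to the evaluation loop below.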
    ################
    ## evaluation ##
    ################
    num_batches = eval_data['num_samples']
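    # With batch_size forced to 1 above, num_samples evaluation steps cover
    # the entire eval split.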
    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=eval_dir,
        logdir=log_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        eval_interval_secs=FLAGS.eval_interval_secs)


if __name__ == '__main__':
  app.run()