# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Segmentation results visualization on a given set of images.

See model.py for more details and usage.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import time

import numpy as np
from six.moves import range
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib import training as contrib_training
from deeplab import common
from deeplab import model
from deeplab.datasets import data_generator
from deeplab.utils import save_annotation

flags = tf.app.flags

FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')

# Settings for log directories.
flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')

flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.')

# Settings for visualizing the model.
flags.DEFINE_integer('vis_batch_size', 1,
                     'The number of images in each batch during evaluation.')

flags.DEFINE_list('vis_crop_size', '513,513',
                  'Crop size [height, width] for visualization.')

flags.DEFINE_integer('eval_interval_secs', 60 * 5,
                     'How often (in seconds) to run evaluation.')

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
# one could use different atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')

# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale test.
flags.DEFINE_multi_float('eval_scales', [1.0],
                         'The scales to resize images for evaluation.')

# Change to True for adding flipped images during test.
flags.DEFINE_bool('add_flipped_images', False,
                  'Add flipped images for evaluation or not.')

flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')

# Dataset settings.
flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')
| 'Which split of the dataset used for visualizing results') | |
| flags.DEFINE_string('dataset_dir', None, 'Where the dataset reside.') | |

flags.DEFINE_enum('colormap_type', 'pascal', ['pascal', 'cityscapes', 'ade20k'],
                  'Visualization colormap type.')

flags.DEFINE_boolean('also_save_raw_predictions', False,
                     'Also save raw predictions.')

flags.DEFINE_integer('max_number_of_iterations', 0,
                     'Maximum number of visualization iterations. Will loop '
                     'indefinitely upon nonpositive values.')

# The folder where semantic segmentation predictions are saved.
_SEMANTIC_PREDICTION_SAVE_FOLDER = 'segmentation_results'

# The folder where raw semantic segmentation predictions are saved.
_RAW_SEMANTIC_PREDICTION_SAVE_FOLDER = 'raw_segmentation_results'

# The format to save image.
_IMAGE_FORMAT = '%06d_image'

# The format to save prediction.
_PREDICTION_FORMAT = '%06d_prediction'

# To evaluate Cityscapes results on the evaluation server, the labels used
# during training should be mapped to the labels for evaluation.
_CITYSCAPES_TRAIN_ID_TO_EVAL_ID = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                                   23, 24, 25, 26, 27, 28, 31, 32, 33]


def _convert_train_id_to_eval_id(prediction, train_id_to_eval_id):
  """Converts the predicted label for evaluation.

  There are cases where the training labels are not equal to the evaluation
  labels. This function is used to perform the conversion so that we can
  evaluate the results on the evaluation server.

  Args:
    prediction: Semantic segmentation prediction.
    train_id_to_eval_id: A list mapping from train id to evaluation id.

  Returns:
    Semantic segmentation prediction whose labels have been changed.
  """
  converted_prediction = prediction.copy()
  for train_id, eval_id in enumerate(train_id_to_eval_id):
    converted_prediction[prediction == train_id] = eval_id
  return converted_prediction
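
# A minimal sketch of the conversion above (hypothetical input, not from this
# file): with _CITYSCAPES_TRAIN_ID_TO_EVAL_ID, train id 0 maps to eval id 7,
# train id 1 to eval id 8, and train id 18 to eval id 33. Because each mask is
# computed against the original `prediction`, overlapping train/eval id ranges
# cannot cause a pixel to be converted twice.
#
#   pred = np.array([[0, 1], [18, 0]])
#   _convert_train_id_to_eval_id(pred, _CITYSCAPES_TRAIN_ID_TO_EVAL_ID)
#   # -> array([[ 7,  8], [33,  7]])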


def _process_batch(sess, original_images, semantic_predictions, image_names,
                   image_heights, image_widths, image_id_offset, save_dir,
                   raw_save_dir, train_id_to_eval_id=None):
  """Evaluates a single batch qualitatively.

  Args:
    sess: TensorFlow session.
    original_images: One batch of original images.
    semantic_predictions: One batch of semantic segmentation predictions.
    image_names: Image names.
    image_heights: Image heights.
    image_widths: Image widths.
    image_id_offset: Image id offset for indexing images.
    save_dir: The directory where the predictions will be saved.
    raw_save_dir: The directory where the raw predictions will be saved.
    train_id_to_eval_id: A list mapping from train id to eval id.
  """
  (original_images,
   semantic_predictions,
   image_names,
   image_heights,
   image_widths) = sess.run([original_images, semantic_predictions,
                             image_names, image_heights, image_widths])

  num_image = semantic_predictions.shape[0]
  for i in range(num_image):
    image_height = np.squeeze(image_heights[i])
    image_width = np.squeeze(image_widths[i])
    original_image = np.squeeze(original_images[i])
    semantic_prediction = np.squeeze(semantic_predictions[i])
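    # The prediction covers the full vis_crop_size canvas; keep only the
    # valid [image_height, image_width] region (the rest is padding).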
    crop_semantic_prediction = semantic_prediction[:image_height, :image_width]

    # Save image.
    save_annotation.save_annotation(
        original_image, save_dir, _IMAGE_FORMAT % (image_id_offset + i),
        add_colormap=False)

    # Save prediction.
    save_annotation.save_annotation(
        crop_semantic_prediction, save_dir,
        _PREDICTION_FORMAT % (image_id_offset + i), add_colormap=True,
        colormap_type=FLAGS.colormap_type)

    if FLAGS.also_save_raw_predictions:
      image_filename = os.path.basename(image_names[i])

      if train_id_to_eval_id is not None:
        crop_semantic_prediction = _convert_train_id_to_eval_id(
            crop_semantic_prediction,
            train_id_to_eval_id)
      save_annotation.save_annotation(
          crop_semantic_prediction, raw_save_dir, image_filename,
          add_colormap=False)
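
# Note (an assumption, based on deeplab/utils/save_annotation.py): the save
# calls above write one pair of files per example, named like 000000_image and
# 000000_prediction, with the image extension appended by save_annotation.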


def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  # Get dataset-dependent information.
  dataset = data_generator.Dataset(
      dataset_name=FLAGS.dataset,
      split_name=FLAGS.vis_split,
      dataset_dir=FLAGS.dataset_dir,
      batch_size=FLAGS.vis_batch_size,
      crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      model_variant=FLAGS.model_variant,
      is_training=False,
      should_shuffle=False,
      should_repeat=False)

  train_id_to_eval_id = None
  if dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
    tf.logging.info('Cityscapes requires converting train_id to eval_id.')
    train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID

  # Prepare for visualization.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(save_dir)
  raw_save_dir = os.path.join(
      FLAGS.vis_logdir, _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(raw_save_dir)

  tf.logging.info('Visualizing on %s set', FLAGS.vis_split)

  with tf.Graph().as_default():
    samples = dataset.get_one_shot_iterator().get_next()

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
        crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions = model.predict_labels(
          samples[common.IMAGE],
          model_options=model_options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Performing multi-scale test.')
      if FLAGS.quantize_delay_step >= 0:
        raise ValueError(
            'Quantize mode is not supported with multi-scale test.')

      predictions = model.predict_labels_multi_scale(
          samples[common.IMAGE],
          model_options=model_options,
          eval_scales=FLAGS.eval_scales,
          add_flipped_images=FLAGS.add_flipped_images)
    predictions = predictions[common.OUTPUT_TYPE]
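    # `predictions` is assumed to be a [batch, height, width] tensor of
    # integer class ids (per-pixel argmax from model.predict_labels).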

    if FLAGS.min_resize_value and FLAGS.max_resize_value:
      # Only support batch_size = 1, since we assume the dimensions of original
      # image after tf.squeeze is [height, width, 3].
      assert FLAGS.vis_batch_size == 1

      # Reverse the resizing and padding operations performed in preprocessing.
      # First, we slice the valid regions (i.e., remove padded region) and then
      # we resize the predictions back.
      original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
      original_image_shape = tf.shape(original_image)
      predictions = tf.slice(
          predictions,
          [0, 0, 0],
          [1, original_image_shape[0], original_image_shape[1]])
      resized_shape = tf.to_int32([tf.squeeze(samples[common.HEIGHT]),
                                   tf.squeeze(samples[common.WIDTH])])
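      # Nearest-neighbor resizing keeps the predictions as valid integer
      # class ids; bilinear interpolation would blend neighboring labels.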
      predictions = tf.squeeze(
          tf.image.resize_images(tf.expand_dims(predictions, 3),
                                 resized_shape,
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                                 align_corners=True), 3)

    tf.train.get_or_create_global_step()
    if FLAGS.quantize_delay_step >= 0:
      contrib_quantize.create_eval_graph()

    num_iteration = 0
    max_num_iteration = FLAGS.max_number_of_iterations
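
    # checkpoints_iterator yields new checkpoint paths as they appear in
    # checkpoint_dir, waiting at least eval_interval_secs between checks,
    # so visualization keeps running alongside training.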
    checkpoints_iterator = contrib_training.checkpoints_iterator(
        FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)
    for checkpoint_path in checkpoints_iterator:
      num_iteration += 1
      tf.logging.info(
          'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
      tf.logging.info('Visualizing with model %s', checkpoint_path)

      scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
      session_creator = tf.train.ChiefSessionCreator(
          scaffold=scaffold,
          master=FLAGS.master,
          checkpoint_filename_with_path=checkpoint_path)
      with tf.train.MonitoredSession(
          session_creator=session_creator, hooks=None) as sess:
        batch = 0
        image_id_offset = 0

        while not sess.should_stop():
          tf.logging.info('Visualizing batch %d', batch + 1)
          _process_batch(sess=sess,
                         original_images=samples[common.ORIGINAL_IMAGE],
                         semantic_predictions=predictions,
                         image_names=samples[common.IMAGE_NAME],
                         image_heights=samples[common.HEIGHT],
                         image_widths=samples[common.WIDTH],
                         image_id_offset=image_id_offset,
                         save_dir=save_dir,
                         raw_save_dir=raw_save_dir,
                         train_id_to_eval_id=train_id_to_eval_id)
          image_id_offset += FLAGS.vis_batch_size
          batch += 1

      tf.logging.info(
          'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
      if max_num_iteration > 0 and num_iteration >= max_num_iteration:
        break


if __name__ == '__main__':
  flags.mark_flag_as_required('checkpoint_dir')
  flags.mark_flag_as_required('vis_logdir')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()