Spaces:
Runtime error
Runtime error
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Detection model trainer. | |
| This file provides a generic training method that can be used to train a | |
| DetectionModel. | |
| """ | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import functools | |
| import six | |
| from six.moves import range | |
| import tensorflow.compat.v1 as tf | |
| import tf_slim as slim | |
| from object_detection.builders import optimizer_builder | |
| from object_detection.core import standard_fields as fields | |
| from object_detection.utils import ops as util_ops | |
| from object_detection.utils import variables_helper | |
| from deployment import model_deploy | |
def create_input_queue(create_tensor_dict_fn):
  """Builds the flat, per-image input dictionary used by the trainer.

  The tensor dictionary returned by `create_tensor_dict_fn` is expected to
  hold a 'batch' entry plus indexable per-field values. Every remaining field
  is split by index and stored under `<field name><index>`.

  Args:
    create_tensor_dict_fn: function to create tensor dictionary.

  Returns:
    all_dict: A dictionary holding the 'batch' entry and, for each index `i`
      below `len(tensor_dict[fields.InputDataFields.image])`, one entry per
      field keyed by `key + str(i)`. Image entries are additionally given a
      leading dimension and converted to float.
  """
  tensor_dict = create_tensor_dict_fn()
  image_field = fields.InputDataFields.image
  num_images = len(tensor_dict[image_field])

  # 'batch' is carried over as-is and must not be split per image below.
  all_dict = {'batch': tensor_dict.pop('batch')}

  for index in range(num_images):
    suffix = str(index)
    all_dict.update(
        {name + suffix: value[index] for name, value in tensor_dict.items()})
    image_key = image_field + suffix
    expanded_image = tf.expand_dims(all_dict[image_key], 0)
    all_dict[image_key] = tf.to_float(expanded_image)
  return all_dict
def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
  """Collects per-image inputs and groundtruth for an object detection model.

  Entries are read for suffixes 0, 1, 2, ... and collection stops at the first
  suffix for which `fields.InputDataFields.image + str(suffix)` is missing.

  Args:
    input_queue: dictionary whose keys are field names followed by a numeric
      image index (for example the output of `create_input_queue`).
    num_classes: Number of classes.
    merge_multiple_label_boxes: Whether to merge boxes with multiple labels
      or not. Defaults to false. Merged boxes are represented with a single
      box and a k-hot encoding of the multiple labels associated with the
      boxes.

  Returns:
    A tuple `(images, image_keys, locations, classes, masks, keypoints)` of
    equally long lists, one element per image:
      images: the image tensors.
      image_keys: the source id of each image, or '' when the dictionary has
        no source id entry for that image.
      locations: the groundtruth box tensors.
      classes: the encoded target classes (k-hot when
        `merge_multiple_label_boxes` is set, padded one-hot otherwise).
      masks: the instance mask tensors; an element is None when the
        dictionary has no mask entry for that image.
      keypoints: the keypoint tensors; an element is None when the dictionary
        has no keypoint entry for that image.
  """
  # Groundtruth class ids are shifted down by this offset before encoding.
  label_id_offset = 1
  input_fields = fields.InputDataFields

  images = []
  keys = []
  locations = []
  classes = []
  masks = []
  keypoints = []

  index = 0
  while input_fields.image + str(index) in input_queue:
    suffix = str(index)
    image = input_queue[input_fields.image + suffix]
    # Look the source id up under the same suffixed key that is read, so a
    # per-image source id is found when present and '' is used otherwise.
    key = input_queue.get(input_fields.source_id + suffix, '')
    location_gt = input_queue[input_fields.groundtruth_boxes + suffix]
    classes_gt = tf.cast(
        input_queue[input_fields.groundtruth_classes + suffix], tf.int32)
    classes_gt -= label_id_offset
    # Masks and keypoints are optional; missing entries stay None.
    masks_gt = input_queue.get(
        input_fields.groundtruth_instance_masks + suffix)
    keypoints_gt = input_queue.get(
        input_fields.groundtruth_keypoints + suffix)

    if merge_multiple_label_boxes:
      location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
          location_gt, classes_gt, num_classes)
    else:
      classes_gt = util_ops.padded_one_hot_encoding(
          indices=classes_gt, depth=num_classes, left_pad=0)

    # Images, locations and classes always get one entry per image, so the
    # returned lists stay index-aligned.
    images.append(image)
    keys.append(key)
    locations.append(location_gt)
    classes.append(classes_gt)
    masks.append(masks_gt)
    keypoints.append(keypoints_gt)
    index += 1

  return (images, keys, locations, classes, masks, keypoints)
def _create_losses(input_queue, create_model_fn, train_config):
  """Builds a DetectionModel and registers its losses with `tf.losses`.

  Nothing is returned; every tensor in the model's loss dictionary is added
  through `tf.losses.add_loss`.

  Args:
    input_queue: dictionary of input tensors; it is passed to `get_inputs` and
      its 'batch' entry is forwarded to the model's `predict` call.
    create_model_fn: A function to create the DetectionModel.
    train_config: a train_pb2.TrainConfig protobuf.
  """
  model = create_model_fn()
  model_inputs = get_inputs(input_queue, model.num_classes,
                            train_config.merge_multiple_label_boxes)
  (raw_images, _, boxes_list, classes_list, masks_list,
   keypoints_list) = model_inputs

  # Preprocess every image, then stack the results along the first axis.
  preprocessed = [model.preprocess(raw_image) for raw_image in raw_images]
  image_batch = tf.concat([resized for resized, _ in preprocessed], 0)
  shape_batch = tf.concat([true_shape for _, true_shape in preprocessed], 0)

  # Optional groundtruth is dropped entirely if any image lacks it.
  if not all(mask is not None for mask in masks_list):
    masks_list = None
  if not all(keypoints is not None for keypoints in keypoints_list):
    keypoints_list = None

  model.provide_groundtruth(boxes_list, classes_list, masks_list,
                            keypoints_list)
  predictions = model.predict(image_batch, shape_batch, input_queue['batch'])
  model_losses = model.loss(predictions, shape_batch)
  for model_loss in model_losses.values():
    tf.losses.add_loss(model_loss)
def get_restore_checkpoint_ops(restore_checkpoints, detection_model,
                               train_config):
  """Creates one saver per checkpoint for initializing the model.

  A variable is restored from the first checkpoint that provides it; later
  checkpoints containing the same variable skip it. Where a checkpoint holds
  a `<name>/ExponentialMovingAverage` entry for a variable, that entry is used
  in place of `<name>`.

  Args:
    restore_checkpoints: loaded checkpoints.
    detection_model: Object detection model built from config file.
    train_config: a train_pb2.TrainConfig protobuf.

  Returns:
    restorers: A list of `tf.train.Saver` objects, exactly one per entry of
      `restore_checkpoints` and in the same order, so that `restorers[i]`
      belongs to `restore_checkpoints[i]`. A checkpoint without restorable
      variables gets a saver built with `allow_empty=True`.
  """
  restorers = []
  vars_restored = []
  for restore_checkpoint in restore_checkpoints:
    var_map = detection_model.restore_map(
        fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type)
    available_var_map = (
        variables_helper.get_variables_available_in_checkpoint(
            var_map, restore_checkpoint))

    # Collect the not-yet-restored variables into a new dict. Deleting from
    # the dict being iterated raises RuntimeError on Python 3.
    unclaimed_var_map = {}
    for var_name, var in six.iteritems(available_var_map):
      # Compare by identity so the check does not depend on how the variable
      # type overloads `==`.
      if any(var is restored_var for restored_var in vars_restored):
        tf.logging.info('Variable %s contained in multiple checkpoints',
                        var.op.name)
      else:
        vars_restored.append(var)
        unclaimed_var_map[var_name] = var

    # Initialize from ExponentialMovingAverages if possible.
    available_ema_var_map = {}
    ckpt_reader = tf.train.NewCheckpointReader(restore_checkpoint)
    ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
    for var_name, var in six.iteritems(unclaimed_var_map):
      var_name_ema = var_name + '/ExponentialMovingAverage'
      if var_name_ema in ckpt_vars_to_shape_map:
        available_ema_var_map[var_name_ema] = var
      else:
        available_ema_var_map[var_name] = var

    if not available_ema_var_map:
      tf.logging.info('WARNING: Checkpoint %s has no restorable variables',
                      restore_checkpoint)
    # `allow_empty=True` lets the saver be built even when the map is empty.
    # Appending a saver for every checkpoint keeps the result index-aligned
    # with `restore_checkpoints`.
    restorers.append(tf.train.Saver(available_ema_var_map, allow_empty=True))
  return restorers
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Builds the training graph (inputs, model clones, optimizer, optional
  checkpoint restoration, moving averages and summaries) inside a fresh
  `tf.Graph` and then runs `slim.learning.train` on it.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
      losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
  """
  # This model instance is only used to obtain the restore map below; each
  # clone creates its own model through `create_model_fn`.
  detection_model = create_model_fn()

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(create_tensor_dict_fn)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set()

    model_fn = functools.partial(
        _create_losses,
        create_model_fn=create_model_fn,
        train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var)

    sync_optimizer = None
    if train_config.sync_replicas:
      # Use the `worker_replicas` argument, i.e. the same replica count that
      # was given to the DeploymentConfig above.
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      # `fine_tune_checkpoint` may hold several comma-separated paths.
      restore_checkpoints = [
          path.strip() for path in train_config.fine_tune_checkpoint.split(',')
      ]

      restorers = get_restore_checkpoint_ops(restore_checkpoints,
                                             detection_model, train_config)

      def initializer_fn(sess):
        # Restorers are paired with checkpoint paths by position.
        for restorer, checkpoint in zip(restorers, restore_checkpoints):
          restorer.restore(sess, checkpoint)

      init_fn = initializer_fn

    with tf.device(deploy_config.optimizer_device()):
      # None is passed when regularization losses are enabled and an empty
      # list when they are disabled.
      regularization_losses = (
          None if train_config.add_regularization_loss else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones,
          training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Maintain exponential moving averages of all model variables.
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step)
      update_ops.append(variable_averages.apply(moving_average_variables))

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)

      # The train tensor is the total loss, gated on all update ops.
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram(model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar(loss_tensor.op.name, loss_tensor))
    global_summaries.add(
        tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, 'critic_loss'))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        # A falsy num_steps is passed on as None.
        number_of_steps=(train_config.num_steps
                         if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)