# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Training script for the FEELVOS model. | |
| See model.py for more details and usage. | |
| """ | |

import six
import tensorflow as tf

from feelvos import common
from feelvos import model
from feelvos.datasets import video_dataset
from feelvos.utils import embedding_utils
from feelvos.utils import train_utils
from feelvos.utils import video_input_generator

from deployment import model_deploy

slim = tf.contrib.slim
prefetch_queue = slim.prefetch_queue

flags = tf.app.flags

FLAGS = flags.FLAGS
# Settings for multi-GPUs/multi-replicas training.

flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.')

flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')

flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')

flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replicas startup.')

flags.DEFINE_integer('num_ps_tasks', 0,
                     'The number of parameter servers. If the value is 0, '
                     'then the parameters are handled locally by the worker.')

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server.')

flags.DEFINE_integer('task', 0, 'The task ID.')

# Settings for logging.

flags.DEFINE_string('train_logdir', None,
                    'Where the checkpoint and logs are stored.')

flags.DEFINE_integer('log_steps', 10,
                     'Display logging information at every log_steps.')

flags.DEFINE_integer('save_interval_secs', 1200,
                     'How often, in seconds, we save the model to disk.')

flags.DEFINE_integer('save_summaries_secs', 600,
                     'How often, in seconds, we compute the summaries.')

# Settings for training strategy.

flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
                  'Learning rate policy for training.')

flags.DEFINE_float('base_learning_rate', 0.0007,
                   'The base learning rate for model training.')

flags.DEFINE_float('learning_rate_decay_factor', 0.1,
                   'The rate to decay the base learning rate.')

flags.DEFINE_integer('learning_rate_decay_step', 2000,
                     'Decay the base learning rate at a fixed step.')

flags.DEFINE_float('learning_power', 0.9,
                   'The power value used in the poly learning policy.')

flags.DEFINE_integer('training_number_of_steps', 200000,
                     'The number of steps used for training.')

flags.DEFINE_float('momentum', 0.9, 'The momentum value to use.')

flags.DEFINE_integer('train_batch_size', 6,
                     'The number of images in each batch during training.')

flags.DEFINE_integer('train_num_frames_per_video', 3,
                     'The number of frames used per video during training.')

flags.DEFINE_float('weight_decay', 0.00004,
                   'The value of the weight decay for training.')

flags.DEFINE_multi_integer('train_crop_size', [465, 465],
                           'Image crop size [height, width] during training.')

flags.DEFINE_float('last_layer_gradient_multiplier', 1.0,
                   'The gradient multiplier for last layers, which is used to '
                   'boost the gradient of last layers if the value > 1.')

flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')

flags.DEFINE_integer('batch_capacity_factor', 16, 'Batch capacity factor.')

flags.DEFINE_integer('num_readers', 1, 'Number of readers for data provider.')

flags.DEFINE_integer('batch_num_threads', 1, 'Batch number of threads.')

flags.DEFINE_integer('prefetch_queue_capacity_factor', 32,
                     'Prefetch queue capacity factor.')

flags.DEFINE_integer('prefetch_queue_num_threads', 1,
                     'Prefetch queue number of threads.')

flags.DEFINE_integer('train_max_neighbors_per_object', 1024,
                     'The maximum number of candidates for the nearest '
                     'neighbor query per object after subsampling.')

# Settings for fine-tuning the network.

flags.DEFINE_string('tf_initial_checkpoint', None,
                    'The initial checkpoint in tensorflow format.')

flags.DEFINE_boolean('initialize_last_layer', False,
                     'Initialize the last layer.')

flags.DEFINE_boolean('last_layers_contain_logits_only', False,
                     'Whether to consider only the logits as the last layers.')

flags.DEFINE_integer('slow_start_step', 0,
                     'Train the model with a small learning rate for the '
                     'first few steps.')

flags.DEFINE_float('slow_start_learning_rate', 1e-4,
                   'Learning rate employed during slow start.')

flags.DEFINE_boolean('fine_tune_batch_norm', False,
                     'Fine tune the batch norm parameters or not.')

flags.DEFINE_float('min_scale_factor', 1.,
                   'Minimum scale factor for data augmentation.')

flags.DEFINE_float('max_scale_factor', 1.3,
                   'Maximum scale factor for data augmentation.')

flags.DEFINE_float('scale_factor_step_size', 0,
                   'Scale factor step size for data augmentation.')

flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 8,
                     'The ratio of input to output spatial resolution.')

flags.DEFINE_boolean('sample_only_first_frame_for_finetuning', False,
                     'Whether to only sample the first frame during '
                     'fine-tuning. This should be False when using lucid '
                     'data, but True when fine-tuning on the first frame '
                     'only. Only has an effect if first_frame_finetuning is '
                     'True.')

flags.DEFINE_multi_integer('first_frame_finetuning', [0],
                           'Whether the corresponding dataset is used for '
                           'fine-tuning on the first frame of a video.')

# Dataset settings.

flags.DEFINE_multi_string('dataset', [], 'Name of the segmentation datasets.')

flags.DEFINE_multi_float('dataset_sampling_probabilities', [],
                         'A list of probabilities to sample each of the '
                         'datasets.')

flags.DEFINE_string('train_split', 'train',
                    'Which split of the dataset to be used for training.')

flags.DEFINE_multi_string('dataset_dir', [], 'Where the datasets reside.')

flags.DEFINE_multi_integer('three_frame_dataset', [0],
                           'Whether the dataset has exactly three frames per '
                           'video, of which the first is to be used as '
                           'reference and the two others are consecutive '
                           'frames to be used as query frames. '
                           'Set true for PASCAL lucid data.')

flags.DEFINE_boolean('damage_initial_previous_frame_mask', False,
                     'Whether to artificially damage the initial previous '
                     'frame mask. Only has an effect if '
                     'also_attend_to_previous_frame is True.')

flags.DEFINE_float('top_k_percent_pixels', 0.15,
                   'Float in [0.0, 1.0]. When its value < 1.0, only compute '
                   'the loss for the top k percent pixels (e.g., the top 20% '
                   'pixels). This is useful for hard pixel mining.')

flags.DEFINE_integer('hard_example_mining_step', 100000,
                     'The training step in which the hard example mining '
                     'kicks off. Note that we gradually reduce the mining '
                     'percent to the top_k_percent_pixels. For example, if '
                     'hard_example_mining_step=100K and '
                     'top_k_percent_pixels=0.25, then mining percent will '
                     'gradually reduce from 100% to 25% until 100K steps '
                     'after which we only mine top 25% pixels. Only has an '
                     'effect if top_k_percent_pixels < 1.0.')
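
# A sketch of the hard-example-mining schedule configured above. The exact
# formula lives in train_utils; a linear ramp, as in the DeepLab code base,
# is assumed here:
#
#   def mined_fraction(step):
#     ratio = min(1.0, float(step) / FLAGS.hard_example_mining_step)
#     return (1.0 - ratio) + ratio * FLAGS.top_k_percent_pixels
#
# With the defaults (hard_example_mining_step=100000,
# top_k_percent_pixels=0.15), step 50000 mines 0.5 + 0.5 * 0.15 = 0.575 of
# the pixels.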


def _build_deeplab(inputs_queue_or_samples, outputs_to_num_classes,
                   ignore_label):
  """Builds a clone of DeepLab.

  Args:
    inputs_queue_or_samples: A prefetch queue for images and labels, or
      directly a dict of the samples.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.

  Returns:
    A map of maps from output_type (e.g., semantic prediction) to a
    dictionary of multi-scale logits names to logits. For each output_type,
    the dictionary has keys which correspond to the scales and values which
    correspond to the logits. For example, if `scales` equals [1.0, 1.5],
    then the keys would include 'merged_logits', 'logits_1.00' and
    'logits_1.50'.

  Raises:
    ValueError: If classification_loss is not softmax, softmax_with_attention,
      or triplet.
  """
  if hasattr(inputs_queue_or_samples, 'dequeue'):
    samples = inputs_queue_or_samples.dequeue()
  else:
    samples = inputs_queue_or_samples
  train_crop_size = (None if 0 in FLAGS.train_crop_size else
                     FLAGS.train_crop_size)
  model_options = common.VideoModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=train_crop_size,
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)
  if model_options.classification_loss == 'softmax_with_attention':
    clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

    # Create summaries of ground truth labels.
    for n in range(clone_batch_size):
      tf.summary.image(
          'gt_label_%d' % n,
          tf.cast(samples[common.LABEL][
              n * FLAGS.train_num_frames_per_video:
              (n + 1) * FLAGS.train_num_frames_per_video],
                  tf.uint8) * 32, max_outputs=FLAGS.train_num_frames_per_video)

    if common.PRECEDING_FRAME_LABEL in samples:
      preceding_frame_label = samples[common.PRECEDING_FRAME_LABEL]
      init_softmax = []
      for n in range(clone_batch_size):
        init_softmax_n = embedding_utils.create_initial_softmax_from_labels(
            preceding_frame_label[n, tf.newaxis],
            samples[common.LABEL][n * FLAGS.train_num_frames_per_video,
                                  tf.newaxis],
            common.parse_decoder_output_stride(),
            reduce_labels=True)
        init_softmax_n = tf.squeeze(init_softmax_n, axis=0)
        init_softmax.append(init_softmax_n)
        tf.summary.image('preceding_frame_label',
                         tf.cast(preceding_frame_label[n, tf.newaxis],
                                 tf.uint8) * 32)
    else:
      init_softmax = None
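    # When a preceding-frame mask is available, it was converted above into
    # initial soft object assignments that seed the softmax-feedback path of
    # the matching model below (an interpretation of
    # create_initial_softmax_from_labels; its exact output format lives in
    # embedding_utils).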
    outputs_to_scales_to_logits = (
        model.multi_scale_logits_with_nearest_neighbor_matching(
            samples[common.IMAGE],
            model_options=model_options,
            image_pyramid=FLAGS.image_pyramid,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            reference_labels=samples[common.LABEL],
            clone_batch_size=FLAGS.train_batch_size // FLAGS.num_clones,
            num_frames_per_video=FLAGS.train_num_frames_per_video,
            embedding_dimension=FLAGS.embedding_dimension,
            max_neighbors_per_object=FLAGS.train_max_neighbors_per_object,
            k_nearest_neighbors=FLAGS.k_nearest_neighbors,
            use_softmax_feedback=FLAGS.use_softmax_feedback,
            initial_softmax_feedback=init_softmax,
            embedding_seg_feature_dimension=
            FLAGS.embedding_seg_feature_dimension,
            embedding_seg_n_layers=FLAGS.embedding_seg_n_layers,
            embedding_seg_kernel_size=FLAGS.embedding_seg_kernel_size,
            embedding_seg_atrous_rates=FLAGS.embedding_seg_atrous_rates,
            normalize_nearest_neighbor_distances=
            FLAGS.normalize_nearest_neighbor_distances,
            also_attend_to_previous_frame=FLAGS.also_attend_to_previous_frame,
            damage_initial_previous_frame_mask=
            FLAGS.damage_initial_previous_frame_mask,
            use_local_previous_frame_attention=
            FLAGS.use_local_previous_frame_attention,
            previous_frame_attention_window_size=
            FLAGS.previous_frame_attention_window_size,
            use_first_frame_matching=FLAGS.use_first_frame_matching
        ))
  else:
    outputs_to_scales_to_logits = model.multi_scale_logits_v2(
        samples[common.IMAGE],
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)

  if model_options.classification_loss == 'softmax':
    for output, num_classes in six.iteritems(outputs_to_num_classes):
      train_utils.add_softmax_cross_entropy_loss_for_each_scale(
          outputs_to_scales_to_logits[output],
          samples[common.LABEL],
          num_classes,
          ignore_label,
          loss_weight=1.0,
          upsample_logits=FLAGS.upsample_logits,
          scope=output)
  elif model_options.classification_loss == 'triplet':
    for output, _ in six.iteritems(outputs_to_num_classes):
      train_utils.add_triplet_loss_for_each_scale(
          FLAGS.train_batch_size // FLAGS.num_clones,
          FLAGS.train_num_frames_per_video,
          FLAGS.embedding_dimension, outputs_to_scales_to_logits[output],
          samples[common.LABEL], scope=output)
  elif model_options.classification_loss == 'softmax_with_attention':
    labels = samples[common.LABEL]
    batch_size = FLAGS.train_batch_size // FLAGS.num_clones
    num_frames_per_video = FLAGS.train_num_frames_per_video
    h, w = train_utils.resolve_shape(labels)[1:3]
    labels = tf.reshape(labels, tf.stack(
        [batch_size, num_frames_per_video, h, w, 1]))
    # Strip the reference labels off.
    if FLAGS.also_attend_to_previous_frame or FLAGS.use_softmax_feedback:
      n_ref_frames = 2
    else:
      n_ref_frames = 1
    labels = labels[:, n_ref_frames:]
    # Merge batch and time dimensions.
    labels = tf.reshape(labels, tf.stack(
        [batch_size * (num_frames_per_video - n_ref_frames), h, w, 1]))
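    # Shape walkthrough (derived from the defaults train_batch_size=6,
    # num_clones=1, train_num_frames_per_video=3, n_ref_frames=1): labels
    # arrive as [18, h, w, 1], are reshaped to [6, 3, h, w, 1], sliced to
    # [6, 2, h, w, 1] once the reference frame is stripped, and merged back
    # to [12, h, w, 1] for the per-pixel loss below.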
    for output, num_classes in six.iteritems(outputs_to_num_classes):
      train_utils.add_dynamic_softmax_cross_entropy_loss_for_each_scale(
          outputs_to_scales_to_logits[output],
          labels,
          ignore_label,
          loss_weight=1.0,
          upsample_logits=FLAGS.upsample_logits,
          scope=output,
          top_k_percent_pixels=FLAGS.top_k_percent_pixels,
          hard_example_mining_step=FLAGS.hard_example_mining_step)
  else:
    raise ValueError('classification_loss must be softmax, '
                     'softmax_with_attention or triplet.')

  return outputs_to_scales_to_logits


def main(unused_argv):
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  with tf.Graph().as_default():
    with tf.device(config.inputs_device()):
      train_crop_size = (None if 0 in FLAGS.train_crop_size else
                         FLAGS.train_crop_size)
      assert FLAGS.dataset
      assert len(FLAGS.dataset) == len(FLAGS.dataset_dir)
      if len(FLAGS.first_frame_finetuning) == 1:
        first_frame_finetuning = (list(FLAGS.first_frame_finetuning)
                                  * len(FLAGS.dataset))
      else:
        first_frame_finetuning = FLAGS.first_frame_finetuning
      if len(FLAGS.three_frame_dataset) == 1:
        three_frame_dataset = (list(FLAGS.three_frame_dataset)
                               * len(FLAGS.dataset))
      else:
        three_frame_dataset = FLAGS.three_frame_dataset
      assert len(FLAGS.dataset) == len(first_frame_finetuning)
      assert len(FLAGS.dataset) == len(three_frame_dataset)
      datasets, samples_list = zip(
          *[_get_dataset_and_samples(config, train_crop_size, dataset,
                                     dataset_dir,
                                     bool(first_frame_finetuning_),
                                     bool(three_frame_dataset_))
            for dataset, dataset_dir, first_frame_finetuning_,
            three_frame_dataset_ in zip(FLAGS.dataset, FLAGS.dataset_dir,
                                        first_frame_finetuning,
                                        three_frame_dataset)])
      # Note that this way of doing things is wasteful since it will evaluate
      # all branches but just use one of them. But let's do it anyway for now,
      # since it's easy and will probably be fast enough.
      dataset = datasets[0]
      if len(samples_list) == 1:
        samples = samples_list[0]
      else:
        probabilities = FLAGS.dataset_sampling_probabilities
        if probabilities:
          assert len(probabilities) == len(samples_list)
        else:
          # Default to uniform probabilities.
          probabilities = [1.0 / len(samples_list) for _ in samples_list]
        probabilities = tf.constant(probabilities)
        logits = tf.log(probabilities[tf.newaxis])
        rand_idx = tf.squeeze(tf.multinomial(logits, 1,
                                             output_dtype=tf.int32),
                              axis=[0, 1])

        def wrap(x):
          def f():
            return x
          return f

        samples = tf.case({tf.equal(rand_idx, idx): wrap(s)
                           for idx, s in enumerate(samples_list)},
                          exclusive=True)
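        # tf.multinomial draws one dataset index per training step from the
        # categorical distribution defined by the (log) sampling
        # probabilities, and tf.case then routes that dataset's batch
        # through. E.g. with probabilities [0.7, 0.3], the first dataset is
        # chosen on roughly 70% of the steps.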
      # Prefetch_queue requires the shape to be known at graph creation time.
      # So we only use it if we crop to a fixed size.
      if train_crop_size is None:
        inputs_queue = samples
      else:
        inputs_queue = prefetch_queue.prefetch_queue(
            samples,
            capacity=FLAGS.prefetch_queue_capacity_factor * config.num_clones,
            num_threads=FLAGS.prefetch_queue_num_threads)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      if FLAGS.classification_loss == 'triplet':
        embedding_dim = FLAGS.embedding_dimension
        output_type_to_dim = {'embedding': embedding_dim}
      else:
        output_type_to_dim = {common.OUTPUT_TYPE: dataset.num_classes}
      model_args = (inputs_queue, output_type_to_dim, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                     first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.contrib.framework.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy,
          FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step,
          FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps,
          FLAGS.learning_power,
          FLAGS.slow_start_step,
          FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(grads_and_vars,
                                                          grad_mult)

      with tf.name_scope('grad_clipping'):
        grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, 5.0)
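      # Note: clip_gradient_norms clips each gradient tensor to norm 5.0
      # individually (per variable); it does not rescale by the global norm
      # across all variables.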
      # Create histogram summaries for the gradients.
      # We have too many summaries for mldash, so disable this one for now.
      # for grad, var in grads_and_vars:
      #   summaries.add(tf.summary.histogram(
      #       var.name.replace(':0', '_0') + '/gradient', grad))

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(grads_and_vars,
                                               global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows ops without a GPU implementation to be placed on
    # the CPU.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(FLAGS.train_logdir,
                                              FLAGS.tf_initial_checkpoint,
                                              FLAGS.initialize_last_layer,
                                              last_layers,
                                              ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)


def _get_dataset_and_samples(config, train_crop_size, dataset_name,
                             dataset_dir, first_frame_finetuning,
                             three_frame_dataset):
  """Creates the dataset object and the samples dict of tensors.

  Args:
    config: A DeploymentConfig.
    train_crop_size: A list [height, width] with the crop size used for
      training, or None to train on whole frames.
    dataset_name: String, the name of the dataset.
    dataset_dir: String, the directory of the dataset.
    first_frame_finetuning: Boolean, whether the used dataset is a dataset
      for first frame fine-tuning.
    three_frame_dataset: Boolean, whether the dataset has exactly three frames
      per video of which the first is to be used as reference and the two
      others are consecutive frames to be used as query frames.

  Returns:
    dataset: An instance of slim Dataset.
    samples: A dictionary of tensors for semantic segmentation.
  """
  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')
  clone_batch_size = FLAGS.train_batch_size // config.num_clones
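  # E.g. train_batch_size=6 with two clones yields a per-clone batch of three
  # videos; integer division keeps the batch size an int for the input
  # pipeline.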
  if first_frame_finetuning:
    train_split = 'val'
  else:
    train_split = FLAGS.train_split

  data_type = 'tf_sequence_example'
  # Get dataset-dependent information.
  dataset = video_dataset.get_dataset(
      dataset_name,
      train_split,
      dataset_dir=dataset_dir,
      data_type=data_type)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', train_split)

  samples = video_input_generator.get(
      dataset,
      FLAGS.train_num_frames_per_video,
      train_crop_size,
      clone_batch_size,
      num_readers=FLAGS.num_readers,
      num_threads=FLAGS.batch_num_threads,
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      min_scale_factor=FLAGS.min_scale_factor,
      max_scale_factor=FLAGS.max_scale_factor,
      scale_factor_step_size=FLAGS.scale_factor_step_size,
      dataset_split=FLAGS.train_split,
      is_training=True,
      model_variant=FLAGS.model_variant,
      batch_capacity_factor=FLAGS.batch_capacity_factor,
      decoder_output_stride=common.parse_decoder_output_stride(),
      first_frame_finetuning=first_frame_finetuning,
      sample_only_first_frame_for_finetuning=
      FLAGS.sample_only_first_frame_for_finetuning,
      sample_adjacent_and_consistent_query_frames=
      FLAGS.sample_adjacent_and_consistent_query_frames or
      FLAGS.use_softmax_feedback,
      remap_labels_to_reference_frame=True,
      three_frame_dataset=three_frame_dataset,
      add_prev_frame_label=not FLAGS.also_attend_to_previous_frame
  )
  return dataset, samples


if __name__ == '__main__':
  flags.mark_flag_as_required('train_logdir')
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()