Spaces:
Runtime error
Runtime error
| # Copyright 2016 The TensorFlow Authors All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Functions to create a DSN model and add the different losses to it. | |
| Specifically, in this file we define the: | |
| - Shared Encoding Similarity Loss Module, with: | |
| - The MMD Similarity method | |
| - The Correlation Similarity method | |
| - The Gradient Reversal (Domain-Adversarial) method | |
| - Difference Loss Module | |
| - Reconstruction Loss Module | |
| - Task Loss Module | |
| """ | |
| from functools import partial | |
| import tensorflow as tf | |
| import losses | |
| import models | |
| import utils | |
| slim = tf.contrib.slim | |
| ################################################################################ | |
| # HELPER FUNCTIONS | |
| ################################################################################ | |
def dsn_loss_coefficient(params):
  """Global-step-dependent weight that switches the DSN losses on.

  Before the configured start step the coefficient is a near-zero value
  (1e-10) so the DSN-related losses are effectively disabled; from that
  step onward it is 1.0.

  Args:
    params: A dictionary of parameters. Expecting
      'domain_separation_startpoint'.

  Returns:
    A scalar weight that effectively enables or disables the DSN-related
    losses, i.e. similarity, difference, and reconstruction losses.
  """
  global_step = slim.get_or_create_global_step()
  startpoint = params['domain_separation_startpoint']
  before_startpoint = tf.less(global_step, startpoint)
  return tf.where(before_startpoint, 1e-10, 1.0)
| ################################################################################ | |
| # MODEL CREATION | |
| ################################################################################ | |
def create_model(source_images, source_labels, domain_selection_mask,
                 target_images, target_labels, similarity_loss, params,
                 basic_tower_name):
  """Creates a DSN model.

  Builds the shared-encoder tower for both domains, adds the task loss on
  the source domain, the similarity loss between the two shared encodings,
  and (optionally) the private autoencoders with difference/reconstruction
  losses. All losses are added as graph side effects; nothing is returned.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
      hot for the number of classes.
    domain_selection_mask: a boolean tensor of size [batch_size, ] which denotes
      the labeled images that belong to the source domain.
    target_images: images from the target domain, a tensor of size
      [batch_size, height width, channels].
    target_labels: a dictionary with the name, tensor pairs.
    similarity_loss: The type of method to use for encouraging
      the codes from the shared encoder to be similar. 'none' disables
      domain adaptation entirely.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
      'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
      'decoder_name', 'encoder_name'
    basic_tower_name: the name of the tower to use for the shared encoder.

  Raises:
    ValueError: if the arch is not one of the available architectures.
  """
  network = getattr(models, basic_tower_name)
  num_classes = source_labels['classes'].get_shape().as_list()[1]

  # Make sure we are using the appropriate number of classes.
  network = partial(network, num_classes=num_classes)

  # Add the classification/pose estimation loss to the source domain.
  # This also builds the 'towers' variable scope that is reused below.
  source_endpoints = add_task_loss(source_images, source_labels, network,
                                   params)

  if similarity_loss == 'none':
    # No domain adaptation, we can stop here.
    return

  # Run the target images through the same tower, reusing its variables.
  with tf.variable_scope('towers', reuse=True):
    target_logits, target_endpoints = network(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Plot target accuracy of the train set.
  target_accuracy = utils.accuracy(
      tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))

  if 'quaternions' in target_labels:
    # Pose estimation: monitor the quaternion loss on the target domain.
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)

  tf.summary.scalar('eval/Target accuracy', target_accuracy)

  # The shared encodings that the similarity/difference losses operate on.
  source_shared = source_endpoints[params['layers_to_regularize']]
  target_shared = target_endpoints[params['layers_to_regularize']]

  # When using the semisupervised model we include labeled target data in the
  # source classifier. We do not want to include these target domain when
  # we use the similarity loss.
  indices = tf.range(0, source_shared.get_shape().as_list()[0])
  indices = tf.boolean_mask(indices, domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, indices),
                      tf.gather(target_shared, indices), params)

  if params['use_separation']:
    # Private/shared separation: autoencoders plus difference and
    # reconstruction losses.
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params,)
def add_similarity_loss(method_name,
                        source_samples,
                        target_samples,
                        params,
                        scope=None):
  """Adds a loss encouraging the shared encoding from each domain to be similar.

  The loss function is looked up by name on the `losses` module and is
  weighted by 'gamma_weight' modulated by the global-step-dependent DSN
  coefficient.

  Args:
    method_name: the name of the encoding similarity method to use. Valid
      options include `dann_loss', `mmd_loss' or `correlation_loss'.
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    params: a dictionary of parameters. Expecting 'gamma_weight'.
    scope: optional name scope for summary tags.

  Raises:
    AttributeError: if `method_name` is not a function defined on the
      `losses` module.
  """
  similarity_fn = getattr(losses, method_name)
  loss_weight = params['gamma_weight'] * dsn_loss_coefficient(params)
  similarity_fn(source_samples, target_samples, loss_weight, scope)
def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
  """Adds a reconstruction loss.

  Args:
    recon_loss_name: The name of the reconstruction loss.
    images: A `Tensor` of size [batch_size, height, width, 3].
    recons: A `Tensor` whose size matches `images`.
    weight: A scalar coefficient for the loss.
    domain: The name of the domain being reconstructed.

  Raises:
    ValueError: If `recon_loss_name` is not recognized.
  """
  # Dispatch table mapping loss names to their tf.contrib implementations.
  loss_fns = {
      'sum_of_pairwise_squares': tf.contrib.losses.mean_pairwise_squared_error,
      'sum_of_squares': tf.contrib.losses.mean_squared_error,
  }
  if recon_loss_name not in loss_fns:
    raise ValueError('recon_loss_name value [%s] not recognized.' %
                     recon_loss_name)

  loss = loss_fns[recon_loss_name](recons, images, weight)

  # Guard the summary with a finiteness assertion on the loss value.
  assert_op = tf.Assert(tf.is_finite(loss), [loss])
  with tf.control_dependencies([assert_op]):
    tf.summary.scalar('losses/%s Recon Loss' % domain, loss)
def add_autoencoders(source_data, source_shared, target_data, target_shared,
                     params):
  """Adds the encoders/decoders for our domain separation model w/ incoherence.

  Builds a private encoder per domain, a shared decoder, and adds the
  difference losses (private vs. shared codes) and reconstruction losses,
  plus image summaries of the reconstructions.

  Args:
    source_data: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_shared: a tensor with first dimension batch_size
    target_data: images from the target domain, a tensor of size
      [batch_size, height, width, channels]
    target_shared: a tensor with first dimension batch_size
    params: A dictionary of parameters. Expecting 'layers_to_regularize',
      'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
      'encoder_name', 'weight_decay'
  """

  def normalize_images(images):
    # Shift to a zero minimum and scale to a unit maximum for visualization.
    images -= tf.reduce_min(images)
    return images / tf.reduce_max(images)

  def concat_operation(shared_repr, private_repr):
    # Shared and private codes are combined additively before decoding.
    return shared_repr + private_repr

  mu = dsn_loss_coefficient(params)

  # The layer to concatenate the networks at.
  concat_layer = params['layers_to_regularize']

  # The coefficient for modulating the private/shared difference loss.
  difference_loss_weight = params['beta_weight'] * mu

  # The reconstruction weight.
  recon_loss_weight = params['alpha_weight'] * mu

  # The reconstruction loss to use.
  recon_loss_name = params['recon_loss_name']

  # The decoder/encoder to use.
  decoder_name = params['decoder_name']
  encoder_name = params['encoder_name']

  _, height, width, _ = source_data.get_shape().as_list()
  code_size = source_shared.get_shape().as_list()[-1]
  weight_decay = params['weight_decay']

  encoder_fn = getattr(models, encoder_name)
  # Private encoders, one per domain.
  with tf.variable_scope('source_encoder'):
    source_endpoints = encoder_fn(
        source_data, code_size, weight_decay=weight_decay)

  with tf.variable_scope('target_encoder'):
    target_endpoints = encoder_fn(
        target_data, code_size, weight_decay=weight_decay)

  # A single decoder shared (via variable reuse) across all reconstructions.
  decoder_fn = getattr(models, decoder_name)
  decoder = partial(
      decoder_fn,
      height=height,
      width=width,
      channels=source_data.get_shape().as_list()[-1],
      weight_decay=weight_decay)

  # Source Auto-encoding.
  source_private = source_endpoints[concat_layer]
  target_private = target_endpoints[concat_layer]
  with tf.variable_scope('decoder'):
    source_recons = decoder(concat_operation(source_shared, source_private))

  with tf.variable_scope('decoder', reuse=True):
    # Decode each code in isolation (zeroing the other) for visualization.
    source_private_recons = decoder(
        concat_operation(tf.zeros_like(source_private), source_private))
    source_shared_recons = decoder(
        concat_operation(source_shared, tf.zeros_like(source_shared)))

  # Target Auto-encoding.
  with tf.variable_scope('decoder', reuse=True):
    target_recons = decoder(concat_operation(target_shared, target_private))
    target_shared_recons = decoder(
        concat_operation(target_shared, tf.zeros_like(target_shared)))
    target_private_recons = decoder(
        concat_operation(tf.zeros_like(target_private), target_private))

  losses.difference_loss(
      source_private,
      source_shared,
      weight=difference_loss_weight,
      name='Source')
  losses.difference_loss(
      target_private,
      target_shared,
      weight=difference_loss_weight,
      name='Target')

  add_reconstruction_loss(recon_loss_name, source_data, source_recons,
                          recon_loss_weight, 'source')
  add_reconstruction_loss(recon_loss_name, target_data, target_recons,
                          recon_loss_weight, 'target')

  # Add summaries. Use list comprehensions rather than map(): under Python 3
  # map() returns an iterator, which tf.concat does not accept as `values`.
  source_reconstructions = tf.concat(
      axis=2,
      values=[normalize_images(image) for image in [
          source_data, source_recons, source_shared_recons,
          source_private_recons
      ]])
  target_reconstructions = tf.concat(
      axis=2,
      values=[normalize_images(image) for image in [
          target_data, target_recons, target_shared_recons,
          target_private_recons
      ]])
  tf.summary.image(
      'Source Images:Recons:RGB',
      source_reconstructions[:, :, :, :3],
      max_outputs=10)
  tf.summary.image(
      'Target Images:Recons:RGB',
      target_reconstructions[:, :, :, :3],
      max_outputs=10)

  # If the images carry a 4th (depth) channel, summarize it separately.
  if source_reconstructions.get_shape().as_list()[3] == 4:
    tf.summary.image(
        'Source Images:Recons:Depth',
        source_reconstructions[:, :, :, 3:4],
        max_outputs=10)
    tf.summary.image(
        'Target Images:Recons:Depth',
        target_reconstructions[:, :, :, 3:4],
        max_outputs=10)
def add_task_loss(source_images, source_labels, basic_tower, params):
  """Adds a classification and/or pose estimation loss to the model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: labels from the source domain, a tensor of size [batch_size].
      or a tuple of (quaternions, class_labels)
    basic_tower: a function that creates the single tower of the model.
    params: A dictionary of parameters. Expecting 'weight_decay', 'pose_weight'.

  Returns:
    The source endpoints.

  Raises:
    RuntimeError: if basic tower does not support pose estimation.
  """
  with tf.variable_scope('towers'):
    source_logits, source_endpoints = basic_tower(
        source_images, weight_decay=params['weight_decay'], prefix='Source')

  if 'quaternions' in source_labels:  # We have pose estimation as well
    if 'quaternion_pred' not in source_endpoints:
      raise RuntimeError('Please use a model for estimation e.g. pose_mini')

    loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                      source_endpoints['quaternion_pred'],
                                      params)

    # Wrap the loss in tf.identity inside the control-dependency scope so the
    # finiteness assertion is actually attached to an op. A bare Python
    # assignment (quaternion_loss = loss) creates no new op and the assert
    # would never run.
    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      quaternion_loss = tf.identity(loss)

    tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

  classification_loss = tf.losses.softmax_cross_entropy(
      source_labels['classes'], source_logits)

  tf.summary.scalar('losses/classification_loss', classification_loss)
  return source_endpoints