Spaces:
Runtime error
Runtime error
| # Lint as: python2, python3 | |
| # Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Cell structure used by NAS.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import functools | |
| from six.moves import range | |
| from six.moves import zip | |
| import tensorflow as tf | |
| from tensorflow.contrib import framework as contrib_framework | |
| from tensorflow.contrib import slim as contrib_slim | |
| from deeplab.core import xception as xception_utils | |
| from deeplab.core.utils import resize_bilinear | |
| from deeplab.core.utils import scale_dimension | |
| from tensorflow.contrib.slim.nets import resnet_utils | |
| arg_scope = contrib_framework.arg_scope | |
| slim = contrib_slim | |
| separable_conv2d_same = functools.partial(xception_utils.separable_conv2d_same, | |
| regularize_depthwise=True) | |
| class NASBaseCell(object): | |
| """NASNet Cell class that is used as a 'layer' in image architectures.""" | |
| def __init__(self, num_conv_filters, operations, used_hiddenstates, | |
| hiddenstate_indices, drop_path_keep_prob, total_num_cells, | |
| total_training_steps, batch_norm_fn=slim.batch_norm): | |
| """Init function. | |
| For more details about NAS cell, see | |
| https://arxiv.org/abs/1707.07012 and https://arxiv.org/abs/1712.00559. | |
| Args: | |
| num_conv_filters: The number of filters for each convolution operation. | |
| operations: List of operations that are performed in the NASNet Cell in | |
| order. | |
| used_hiddenstates: Binary array that signals if the hiddenstate was used | |
| within the cell. This is used to determine what outputs of the cell | |
| should be concatenated together. | |
| hiddenstate_indices: Determines what hiddenstates should be combined | |
| together with the specified operations to create the NASNet cell. | |
| drop_path_keep_prob: Float, drop path keep probability. | |
| total_num_cells: Integer, total number of cells. | |
| total_training_steps: Integer, total training steps. | |
| batch_norm_fn: Function, batch norm function. Defaults to | |
| slim.batch_norm. | |
| """ | |
| if len(hiddenstate_indices) != len(operations): | |
| raise ValueError( | |
| 'Number of hiddenstate_indices and operations should be the same.') | |
| if len(operations) % 2: | |
| raise ValueError('Number of operations should be even.') | |
| self._num_conv_filters = num_conv_filters | |
| self._operations = operations | |
| self._used_hiddenstates = used_hiddenstates | |
| self._hiddenstate_indices = hiddenstate_indices | |
| self._drop_path_keep_prob = drop_path_keep_prob | |
| self._total_num_cells = total_num_cells | |
| self._total_training_steps = total_training_steps | |
| self._batch_norm_fn = batch_norm_fn | |
| def __call__(self, net, scope, filter_scaling, stride, prev_layer, cell_num): | |
| """Runs the conv cell.""" | |
| self._cell_num = cell_num | |
| self._filter_scaling = filter_scaling | |
| self._filter_size = int(self._num_conv_filters * filter_scaling) | |
| with tf.variable_scope(scope): | |
| net = self._cell_base(net, prev_layer) | |
| for i in range(len(self._operations) // 2): | |
| with tf.variable_scope('comb_iter_{}'.format(i)): | |
| h1 = net[self._hiddenstate_indices[i * 2]] | |
| h2 = net[self._hiddenstate_indices[i * 2 + 1]] | |
| with tf.variable_scope('left'): | |
| h1 = self._apply_conv_operation( | |
| h1, self._operations[i * 2], stride, | |
| self._hiddenstate_indices[i * 2] < 2) | |
| with tf.variable_scope('right'): | |
| h2 = self._apply_conv_operation( | |
| h2, self._operations[i * 2 + 1], stride, | |
| self._hiddenstate_indices[i * 2 + 1] < 2) | |
| with tf.variable_scope('combine'): | |
| h = h1 + h2 | |
| net.append(h) | |
| with tf.variable_scope('cell_output'): | |
| net = self._combine_unused_states(net) | |
| return net | |
| def _cell_base(self, net, prev_layer): | |
| """Runs the beginning of the conv cell before the chosen ops are run.""" | |
| filter_size = self._filter_size | |
| if prev_layer is None: | |
| prev_layer = net | |
| else: | |
| if net.shape[2] != prev_layer.shape[2]: | |
| prev_layer = resize_bilinear( | |
| prev_layer, tf.shape(net)[1:3], prev_layer.dtype) | |
| if filter_size != prev_layer.shape[3]: | |
| prev_layer = tf.nn.relu(prev_layer) | |
| prev_layer = slim.conv2d(prev_layer, filter_size, 1, scope='prev_1x1') | |
| prev_layer = self._batch_norm_fn(prev_layer, scope='prev_bn') | |
| net = tf.nn.relu(net) | |
| net = slim.conv2d(net, filter_size, 1, scope='1x1') | |
| net = self._batch_norm_fn(net, scope='beginning_bn') | |
| net = tf.split(axis=3, num_or_size_splits=1, value=net) | |
| net.append(prev_layer) | |
| return net | |
| def _apply_conv_operation(self, net, operation, stride, | |
| is_from_original_input): | |
| """Applies the predicted conv operation to net.""" | |
| if stride > 1 and not is_from_original_input: | |
| stride = 1 | |
| input_filters = net.shape[3] | |
| filter_size = self._filter_size | |
| if 'separable' in operation: | |
| num_layers = int(operation.split('_')[-1]) | |
| kernel_size = int(operation.split('x')[0][-1]) | |
| for layer_num in range(num_layers): | |
| net = tf.nn.relu(net) | |
| net = separable_conv2d_same( | |
| net, | |
| filter_size, | |
| kernel_size, | |
| depth_multiplier=1, | |
| scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1), | |
| stride=stride) | |
| net = self._batch_norm_fn( | |
| net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1)) | |
| stride = 1 | |
| elif 'atrous' in operation: | |
| kernel_size = int(operation.split('x')[0][-1]) | |
| net = tf.nn.relu(net) | |
| if stride == 2: | |
| scaled_height = scale_dimension(tf.shape(net)[1], 0.5) | |
| scaled_width = scale_dimension(tf.shape(net)[2], 0.5) | |
| net = resize_bilinear(net, [scaled_height, scaled_width], net.dtype) | |
| net = resnet_utils.conv2d_same( | |
| net, filter_size, kernel_size, rate=1, stride=1, | |
| scope='atrous_{0}x{0}'.format(kernel_size)) | |
| else: | |
| net = resnet_utils.conv2d_same( | |
| net, filter_size, kernel_size, rate=2, stride=1, | |
| scope='atrous_{0}x{0}'.format(kernel_size)) | |
| net = self._batch_norm_fn(net, scope='bn_atr_{0}x{0}'.format(kernel_size)) | |
| elif operation in ['none']: | |
| if stride > 1 or (input_filters != filter_size): | |
| net = tf.nn.relu(net) | |
| net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1') | |
| net = self._batch_norm_fn(net, scope='bn_1') | |
| elif 'pool' in operation: | |
| pooling_type = operation.split('_')[0] | |
| pooling_shape = int(operation.split('_')[-1].split('x')[0]) | |
| if pooling_type == 'avg': | |
| net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding='SAME') | |
| elif pooling_type == 'max': | |
| net = slim.max_pool2d(net, pooling_shape, stride=stride, padding='SAME') | |
| else: | |
| raise ValueError('Unimplemented pooling type: ', pooling_type) | |
| if input_filters != filter_size: | |
| net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1') | |
| net = self._batch_norm_fn(net, scope='bn_1') | |
| else: | |
| raise ValueError('Unimplemented operation', operation) | |
| if operation != 'none': | |
| net = self._apply_drop_path(net) | |
| return net | |
| def _combine_unused_states(self, net): | |
| """Concatenates the unused hidden states of the cell.""" | |
| used_hiddenstates = self._used_hiddenstates | |
| states_to_combine = ([ | |
| h for h, is_used in zip(net, used_hiddenstates) if not is_used]) | |
| net = tf.concat(values=states_to_combine, axis=3) | |
| return net | |
| def _apply_drop_path(self, net): | |
| """Apply drop_path regularization.""" | |
| drop_path_keep_prob = self._drop_path_keep_prob | |
| if drop_path_keep_prob < 1.0: | |
| # Scale keep prob by layer number. | |
| assert self._cell_num != -1 | |
| layer_ratio = (self._cell_num + 1) / float(self._total_num_cells) | |
| drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob) | |
| # Decrease keep prob over time. | |
| current_step = tf.cast(tf.train.get_or_create_global_step(), tf.float32) | |
| current_ratio = tf.minimum(1.0, current_step / self._total_training_steps) | |
| drop_path_keep_prob = (1 - current_ratio * (1 - drop_path_keep_prob)) | |
| # Drop path. | |
| noise_shape = [tf.shape(net)[0], 1, 1, 1] | |
| random_tensor = drop_path_keep_prob | |
| random_tensor += tf.random_uniform(noise_shape, dtype=tf.float32) | |
| binary_tensor = tf.cast(tf.floor(random_tensor), net.dtype) | |
| keep_prob_inv = tf.cast(1.0 / drop_path_keep_prob, net.dtype) | |
| net = net * keep_prob_inv * binary_tensor | |
| return net | |