# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Builds the Shake-Shake Model.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import math | |
| import custom_ops as ops | |
| import tensorflow as tf | |


def round_int(x):
  """Rounds `x` and then converts to an int."""
  return int(math.floor(x + 0.5))
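# Note: floor(x + 0.5) rounds to the nearest integer with ties rounded up,
# e.g. round_int(18.22) == 18 and round_int(18.5) == 19 (fractional values
# like these arise from the PyramidNet channel widths computed below).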


def shortcut(x, output_filters, stride):
  """Downsamples `x` and zero-pads its channels to match `output_filters`."""
  num_filters = int(x.shape[3])
  if stride == 2:
    x = ops.avg_pool(x, 2, stride=stride, padding='SAME')
  if num_filters != output_filters:
    diff = output_filters - num_filters
    assert diff > 0
    # Zero-pad the channel dimension with `diff` extra channels.
    padding = [[0, 0], [0, 0], [0, 0], [0, diff]]
    x = tf.pad(x, padding)
  return x
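# Shape example (NHWC, matching the x.shape[3] channel indexing above): a
# [32, 16, 16, 64] input passed through shortcut(x, 80, stride=2) is
# average-pooled to [32, 8, 8, 64] and then zero-padded to [32, 8, 8, 80].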


def calc_prob(curr_layer, total_layers, p_l):
  """Linearly decayed survival probability for the current layer.

  Following the ShakeDrop linear decay rule, layer l of L survives with
  probability 1 - (l / L) * p_l, so early layers are almost never dropped
  and the final layer survives with probability 1 - p_l.
  """
  return 1 - (float(curr_layer) / total_layers) * p_l
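# With the hyperparameters used below (p_l = 0.5, total_layers = 90), the
# survival probability decays linearly with depth:
#   calc_prob(1, 90, 0.5)  == 1 - (1 / 90) * 0.5  ~= 0.9944
#   calc_prob(90, 90, 0.5) == 1 - (90 / 90) * 0.5 == 0.5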


def bottleneck_layer(x, n, stride, prob, is_training, alpha, beta):
  """Bottleneck layer for the shake-drop model."""
  assert alpha[1] > alpha[0]
  assert beta[1] > beta[0]
  with tf.variable_scope('bottleneck_{}'.format(prob)):
    input_layer = x
    # PyramidNet-style bottleneck: BN - 1x1 conv (contract to n channels) -
    # BN - ReLU - 3x3 conv (optionally strided) - BN - ReLU - 1x1 conv
    # (expand to 4n channels) - BN.
    x = ops.batch_norm(x, scope='bn_1_pre')
    x = ops.conv2d(x, n, 1, scope='1x1_conv_contract')
    x = ops.batch_norm(x, scope='bn_1_post')
    x = tf.nn.relu(x)
    x = ops.conv2d(x, n, 3, stride=stride, scope='3x3')
    x = ops.batch_norm(x, scope='bn_2')
    x = tf.nn.relu(x)
    x = ops.conv2d(x, n * 4, 1, scope='1x1_conv_expand')
    x = ops.batch_norm(x, scope='bn_3')
    # Shake-drop regularization: sample a per-example Bernoulli gate that
    # keeps this residual branch with probability `prob`.
    if is_training:
      batch_size = tf.shape(x)[0]
      bern_shape = [batch_size, 1, 1, 1]
      # floor(prob + U[0, 1)) is 1 with probability `prob` and 0 otherwise.
      random_tensor = prob
      random_tensor += tf.random_uniform(bern_shape, dtype=tf.float32)
      binary_tensor = tf.floor(random_tensor)
      alpha_values = tf.random_uniform(
          [batch_size, 1, 1, 1], minval=alpha[0], maxval=alpha[1],
          dtype=tf.float32)
      beta_values = tf.random_uniform(
          [batch_size, 1, 1, 1], minval=beta[0], maxval=beta[1],
          dtype=tf.float32)
      # When the gate is 1, both coefficients reduce to 1 (branch kept);
      # when it is 0, they fall back to the random alpha / beta scalings.
      rand_forward = (
          binary_tensor + alpha_values - binary_tensor * alpha_values)
      rand_backward = (
          binary_tensor + beta_values - binary_tensor * beta_values)
      # The forward pass is scaled by rand_forward, while the gradient flows
      # through rand_backward only, via the stop_gradient trick.
      x = x * rand_backward + tf.stop_gradient(x * rand_forward -
                                               x * rand_backward)
    else:
      # At eval time, scale by the expected forward coefficient:
      # E[b + alpha - b * alpha] = prob + E[alpha] * (1 - prob),
      # since `prob` is the expectation of the Bernoulli gate.
      expected_alpha = (alpha[1] + alpha[0]) / 2
      x = (prob + expected_alpha - prob * expected_alpha) * x
    res = shortcut(input_layer, n * 4, stride)
    return x + res
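# A worked check of the eval-time scaling above: with alpha_shake = [-1, 1]
# (the setting used below), expected_alpha == 0, so each residual branch is
# simply scaled by its survival probability `prob`, as in stochastic depth.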


def build_shake_drop_model(images, num_classes, is_training):
  """Builds the PyramidNet Shake-Drop model.

  Builds the PyramidNet Shake-Drop model from https://arxiv.org/abs/1802.02375.

  Args:
    images: Tensor of images that will be fed into the model.
    num_classes: Number of classes that the model needs to predict.
    is_training: Is the model training or not.

  Returns:
    The logits of the PyramidNet Shake-Drop model.
  """
  # ShakeDrop hparams
  p_l = 0.5
  alpha_shake = [-1, 1]
  beta_shake = [0, 1]

  # PyramidNet hparams
  alpha = 200
  depth = 272
  # For the bottleneck architecture, depth = 9 * n + 2.
  n = int((depth - 2) / 9)
  start_channel = 16
  add_channel = alpha / (3 * n)
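  # With depth = 272 and alpha = 200 this gives n = 30, i.e. 90 bottleneck
  # layers in total, each widening the trunk by add_channel = 200 / 90
  # ~= 2.22 channels, so the width grows linearly from 16 to 216 (times 4
  # after the bottleneck expansion).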

  # Building the model
  x = images
  x = ops.conv2d(x, 16, 3, scope='init_conv')
  x = ops.batch_norm(x, scope='init_bn')

  layer_num = 1
  total_layers = n * 3

  # Stage 1: n bottleneck layers at full resolution.
  start_channel += add_channel
  prob = calc_prob(layer_num, total_layers, p_l)
  x = bottleneck_layer(
      x, round_int(start_channel), 1, prob, is_training, alpha_shake,
      beta_shake)
  layer_num += 1
  for _ in range(1, n):
    start_channel += add_channel
    prob = calc_prob(layer_num, total_layers, p_l)
    x = bottleneck_layer(
        x, round_int(start_channel), 1, prob, is_training, alpha_shake,
        beta_shake)
    layer_num += 1

  # Stage 2: one stride-2 (downsampling) layer, then n - 1 stride-1 layers.
  start_channel += add_channel
  prob = calc_prob(layer_num, total_layers, p_l)
  x = bottleneck_layer(
      x, round_int(start_channel), 2, prob, is_training, alpha_shake,
      beta_shake)
  layer_num += 1
  for _ in range(1, n):
    start_channel += add_channel
    prob = calc_prob(layer_num, total_layers, p_l)
    x = bottleneck_layer(
        x, round_int(start_channel), 1, prob, is_training, alpha_shake,
        beta_shake)
    layer_num += 1

  # Stage 3: one stride-2 (downsampling) layer, then n - 1 stride-1 layers.
  start_channel += add_channel
  prob = calc_prob(layer_num, total_layers, p_l)
  x = bottleneck_layer(
      x, round_int(start_channel), 2, prob, is_training, alpha_shake,
      beta_shake)
  layer_num += 1
  for _ in range(1, n):
    start_channel += add_channel
    prob = calc_prob(layer_num, total_layers, p_l)
    x = bottleneck_layer(
        x, round_int(start_channel), 1, prob, is_training, alpha_shake,
        beta_shake)
    layer_num += 1
  assert layer_num - 1 == total_layers

  # Classification head: BN - ReLU - global average pool - fully connected.
  x = ops.batch_norm(x, scope='final_bn')
  x = tf.nn.relu(x)
  x = ops.global_avg_pool(x)
  logits = ops.fc(x, num_classes)
  return logits
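

# A minimal usage sketch (not part of the original file), assuming TF1-style
# graph mode and CIFAR-10-shaped inputs:
#
#   images = tf.placeholder(tf.float32, [None, 32, 32, 3])
#   logits = build_shake_drop_model(images, num_classes=10, is_training=True)
#
# `logits` then has shape [None, 10].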