Spaces:
Runtime error
Runtime error
| # Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Helper functions used for training AutoAugment models.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import numpy as np | |
| import tensorflow as tf | |
def setup_loss(logits, labels):
    """Builds the softmax predictions and cross-entropy cost for `logits`.

    Args:
        logits: Unnormalized class scores; fed to `tf.nn.softmax` and to
            `tf.losses.softmax_cross_entropy`.
        labels: Targets forwarded as `onehot_labels` to
            `tf.losses.softmax_cross_entropy`.

    Returns:
        A `(predictions, cost)` tuple holding the softmax of `logits` and
        the softmax cross-entropy loss.
    """
    class_probabilities = tf.nn.softmax(logits)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=labels)
    return class_probabilities, cross_entropy
def decay_weights(cost, weight_decay_rate):
    """Adds an L2 weight-decay penalty over all trainable variables.

    Args:
        cost: Loss value the penalty is added to.
        weight_decay_rate: Multiplier applied to the summed L2 losses.

    Returns:
        `cost` plus `weight_decay_rate` times the sum of `tf.nn.l2_loss`
        taken over every variable in `tf.trainable_variables()`.
    """
    l2_terms = [tf.nn.l2_loss(variable)
                for variable in tf.trainable_variables()]
    penalty = tf.multiply(weight_decay_rate, tf.add_n(l2_terms))
    cost += penalty
    return cost
def eval_child_model(session, model, data_loader, mode):
    """Runs `model.eval_op` over a held-out split and fetches the accuracy.

    Args:
        session: TensorFlow session used for every `run` call.
        model: Object exposing `batch_size`, `eval_op`, `accuracy` and the
            `images` / `labels` feed targets.
        data_loader: Object exposing `val_images` / `val_labels` and
            `test_images` / `test_labels`.
        mode: Either 'val' or 'test'; selects which split is evaluated.

    Returns:
        The result of running `model.accuracy` after every batch of the
        chosen split has been passed through `model.eval_op`.

    Raises:
        ValueError: if `mode` is neither 'val' nor 'test'.
    """
    if mode == 'val':
        images, labels = data_loader.val_images, data_loader.val_labels
    elif mode == 'test':
        images, labels = data_loader.test_images, data_loader.test_labels
    else:
        raise ValueError('Not valid eval mode')

    num_examples = len(images)
    assert num_examples == len(labels)
    tf.logging.info('model.batch_size is {}'.format(model.batch_size))
    batch_size = model.batch_size
    # Only whole batches are fed, so the split must divide evenly.
    assert num_examples % batch_size == 0

    num_batches = int(num_examples / batch_size)
    for batch_index in range(num_batches):
        start = batch_index * batch_size
        stop = (batch_index + 1) * batch_size
        feed = {
            model.images: images[start:stop],
            model.labels: labels[start:stop],
        }
        session.run(model.eval_op, feed_dict=feed)
    return session.run(model.accuracy)
def cosine_lr(learning_rate, epoch, iteration, batches_per_epoch, total_epochs):
    """Computes a cosine-annealed learning rate for a single batch.

    The rate is `0.5 * learning_rate * (1 + cos(pi * t_cur / t_total))`,
    where `t_cur = epoch * batches_per_epoch + iteration` and
    `t_total = total_epochs * batches_per_epoch`.

    Args:
        learning_rate: Initial learning rate.
        epoch: Index of the current epoch.
        iteration: Index of the current batch within this epoch.
        batches_per_epoch: Number of batches in one epoch.
        total_epochs: Total number of epochs being trained for.

    Returns:
        The learning rate to be used for the current batch.
    """
    batches_elapsed = float(epoch * batches_per_epoch + iteration)
    batches_planned = total_epochs * batches_per_epoch
    angle = np.pi * batches_elapsed / batches_planned
    return 0.5 * learning_rate * (1 + np.cos(angle))
def get_lr(curr_epoch, hparams, iteration=None):
    """Returns the cosine-schedule learning rate for one training step.

    Args:
        curr_epoch: Epoch index forwarded to `cosine_lr`.
        hparams: Object exposing `train_size`, `batch_size`, `lr` and
            `num_epochs`.
        iteration: Batch index within the epoch. It defaults to None but
            must be supplied; None trips the assertion below.

    Returns:
        The learning rate `cosine_lr` computes for this epoch and batch.
    """
    assert iteration is not None
    num_batches = int(hparams.train_size / hparams.batch_size)
    return cosine_lr(
        learning_rate=hparams.lr,
        epoch=curr_epoch,
        iteration=iteration,
        batches_per_epoch=num_batches,
        total_epochs=hparams.num_epochs)
def run_epoch_training(session, model, data_loader, curr_epoch):
    """Runs one epoch of training for the model passed in.

    Args:
        session: TensorFlow session the model will be run with.
        model: Object exposing `hparams`, `global_step`, `lr_rate_ph`,
            `train_op`, `eval_op`, `accuracy` and the `images` / `labels`
            feed targets.
        data_loader: Object whose `next_batch()` returns an
            `(images, labels)` pair for one training step.
        curr_epoch: How many epochs of training have been done so far.

    Returns:
        The value of `model.accuracy` fetched after the epoch's last step
        (presumably the training accuracy accumulated by `model.eval_op`;
        confirm against the model definition).
    """
    steps_per_epoch = int(
        model.hparams.train_size / model.hparams.batch_size)
    tf.logging.info('steps per epoch: {}'.format(steps_per_epoch))
    curr_step = session.run(model.global_step)
    # An epoch may only start on an epoch boundary of the global step.
    assert curr_step % steps_per_epoch == 0

    # Log the learning rate the epoch starts from.
    curr_lr = get_lr(curr_epoch, model.hparams, iteration=0)
    tf.logging.info('lr of {} for epoch {}'.format(curr_lr, curr_epoch))

    # `range` instead of `xrange`: `xrange` does not exist on Python 3 and
    # this module does not import a compatibility shim for it.
    for step in range(steps_per_epoch):
        curr_lr = get_lr(curr_epoch, model.hparams, iteration=(step + 1))
        # Update the lr rate variable to the current LR.
        model.lr_rate_ph.load(curr_lr, session=session)
        if step % 20 == 0:
            tf.logging.info('Training {}/{}'.format(step, steps_per_epoch))

        train_images, train_labels = data_loader.next_batch()
        # The fetch list is unchanged, but the fetched global step is no
        # longer bound to `step`, so the loop counter is never clobbered.
        session.run(
            [model.train_op, model.global_step, model.eval_op],
            feed_dict={
                model.images: train_images,
                model.labels: train_labels,
            })

    train_accuracy = session.run(model.accuracy)
    tf.logging.info('Train accuracy: {}'.format(train_accuracy))
    return train_accuracy