# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utils for CIFAR-10 and CIFAR-100."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import cPickle
import os

import augmentation_transforms
import numpy as np
import policies as found_policies
import tensorflow as tf

# pylint:disable=logging-format-interpolation


class DataSet(object):
| """Dataset object that produces augmented training and eval data.""" | |
| def __init__(self, hparams): | |
| self.hparams = hparams | |
| self.epochs = 0 | |
| self.curr_train_index = 0 | |
| all_labels = [] | |
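    # good_policies() returns the augmentation sub-policies found by the
    # AutoAugment search; next_batch() samples one at random for every
    # training image.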
    self.good_policies = found_policies.good_policies()
    # Determine how many data batches to load.
    num_data_batches_to_load = 5
    total_batches_to_load = num_data_batches_to_load
    train_batches_to_load = total_batches_to_load
    assert hparams.train_size + hparams.validation_size <= 50000
    if hparams.eval_test:
      total_batches_to_load += 1
    # Determine how many images will be loaded.
    total_dataset_size = 10000 * num_data_batches_to_load
    train_dataset_size = total_dataset_size
    if hparams.eval_test:
      total_dataset_size += 10000
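    # Each CIFAR batch file holds 10,000 images, so the training pool is
    # 50,000 images; evaluating on the test split adds one more file of
    # 10,000 images.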
    if hparams.dataset == 'cifar10':
      all_data = np.empty((total_batches_to_load, 10000, 3072), dtype=np.uint8)
    elif hparams.dataset == 'cifar100':
      assert num_data_batches_to_load == 5
      all_data = np.empty((1, 50000, 3072), dtype=np.uint8)
      if hparams.eval_test:
        test_data = np.empty((1, 10000, 3072), dtype=np.uint8)
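      # CIFAR-100 ships as a single 'train' file of 50,000 rows, so its test
      # rows are staged in test_data and concatenated onto all_data while
      # loading below.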
    if hparams.dataset == 'cifar10':
      tf.logging.info('Cifar10')
      datafiles = [
          'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4',
          'data_batch_5']
      datafiles = datafiles[:train_batches_to_load]
      if hparams.eval_test:
        datafiles.append('test_batch')
      num_classes = 10
    elif hparams.dataset == 'cifar100':
      datafiles = ['train']
      if hparams.eval_test:
        datafiles.append('test')
      num_classes = 100
    else:
      raise NotImplementedError(
          'Unimplemented dataset: {}'.format(hparams.dataset))
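    # Load each pickle file, copy its raw image rows into all_data, and
    # accumulate the labels ('labels' for CIFAR-10, 'fine_labels' for
    # CIFAR-100).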
    if hparams.dataset != 'test':
      for file_num, f in enumerate(datafiles):
        d = unpickle(os.path.join(hparams.data_path, f))
        if f == 'test':
          test_data[0] = copy.deepcopy(d['data'])
          all_data = np.concatenate([all_data, test_data], axis=1)
        else:
          all_data[file_num] = copy.deepcopy(d['data'])
        if hparams.dataset == 'cifar10':
          labels = np.array(d['labels'])
        else:
          labels = np.array(d['fine_labels'])
        nsamples = len(labels)
        for idx in range(nsamples):
          all_labels.append(labels[idx])
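    # Each row is 3,072 uint8 values laid out channel-major (3 x 32 x 32).
    # Reshape to NCHW, transpose to NHWC, scale to [0, 1], then normalize
    # with the per-channel MEANS/STDS from augmentation_transforms.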
    all_data = all_data.reshape(total_dataset_size, 3072)
    all_data = all_data.reshape(-1, 3, 32, 32)
    all_data = all_data.transpose(0, 2, 3, 1).copy()
    all_data = all_data / 255.0
    mean = augmentation_transforms.MEANS
    std = augmentation_transforms.STDS
    tf.logging.info('mean: {} std: {}'.format(mean, std))
    all_data = (all_data - mean) / std
    all_labels = np.eye(num_classes)[np.array(all_labels, dtype=np.int32)]
    assert len(all_data) == len(all_labels)
    tf.logging.info(
        'In CIFAR loader, number of images: {}'.format(len(all_data)))
    # Break off test data.
    if hparams.eval_test:
      self.test_images = all_data[train_dataset_size:]
      self.test_labels = all_labels[train_dataset_size:]

    # Shuffle the rest of the data.
    all_data = all_data[:train_dataset_size]
    all_labels = all_labels[:train_dataset_size]
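    # The fixed seed below makes the shuffle, and hence the train/validation
    # split, deterministic across runs.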
    np.random.seed(0)
    perm = np.arange(len(all_data))
    np.random.shuffle(perm)
    all_data = all_data[perm]
    all_labels = all_labels[perm]

    # Break into train and val.
    train_size, val_size = hparams.train_size, hparams.validation_size
    assert 50000 >= train_size + val_size
    self.train_images = all_data[:train_size]
    self.train_labels = all_labels[:train_size]
    self.val_images = all_data[train_size:train_size + val_size]
    self.val_labels = all_labels[train_size:train_size + val_size]
    self.num_train = self.train_images.shape[0]

  def next_batch(self):
| """Return the next minibatch of augmented data.""" | |
| next_train_index = self.curr_train_index + self.hparams.batch_size | |
| if next_train_index > self.num_train: | |
| # Increase epoch number | |
| epoch = self.epochs + 1 | |
| self.reset() | |
| self.epochs = epoch | |
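      # reset() reshuffles the training data and rewinds curr_train_index to
      # zero, so any partial batch left over at the end of an epoch is
      # skipped rather than returned.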
    batched_data = (
        self.train_images[self.curr_train_index:
                          self.curr_train_index + self.hparams.batch_size],
        self.train_labels[self.curr_train_index:
                          self.curr_train_index + self.hparams.batch_size])
    final_imgs = []
    images, labels = batched_data
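    # Augment each image independently: sample one of the good sub-policies,
    # apply it, then random flip, zero-pad-and-crop with a pad of 4 pixels,
    # and finally cutout.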
    for data in images:
      epoch_policy = self.good_policies[np.random.choice(
          len(self.good_policies))]
      final_img = augmentation_transforms.apply_policy(
          epoch_policy, data)
      final_img = augmentation_transforms.random_flip(
          augmentation_transforms.zero_pad_and_crop(final_img, 4))
      # Apply cutout.
      final_img = augmentation_transforms.cutout_numpy(final_img)
      final_imgs.append(final_img)
    batched_data = (np.array(final_imgs, np.float32), labels)
    self.curr_train_index += self.hparams.batch_size
    return batched_data

  def reset(self):
| """Reset training data and index into the training data.""" | |
| self.epochs = 0 | |
| # Shuffle the training data | |
| perm = np.arange(self.num_train) | |
| np.random.shuffle(perm) | |
| assert self.num_train == self.train_images.shape[ | |
| 0], 'Error incorrect shuffling mask' | |
| self.train_images = self.train_images[perm] | |
| self.train_labels = self.train_labels[perm] | |
| self.curr_train_index = 0 | |
| def unpickle(f): | |
  """Load a pickled CIFAR batch file."""
  tf.logging.info('loading file: {}'.format(f))
  # CIFAR pickle files are binary, so open in binary mode.
  fo = tf.gfile.Open(f, 'rb')
  d = cPickle.load(fo)
  fo.close()
  return d
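

# Illustrative usage sketch (not part of the original module). It assumes a
# TF 1.x HParams-style object exposing the fields read above (dataset,
# data_path, train_size, validation_size, eval_test, batch_size); the data
# path is a placeholder for the extracted CIFAR-10 python batch files.
#
#   hparams = tf.contrib.training.HParams(
#       dataset='cifar10', data_path='/tmp/cifar-10-batches-py',
#       train_size=45000, validation_size=5000, eval_test=1, batch_size=128)
#   dataset = DataSet(hparams)
#   images, labels = dataset.next_batch()
#   # images: (128, 32, 32, 3) float32; labels: (128, 10) one-hot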