Spaces:
Runtime error
Runtime error
| # Copyright 2017 Google Inc. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Download MNIST, Omniglot datasets for Rebar.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import urllib | |
| import gzip | |
| import os | |
| import config | |
| import struct | |
| import numpy as np | |
| import cPickle as pickle | |
| import datasets | |
| MNIST_URL = 'see README' | |
| MNIST_BINARIZED_URL = 'see README' | |
| OMNIGLOT_URL = 'see README' | |
| MNIST_FLOAT_TRAIN = 'train-images-idx3-ubyte' | |
| def load_mnist_float(local_filename): | |
| with open(local_filename, 'rb') as f: | |
| f.seek(4) | |
| nimages, rows, cols = struct.unpack('>iii', f.read(12)) | |
| dim = rows*cols | |
| images = np.fromfile(f, dtype=np.dtype(np.ubyte)) | |
| images = (images/255.0).astype('float32').reshape((nimages, dim)) | |
| return images | |
| if __name__ == '__main__': | |
| if not os.path.exists(config.DATA_DIR): | |
| os.makedirs(config.DATA_DIR) | |
| # Get MNIST and convert to npy file | |
| local_filename = os.path.join(config.DATA_DIR, MNIST_FLOAT_TRAIN) | |
| if not os.path.exists(local_filename): | |
| urllib.urlretrieve("%s/%s.gz" % (MNIST_URL, MNIST_FLOAT_TRAIN), local_filename+'.gz') | |
| with gzip.open(local_filename+'.gz', 'rb') as f: | |
| file_content = f.read() | |
| with open(local_filename, 'wb') as f: | |
| f.write(file_content) | |
| os.remove(local_filename+'.gz') | |
| mnist_float_train = load_mnist_float(local_filename)[:-10000] | |
| # save in a nice format | |
| np.save(os.path.join(config.DATA_DIR, config.MNIST_FLOAT), mnist_float_train) | |
| # Get binarized MNIST | |
| splits = ['train', 'valid', 'test'] | |
| mnist_binarized = [] | |
| for split in splits: | |
| filename = 'binarized_mnist_%s.amat' % split | |
| url = '%s/binarized_mnist_%s.amat' % (MNIST_BINARIZED_URL, split) | |
| local_filename = os.path.join(config.DATA_DIR, filename) | |
| if not os.path.exists(local_filename): | |
| urllib.urlretrieve(url, local_filename) | |
| with open(local_filename, 'rb') as f: | |
| mnist_binarized.append((np.array([map(int, line.split()) for line in f.readlines()]).astype('float32'), None)) | |
| # save in a nice format | |
| with open(os.path.join(config.DATA_DIR, config.MNIST_BINARIZED), 'w') as out: | |
| pickle.dump(mnist_binarized, out) | |
| # Get Omniglot | |
| local_filename = os.path.join(config.DATA_DIR, config.OMNIGLOT) | |
| if not os.path.exists(local_filename): | |
| urllib.urlretrieve(OMNIGLOT_URL, | |
| local_filename) | |