Spaces:
Build error
Build error
| import tensorflow as tf | |
| import tensorflow_datasets as tfds | |
| import jax | |
| import flax | |
| import numpy as np | |
| from PIL import Image | |
| import os | |
| from typing import Sequence | |
| from tqdm import tqdm | |
| import json | |
| from tqdm import tqdm | |
| import logging | |
| logger = logging.getLogger(__name__) | |
def prefetch(dataset, n_prefetch):
    """Iterate ``dataset`` as NumPy pytrees, optionally prefetching to device.

    Adapted from
    https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py

    Args:
        dataset: Iterable of pytrees whose leaves support the buffer protocol
            (e.g. eager ``tf.Tensor`` or ``np.ndarray``).
        n_prefetch (int): Number of batches to stage on device ahead of use
            via ``flax.jax_utils.prefetch_to_device``. A falsy value (0/None)
            disables device prefetching.

    Returns:
        An iterator over the same pytrees with every leaf converted to
        ``np.ndarray``.
    """
    def to_numpy(batch):
        # np.asarray over a memoryview wraps the leaf's exported buffer
        # instead of copying it. jax.tree_util.tree_map is used because the
        # jax.tree_map alias is deprecated and removed in recent JAX releases.
        return jax.tree_util.tree_map(lambda leaf: np.asarray(memoryview(leaf)), batch)

    ds_iter = map(to_numpy, iter(dataset))
    if n_prefetch:
        # prefetch_to_device shards each leaf along its leading axis across
        # local devices, so that axis presumably equals the local device
        # count (get_data's shard() produces that layout) -- confirm in callers.
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter
def get_data(data_dir, img_size, img_channels, num_classes, num_local_devices, batch_size, shuffle_buffer=1000):
    """Build the sharded, batched training input pipeline from a TFRecord file.

    Reads ``dataset_info.json`` and ``dataset.tfrecords`` from ``data_dir``,
    splits the records across JAX processes, shuffles, decodes, resizes,
    randomly flips, scales pixels to [-1, 1], one-hot encodes labels, batches,
    and reshapes every batch so its leading axis is the local-device axis
    consumed by ``jax.pmap``.

    Args:
        data_dir (str): Root directory of the dataset; must contain
            ``dataset_info.json`` (with a ``num_examples`` entry) and
            ``dataset.tfrecords``.
        img_size (int): Output image height and width.
        img_channels (int): Number of image channels. Must match the
            ``channels`` value stored in the records.
        num_classes (int): Number of classes, 0 for no classes.
        num_local_devices (int): Number of local devices.
        batch_size (int): Batch size (per device).
        shuffle_buffer (int): Upper bound on the shuffle buffer size.

    Returns:
        tuple: ``(ds, dataset_info)``. ``ds`` is a ``tf.data.Dataset`` yielding
        dicts with ``'image'`` of shape
        ``[num_local_devices, batch_size, img_size, img_size, img_channels]``
        and ``'label'`` of shape ``[num_local_devices, batch_size, num_classes]``.
        ``dataset_info`` is the parsed ``dataset_info.json``.
    """
    feature_spec = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'channels': tf.io.FixedLenFeature([], tf.int64),
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def pre_process(serialized_example):
        """Decode one serialized record into a [-1, 1] image and a one-hot label."""
        example = tf.io.parse_single_example(serialized_example, feature_spec)
        # Pixels are stored as raw uint8 bytes; the record carries its own shape.
        image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
        image = tf.reshape(image, shape=[example['height'], example['width'], example['channels']])
        image = tf.cast(image, dtype=tf.float32)
        image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
        # Bicubic interpolation can overshoot the input range; clip so the
        # normalized pixels stay within [-1, 1].
        image = tf.clip_by_value(image, 0.0, 255.0)
        image = tf.image.random_flip_left_right(image)
        image = (image - 127.5) / 127.5
        label = tf.one_hot(example['label'], num_classes)
        return {'image': image, 'label': label}

    def shard(data):
        """Split a per-worker batch into one sub-batch per local device."""
        # Reshape [num_local_devices * batch_size, ...] to
        # [num_local_devices, batch_size, ...]; jax.pmap maps the leading axis
        # across devices. The batch axis is explicit rather than -1: TF cannot
        # infer -1 for an empty tensor (num_classes == 0), and -1 would silently
        # absorb an img_channels mismatch into the batch axis.
        data['image'] = tf.reshape(data['image'], [num_local_devices, batch_size, img_size, img_size, img_channels])
        data['label'] = tf.reshape(data['label'], [num_local_devices, batch_size, num_classes])
        return data

    logger.info('Loading TFRecord...')
    with tf.io.gfile.GFile(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
        dataset_info = json.load(fin)

    ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
    # Each JAX process reads a disjoint subset of the records.
    ds = ds.shard(jax.process_count(), jax.process_index())
    # tf.data rejects a zero-sized shuffle buffer, so keep it at least 1.
    ds = ds.shuffle(max(1, min(dataset_info['num_examples'], shuffle_buffer)))
    ds = ds.map(pre_process, tf.data.AUTOTUNE)
    # Per-worker batch. drop_remainder guarantees every batch holds exactly
    # batch_size * num_local_devices examples, which shard() relies on.
    ds = ds.batch(batch_size * num_local_devices, drop_remainder=True)
    ds = ds.map(shard, tf.data.AUTOTUNE)
    ds = ds.prefetch(1)  # overlap preparation of the next batch with consumption
    return ds, dataset_info