| """ Dataset parser interface that wraps TFDS datasets | |
| Wraps many (most?) TFDS image-classification datasets | |
| from https://github.com/tensorflow/datasets | |
| https://www.tensorflow.org/datasets/catalog/overview#image_classification | |
| Hacked together by / Copyright 2020 Ross Wightman | |
| """ | |
import os
import io
import math

import torch
import torch.distributed as dist
from PIL import Image

try:
    import tensorflow as tf
    tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
    import tensorflow_datasets as tfds
except ImportError as e:
    print(e)
    print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
    exit(1)

from .parser import Parser
MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 20480  # samples to shuffle in DS queue
PREFETCH_SIZE = 2048  # samples to prefetch


def even_split_indices(split, n, num_samples):
    partitions = [round(i * num_samples / n) for i in range(n + 1)]
    return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]


class ParserTfds(Parser):
    """ Wrap Tensorflow Datasets for use in PyTorch

    There are several things to be aware of:
      * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
        dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
        https://github.com/pytorch/pytorch/issues/33413
      * With PyTorch IterableDatasets, each worker in each replica operates in isolation, so the final batch
        from each worker could be a different size. For training this is worked around by the option above; for
        validation, extra samples are inserted iff distributed mode is enabled so that the batches being reduced
        across replicas are of the same size. This will slightly alter the results, so distributed validation
        will not be 100% correct. This is similar to the common handling in DistributedSampler for normal
        Datasets, but a bit worse since there are up to N * J extra samples with IterableDatasets
        (where N = num distributed replicas, J = num dataloader workers).
      * The sharding (splitting of the dataset into TFRecord files) imposes limitations on the number of
        replicas and dataloader workers you can use. For really small datasets that only contain a few shards
        you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
        benefit of distributed training or fast dataloading should be much less for small datasets.
      * This wrapper is currently configured to return individual, decompressed image samples from the TFDS
        dataset. The augmentation (transforms) and batching are still done in PyTorch. It would be possible
        to specify a TF augmentation fn and return augmented batches w/ some modifications to other downstream
        components. A minimal usage sketch is included at the bottom of this file.
    """
    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
        super().__init__()
        self.root = root
        self.split = split
        self.shuffle = shuffle
        self.is_training = is_training
        if self.is_training:
            assert batch_size is not None, \
                "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
        self.batch_size = batch_size
        self.repeats = repeats
        self.subsplit = None

        self.builder = tfds.builder(name, data_dir=root)
        # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
        # download_and_prepare() by default here as it's caused issues generating unwanted paths.
        self.num_samples = self.builder.info.splits[split].num_examples
        self.ds = None  # initialized lazily on each dataloader worker process

        self.worker_info = None
        self.dist_rank = 0
        self.dist_num_replicas = 1
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            self.dist_rank = dist.get_rank()
            self.dist_num_replicas = dist.get_world_size()
    def _lazy_init(self):
        """ Lazily initialize the dataset.

        This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
        will be using the dataset instance. The __init__ method is called on the main process,
        while this will be called in a dataloader worker process.

        NOTE: There will be problems if you try to re-use this dataset across different loader/worker
        instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
        before it is passed to the dataloader.
        """
        worker_info = torch.utils.data.get_worker_info()

        # setup input context to split dataset across distributed processes
        split = self.split
        num_workers = 1
        worker_id = 0  # default so the InputContext below is valid when there are no worker processes
        if worker_info is not None:
            self.worker_info = worker_info
            num_workers = worker_info.num_workers
            global_num_workers = self.dist_num_replicas * num_workers
            worker_id = worker_info.id

            # FIXME I need to spend more time figuring out the best way to distribute/split data across
            # combo of distributed replicas + dataloader worker processes
            """
            InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
            My understanding is that using split, the underlying TFRecord files will shuffle (shuffle_files=True)
            between the splits each iteration, but that understanding could be wrong.
            Possible split options include:
             * InputContext for both distributed & worker processes (current)
             * InputContext for distributed and sub-splits for worker processes
             * sub-splits for both
            """
            # split_size = self.num_samples // num_workers
            # start = worker_id * split_size
            # if worker_id == num_workers - 1:
            #     split = split + '[{}:]'.format(start)
            # else:
            #     split = split + '[{}:{}]'.format(start, start + split_size)

            if not self.is_training and '[' not in self.split:
                # If not training, and the split doesn't define a subsplit, manually split the dataset
                # for more even samples / worker
                self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
                    self.dist_rank * num_workers + worker_id]

        if self.subsplit is None:
            input_context = tf.distribute.InputContext(
                num_input_pipelines=self.dist_num_replicas * num_workers,
                input_pipeline_id=self.dist_rank * num_workers + worker_id,
                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
            )
        else:
            input_context = None
| ds = self.builder.as_dataset( | |
| split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) | |
| # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers | |
| options = tf.data.Options() | |
| options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) | |
| options.experimental_threading.max_intra_op_parallelism = 1 | |
| ds = ds.with_options(options) | |
| if self.is_training or self.repeats > 1: | |
| # to prevent excessive drop_last batch behaviour w/ IterableDatasets | |
| # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading | |
| ds = ds.repeat() # allow wrap around and break iteration manually | |
| if self.shuffle: | |
| ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0) | |
| ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) | |
| self.ds = tfds.as_numpy(ds) | |
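
        # Worked example of the scaling above: with 2 distributed replicas and 4 dataloader workers
        # per replica there are 2 * 4 = 8 input pipelines, so each pipeline reads ~1/8 of the TFRecord
        # shards (via InputContext) or one of 8 even_split_indices subsplits, and the shuffle/prefetch
        # buffer sizes are divided by the same factor of 8.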

    def __iter__(self):
        if self.ds is None:
            self._lazy_init()

        # Compute a rounded up sample count that is used to:
        #   1. make batches even across workers & replicas in distributed validation.
        #      This adds extra samples and will slightly alter validation results.
        #   2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
        #      batches are produced (underlying tfds iter wraps around)
        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
        if self.is_training:
            # round up to nearest batch_size per worker-replica
            target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
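            # e.g. with num_samples=50000, 8 pipelines, batch_size=32:
            #   ceil(50000 / 8) = 6250 -> ceil(6250 / 32) * 32 = 6272 target samples per pipeline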

        sample_count = 0
        for sample in self.ds:
            img = Image.fromarray(sample['image'], mode='RGB')
            yield img, sample['label']
            sample_count += 1
            if self.is_training and sample_count >= target_sample_count:
                # Need to break out of loop when repeat() is enabled for training w/ oversampling
                # this results in extra samples per epoch but seems more desirable than dropping
                # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
                break
        if not self.is_training and self.dist_num_replicas > 1 and 0 < sample_count < target_sample_count:
            # Validation batch padding is only done for distributed training, where results are reduced across
            # nodes. For the single process case, it won't matter if workers return different batch sizes.
            # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and
            # this approach is not optimal
            yield img, sample['label']  # yield prev sample again
            sample_count += 1

    @property
    def _num_workers(self):
        return 1 if self.worker_info is None else self.worker_info.num_workers

    @property
    def _num_pipelines(self):
        return self._num_workers * self.dist_num_replicas

    def __len__(self):
        # this is just an estimate and does not factor in extra samples added to pad batches based on
        # complete worker & replica info (not available until init in dataloader).
        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)

    def _filename(self, index, basename=False, absolute=False):
        assert False, "Not supported"  # no random access to samples

    def filenames(self, basename=False, absolute=False):
        """ Return all filenames in dataset, overrides base"""
        if self.ds is None:
            self._lazy_init()
        names = []
        for sample in self.ds:
            if len(names) > self.num_samples:
                break  # safety for ds.repeat() case
            if 'file_name' in sample:
                name = sample['file_name']
            elif 'filename' in sample:
                name = sample['filename']
            elif 'id' in sample:
                name = sample['id']
            else:
                assert False, "No supported name field present"
            names.append(name)
        return names
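

# Minimal usage sketch (illustrative only, not part of the module). Iterating the parser directly
# yields (PIL.Image, label) pairs; in timm it is normally wrapped in an IterableDataset before being
# handed to a DataLoader. Assumes the named TFDS dataset was already downloaded & prepared with the
# tfds command line app (dataset name 'cifar10' is just an example):
#
#   parser = ParserTfds(
#       root='/data/tfds', name='cifar10', split='train',
#       shuffle=True, is_training=True, batch_size=32)
#   for img, label in parser:
#       ...  # apply transforms / batching downstream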