import random

import numpy as np
import torch


def reset_numpy_seed(worker_id):
    """Re-seed NumPy and Python's ``random`` inside a DataLoader worker.

    https://github.com/pytorch/pytorch/issues/5059
    https://discuss.pytorch.org/t/dataloader-multi-threading-random-number/27719

    PyTorch seeds each worker with ``base_seed + worker_id`` (retrievable via
    ``torch.initial_seed()``), and this value changes at every epoch. Propagating
    it to NumPy and ``random`` avoids all workers producing identical numbers.

    Args:
        worker_id: index of the worker process (passed by the DataLoader).
    """
    np.random.seed(int(torch.initial_seed()) % (2**32 - 1))
    random.seed(int(torch.initial_seed()) % (2**32 - 1))
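
# Usage sketch: pass ``reset_numpy_seed`` as ``worker_init_fn`` so that NumPy and
# ``random`` draw different numbers in every worker and every epoch. ``my_dataset``
# below is a hypothetical map-style dataset, used only for illustration.
#
#   loader = torch.utils.data.DataLoader(my_dataset, batch_size=4, num_workers=2,
#                                         worker_init_fn=reset_numpy_seed)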


class Loader(torch.utils.data.DataLoader):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Note: The only difference with the default PyTorch DataLoader is that this loader
    additionally stores a ``name``, a ``training`` flag and an ``epoch_interval`` as
    attributes for use by the training loop.

    Arguments:
        name (str): name identifying this loader.
        dataset (Dataset): dataset from which to load the data.
        training (bool, optional): set to ``True`` if this loader is used for
            training (default: True).
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        epoch_interval (int, optional): stored as ``self.epoch_interval`` for
            use by the training loop (default: 1).
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of the dataset is not divisible by the batch size, then the last
            batch will be smaller. (default: True)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: reset_numpy_seed)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by the main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """
    __initialized = False

    def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, epoch_interval=1, collate_fn=None, pin_memory=False, drop_last=True,
                 timeout=0, worker_init_fn=reset_numpy_seed):
        super().__init__(dataset, batch_size, shuffle, sampler=sampler,
                         batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
                         pin_memory=pin_memory, drop_last=drop_last, timeout=timeout,
                         worker_init_fn=worker_init_fn)
        # Extra bookkeeping attributes used by the training loop.
        self.name = name
        self.training = training
        self.epoch_interval = epoch_interval
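

if __name__ == "__main__":
    # Minimal usage sketch. ``_ToyDataset`` is a made-up placeholder dataset,
    # defined here only to show how ``Loader`` is instantiated and iterated.
    class _ToyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return 8

        def __getitem__(self, index):
            return torch.tensor([float(index)])

    loader = Loader("toy_train", _ToyDataset(), training=True, batch_size=2,
                    shuffle=True, num_workers=0)
    for batch in loader:
        # With the default collate function, samples are stacked along dim 0,
        # giving batches of shape (2, 1) here.
        print(loader.name, batch.shape)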