# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# domainbed/lib/fast_data_loader.py

import torch

from .datasets.ab_dataset import ABDataset


class _InfiniteSampler(torch.utils.data.Sampler):
    """Wraps another Sampler to yield an infinite stream."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for batch in self.sampler:
                yield batch
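

# Usage sketch (illustrative, not part of the original module): any sampler,
# including a BatchSampler, can be wrapped so iteration restarts instead of
# stopping at the end of an epoch.
#
#   batch_sampler = torch.utils.data.BatchSampler(
#       torch.utils.data.SequentialSampler(range(10)),
#       batch_size=4, drop_last=False)
#   it = iter(_InfiniteSampler(batch_sampler))
#   next(it)  # [0, 1, 2, 3]; keeps cycling forever after index 9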


class InfiniteDataLoader:
    """Yields an endless stream of batches; it deliberately has no length."""

    def __init__(self, dataset, weights, batch_size, num_workers, collate_fn=None):
        super().__init__()

        # Test against None explicitly: `if weights:` is ambiguous for a
        # multi-element tensor and would raise at runtime.
        if weights is not None:
            sampler = torch.utils.data.WeightedRandomSampler(
                weights, replacement=True, num_samples=batch_size
            )
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler, batch_size=batch_size, drop_last=True
        )

        # DataLoader treats collate_fn=None as "use the default collate", so a
        # single call covers both the custom and the default case.
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=_InfiniteSampler(batch_sampler),
                pin_memory=False,
                collate_fn=collate_fn,
            )
        )

        self.dataset = dataset

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        raise ValueError("InfiniteDataLoader has no fixed length")
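

# Usage sketch (hypothetical names; assumes `my_dataset` is an ABDataset):
# the loader never raises StopIteration, so bound the loop yourself, e.g.
# with itertools.islice or a step counter.
#
#   from itertools import islice
#   loader = InfiniteDataLoader(my_dataset, weights=None, batch_size=32,
#                               num_workers=2)
#   for batch in islice(loader, 1000):  # exactly 1000 training steps
#       ...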


class FastDataLoader:
    """
    DataLoader wrapper with slightly improved speed by not respawning worker
    processes at every epoch.
    """

    def __init__(self, dataset, batch_size, num_workers, shuffle=False, collate_fn=None):
        super().__init__()
        self.num_workers = num_workers

        if shuffle:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=False)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size=batch_size,
            drop_last=False,
        )

        # The underlying DataLoader is infinite, so its workers are never torn
        # down; __iter__ below stops after exactly one epoch's worth of batches.
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=_InfiniteSampler(batch_sampler),
                pin_memory=False,
                collate_fn=collate_fn,
            )
        )

        self.dataset = dataset
        self.batch_size = batch_size
        self._length = len(batch_sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length
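

# Usage sketch (hypothetical names): behaves like a normal finite DataLoader,
# except the worker processes stay alive between the two epochs below.
#
#   loader = FastDataLoader(my_dataset, batch_size=64, num_workers=2,
#                           shuffle=False)
#   for epoch in range(2):
#       for batch in loader:  # len(loader) batches per epoch
#           ...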


def build_dataloader(dataset: ABDataset, batch_size: int, num_workers: int,
                     infinite: bool, shuffle_when_finite: bool, collate_fn=None):
    assert batch_size <= len(dataset), \
        f'batch_size ({batch_size}) exceeds dataset size ({len(dataset)})'
    if infinite:
        dataloader = InfiniteDataLoader(
            dataset, None, batch_size, num_workers=num_workers,
            collate_fn=collate_fn)
    else:
        dataloader = FastDataLoader(
            dataset, batch_size, num_workers, shuffle=shuffle_when_finite,
            collate_fn=collate_fn)
    return dataloader
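

# Usage sketch (hypothetical names): infinite=True returns an endless stream
# for step-based training; infinite=False returns an epoch-style loader whose
# ordering is controlled by shuffle_when_finite.
#
#   train_loader = build_dataloader(train_set, batch_size=32, num_workers=4,
#                                   infinite=True, shuffle_when_finite=False)
#   val_loader = build_dataloader(val_set, batch_size=32, num_workers=4,
#                                 infinite=False, shuffle_when_finite=False)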


def get_a_batch_dataloader(dataset: ABDataset, batch_size: int, num_workers: int,
                           infinite: bool, shuffle_when_finite: bool):
    # Unimplemented stub in the original source; left as-is.
    pass