from torch.utils.data import Sampler
import numpy as np


class ConcatDatasetBatchSampler(Sampler):
| """This sampler is built to work with a standard Pytorch ConcatDataset. | |
| From SpeechBrain dataio see https://github.com/speechbrain/ | |
| It is used to retrieve elements from the different concatenated datasets placing them in the same batch | |
| with proportion specified by batch_sizes, e.g 8, 16 means each batch will | |
| be of 24 elements with the first 8 belonging to the first dataset in ConcatDataset | |
| object and the last 16 to the second. | |
| More than two datasets are supported, in that case you need to provide 3 batch | |
| sizes. | |
| Note | |
| ---- | |
| Batched are drawn from the datasets till the one with smallest length is exhausted. | |
| Thus number of examples in your training epoch is dictated by the dataset | |
| whose length is the smallest. | |
| Arguments | |
| --------- | |
| samplers : int | |
| The base seed to use for the random number generator. It is recommended | |
| to use a value which has a good mix of 0 and 1 bits. | |
| batch_sizes: list | |
| Batch sizes. | |
| epoch : int | |
| The epoch to start at. | |
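
    Example
    -------
    A minimal sketch, not taken from the original; the dataset sizes below
    are illustrative only:

    >>> import torch
    >>> from torch.utils.data import TensorDataset, RandomSampler
    >>> dataset_a = TensorDataset(torch.zeros(40, 3))
    >>> dataset_b = TensorDataset(torch.ones(100, 3))
    >>> samplers = [RandomSampler(dataset_a), RandomSampler(dataset_b)]
    >>> sampler = ConcatDatasetBatchSampler(samplers, batch_sizes=(8, 16))
    >>> len(sampler)  # limited by dataset_a: 40 // 8
    5
    >>> len(next(iter(sampler)))  # 8 indices from dataset_a + 16 from dataset_b
    24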
| """ | |
    def __init__(self, samplers, batch_sizes: (tuple, list), epoch=0) -> None:
        if not isinstance(samplers, (list, tuple)):
            raise ValueError(
                "samplers should be a list or tuple of Pytorch Samplers, "
                "but got samplers={}".format(samplers)
            )
        if not isinstance(batch_sizes, (list, tuple)):
            raise ValueError(
                "batch_sizes should be a list or tuple of integers, "
                "but got batch_sizes={}".format(batch_sizes)
            )
        if not len(batch_sizes) == len(samplers):
            raise ValueError("batch_sizes and samplers should have the same length")

        self.batch_sizes = batch_sizes
        self.samplers = samplers
        # Offsets map each sampler's local indices to global ConcatDataset indices:
        # dataset i starts right after the cumulative length of datasets 0..i-1.
        self.offsets = [0] + np.cumsum([len(x) for x in self.samplers]).tolist()[:-1]

        self.epoch = epoch
        self.set_epoch(self.epoch)

    def _iter_one_dataset(self, c_batch_size, c_sampler, c_offset):
        batch = []
        for idx in c_sampler:
            batch.append(c_offset + idx)
            if len(batch) == c_batch_size:
                yield batch
                batch = []  # start a new batch after yielding a full one

    def set_epoch(self, epoch):
        # Propagate the epoch to the underlying samplers (e.g. DistributedSampler)
        # so that their shuffling changes across epochs.
        if hasattr(self.samplers[0], "epoch"):
            for s in self.samplers:
                s.set_epoch(epoch)

    def __iter__(self):
        iterators = [iter(i) for i in self.samplers]
        tot_batch = []
        for b_num in range(len(self)):
            # Draw one sub-batch from each sampler and concatenate them,
            # so every batch keeps the proportions given by batch_sizes.
            for samp_idx in range(len(self.samplers)):
                c_batch = []
                while len(c_batch) < self.batch_sizes[samp_idx]:
                    c_batch.append(self.offsets[samp_idx] + next(iterators[samp_idx]))
                tot_batch.extend(c_batch)
            yield tot_batch
            tot_batch = []

    def __len__(self):
        min_len = float("inf")
        for idx, sampler in enumerate(self.samplers):
            c_len = len(sampler) // self.batch_sizes[idx]
            min_len = min(c_len, min_len)
        return min_len
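

# --- Usage sketch (illustrative, not part of the original module) ---
# Shows how the sampler plugs into a DataLoader built on a ConcatDataset;
# the dataset sizes and tensor shapes below are assumptions for the demo.
if __name__ == "__main__":
    import torch
    from torch.utils.data import ConcatDataset, DataLoader, RandomSampler, TensorDataset

    dataset_a = TensorDataset(torch.zeros(40, 3))   # 40 examples of zeros
    dataset_b = TensorDataset(torch.ones(100, 3))   # 100 examples of ones
    concat = ConcatDataset([dataset_a, dataset_b])

    samplers = [RandomSampler(dataset_a), RandomSampler(dataset_b)]
    batch_sampler = ConcatDatasetBatchSampler(samplers, batch_sizes=(8, 16))
    loader = DataLoader(concat, batch_sampler=batch_sampler)

    for (features,) in loader:
        # Each batch has 24 rows: the first 8 come from dataset_a (zeros),
        # the last 16 from dataset_b (ones).
        assert features.shape == (24, 3)
        assert torch.all(features[:8] == 0) and torch.all(features[8:] == 1)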