"""Original sampling logic of MQTTS.

Copyright PolyAI Limited.
"""
import math
import random

import numpy as np
from torch.utils import data
def StandardSampler(dataset, shuffle, distributed=False,
                    world_size=None, rank=None):
    """Build a plain per-example sampler for ``dataset``.

    Returns a ``DistributedSampler`` when ``distributed`` is set;
    otherwise a ``RandomSampler`` (shuffled) or ``SequentialSampler``
    (original order).
    """
    if distributed:
        return data.distributed.DistributedSampler(
            dataset,
            shuffle=shuffle,
            num_replicas=world_size,
            rank=rank,
        )
    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
def RandomBucketSampler(
        nbuckets, length, batch_size, drop_last, distributed=False,
        world_size=None, rank=None):
    """Dispatch to the single-process or distributed bucket sampler."""
    if not distributed:
        return SingleRandomBucketSampler(
            nbuckets, length, batch_size, drop_last)
    return DistributedRandomBucketSampler(
        nbuckets, length, batch_size, drop_last, world_size, rank)
class SingleRandomBucketSampler(data.Sampler):
    """Length-bucketed batch sampler for single-process training.

    Indices are sorted by descending length into ``nbuckets`` buckets so
    each bucket holds examples of similar length, then shuffled within
    and across buckets each epoch, and greedily packed into batches
    whose padded cost (max length in batch * batch size) stays within
    ``batch_size``.

    Args:
        nbuckets: number of length buckets.
        length: per-example lengths, indexable by example index.
        batch_size: budget in "max length * num examples" units.
        drop_last: if True, drop the final incomplete batch.
    """

    def __init__(self, nbuckets, length, batch_size, drop_last):
        self.length = length
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Negate so argsort yields descending-length order.
        order = np.argsort([-x for x in length])
        split = len(order) // nbuckets
        self.indices = []
        for i in range(nbuckets):
            self.indices.append(order[i * split:(i + 1) * split])
        if nbuckets * split < len(length):  # remainder bucket
            self.indices.append(order[nbuckets * split:])

    def __iter__(self):
        random.shuffle(self.indices)
        for bucket in self.indices:
            random.shuffle(bucket)
        idxs = [i for bucket in self.indices for i in bucket]
        # Greedy packing; padded cost = max length * examples in batch.
        # (The original also accumulated a sum of lengths, but never
        # read it — removed.)
        batches, batch, max_len = [], [], 0
        for idx in idxs:
            batch.append(idx)
            max_len = max(self.length[idx], max_len)
            if max_len * len(batch) > self.batch_size:
                # Guard: if a single example alone exceeds the budget,
                # batch[:-1] is empty — do not emit an empty batch.
                if batch[:-1]:
                    batches.append(batch[:-1])
                batch, max_len = [batch[-1]], self.length[idx]
        if len(batch) > 0 and not self.drop_last:
            batches.append(batch)
        random.shuffle(batches)
        return iter(batches)
class DistributedRandomBucketSampler(data.Sampler):
    """Length-bucketed batch sampler for distributed (DDP) training.

    Buckets indices by length, shuffles deterministically from
    ``(seed, epoch)`` so every rank constructs the identical batch
    list, then each rank keeps its own equal-size contiguous slice of
    that list.

    Args:
        nbuckets: number of length buckets.
        length: per-example lengths, indexable by example index.
        batch_size: budget in "max length * num examples" units.
        drop_last: stored for interface parity with the single-process
            sampler; the trailing incomplete batch is always dropped
            here regardless. NOTE(review): confirm this is intentional.
        num_replicas: world size.
        rank: this process's rank, in [0, num_replicas).
        seed: base seed shared by all ranks.

    Raises:
        ValueError: if ``rank`` is outside [0, num_replicas).
    """

    def __init__(self, nbuckets, length, batch_size,
                 drop_last, num_replicas, rank, seed=1234):
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        # NOTE(review): ascending length order here, whereas
        # SingleRandomBucketSampler sorts descending — confirm the
        # asymmetry is intentional.
        indices = np.argsort(length)
        split = len(indices) // nbuckets
        self.length = length
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.indices = []
        for i in range(nbuckets):
            self.indices.append(indices[i * split:(i + 1) * split])
        if nbuckets * split < len(length):  # remainder bucket
            self.indices.append(indices[nbuckets * split:])
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.seed = seed

    def __iter__(self):
        # Deterministic shuffling: seeded only by (epoch, seed), so
        # every rank produces the identical ordering.
        random.Random(self.epoch + self.seed).shuffle(self.indices)
        for bucket_i, bucket in enumerate(self.indices):
            seed = self.epoch + self.seed + bucket_i * 5
            random.Random(seed).shuffle(bucket)
        flat = [idx for bucket in self.indices for idx in bucket]
        # Greedy batching; padded cost = max length * examples.
        # (The original also accumulated a sum of lengths, but never
        # read it — removed.)
        batches, batch, max_len = [], [], 0
        for idx in flat:
            batch.append(idx)
            max_len = max(self.length[idx], max_len)
            if max_len * len(batch) > self.batch_size:
                # Guard: a single example exceeding the budget would
                # otherwise append an empty batch here.
                if batch[:-1]:
                    batches.append(batch[:-1])
                batch, max_len = [batch[-1]], self.length[idx]
        # Subsample: truncate so the batch list divides evenly, then
        # give each rank its own contiguous slice of num_samples.
        num_samples = math.ceil(
            (len(batches) - self.num_replicas) / self.num_replicas)
        total_size = num_samples * self.num_replicas
        batches = batches[:total_size]
        batches = batches[self.rank * num_samples:
                          (self.rank + 1) * num_samples]
        assert len(batches) == num_samples
        # Stochastic shuffling: per-rank order only; membership of each
        # rank's slice was fixed deterministically above.
        random.shuffle(batches)
        return iter(batches)

    def set_epoch(self, epoch):
        """Advance the epoch so deterministic shuffles differ per epoch."""
        self.epoch = epoch