import torch
from torch.utils.data import Dataset, DistributedSampler, Sampler
from operator import itemgetter
from typing import Iterator, List, Optional
class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`. From the catalyst library.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialization for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets an element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        # Materialize the sampler lazily on first access, so one pass over
        # the (possibly stochastic) sampler is captured and then reused.
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)
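
# Sketch (illustrative, not part of the original module): wrapping a
# RandomSampler over 5 items yields a dataset view of its index order.
#
#   from torch.utils.data import RandomSampler
#   ds = DatasetFromSampler(RandomSampler(range(5)))
#   print([ds[i] for i in range(len(ds))])  # e.g. [3, 0, 4, 1, 2]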
class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.
    From https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super().__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator
        """
        # Rebuild the index dataset each epoch so a stochastic inner sampler
        # is re-drawn, then map this rank's subset of positions back onto
        # the inner sampler's indices.
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
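
# Sketch (illustrative, not part of the original module): simulate two ranks
# in one process. Each rank receives a disjoint subset of positions in the
# inner sampler's order; for the subsets to be truly exclusive, the inner
# sampler must draw the same sequence on every rank (e.g. via a shared seed).
#
#   from torch.utils.data import WeightedRandomSampler
#   for rank in range(2):
#       torch.manual_seed(0)  # keep the inner draw identical across ranks
#       inner = WeightedRandomSampler([0.1, 0.2, 0.3, 0.4], num_samples=8)
#       wrapped = DistributedSamplerWrapper(inner, num_replicas=2, rank=rank)
#       print(rank, list(wrapped))  # 4 of the 8 sampled indices per rank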
class UnimaxSampler(Sampler):
    """Sampler over language indices following the UniMax algorithm
    (Chung et al., 2023): the character budget is spread uniformly across
    languages, capped at `num_epochs` epochs of any single language."""

    # Initialize the sampler with the character counts for each language,
    # the total character budget, and the number of epochs per language.
    def __init__(
        self,
        language_character_counts: List[int],
        total_character_budget: int,
        num_epochs: int,
    ) -> None:
        # Keep counts as a float tensor so the budgets below stay fractional.
        self.language_character_counts = torch.tensor(
            language_character_counts, dtype=torch.float
        )
        self.total_character_budget = total_character_budget
        self.num_epochs = num_epochs
        # Compute the sampling distribution p.
        self.p = self._unimax()

    # Iterate by drawing language indices with PyTorch's multinomial
    # function according to the distribution p.
    def __iter__(self) -> Iterator[int]:
        return iter(
            torch.multinomial(self.p, len(self.p), replacement=True).tolist()
        )

    # The length of the sampler is the number of languages.
    def __len__(self) -> int:
        return len(self.p)
    # Implement the UniMax algorithm to compute the sampling distribution p.
    def _unimax(self) -> torch.Tensor:
        # Sort languages by character count, smallest first.
        L, indices = torch.sort(self.language_character_counts)
        # Initialize the remaining budget to the total character budget.
        B = float(self.total_character_budget)
        # Initialize the budget per language.
        U = torch.zeros_like(L)
        # Visit languages from smallest to largest.
        for i, idx in enumerate(indices):
            # Uniform share of the remaining budget over the languages left.
            bl = B / (len(L) - i)
            # Character count of the i-th smallest language (L is sorted;
            # `idx` is its position in the original, unsorted order).
            cl = float(L[i])
            # If the uniform share exceeds num_epochs epochs of this
            # language, cap its budget at num_epochs epochs.
            if bl > cl * self.num_epochs:
                Ul = cl * self.num_epochs
            # Otherwise use the uniform per-language share.
            else:
                Ul = bl
            # Store the computed budget under the language's original index.
            U[idx] = Ul
            # Update the remaining budget.
            B -= Ul
        # Normalize the budgets into a distribution.
        p = U / U.sum()
        # Return the computed distribution.
        return p
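
# Worked example (illustrative numbers): counts [100, 1000, 10000],
# budget 3000, num_epochs 2. Visiting the smallest language first:
#   share 3000/3 = 1000 > 100*2,   so U = 200,  budget left 2800
#   share 2800/2 = 1400 < 1000*2,  so U = 1400, budget left 1400
#   share 1400/1 = 1400 < 10000*2, so U = 1400, budget left 0
# p = [200, 1400, 1400] / 3000 ≈ [0.067, 0.467, 0.467].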
class DistributedUnimaxSampler(UnimaxSampler):
    def __init__(
        self,
        language_character_counts: List[int],
        total_character_budget: int,
        num_epochs: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ) -> None:
        super().__init__(language_character_counts, total_character_budget, num_epochs)
        self.distributed_sampler = DistributedSamplerWrapper(
            self, num_replicas=num_replicas, rank=rank, shuffle=shuffle
        )

    def __iter__(self):
        return iter(self.distributed_sampler)

    def __len__(self):
        return len(self.distributed_sampler)

    def set_epoch(self, epoch):
        self.distributed_sampler.set_epoch(epoch)
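
# Minimal local smoke test (illustrative: the counts and budget are made-up
# numbers, and two ranks are simulated in one process instead of being
# launched via torch.distributed).
if __name__ == "__main__":
    counts = [100, 1000, 10000]  # hypothetical per-language character counts
    sampler = UnimaxSampler(counts, total_character_budget=3000, num_epochs=2)
    print("p =", sampler.p)  # ≈ tensor([0.0667, 0.4667, 0.4667])

    # Simulate two ranks; in a real DDP run each process would build the
    # sampler with its own rank, sharing a seed so the draws line up.
    for rank in range(2):
        torch.manual_seed(0)  # keep the multinomial draw identical across ranks
        dist_sampler = DistributedUnimaxSampler(
            counts,
            total_character_budget=3000,
            num_epochs=2,
            num_replicas=2,
            rank=rank,
        )
        dist_sampler.set_epoch(0)
        print(f"rank {rank}:", list(dist_sampler))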