import math
import random
from typing import TypeVar, Optional, Iterator

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Sampler, Dataset

class DistributedSamplerChunkByNode(torch.utils.data.Sampler):
    def __init__(self,
                 dataset,
                 all_datasets,
                 chunk_or_not,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 seed: int = 0,
                 drop_last: bool = False,
                 node_rank=0,
                 node_number=1,
                 process_num_per_node=1,
                 rank_within_local_node=0) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.node_number = node_number
        self.node_rank = node_rank
        self.chunk_or_not = chunk_or_not
        self.process_num_per_node = process_num_per_node
        self.rank_within_local_node = rank_within_local_node
        assert (self.process_num_per_node * self.node_number == self.num_replicas)

        # 1. Divide the datasets into two parts: normal (shared by all processes)
        #    and chunked (split across nodes).
        normal_datasets = []
        chunked_datasets = []
        for dataset_i, chunk_i in zip(all_datasets, chunk_or_not):
            if chunk_i:
                chunked_datasets.append(dataset_i)
            else:
                normal_datasets.append(dataset_i)

        # 2. Calculate dataset sizes: the normal part follows the conventional distributed sampler.
        self.normal_dataset_size = sum([len(i) for i in normal_datasets])

        # 3. Divide the chunked datasets across nodes and record this node's global index range.
        self.current_node_start_range = -1
        self.current_node_end_range = -1
        assert (len(chunked_datasets) >= self.node_number)
        chunk_size = len(chunked_datasets) // self.node_number
        current_example_num = self.normal_dataset_size
        for index in range(len(chunked_datasets)):
            if index == self.node_rank * chunk_size:
                self.current_node_start_range = current_example_num
            current_example_num += len(chunked_datasets[index])
            if index == (self.node_rank + 1) * chunk_size - 1:
                self.current_node_end_range = current_example_num
        if self.current_node_end_range == -1:  # boundary
            self.current_node_end_range = current_example_num

        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                # `type:ignore` is required because Dataset cannot provide a default __len__
                # see NOTE in pytorch/torch/utils/data/sampler.py
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
    def __iter__(self):
        # Sample the normal part: every process draws a disjoint share of the shared datasets.
        indices = self.generate_indices_within_range_with_rank(
            seed=self.seed,
            epoch=self.epoch,
            # NOTE: distribute among all processes
            process_num=self.num_replicas,
            rank=self.rank,
            generate_length=-1,
            valid_indices=list(range(self.normal_dataset_size)),
            prefix="Normal "
        )
        # Sample the chunked part: only the processes on this node share its chunk.
        addition_indices = self.generate_indices_within_range_with_rank(
            seed=self.seed,
            epoch=self.epoch,
            # NOTE: very important arguments, distribute among the processes of the local node
            process_num=self.process_num_per_node,
            rank=self.rank_within_local_node,
            generate_length=self.num_samples - len(indices),
            valid_indices=list(range(self.current_node_start_range, self.current_node_end_range)),
            prefix="Distribute "
        )
        indices.extend(addition_indices)
        random.seed(self.seed + self.epoch + 10 * self.rank)  # set the seed to maximize randomness
        random.shuffle(indices)  # reshuffle
        assert len(indices) == self.num_samples
        return iter(indices)
    def generate_indices_within_range_with_rank(self, seed, epoch, process_num, generate_length, valid_indices,
                                                rank=-1, shuffle=True, prefix=""):
        '''
        Use scenario: we want to sample, e.g., 2500 examples out of 10000 examples without
        sampling examples that overlap with the other three processes.
        Modified from DistributedSampler.
        '''
        dataset_size = len(valid_indices)
        if shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(seed + epoch)
            indices = torch.randperm(dataset_size, generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(dataset_size))  # type: ignore[arg-type]
        # map shuffled positions back to the actual (global) indices
        indices = [valid_indices[i] for i in indices]

        num_samples_normal = math.ceil(
            (dataset_size - process_num) / process_num  # type: ignore[arg-type]
        )
        # remove tail of data to make it evenly divisible
        indices = indices[:num_samples_normal * process_num]
        print("\n")
        print(prefix,
              "Global Rank {} Local Rank {} generate_length {} valid_indices {} process_num {} indices_before_subsample {} {}".format(
                  self.rank, rank, generate_length, len(valid_indices), process_num, len(indices), indices[:10]))
        # subsample: each rank takes every `process_num`-th index starting at its own offset
        indices = indices[rank:num_samples_normal * process_num:process_num]
        print(prefix,
              "Global Rank {} Local Rank {} generate_length {} valid_indices {} process_num {} indices_after_subsample {} {}".format(
                  self.rank, rank, generate_length, len(valid_indices), process_num, len(indices), indices[:10]))
        print("\n")

        if generate_length != -1:
            if len(indices) > generate_length:
                # too many: truncate to the requested length
                indices = indices[:generate_length]
            else:
                # too few: pad by sampling (with replacement) from the valid range
                indices.extend(np.random.choice(valid_indices, generate_length - len(indices)).tolist())
        return indices
    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
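

# ---------------------------------------------------------------------------------------
# Minimal single-process usage sketch (not part of the original file). It assumes the
# `dataset` argument is the concatenation of `all_datasets` with the non-chunked
# ("normal") datasets laid out first, followed by the chunked ones, since the sampler
# addresses the chunked region by global offsets starting at `normal_dataset_size`.
# `RangeDataset` and all sizes below are hypothetical stand-ins chosen for illustration.
if __name__ == "__main__":
    from torch.utils.data import ConcatDataset, DataLoader

    class RangeDataset(torch.utils.data.Dataset):
        """Toy dataset returning consecutive integers starting at `offset`."""

        def __init__(self, offset, length):
            self.offset = offset
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            return self.offset + idx

    # One "normal" dataset shared by every process, two "chunked" datasets split by node.
    normal = RangeDataset(0, 8)
    chunk_a = RangeDataset(100, 6)
    chunk_b = RangeDataset(200, 6)
    all_datasets = [normal, chunk_a, chunk_b]
    chunk_or_not = [False, True, True]
    combined = ConcatDataset(all_datasets)  # normal part first, then the chunked parts

    # Pretend setup: 2 nodes x 2 processes = 4 replicas; this process is global rank 1,
    # i.e. local rank 1 on node 0. Passing num_replicas/rank explicitly avoids needing
    # an initialized torch.distributed process group for this toy run.
    sampler = DistributedSamplerChunkByNode(
        dataset=combined,
        all_datasets=all_datasets,
        chunk_or_not=chunk_or_not,
        num_replicas=4,
        rank=1,
        node_rank=0,
        node_number=2,
        process_num_per_node=2,
        rank_within_local_node=1,
    )
    sampler.set_epoch(0)  # call once per epoch so shuffling differs across epochs

    loader = DataLoader(combined, batch_size=2, sampler=sampler)
    for batch in loader:
        print(batch)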