import bisect
import warnings

from torch._utils import _accumulate
from torch import randperm

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
    which supports integer indexing in the range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
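
# A minimal sketch (not part of the original module) of how a subclass is
# expected to implement the two abstract methods. The name RangeDataset and
# its contents are illustrative only:
#
#     >>> class RangeDataset(Dataset):
#     ...     def __init__(self, n):
#     ...         self.n = n
#     ...     def __getitem__(self, index):
#     ...         return index * 2  # any per-sample value
#     ...     def __len__(self):
#     ...         return self.n
#     >>> ds = RangeDataset(3)
#     >>> len(ds), ds[1]
#     (3, 2)
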
class TensorDataset(Dataset):
    """Dataset wrapping data and target tensors.

    Each sample will be retrieved by indexing both tensors along the first
    dimension.

    Arguments:
        data_tensor (Tensor): contains sample data.
        target_tensor (Tensor): contains sample targets (labels).
    """

    def __init__(self, data_tensor, target_tensor):
        assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)
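
# Usage sketch (illustrative, assumes torch is importable): both tensors are
# indexed along dimension 0, so sample i is the pair (data[i], target[i]):
#
#     >>> import torch
#     >>> data = torch.randn(4, 2)  # 4 samples, 2 features each
#     >>> targets = torch.LongTensor([0, 1, 1, 0])
#     >>> pairs = TensorDataset(data, targets)
#     >>> len(pairs)
#     4
#     >>> x, y = pairs[2]  # (data[2], targets[2])
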
class ConcatDataset(Dataset):
    """Dataset to concatenate multiple datasets.

    Useful for assembling different existing datasets, possibly large-scale
    ones, since the concatenation is computed on the fly.

    Arguments:
        datasets (iterable): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        # Running total of dataset lengths, e.g. lengths (3, 2) -> [3, 5].
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Locate the dataset containing idx, then convert idx to an index
        # local to that dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
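
# Indexing sketch (illustrative names, assumes torch is importable):
# cumulative_sizes holds running lengths, and bisect_right maps a global
# index to a (dataset, local index) pair. With datasets of length 3 and 2,
# cumulative_sizes == [3, 5], so global index 4 lands in the second dataset
# at local index 4 - 3 == 1:
#
#     >>> import torch
#     >>> a = TensorDataset(torch.zeros(3, 1), torch.zeros(3))
#     >>> b = TensorDataset(torch.ones(2, 1), torch.ones(2))
#     >>> both = a + b  # same as ConcatDataset([a, b])
#     >>> len(both)
#     5
#     >>> both[4]  # comes from b, local index 1
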
class Subset(Dataset):
    """Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
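
# Sketch (illustrative, assumes torch is importable): a Subset is just an
# index remapping over the parent dataset:
#
#     >>> import torch
#     >>> full = TensorDataset(torch.arange(0, 5).view(5, 1), torch.arange(0, 5))
#     >>> sub = Subset(full, [4, 0])
#     >>> len(sub)
#     2
#     >>> sub[0]  # same as full[4]
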
def random_split(dataset, lengths):
    """
    Randomly split a dataset into non-overlapping new datasets of given
    lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (iterable): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths))
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
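
# Usage sketch (illustrative sizes, assumes torch is importable): lengths
# must sum to len(dataset), and the returned Subsets slice one shared random
# permutation, so they do not overlap:
#
#     >>> import torch
#     >>> ds = TensorDataset(torch.randn(10, 3), torch.randn(10))
#     >>> train, val = random_split(ds, [8, 2])
#     >>> len(train), len(val)
#     (8, 2)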