Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import bisect | |
| from torch.utils.data.dataset import ConcatDataset as _ConcatDataset | |
| class ConcatDataset(_ConcatDataset): | |
| """ | |
| Same as torch.utils.data.dataset.ConcatDataset, but exposes an extra | |
| method for querying the sizes of the image | |
| """ | |
| def get_idxs(self, idx): | |
| dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) | |
| if dataset_idx == 0: | |
| sample_idx = idx | |
| else: | |
| sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] | |
| return dataset_idx, sample_idx | |
| def get_img_info(self, idx): | |
| dataset_idx, sample_idx = self.get_idxs(idx) | |
| return self.datasets[dataset_idx].get_img_info(sample_idx) | |