Spaces:
Running
Running
| import importlib | |
| from typing import Type | |
| import torch | |
| from torch.utils.data import TensorDataset | |
| from torch.utils.data.dataloader import DataLoader | |
| from .datasets.ab_dataset import ABDataset | |
| from .datasets import * # import all datasets | |
| from .datasets.registery import static_dataset_registery | |
def get_dataset(dataset_name, root_dir, split, transform=None, ignore_classes=None, idx_map=None) -> ABDataset:
    """Instantiate a registered dataset by name.

    Looks up ``dataset_name`` in ``static_dataset_registery`` and calls the
    first element of the registered entry with the remaining arguments,
    passed positionally.

    Args:
        dataset_name: Key under which the dataset is registered.
        root_dir: Passed through to the dataset constructor
            (presumably the data root directory — confirm against ABDataset).
        split: Passed through to the dataset constructor.
        transform: Optional transform, passed through unchanged.
        ignore_classes: Passed through as ``ignore_classes``; ``None``
            (the default) is replaced by a fresh empty list.
        idx_map: Passed through unchanged.

    Returns:
        The constructed dataset instance.

    Note:
        An unknown ``dataset_name`` presumably raises ``KeyError`` from the
        registry lookup (depends on the registry's type — verify).
    """
    # Build a new list on every call: a shared ``[]`` default could be
    # mutated by one dataset instance and leak into subsequent calls.
    if ignore_classes is None:
        ignore_classes = []
    dataset_cls = static_dataset_registery[dataset_name][0]
    return dataset_cls(root_dir, split, transform, ignore_classes, idx_map)
def get_num_limited_dataset(dataset: "ABDataset", num_samples: int, discard_label=True):
    """Replace the samples of ``dataset`` with a random subset of at most ``num_samples``.

    Batches are drawn from a shuffled ``DataLoader`` over ``dataset`` and
    concatenated until ``num_samples`` items are collected or the dataset is
    exhausted. The selected items are materialized into a ``TensorDataset``
    and assigned to ``dataset.dataset``.

    Args:
        dataset: Dataset to subsample. It must be usable with ``DataLoader``,
            and each collated batch must expose inputs at index 0 and labels
            at index 1 as tensors. NOTE: it is modified in place.
        num_samples: Maximum number of samples to keep; must be >= 1.
        discard_label: If True, keep only the inputs (``TensorDataset(x)``);
            otherwise keep ``(x, y)`` pairs (``TensorDataset(x, y)``).

    Returns:
        The same ``dataset`` object, whose ``dataset`` attribute now holds the
        subset. It holds fewer than ``num_samples`` items if the source
        dataset is smaller.

    Raises:
        ValueError: If ``num_samples`` < 1 or ``dataset`` yields no samples.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be a positive integer, got {num_samples}")

    # Aim for roughly two batches, but clamp so num_samples == 1 does not
    # produce an invalid batch size of 0.
    batch_size = max(1, num_samples // 2)
    loader = DataLoader(dataset, batch_size, shuffle=True)

    xs, ys = [], []
    collected = 0
    # A for-loop ends cleanly when the dataset has fewer than num_samples
    # items, instead of leaking StopIteration from a bare next().
    for batch in loader:
        cur_x, cur_y = batch[0], batch[1]
        xs.append(cur_x)
        ys.append(cur_y)
        collected += cur_x.size(0)
        if collected >= num_samples:
            break

    if not xs:
        # torch.cat([]) would fail with an obscure error; report it clearly.
        raise ValueError("cannot subsample an empty dataset")

    x = torch.cat(xs)[:num_samples]
    y = torch.cat(ys)[:num_samples]
    dataset.dataset = TensorDataset(x) if discard_label else TensorDataset(x, y)
    return dataset