Spaces:
Running
Running
| from .mvtec import MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT | |
| from .visa import VISA_CLS_NAMES, VisaDataset, VISA_ROOT | |
| from .mpdd import MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT | |
| from .btad import BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT | |
| from .sdd import SDD_CLS_NAMES, SDDDataset, SDD_ROOT | |
| from .dagm import DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT | |
| from .dtd import DTD_CLS_NAMES,DTDDataset,DTD_ROOT | |
| from .isic import ISIC_CLS_NAMES,ISICDataset,ISIC_ROOT | |
| from .colondb import ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT | |
| from .clinicdb import ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT | |
| from .tn3k import TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT | |
| from .headct import HEADCT_CLS_NAMES,HEADCTDataset,HEADCT_ROOT | |
| from .brain_mri import BrainMRI_CLS_NAMES,BrainMRIDataset,BrainMRI_ROOT | |
| from .br35h import Br35h_CLS_NAMES,Br35hDataset,Br35h_ROOT | |
| from torch.utils.data import ConcatDataset | |
# Registry keyed by dataset name. Each value is the triple
# (<NAME>_CLS_NAMES, <Name>Dataset, <NAME>_ROOT) imported from the matching
# submodule; get_data() unpacks the triple in exactly this order.
dataset_dict = dict(
    br35h=(Br35h_CLS_NAMES, Br35hDataset, Br35h_ROOT),
    brain_mri=(BrainMRI_CLS_NAMES, BrainMRIDataset, BrainMRI_ROOT),
    btad=(BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT),
    clinicdb=(ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT),
    colondb=(ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT),
    dagm=(DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT),
    dtd=(DTD_CLS_NAMES, DTDDataset, DTD_ROOT),
    headct=(HEADCT_CLS_NAMES, HEADCTDataset, HEADCT_ROOT),
    isic=(ISIC_CLS_NAMES, ISICDataset, ISIC_ROOT),
    mpdd=(MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT),
    mvtec=(MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT),
    sdd=(SDD_CLS_NAMES, SDDDataset, SDD_ROOT),
    tn3k=(TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT),
    visa=(VISA_CLS_NAMES, VisaDataset, VISA_ROOT),
)
def get_data(dataset_type_list, transform, target_transform, training):
    """Build the dataset(s) named by ``dataset_type_list``.

    Each name is looked up in the module-level ``dataset_dict`` registry, whose
    values are ``(class_names, dataset_class, root)`` triples. The dataset
    class is instantiated with the keyword arguments ``clsnames``,
    ``transform``, ``target_transform`` and ``training``.

    Args:
        dataset_type_list: A single registry key, or a list/tuple of keys.
        transform: Forwarded unchanged to every dataset constructor.
        target_transform: Forwarded unchanged to every dataset constructor.
        training: Forwarded unchanged to every dataset constructor.

    Returns:
        A ``(class_names, dataset, root)`` tuple. When one dataset is
        requested these are the registry's class names, the dataset instance
        and the registry's root. When several are requested they are a list of
        per-dataset class names, a ``ConcatDataset`` over the instances, and a
        list of roots, all in request order.

    Raises:
        ValueError: If no dataset name is given.
        NotImplementedError: If a name is not a key of ``dataset_dict``.
    """
    if isinstance(dataset_type_list, (list, tuple)):
        requested = list(dataset_type_list)
    else:
        requested = [dataset_type_list]

    if not requested:
        raise ValueError('dataset_type_list must name at least one dataset')

    # Validate every name up front so that a bad entry late in the list fails
    # before any dataset object has been constructed.
    for dataset_type in requested:
        if dataset_type not in dataset_dict:
            raise NotImplementedError(
                f'Only support {list(dataset_dict)}, but entered {dataset_type}...'
            )

    cls_names_list = []
    instance_list = []
    root_list = []
    for dataset_type in requested:
        cls_names, dataset_cls, root = dataset_dict[dataset_type]
        instance_list.append(
            dataset_cls(
                clsnames=cls_names,
                transform=transform,
                target_transform=target_transform,
                training=training,
            )
        )
        cls_names_list.append(cls_names)
        root_list.append(root)

    # Several datasets: expose them as one concatenated dataset and keep the
    # per-dataset metadata as parallel lists.
    if len(requested) > 1:
        return cls_names_list, ConcatDataset(instance_list), root_list

    # Exactly one dataset: return its entries directly rather than
    # single-element lists.
    return cls_names_list[0], instance_list[0], root_list[0]