Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
| import importlib | |
| import torch.utils.data | |
| from data.base_data_loader import BaseDataLoader | |
| from data.base_dataset import BaseDataset | |
def find_dataset_using_name(dataset_name):
    """Import "data/[dataset_name]_dataset.py" and return its dataset class.

    Given the option --dataset_mode [dataset_name], the module
    "data.[dataset_name]_dataset" is imported and searched for a subclass
    of BaseDataset whose class name matches "[dataset_name]dataset"
    case-insensitively (underscores in dataset_name are ignored).

    Parameters:
        dataset_name (str) -- the value of --dataset_mode

    Returns:
        the matching dataset class; terminates the process with a non-zero
        exit status if no matching subclass is found.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # isinstance(cls, type) guards issubclass() from raising TypeError
        # when a non-class module attribute happens to match the name.
        if name.lower() == target_dataset_name.lower() \
                and isinstance(cls, type) and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
        # Exit with a non-zero status: this is a fatal configuration error,
        # not a successful run (the original exit(0) signalled success).
        exit(1)

    return dataset
def get_option_setter(dataset_name):
    """Return the modify_commandline_options hook of the dataset class
    selected by dataset_name."""
    return find_dataset_using_name(dataset_name).modify_commandline_options
def create_dataset(opt):
    """Look up the dataset class named by opt.dataset_mode, instantiate it,
    initialize it with opt, and return the ready instance."""
    dataset_class = find_dataset_using_name(opt.dataset_mode)
    instance = dataset_class()
    instance.initialize(opt)
    print("dataset [%s] was created" % (instance.name()))
    return instance
def CreateDataLoader(opt):
    """Build a CustomDatasetDataLoader configured from opt and return it."""
    loader = CustomDatasetDataLoader()
    loader.initialize(opt)
    return loader
# Wrapper class of Dataset class that performs
# multi-threaded data loading
class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps the selected dataset in a torch.utils.data.DataLoader for
    multi-process batch loading, capping iteration at opt.max_dataset_size."""

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset chosen by opt and wrap it in a DataLoader.

        DataLoader settings all come from opt: batch size (opt.batchSize),
        shuffling (disabled when opt.serial_batches is set), and the number
        of worker processes (opt.nThreads; 0 means load in the main process).
        """
        BaseDataLoader.initialize(self, opt)
        self.dataset = create_dataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        # The loader is itself iterable, so it serves as its own data source.
        return self

    def __len__(self):
        # Never report more samples than the configured cap.
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        # Stop yielding once the sample cap has been reached.
        for batch_index, batch in enumerate(self.dataloader):
            if batch_index * self.opt.batchSize >= self.opt.max_dataset_size:
                break
            yield batch