| from importlib import import_module | |
| from torch.utils.data import dataloader | |
class Data:
    """Build test-set data loaders from an options object.

    After construction, ``self.loader_test`` holds one
    ``torch.utils.data.DataLoader`` per entry of ``args.data_test``, in the
    same order as the names appear there.
    """

    def __init__(self, args):
        """Create a DataLoader for every test dataset named in ``args``.

        Args:
            args: Options object. Reads ``args.data_test`` (an iterable of
                dataset names) and ``args.n_threads`` (forwarded to the
                DataLoader as ``num_workers``). ``args`` itself is also passed
                to the dataset constructor.

        Raises:
            NotImplementedError: If a name in ``args.data_test`` is not one of
                the supported benchmark names.
        """
        # Names served by the project's ``data.benchmark.Benchmark`` dataset.
        supported = ('Set5', 'Set14', 'B100', 'Urban100')

        self.loader_test = []
        for d in args.data_test:
            if d not in supported:
                raise NotImplementedError(
                    f"Unsupported test dataset {d!r}; "
                    f"expected one of: {', '.join(supported)}"
                )
            # The dataset module is imported only when a supported name is
            # requested, so an unused project module is never loaded.
            m = import_module('data.benchmark')
            testset = m.Benchmark(args, name=d)

            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,  # keep evaluation order deterministic
                    pin_memory=False,
                    num_workers=args.n_threads,
                )
            )