Spaces:
Runtime error
Runtime error
| import cv2 | |
| import os | |
| import glob | |
| import torch | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| import random | |
| import numpy as np | |
| import math | |
| from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels | |
| from basicsr.data.transforms import augment | |
| from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor | |
| from basicsr.utils.registry import DATASET_REGISTRY | |
| from PIL import Image | |
| class DAPEDataset(Dataset): | |
| def __init__(self, opt, image_size=384): | |
| self.opt = opt | |
| self.root = opt['root'] | |
| exts = opt['ext'] | |
| gt_lists = [] | |
| lr_lists = [] | |
| for idx_dir, root_dir in enumerate(self.root): | |
| gt_path = os.path.join(root_dir, 'gt') | |
| lr_path = os.path.join(root_dir, 'sr_bicubic') | |
| print(f'gt_path: {gt_path}') | |
| for ext in exts: | |
| gt_list = glob.glob(os.path.join(gt_path, ext)) | |
| lr_list = glob.glob(os.path.join(lr_path, ext)) | |
| gt_lists += gt_list | |
| lr_lists += lr_list | |
| self.lr_lists = lr_lists | |
| self.gt_lists = gt_lists | |
| assert len(self.gt_lists) == len(self.lr_lists) | |
| print(f'=========================Dataset Length {len(self.gt_lists)}=========================') | |
| self.img_preproc = transforms.Compose([ | |
| transforms.ToTensor(), | |
| transforms.Resize((512, 512)), | |
| ]) | |
| self.ram_preproc = transforms.Compose([ | |
| transforms.Resize((384, 384)), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
| ]) | |
| def __getitem__(self, index): | |
| gt_image = Image.open(self.gt_lists[index]).convert('RGB') | |
| lr_image = Image.open(self.lr_lists[index]).convert('RGB') | |
| lr_image, gt_image = self.img_preproc(lr_image), self.img_preproc(gt_image) | |
| lr_image_ram, gt_image_ram = self.ram_preproc(lr_image), self.ram_preproc(gt_image) | |
| return_d = {'gt': gt_image, 'lq': lr_image, 'gt_ram': gt_image_ram, 'lq_ram': lr_image_ram, 'lq_path':self.lr_lists[index]} | |
| return return_d | |
| def __len__(self): | |
| return len(self.gt_lists) | |