Spaces:
Runtime error
Runtime error
| import cv2 | |
| import os | |
| import glob | |
| import torch | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| import random | |
| import numpy as np | |
| import math | |
| from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels | |
| from basicsr.data.transforms import augment | |
| from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor | |
| from PIL import Image | |
class SimpleDataset(Dataset):
    """Ground-truth image dataset that also samples degradation blur kernels.

    Each item is an RGB image resized to ``fix_size x fix_size`` together with
    two randomly sampled blur kernels (for a first and a second degradation
    stage) and a final sinc kernel. All kernels are 21 x 21 float tensors.

    Args:
        opt (dict): Configuration. Keys read here: ``gt_path`` (one directory
            or a list of directories), ``blur_kernel_size``, ``kernel_list``,
            ``kernel_prob``, ``blur_sigma``, ``betag_range``, ``betap_range``,
            ``sinc_prob``, their ``*2`` counterparts for the second
            degradation, and ``final_sinc_prob``.
        fix_size (int): Side length the images are resized to. Default: 512.
    """

    # Side length every sampled kernel is zero-padded to. The largest entry of
    # ``kernel_range`` below must not exceed this value.
    _PADDED_KERNEL_SIZE = 21

    def __init__(self, opt, fix_size=512):
        super().__init__()
        self.opt = opt
        self.image_root = opt['gt_path']
        self.fix_size = fix_size
        self.image_list = self._collect_image_paths(self.image_root)

        # Not used by __getitem__ (which resizes with PIL directly); kept so
        # external code relying on the attribute keeps working.
        self.crop_preproc = transforms.Compose([
            transforms.Resize(fix_size),
        ])
        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
        ])

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        # TODO: kernel range is now hard-coded, should be in the configure file
        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # odd sizes from 7 to 21

        # convolving with pulse tensor brings no blurry effect
        padded = self._PADDED_KERNEL_SIZE
        self.pulse_tensor = torch.zeros(padded, padded).float()
        self.pulse_tensor[padded // 2, padded // 2] = 1

        print(f'The dataset length: {len(self.image_list)}')

    @staticmethod
    def _collect_image_paths(roots, patterns=('*.jpg', '*.png')):
        """Glob image files directly under each root and under its ``00*`` subfolders.

        Args:
            roots: A single directory (``str`` / ``os.PathLike``) or an
                iterable of directories.
            patterns: Glob patterns matched against file names.

        Returns:
            list[str]: Matching paths, ordered by root, then pattern, with
            direct children listed before ``00*`` children.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(roots, (str, os.PathLike)):
            roots = [roots]

        paths = []
        for root in roots:
            for pattern in patterns:
                paths += glob.glob(os.path.join(root, pattern))
                # if add lsdir dataset (images live in '00*' subfolders)
                paths += glob.glob(os.path.join(root, '00*', pattern))
        return paths

    def _sample_blur_kernel(self, kernel_list, kernel_prob, blur_sigma,
                            betag_range, betap_range, sinc_prob):
        """Sample one blur kernel and zero-pad it to the common kernel size.

        With probability ``sinc_prob`` a circular low-pass (sinc) kernel is
        drawn; otherwise a kernel is drawn via ``random_mixed_kernels``.

        Returns:
            torch.FloatTensor: The zero-padded kernel.
        """
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            omega_low = np.pi / 3 if kernel_size < 13 else np.pi / 5
            omega_c = np.random.uniform(omega_low, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-math.pi, math.pi],
                betag_range,
                betap_range,
                noise_range=None)

        # pad kernel
        pad_size = (self._PADDED_KERNEL_SIZE - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
        return torch.FloatTensor(kernel)

    def __getitem__(self, index):
        """Load one image and sample its degradation kernels.

        Args:
            index (int): Position in ``self.image_list``.

        Returns:
            dict: ``gt`` (image after resize and ``ToTensor``), ``kernel1``,
            ``kernel2``, ``sinc_kernel`` (float tensors) and ``lq_path`` (the
            path of the loaded image file).
        """
        image_path = self.image_list[index]
        # Context manager releases the file handle once the RGB copy exists.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        # Resize straight to a square; the aspect ratio is not preserved.
        image = image.resize((self.fix_size, self.fix_size), Image.LANCZOS)
        image = self.img_preproc(image)

        # ---------------- Generate kernels (used in the first degradation) ---------------- #
        kernel = self._sample_blur_kernel(
            self.kernel_list, self.kernel_prob, self.blur_sigma,
            self.betag_range, self.betap_range, self.sinc_prob)

        # ---------------- Generate kernels (used in the second degradation) ---------------- #
        kernel2 = self._sample_blur_kernel(
            self.kernel_list2, self.kernel_prob2, self.blur_sigma2,
            self.betag_range2, self.betap_range2, self.sinc_prob2)

        # ------------------------------ the final sinc kernel ------------------------------ #
        if np.random.uniform() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(
                omega_c, kernel_size, pad_to=self._PADDED_KERNEL_SIZE)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        return {
            'gt': image,
            'kernel1': kernel,
            'kernel2': kernel2,
            'sinc_kernel': sinc_kernel,
            'lq_path': image_path,
        }

    def __len__(self):
        """Return the number of image files found at construction time."""
        return len(self.image_list)