Spaces:
Runtime error
Runtime error
from __future__ import print_function, division

import os

import numpy as np
import torch
from torch.utils.data import Dataset
class BaseDataset(Dataset):
    """Base class for paired RAW-input / RGB-target camera datasets.

    Provides shared helpers (normalisation, Bayer packing, centre cropping,
    numpy -> tensor conversion) and a loader that builds the per-split file
    lists. ``__len__`` and ``__getitem__`` are placeholders, presumably
    overridden by concrete subclasses.
    """

    def __init__(self, opt):
        """Read dataset settings from an options object.

        Args:
            opt: object exposing ``debug_mode``, ``data_path``, ``camera`` and
                ``gamma`` attributes (e.g. an argparse namespace).
        """
        super().__init__()
        # Default side length (pixels) used by center_crop when no size is given.
        self.crop_size = 512
        self.debug_mode = opt.debug_mode
        self.data_path = opt.data_path  # dataset root, e.g. ./data/
        self.camera_name = opt.camera
        # Stored for subclasses; not used by the methods of this class.
        self.gamma = opt.gamma

    def norm_img(self, img, max_value):
        """Return ``img`` divided by ``max_value`` (as a float)."""
        return img / float(max_value)

    def pack_raw(self, raw):
        """Pack a single-plane Bayer mosaic into a 4-channel, half-resolution array.

        The four channels come from positions (0,0), (0,1), (1,1), (1,0) of each
        2x2 block, in that order. This is "RGBG" for an RGGB mosaic; the sensor
        layout is an assumption — confirm it per camera.

        Args:
            raw: 2-D array of shape (H, W). An odd trailing row/column is
                dropped so that all four planes have the same size.

        Returns:
            Array of shape (H // 2, W // 2, 4) with ``raw``'s dtype.
        """
        H, W = raw.shape[0], raw.shape[1]
        # Trim to an even size: with odd H or W the even- and odd-offset slices
        # differ in length and np.concatenate would raise.
        H -= H % 2
        W -= W % 2
        im = np.expand_dims(raw[:H, :W], axis=2)
        out = np.concatenate(
            (
                im[0:H:2, 0:W:2, :],
                im[0:H:2, 1:W:2, :],
                im[1:H:2, 1:W:2, :],
                im[1:H:2, 0:W:2, :],
            ),
            axis=2,
        )
        return out

    def np2tensor(self, array):
        """Convert an (H, W, C) numpy array into a (C, H, W) torch tensor.

        ``torch.Tensor`` builds a tensor of torch's default tensor type
        (float), whatever the input dtype.
        """
        return torch.Tensor(array).permute(2, 0, 1)

    def center_crop(self, img, crop_size=None):
        """Return the centred ``th x tw`` window of ``img``.

        Args:
            img: array whose first two axes are (H, W). Any trailing axes
                (e.g. channels) are kept whole.
            crop_size: ``(th, tw)`` pair, a single scalar for a square crop,
                or ``None`` to use ``self.crop_size`` for both sides.

        Returns:
            A slice (view) of ``img``. If the crop is larger than the image
            along an axis, the full extent of that axis is returned.
        """
        H, W = img.shape[0], img.shape[1]
        if crop_size is None:
            th = tw = self.crop_size
        elif np.isscalar(crop_size):
            th = tw = crop_size
        else:
            th, tw = crop_size[0], crop_size[1]
        # Clamp at 0: a negative start index would make numpy slice from the
        # end of the axis and return an off-centre, truncated patch.
        x1_img = max(0, int(round((W - tw) / 2.0)))
        y1_img = max(0, int(round((H - th) / 2.0)))
        # Slicing only the first two axes leaves trailing axes untouched, so
        # 2-D and 3-D inputs share one path.
        return img[y1_img : y1_img + th, x1_img : x1_img + tw]

    def load(self, is_train=True):
        """Build the RAW-input and RGB-target file lists for one split.

        Expected layout (``data_path`` = ``./data/``, camera ``NIKON D700``)::

            ./data/NIKON D700_train.txt      one sample name per line, e.g. a0016
            ./data/NIKON D700_test.txt
            ./data/NIKON D700/RAW/<name>.npz
            ./data/NIKON D700/RGB/<name>.jpg

        Args:
            is_train: read ``<camera>_train.txt`` if True, otherwise
                ``<camera>_test.txt``.

        Returns:
            ``(input_RAWs_WBs, target_RGBs)``: two path lists in the same order
            as the names in the split file. The files are not checked for
            existence. In debug mode only the first 10 names are used.
        """
        split_suffix = "_train.txt" if is_train else "_test.txt"
        # os.path.join also tolerates a data_path without a trailing separator.
        txt_path = os.path.join(self.data_path, self.camera_name + split_suffix)
        with open(txt_path, "r", encoding="utf-8") as f_read:
            # Skip blank lines (e.g. a trailing newline) so they don't become
            # bogus ".../RAW/.npz" paths.
            valid_camera_list = [line.strip() for line in f_read if line.strip()]
        if self.debug_mode:
            valid_camera_list = valid_camera_list[:10]
        camera_dir = os.path.join(self.data_path, self.camera_name)
        input_RAWs_WBs = [
            os.path.join(camera_dir, "RAW", name + ".npz") for name in valid_camera_list
        ]
        target_RGBs = [
            os.path.join(camera_dir, "RGB", name + ".jpg") for name in valid_camera_list
        ]
        return input_RAWs_WBs, target_RGBs

    def __len__(self):
        """Placeholder: the base class holds no samples."""
        return 0

    def __getitem__(self, idx):
        """Placeholder: the base class returns no sample."""
        return None