Spaces:
Runtime error
Runtime error
| from __future__ import print_function, division | |
| import os, random, time | |
| import torch | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms, utils | |
| import rawpy | |
| from glob import glob | |
| from PIL import Image as PILImage | |
| import numbers | |
| from scipy.misc import imread | |
| from .base_dataset import BaseDataset | |
class FiveKDatasetTrain(BaseDataset):
    """Training-time dataset pairing RAW + white-balance archives with RGB targets.

    File lists come from ``self.load(is_train=True)``. Every item is randomly
    cropped, rotated and flipped (the same transform is applied to the RAW
    input and the RGB target), normalised, and converted to float tensors.
    """

    def __init__(self, opt):
        """Build the list of (RAW archive, RGB image) path pairs.

        Args:
            opt: Options object forwarded unchanged to ``BaseDataset``.
        """
        super().__init__(opt=opt)
        self.patch_size = 256
        input_RAWs_WBs, target_RGBs = self.load(is_train=True)
        assert len(input_RAWs_WBs) == len(target_RGBs), (
            "number of RAW inputs and RGB targets differ"
        )
        self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs}

    def random_flip(self, input_raw, target_rgb):
        """Flip both arrays along the same randomly chosen axis (0 or 1).

        A flip is always applied; there is no "leave unchanged" outcome.
        ``copy()`` materialises the flipped view (which has negative strides)
        into a fresh array.
        """
        axis = np.random.randint(2)
        input_raw = np.flip(input_raw, axis=axis).copy()
        target_rgb = np.flip(target_rgb, axis=axis).copy()
        return input_raw, target_rgb

    def random_rotate(self, input_raw, target_rgb):
        """Rotate both arrays by the same random multiple (0-3) of 90 degrees."""
        quarter_turns = np.random.randint(4)
        input_raw = np.rot90(input_raw, k=quarter_turns)
        target_rgb = np.rot90(target_rgb, k=quarter_turns)
        return input_raw, target_rgb

    def random_crop(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Cut a random ``patch_size`` window out of the RAW and the matching target.

        Args:
            patch_size: Side length of the window taken from ``input_raw``.
            input_raw: Array indexed as (H, W, C).
            target_rgb: Array indexed as (H, W, C).
            flow, demos: If either is truthy the target is cropped with the
                very same window as the input. Otherwise the target window is
                taken at doubled offsets and doubled size
                (presumably the target is at 2x the RAW resolution -- confirm).

        Returns:
            Tuple ``(patch_input_raw, patch_target_rgb)``.
        """
        H, W, _ = input_raw.shape
        # max(0, ...) keeps the offset range valid when the image is smaller
        # than the requested patch; the slice is then simply shorter.
        top = random.randint(0, max(0, H - patch_size))
        left = random.randint(0, max(0, W - patch_size))
        patch_input_raw = input_raw[top : top + patch_size, left : left + patch_size, :]

        if flow or demos:
            patch_target_rgb = target_rgb[
                top : top + patch_size, left : left + patch_size, :
            ]
        else:
            patch_target_rgb = target_rgb[
                top * 2 : top * 2 + patch_size * 2,
                left * 2 : left * 2 + patch_size * 2,
                :,
            ]
        return patch_input_raw, patch_target_rgb

    def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
        """Apply random crop, then random rotation, then random flip to the pair."""
        input_raw, target_rgb = self.random_crop(
            patch_size, input_raw, target_rgb, flow=flow, demos=demos
        )
        input_raw, target_rgb = self.random_rotate(input_raw, target_rgb)
        input_raw, target_rgb = self.random_flip(input_raw, target_rgb)
        return input_raw, target_rgb

    def __len__(self):
        """Return the number of RAW/RGB pairs."""
        return len(self.data["input_RAWs_WBs"])

    def __getitem__(self, idx):
        """Load, augment and normalise the ``idx``-th training pair.

        Returns:
            dict with keys ``"input_raw"``, ``"target_rgb"``, ``"target_raw"``
            (a copy of the normalised RAW input) and ``"file_name"`` (base
            name of the RAW archive up to its first dot).
        """
        input_raw_wb_path = self.data["input_RAWs_WBs"][idx]
        target_rgb_path = self.data["target_RGBs"][idx]

        # NOTE(review): `imread` is imported from scipy.misc at the top of the
        # file; that function no longer exists in current SciPy releases --
        # confirm the pinned SciPy version or switch the file-level import.
        target_rgb_img = imread(target_rgb_path)

        # The file is indexed by the string keys "raw"/"wb", i.e. it is read
        # as an .npz archive. Such archives keep their file handle open until
        # closed, so read both arrays and release the handle right away.
        with np.load(input_raw_wb_path) as archive:
            input_raw_img = archive["raw"]
            wb = archive["wb"]

        # Scale the white-balance gains so the largest is 1, then apply all
        # but the last gain across the RAW array's trailing axis.
        wb = wb / wb.max()
        input_raw_img = input_raw_img * wb[:-1]

        input_raw_img, target_rgb_img = self.aug(
            self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True
        )

        # Normalisation ceiling depends on the camera
        # (presumably 12-bit vs 14-bit sensor data -- confirm).
        raw_max = 4095 if self.camera_name == "Canon_EOS_5D" else 16383
        if self.gamma:
            # Gamma-encode the data and its ceiling with the same exponent.
            norm_value = np.power(raw_max, 1 / 2.2)
            input_raw_img = np.power(input_raw_img, 1 / 2.2)
        else:
            norm_value = raw_max

        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
        target_raw_img = input_raw_img.copy()

        input_raw_img = self.np2tensor(input_raw_img).float()
        target_rgb_img = self.np2tensor(target_rgb_img).float()
        target_raw_img = self.np2tensor(target_raw_img).float()

        return {
            "input_raw": input_raw_img,
            "target_rgb": target_rgb_img,
            "target_raw": target_raw_img,
            # basename() honours the platform's path separators; the name is
            # still cut at the first dot, as before.
            "file_name": os.path.basename(input_raw_wb_path).split(".")[0],
        }
class FiveKDatasetTest(BaseDataset):
    """Evaluation-time dataset pairing RAW + white-balance archives with RGB targets.

    File lists come from ``self.load(is_train=False)``. Items are returned at
    full size: no cropping or other augmentation is applied.
    """

    def __init__(self, opt):
        """Build the list of (RAW archive, RGB image) path pairs.

        Args:
            opt: Options object forwarded unchanged to ``BaseDataset``.
        """
        super().__init__(opt=opt)
        self.patch_size = 256
        input_RAWs_WBs, target_RGBs = self.load(is_train=False)
        assert len(input_RAWs_WBs) == len(target_RGBs), (
            "number of RAW inputs and RGB targets differ"
        )
        self.data = {"input_RAWs_WBs": input_RAWs_WBs, "target_RGBs": target_RGBs}

    def __len__(self):
        """Return the number of RAW/RGB pairs."""
        return len(self.data["input_RAWs_WBs"])

    def __getitem__(self, idx):
        """Load and normalise the ``idx``-th test pair.

        Returns:
            dict with keys ``"input_raw"``, ``"target_rgb"``, ``"target_raw"``
            (a copy of the normalised RAW input) and ``"file_name"`` (base
            name of the RAW archive up to its first dot).
        """
        input_raw_wb_path = self.data["input_RAWs_WBs"][idx]
        target_rgb_path = self.data["target_RGBs"][idx]

        # NOTE(review): `imread` is imported from scipy.misc at the top of the
        # file; that function no longer exists in current SciPy releases --
        # confirm the pinned SciPy version or switch the file-level import.
        target_rgb_img = imread(target_rgb_path)

        # The file is indexed by the string keys "raw"/"wb", i.e. it is read
        # as an .npz archive. Such archives keep their file handle open until
        # closed, so read both arrays and release the handle right away.
        with np.load(input_raw_wb_path) as archive:
            input_raw_img = archive["raw"]
            wb = archive["wb"]

        # Scale the white-balance gains so the largest is 1, then apply all
        # but the last gain across the RAW array's trailing axis.
        wb = wb / wb.max()
        input_raw_img = input_raw_img * wb[:-1]

        # Normalisation ceiling depends on the camera
        # (presumably 12-bit vs 14-bit sensor data -- confirm).
        raw_max = 4095 if self.camera_name == "Canon_EOS_5D" else 16383
        if self.gamma:
            # Gamma-encode the data and its ceiling with the same exponent.
            norm_value = np.power(raw_max, 1 / 2.2)
            input_raw_img = np.power(input_raw_img, 1 / 2.2)
        else:
            norm_value = raw_max

        target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
        input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
        target_raw_img = input_raw_img.copy()

        input_raw_img = self.np2tensor(input_raw_img).float()
        target_rgb_img = self.np2tensor(target_rgb_img).float()
        target_raw_img = self.np2tensor(target_raw_img).float()

        return {
            "input_raw": input_raw_img,
            "target_rgb": target_rgb_img,
            "target_raw": target_raw_img,
            # basename() honours the platform's path separators; the name is
            # still cut at the first dot, as before.
            "file_name": os.path.basename(input_raw_wb_path).split(".")[0],
        }