Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.autograd import Variable | |
| import torch | |
| import numpy as np | |
| import os, time, random | |
| import argparse | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image as PILImage | |
| from glob import glob | |
| from tqdm import tqdm | |
| import rawpy | |
| import colour_demosaicing | |
| from .InvISP.model.model import InvISPNet | |
| from .utils.common import Notify | |
| from datasets.noise import ( | |
| camera_params, | |
| addGStarNoise, | |
| addPStarNoise, | |
| addQuantNoise, | |
| addRowNoise, | |
| sampleK, | |
| ) | |
class NoiseSimulator:
    """Turn RGB images into simulated noisy RAW data and back.

    The pipeline is built from an invertible ISP network (``InvISPNet``), a
    hand-written gamma / white-balance / black-level chain, and the noise
    helpers imported from ``datasets.noise``.

    RAW mosaics are read and written with this 2x2 layout (see
    ``invDemosaic``)::

        R  G
        G  B
    """

    def __init__(self, device, ckpt_path="./datasets/InvISP/pretrained/canon.pth"):
        """Build the ISP network and load its weights.

        Args:
            device: device the network is moved to and run on.
            ckpt_path: path of the InvISPNet checkpoint to load.
        """
        self.device = device
        # load Invertible ISP Network
        self.net = (
            InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval()
        )
        # Deserialize onto the CPU so that a checkpoint saved from a GPU still
        # loads on a CPU-only host; load_state_dict then copies each weight
        # onto whatever device the module's parameters already live on.
        # NOTE(review): torch.load unpickles the file - only load trusted
        # checkpoints. strict=False also hides missing/unexpected keys.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        self.net.load_state_dict(state_dict, strict=False)
        print(
            Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC
        )
        # white balance parameters; only the first three entries are applied
        # (one per colour channel) - presumably stored in R, G, B, G order.
        self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0])
        # use Canon EOS 5D4 noise parameters provided by ELD
        self.camera_params = camera_params
        # random specify exposure time ratio from 50 to 150
        self.ratio_min = 50
        self.ratio_max = 150

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _remove_black_level(self, raw):
        """Return ``raw`` minus the black level, clipped to the valid range.

        Always allocates a new array, so the caller's data is never modified.
        """
        black_level = self.camera_params["black_level"]
        white_level = self.camera_params["max_value"] - black_level
        return np.clip(raw - black_level, 0, white_level)

    def _demosaic_any(self, raw, batched):
        """Demosaic one mosaic [H, W] or, if ``batched``, a stack [B, H, W]."""
        if not batched:
            return self.demosaic(raw)
        return np.stack([self.demosaic(frame) for frame in raw], axis=0)

    def _sample_ratio(self, ratio_dec):
        """Draw an exposure ratio; ``ratio_dec`` scales its distance from 1."""
        return (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1

    def _simulate_noise(self, raw, ratio):
        """Darken ``raw`` by ``ratio``, add noise, and scale back up.

        Returns a new array; ``raw`` itself is left untouched by this method.
        """
        profile = self.camera_params["Profile-1"]
        K = sampleK(self.camera_params["Kmin"], self.camera_params["Kmax"])
        q = 1 / (self.camera_params["max_value"] - self.camera_params["black_level"])
        dark = raw / ratio
        dark = addPStarNoise(dark, K)
        dark = addGStarNoise(dark, K, self.camera_params["G_shape"], profile["G_scale"])
        dark = addRowNoise(dark, K, profile["R_scale"])
        dark = addQuantNoise(dark, q)
        return dark * ratio

    # ------------------------------------------------------------------
    # mosaic <-> 3-channel conversions
    # ------------------------------------------------------------------
    def invDemosaic(self, img):
        """Inverse demosaic: sample a 3-channel image back to a mosaic.

        Args:
            img: array of shape [H, W, 3].

        Returns:
            Array of shape [H, W] where each pixel keeps only the channel
            assigned to its position in the 2x2 pattern.
        """
        mosaic = np.ones(img.shape[:2])
        mosaic[::2, ::2] = img[::2, ::2, 0]
        mosaic[::2, 1::2] = img[::2, 1::2, 1]
        mosaic[1::2, ::2] = img[1::2, ::2, 1]
        mosaic[1::2, 1::2] = img[1::2, 1::2, 2]
        return mosaic

    def demosaicNearest(self, img):
        """Demosaic by copying samples within each 2x2 cell.

        Args:
            img: mosaic of shape [H, W]; H and W must be even, otherwise the
                assignments below fail with ``ValueError``.

        Returns:
            Array of shape [H, W, 3].
        """
        out = np.ones((img.shape[0], img.shape[1], 3))
        # The single red / blue sample of a cell fills the whole cell.
        out[..., 0] = np.repeat(np.repeat(img[::2, ::2], 2, axis=0), 2, axis=1)
        out[..., 2] = np.repeat(np.repeat(img[1::2, 1::2], 2, axis=0), 2, axis=1)
        # Each row of a cell reuses the green sample found on that same row.
        out[::2, :, 1] = np.repeat(img[::2, 1::2], 2, axis=1)
        out[1::2, :, 1] = np.repeat(img[1::2, ::2], 2, axis=1)
        return out

    def demosaic(self, img):
        """Bilinear demosaic of an RGGB mosaic [H, W] into [H, W, 3]."""
        return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, "RGGB")

    def path2rgb(self, path):
        """Load the image at ``path`` as a tensor scaled by 1/255.

        The image mode is not converted; the tensor has whatever layout
        ``np.array`` yields for the opened image.
        """
        # The context manager closes the underlying file handle; the pixel
        # data has to be copied out while the image is still open.
        with PILImage.open(path) as image:
            pixels = np.array(image)
        return torch.from_numpy(pixels / 255.0)

    # ------------------------------------------------------------------
    # pipeline stages
    # ------------------------------------------------------------------
    def rgb2raw(self, rgb, batched=False):
        """Run the ISP backwards: RGB tensor -> RAW mosaic.

        Args:
            rgb: tensor of shape [H, W, 3], or [B, H, W, 3] when ``batched``.
            batched: whether ``rgb`` carries a leading batch axis.

        Returns:
            NumPy mosaic of shape [H, W], or [B, H, W] when ``batched``.
        """
        # 1. rgb -> invnet
        if not batched:
            rgb = rgb.unsqueeze(0)
        rgb = rgb.permute(0, 3, 1, 2).float().to(self.device)
        with torch.no_grad():
            reconstruct_raw = self.net(rgb, rev=True)
        pred_raw = reconstruct_raw.detach().permute(0, 2, 3, 1)
        pred_raw = torch.clamp(pred_raw, 0, 1)
        if not batched:
            pred_raw = pred_raw[0, ...]
        pred_raw = pred_raw.cpu().numpy()
        # 2. -> inv gamma
        norm_value = np.power(16383, 1 / 2.2)
        pred_raw *= norm_value
        pred_raw = np.power(pred_raw, 2.2)
        # 3. -> inv white balance
        wb = self.wb / self.wb.max()
        pred_raw = pred_raw / wb[:-1]
        # 4. -> add black level
        pred_raw += self.camera_params["black_level"]
        # 5. -> inv demosaic
        if not batched:
            return self.invDemosaic(pred_raw)
        return np.stack([self.invDemosaic(frame) for frame in pred_raw], axis=0)

    def raw2noisyRaw(self, raw, ratio_dec=1, batched=False):
        """Add simulated sensor noise to a RAW mosaic.

        The data is divided by a random exposure ratio, passed through the
        ``addPStarNoise`` / ``addGStarNoise`` / ``addRowNoise`` /
        ``addQuantNoise`` helpers, then multiplied by the ratio again.

        Args:
            raw: mosaic [H, W], or a stack [B, H, W] when ``batched``.
            ratio_dec: scales how far the sampled ratio lies from 1; applied
                in both the single-image and the batched path.
            batched: whether ``raw`` carries a leading batch axis. Every
                batch element gets its own ratio and its own ``K``.

        Returns:
            A new array; ``raw`` is not modified by this method.
        """
        if not batched:
            return self._simulate_noise(raw, self._sample_ratio(ratio_dec))
        noisy = raw.copy()
        for i in range(noisy.shape[0]):
            noisy[i] = self._simulate_noise(noisy[i], self._sample_ratio(ratio_dec))
        return noisy

    def raw2rgb(self, raw, batched=False):
        """Run the ISP forwards: RAW mosaic -> RGB image.

        Args:
            raw: mosaic [H, W], or a stack [B, H, W] when ``batched``.
            batched: whether ``raw`` carries a leading batch axis.

        Returns:
            NumPy array [H, W, 3] (or [B, H, W, 3]) clamped to [0, 1].
        """
        # 1. -> demosaic, 2. -> subtract black level
        raw = self._remove_black_level(self._demosaic_any(raw, batched))
        # 3. -> white balance
        wb = self.wb / self.wb.max()
        raw = raw * wb[:-1]
        # 4. -> gamma
        norm_value = np.power(16383, 1 / 2.2)
        raw = np.power(raw, 1 / 2.2)
        raw /= norm_value
        # 5. -> ispnet (expects channels-first float input)
        net_input = torch.as_tensor(raw, dtype=torch.float32)
        if batched:
            net_input = net_input.permute(0, 3, 1, 2)
        else:
            net_input = net_input.permute(2, 0, 1).unsqueeze(0)
        net_input = net_input.to(self.device)
        with torch.no_grad():
            reconstruct_rgb = self.net(net_input)
        reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)
        pred_rgb = reconstruct_rgb.detach().permute(0, 2, 3, 1)
        if not batched:
            pred_rgb = pred_rgb[0, ...]
        return pred_rgb.cpu().numpy()

    def raw2packedRaw(self, raw, batched=False):
        """Pack a RAW mosaic into four half-resolution planes.

        Args:
            raw: mosaic [H, W], or a stack [B, H, W] when ``batched``.
                The input array is left unmodified.
            batched: whether ``raw`` carries a leading batch axis.

        Returns:
            Array [H/2, W/2, 4] (or [B, H/2, W/2, 4]). The last axis holds
            the 2x2 cell positions in the order (0, 0), (0, 1), (1, 1),
            (1, 0).
        """
        # 1. -> subtract black level (out of place: the caller usually feeds
        # the same mosaic to other stages afterwards) and normalise.
        raw = self._remove_black_level(raw)
        raw = raw / self.camera_params["max_value"]
        # 2. pack
        if batched:
            im = np.expand_dims(raw, axis=3)
            planes = (
                im[:, 0::2, 0::2, :],
                im[:, 0::2, 1::2, :],
                im[:, 1::2, 1::2, :],
                im[:, 1::2, 0::2, :],
            )
            return np.concatenate(planes, axis=3)
        im = np.expand_dims(raw, axis=2)
        planes = (
            im[0::2, 0::2, :],
            im[0::2, 1::2, :],
            im[1::2, 1::2, :],
            im[1::2, 0::2, :],
        )
        return np.concatenate(planes, axis=2)

    def raw2demosaicRaw(self, raw, batched=False):
        """Demosaic a RAW mosaic and normalise it, without running the ISP.

        Args:
            raw: mosaic [H, W], or a stack [B, H, W] when ``batched``.
            batched: whether ``raw`` carries a leading batch axis.

        Returns:
            Array [H, W, 3] (or [B, H, W, 3]): black level removed, clipped,
            then divided by ``camera_params["max_value"]``.
        """
        # 1. -> demosaic, 2. -> subtract black level
        raw = self._remove_black_level(self._demosaic_any(raw, batched))
        return raw / self.camera_params["max_value"]