import os
import cv2
import glob
import random
import numpy as np
import skimage.exposure
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset
from torch.distributions import Normal

# from utils.utils import RGB2YCbCr


class RandomGammaCorrection(object):
    def __init__(self, gamma=None):
        self.gamma = gamma

    def __call__(self, image):
        if self.gamma is None:
            # pick a random gamma; gamma == 1 leaves the image unchanged
            # (use a local variable so the choice is re-drawn on every call)
            gammas = [0.5, 1, 2]
            gamma = random.choice(gammas)
            return TF.adjust_gamma(image, gamma, gain=1)
        elif isinstance(self.gamma, tuple):
            gamma = random.uniform(*self.gamma)
            return TF.adjust_gamma(image, gamma, gain=1)
        elif self.gamma == 0:
            return image
        else:
            return TF.adjust_gamma(image, self.gamma, gain=1)


def remove_background(image):
    # the input image is expected in [H, W, C] layout
    image = np.float32(np.array(image))
    _EPS = 1e-7
    rgb_max = np.max(image, (0, 1))
    rgb_min = np.min(image, (0, 1))
    image = (image - rgb_min) * rgb_max / (rgb_max - rgb_min + _EPS)
    image = torch.from_numpy(image)
    return image


def glod_from_folder(folder_list, index_list):
    ext = ["png", "jpeg", "jpg", "bmp", "tif"]
    index_dict = {}
    for i, folder_name in enumerate(folder_list):
        data_list = []
        for e in ext:
            data_list.extend(glob.glob(folder_name + "/*." + e))
        data_list.sort()
        index_dict[index_list[i]] = data_list
    return index_dict
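
# A minimal usage sketch of the helpers above (illustrative only, not part of
# the pipeline; paths are hypothetical). The transforms operate on tensors in
# [0, 1], and `glod_from_folder` pairs parallel folders by sorted filename:
#
#   to_tensor = transforms.ToTensor()
#   img = to_tensor(Image.open("example.png").convert("RGB"))
#   img = RandomGammaCorrection(gamma=(1.8, 2.2))(img)  # random gamma in range
#   paths = glod_from_folder(["data/input", "data/gt"], ["input", "gt"])
#   # paths["input"][i] and paths["gt"][i] refer to the same scene
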

class Flare_Image_Loader(Dataset):
    def __init__(self, image_path, transform_base, transform_flare, mask_type=None):
        self.ext = ["png", "jpeg", "jpg", "bmp", "tif"]
        self.data_list = []
        for e in self.ext:
            self.data_list.extend(glob.glob(image_path + "/*." + e))

        self.flare_dict = {}
        self.flare_list = []
        self.flare_name_list = []

        self.reflective_flag = False
        self.reflective_dict = {}
        self.reflective_list = []
        self.reflective_name_list = []

        self.light_flag = False
        self.light_dict = {}
        self.light_list = []
        self.light_name_list = []

        self.mask_type = mask_type  # None, "luminance", "color", "flare", or "light"

        self.img_size = transform_base["img_size"]
        self.transform_base = transforms.Compose(
            [
                transforms.RandomCrop(
                    (self.img_size, self.img_size),
                    pad_if_needed=True,
                    padding_mode="reflect",
                ),
                transforms.RandomHorizontalFlip(),
                # transforms.RandomVerticalFlip(),
            ]
        )
        self.transform_flare = transforms.Compose(
            [
                transforms.RandomAffine(
                    degrees=(0, 360),
                    scale=(transform_flare["scale_min"], transform_flare["scale_max"]),
                    translate=(
                        transform_flare["translate"] / 1440,
                        transform_flare["translate"] / 1440,
                    ),
                    shear=(-transform_flare["shear"], transform_flare["shear"]),
                ),
                transforms.CenterCrop((self.img_size, self.img_size)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ]
        )
        self.normalize = transforms.Compose(
            [
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.data_ratio = []

    def lightsource_crop(self, matrix):
        """Find the largest rectangle of 0s (light-source-free region) in a binary matrix."""

        def largestRectangleArea(heights):
            heights.append(0)
            stack = [-1]
            max_area = 0
            max_rectangle = (0, 0, 0, 0)  # (area, left, right, height)
            for i in range(len(heights)):
                while heights[i] < heights[stack[-1]]:
                    h = heights[stack.pop()]
                    w = i - stack[-1] - 1
                    area = h * w
                    if area > max_area:
                        max_area = area
                        max_rectangle = (area, stack[-1] + 1, i - 1, h)
                stack.append(i)
            heights.pop()
            return max_rectangle

        max_area = 0
        max_rectangle = [0, 0, 0, 0]  # (left, right, top, bottom)
        heights = torch.zeros(matrix.shape[1])
        for row in range(matrix.shape[0]):
            temp = 1 - matrix[row]
            # running column heights of consecutive 0s; reset wherever the row has a 1
            heights = (heights + temp) * temp
            area, left, right, height = largestRectangleArea(heights.tolist())
            if area > max_area:
                max_area = area
                max_rectangle = [int(left), int(right), int(row - height + 1), int(row)]
        return torch.tensor(max_rectangle)

    def __getitem__(self, index):
        # load base image
        img_path = self.data_list[index]
        base_img = Image.open(img_path).convert("RGB")

        gamma = np.random.uniform(1.8, 2.2)
        to_tensor = transforms.ToTensor()
        adjust_gamma = RandomGammaCorrection(gamma)
        adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
        color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)

        if self.transform_base is not None:
            base_img = to_tensor(base_img)
            base_img = adjust_gamma(base_img)
            base_img = self.transform_base(base_img)
        else:
            base_img = to_tensor(base_img)
            base_img = adjust_gamma(base_img)

        sigma_chi = 0.01 * np.random.chisquare(df=1)
        base_img = Normal(base_img, sigma_chi).sample()
        gain = np.random.uniform(0.5, 1.2)
        flare_DC_offset = np.random.uniform(-0.02, 0.02)
        base_img = gain * base_img
        base_img = torch.clamp(base_img, min=0, max=1)

        choice_dataset = random.choices(
            range(len(self.flare_list)), self.data_ratio
        )[0]
        choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)

        # load flare and light source images
        if self.light_flag:
            assert len(self.flare_list) == len(
                self.light_list
            ), "Error: the numbers of light-source and flare datasets do not match!"
            for i in range(len(self.flare_list)):
                assert len(self.flare_list[i]) == len(
                    self.light_list[i]
                ), f"Error: the numbers of light-source and flare images do not match in dataset {i}!"
            flare_path = self.flare_list[choice_dataset][choice_index]
            light_path = self.light_list[choice_dataset][choice_index]
            light_img = Image.open(light_path).convert("RGB")
            light_img = to_tensor(light_img)
            light_img = adjust_gamma(light_img)
        else:
            flare_path = self.flare_list[choice_dataset][choice_index]
        flare_img = Image.open(flare_path).convert("RGB")
        if self.reflective_flag:
            reflective_path_list = self.reflective_list[choice_dataset]
            if len(reflective_path_list) != 0:
                reflective_path = random.choice(reflective_path_list)
                reflective_img = Image.open(reflective_path).convert("RGB")
            else:
                reflective_img = None

        flare_img = to_tensor(flare_img)
        flare_img = adjust_gamma(flare_img)

        if self.reflective_flag and reflective_img is not None:
            reflective_img = to_tensor(reflective_img)
            reflective_img = adjust_gamma(reflective_img)
            flare_img = torch.clamp(flare_img + reflective_img, min=0, max=1)

        flare_img = remove_background(flare_img)

        if self.transform_flare is not None:
            if self.light_flag:
                flare_merge = torch.cat((flare_img, light_img), dim=0)
                flare_merge = self.transform_flare(flare_merge)
            else:
                flare_img = self.transform_flare(flare_img)

        # change color
        if self.light_flag:
            # flare_merge = color_jitter(flare_merge)
            flare_img, light_img = torch.split(flare_merge, 3, dim=0)
        else:
            flare_img = color_jitter(flare_img)

        # flare blur
        blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
        flare_img = blur_transform(flare_img)
        # flare_img = flare_img + flare_DC_offset
        flare_img = torch.clamp(flare_img, min=0, max=1)

        # merge image
        merge_img = flare_img + base_img
        merge_img = torch.clamp(merge_img, min=0, max=1)

        if self.light_flag:
            base_img = base_img + light_img
            base_img = torch.clamp(base_img, min=0, max=1)
            flare_img = flare_img - light_img
            flare_img = torch.clamp(flare_img, min=0, max=1)

        flare_mask = None
        if self.mask_type is None:
            return {
                "gt": adjust_gamma_reverse(base_img),
                "flare": adjust_gamma_reverse(flare_img),
                "lq": adjust_gamma_reverse(merge_img),
                "gamma": gamma,
            }
        elif self.mask_type == "luminance":
            # calculate mask (the mask has 3 channels)
            one = torch.ones_like(base_img)
            zero = torch.zeros_like(base_img)
            luminance = 0.3 * flare_img[0] + 0.59 * flare_img[1] + 0.11 * flare_img[2]
            threshold_value = 0.99**gamma
            flare_mask = torch.where(luminance > threshold_value, one, zero)
        elif self.mask_type == "color":
            one = torch.ones_like(base_img)
            zero = torch.zeros_like(base_img)
            threshold_value = 0.99**gamma
            flare_mask = torch.where(merge_img > threshold_value, one, zero)
        elif self.mask_type == "flare":
            one = torch.ones_like(base_img)
            zero = torch.zeros_like(base_img)
            threshold_value = 0.7**gamma
            flare_mask = torch.where(flare_img > threshold_value, one, zero)
        elif self.mask_type == "light":
            # Deprecated: the light mask is no longer needed
            one = torch.ones_like(base_img)
            zero = torch.zeros_like(base_img)
            luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
            threshold_value = 0.01
            flare_mask = torch.where(luminance > threshold_value, one, zero)

        light_source_cond = (flare_mask[0] + flare_mask[1] + flare_mask[2]) > 0
        light_source_cond = light_source_cond.float()
        light_source_cond = torch.repeat_interleave(
            light_source_cond[None, ...], 3, dim=0
        )

        # box = self.crop(light_source_cond[0])
        box = self.lightsource_crop(light_source_cond[0])
        # random int between -15 and 15
        margin = random.randint(-15, 15)
        if box[0] - margin >= 0:
            box[0] -= margin
        if box[1] + margin < self.img_size:
            box[1] += margin
        if box[2] - margin >= 0:
            box[2] -= margin
        if box[3] + margin < self.img_size:
            box[3] += margin
        top, bottom, left, right = box[2], box[3], box[0], box[1]

        merge_img = adjust_gamma_reverse(merge_img)
        cropped_mask = torch.ones((self.img_size, self.img_size))
        cropped_mask[top : bottom + 1, left : right + 1] = 0
        cropped_mask = torch.repeat_interleave(cropped_mask[None, ...], 1, dim=0)
        channel3_mask = cropped_mask.repeat(3, 1, 1)
        masked_img = merge_img * (1 - channel3_mask)
        masked_img[channel3_mask == 1] = 0.5

        return {
            "pixel_values": self.normalize(merge_img),
            "masks": cropped_mask,
            "masked_images": self.normalize(masked_img),
            "conditioning_pixel_values": light_source_cond,
        }

    def __len__(self):
        return len(self.data_list)

    def load_scattering_flare(self, flare_name, flare_path):
        flare_list = []
        for e in self.ext:
            flare_list.extend(glob.glob(flare_path + "/*." + e))
        flare_list = sorted(flare_list)
        self.flare_name_list.append(flare_name)
        self.flare_dict[flare_name] = flare_list
        self.flare_list.append(flare_list)
        len_flare_list = len(self.flare_dict[flare_name])
        if len_flare_list == 0:
            print("ERROR: scattering flare images are not loaded properly")
        else:
            print(
                "Scattering Flare Image:",
                flare_name,
                "is loaded successfully with",
                len_flare_list,
                "examples",
            )
        # print("Now we have", len(self.flare_list), "scattering flare images")

    def load_light_source(self, light_name, light_path):
        # The number of light source images should match the number of scattering flares
        light_list = []
        for e in self.ext:
            light_list.extend(glob.glob(light_path + "/*." + e))
        light_list = sorted(light_list)
        self.light_name_list.append(light_name)  # fixed: was appended to flare_name_list
        self.light_dict[light_name] = light_list
        self.light_list.append(light_list)
        len_light_list = len(self.light_dict[light_name])
        if len_light_list == 0:
            print("ERROR: light source images are not loaded properly")
        else:
            self.light_flag = True
            print(
                "Light Source Image:",
                light_name,
                "is loaded successfully with",
                len_light_list,
                "examples",
            )
        # print("Now we have", len(self.light_list), "light source images")

    def load_reflective_flare(self, reflective_name, reflective_path):
        reflective_list = []
        if reflective_path is not None:
            for e in self.ext:
                reflective_list.extend(glob.glob(reflective_path + "/*." + e))
        reflective_list = sorted(reflective_list)
        self.reflective_name_list.append(reflective_name)
        self.reflective_dict[reflective_name] = reflective_list
        self.reflective_list.append(reflective_list)
        len_reflective_list = len(self.reflective_dict[reflective_name])
        if len_reflective_list == 0 and reflective_path is not None:
            print("ERROR: reflective flare images are not loaded properly")
        else:
            self.reflective_flag = True
            print(
                "Reflective Flare Image:",
                reflective_name,
                "is loaded successfully with",
                len_reflective_list,
                "examples",
            )
        # print("Now we have", len(self.reflective_list), "reflective flare images")
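
# A minimal sketch of `lightsource_crop` on a toy mask (illustrative only):
# it returns [left, right, top, bottom] of the largest rectangle containing
# no light-source pixels, i.e. the region that is kept while the rest is
# grayed out and outpainted.
#
#   mask = torch.zeros(4, 4)
#   mask[1, 1] = 1  # a single light-source pixel
#   # The largest light-free rectangle is columns 2..3 over all four rows:
#   # loader.lightsource_crop(mask) -> tensor([2, 3, 0, 3])
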

class Flare7kpp_Pair_Loader(Flare_Image_Loader):
    def __init__(self, config):
        Flare_Image_Loader.__init__(
            self,
            config["image_path"],
            config["transform_base"],
            config["transform_flare"],
            config["mask_type"],
        )
        scattering_dict = config["scattering_dict"]
        reflective_dict = config["reflective_dict"]
        light_dict = config["light_dict"]

        # default to uniform sampling if config["data_ratio"] is not declared
        if "data_ratio" not in config or len(config["data_ratio"]) == 0:
            self.data_ratio = [1] * len(scattering_dict)
        else:
            self.data_ratio = config["data_ratio"]

        if len(scattering_dict) != 0:
            for key in scattering_dict.keys():
                self.load_scattering_flare(key, scattering_dict[key])
        if len(reflective_dict) != 0:
            for key in reflective_dict.keys():
                self.load_reflective_flare(key, reflective_dict[key])
        if len(light_dict) != 0:
            for key in light_dict.keys():
                self.load_light_source(key, light_dict[key])
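
# A minimal configuration sketch (hypothetical paths and values) showing how
# Flare7kpp_Pair_Loader is typically constructed, e.g. from a parsed YAML file;
# the keys match what __init__ and the load_* methods read:
#
#   config = {
#       "image_path": "dataset/background",
#       "transform_base": {"img_size": 512},
#       "transform_flare": {"scale_min": 0.8, "scale_max": 1.5,
#                           "translate": 300, "shear": 20},
#       "mask_type": None,
#       "scattering_dict": {"Flare7K": "dataset/Flare7K/Scattering_Flare"},
#       "reflective_dict": {"Flare7K": "dataset/Flare7K/Reflective_Flare"},
#       "light_dict": {"Flare7K": "dataset/Flare7K/Light_Source"},
#       "data_ratio": [1],
#   }
#   dataset = Flare7kpp_Pair_Loader(config)
#   sample = dataset[0]  # {"gt", "flare", "lq", "gamma"} when mask_type is None
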

class Lightsource_Regress_Loader(Flare7kpp_Pair_Loader):
    def __init__(self, config, num_lights=4):
        Flare7kpp_Pair_Loader.__init__(self, config)
        # translation and center-crop are applied manually per light source
        # in __getitem__, so they are omitted from this RandomAffine
        self.transform_flare = transforms.Compose(
            [
                transforms.RandomAffine(
                    degrees=(0, 360),
                    scale=(
                        config["transform_flare"]["scale_min"],
                        config["transform_flare"]["scale_max"],
                    ),
                    shear=(
                        -config["transform_flare"]["shear"],
                        config["transform_flare"]["shear"],
                    ),
                ),
                # transforms.CenterCrop((self.img_size, self.img_size)),
            ]
        )
        self.mask_type = "light"
        self.num_lights = num_lights

    def __getitem__(self, index):
        # load base image
        img_path = self.data_list[index]
        base_img = Image.open(img_path).convert("RGB")

        gamma = np.random.uniform(1.8, 2.2)
        to_tensor = transforms.ToTensor()
        adjust_gamma = RandomGammaCorrection(gamma)
        adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
        color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)

        base_img = to_tensor(base_img)
        base_img = adjust_gamma(base_img)
        if self.transform_base is not None:
            base_img = self.transform_base(base_img)

        sigma_chi = 0.01 * np.random.chisquare(df=1)
        base_img = Normal(base_img, sigma_chi).sample()
        gain = np.random.uniform(0.5, 1.2)
        base_img = gain * base_img
        base_img = torch.clamp(base_img, min=0, max=1)

        # init flare and light imgs
        flare_imgs = []
        light_imgs = []
        # translation ranges per quadrant: [dx range, dy range]
        position = [
            [[-224, 0], [-224, 0]],
            [[-224, 0], [0, 224]],
            [[0, 224], [-224, 0]],
            [[0, 224], [0, 224]],
        ]
        # shuffled quadrant order; a fourth flare reuses the first quadrant
        axis = random.sample(range(4), 4)
        axis[-1] = axis[0]
        flare_nums = int(
            random.random() * self.num_lights + 1
        )  # random number of flares, from 1 to num_lights

        for fn in range(flare_nums):
            choice_dataset = random.choices(
                range(len(self.flare_list)), self.data_ratio
            )[0]
            choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)
            flare_path = self.flare_list[choice_dataset][choice_index]
            flare_img = Image.open(flare_path).convert("RGB")
            flare_img = to_tensor(flare_img)
            flare_img = adjust_gamma(flare_img)
            flare_img = remove_background(flare_img)

            if self.light_flag:
                light_path = self.light_list[choice_dataset][choice_index]
                light_img = Image.open(light_path).convert("RGB")
                light_img = to_tensor(light_img)
                light_img = adjust_gamma(light_img)

            if self.transform_flare is not None:
                if self.light_flag:
                    flare_merge = torch.cat((flare_img, light_img), dim=0)
                    if flare_nums == 1:
                        dx = random.randint(-224, 224)
                        dy = random.randint(-224, 224)
                    else:
                        dx = random.randint(
                            position[axis[fn]][0][0], position[axis[fn]][0][1]
                        )
                        dy = random.randint(
                            position[axis[fn]][1][0], position[axis[fn]][1][1]
                        )
                        # keep multiple light sources away from the center
                        if -160 < dx < 160 and -160 < dy < 160:
                            if random.random() < 0.5:
                                dx = 160 if dx > 0 else -160
                            else:
                                dy = 160 if dy > 0 else -160
                    flare_merge = self.transform_flare(flare_merge)
                    flare_merge = TF.affine(
                        flare_merge, angle=0, translate=(dx, dy), scale=1.0, shear=0
                    )
                    flare_merge = TF.center_crop(
                        flare_merge, (self.img_size, self.img_size)
                    )
                else:
                    flare_img = self.transform_flare(flare_img)

            # change color
            if self.light_flag:
                flare_img, light_img = torch.split(flare_merge, 3, dim=0)
            else:
                flare_img = color_jitter(flare_img)
            flare_imgs.append(flare_img)
            if self.light_flag:
                light_img = torch.clamp(light_img, min=0, max=1)
                light_imgs.append(light_img)

        flare_img = torch.sum(torch.stack(flare_imgs), dim=0)
        flare_img = torch.clamp(flare_img, min=0, max=1)

        # flare blur
        blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
        flare_img = blur_transform(flare_img)
        flare_img = torch.clamp(flare_img, min=0, max=1)

        merge_img = torch.clamp(flare_img + base_img, min=0, max=1)

        if self.light_flag:
            light_img = torch.sum(torch.stack(light_imgs), dim=0)
            light_img = torch.clamp(light_img, min=0, max=1)
            base_img = torch.clamp(base_img + light_img, min=0, max=1)
            flare_img = torch.clamp(flare_img - light_img, min=0, max=1)

        flare_mask = None
        if self.mask_type is None:
            return {
                "gt": adjust_gamma_reverse(base_img),
                "flare": adjust_gamma_reverse(flare_img),
                "lq": adjust_gamma_reverse(merge_img),
                "gamma": gamma,
            }
        elif self.mask_type == "light":
            one = torch.ones_like(base_img)
            zero = torch.zeros_like(base_img)
            threshold_value = 0.01

            XYRs = torch.zeros((self.num_lights, 4))
            for i in range(flare_nums):
                luminance = (
                    0.3 * light_imgs[i][0]
                    + 0.59 * light_imgs[i][1]
                    + 0.11 * light_imgs[i][2]
                )
                flare_mask = torch.where(luminance > threshold_value, one, zero)
                light_source_cond = (flare_mask.sum(dim=0) > 0).float()
                x, y, r = self.find_circle_properties(light_source_cond, i)
                XYRs[i] = torch.tensor([x, y, r, 1.0])
            XYRs[:, :3] = XYRs[:, :3] / self.img_size

            luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
            flare_mask = torch.where(luminance > threshold_value, one, zero)
            light_source_cond = (flare_mask.sum(dim=0) > 0).float()
            light_source_cond = torch.repeat_interleave(
                light_source_cond[None, ...], 1, dim=0
            )

            # box = self.crop(light_source_cond[0])
            box = self.lightsource_crop(light_source_cond[0])
            # random int between 0 and 15
            margin = random.randint(0, 15)
            if box[0] - margin >= 0:
                box[0] -= margin
            if box[1] + margin < self.img_size:
                box[1] += margin
            if box[2] - margin >= 0:
                box[2] -= margin
            if box[3] + margin < self.img_size:
                box[3] += margin
            top, bottom, left, right = box[2], box[3], box[0], box[1]

            merge_img = adjust_gamma_reverse(merge_img)
            cropped_mask = torch.full(
                (self.img_size, self.img_size), True, dtype=torch.bool
            )
            cropped_mask[top : bottom + 1, left : right + 1] = False
            channel3_mask = cropped_mask.unsqueeze(0).expand(3, -1, -1)
            masked_img = merge_img * (1 - channel3_mask.float())
            masked_img[channel3_mask] = 0.5

            return {
                "input": self.normalize(masked_img),  # normalized to [-1, 1]
                "light_masks": light_source_cond,
                "xyrs": XYRs,
            }

    def find_circle_properties(self, mask, i, method="minEnclosingCircle"):
        """Find the center and radius of the light source circle in the mask."""
        _mask = (mask.numpy() * 255).astype(np.uint8)
        _, binary_mask = cv2.threshold(_mask, 127, 255, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(
            binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        if len(contours) == 0:
            return 0.0, 0.0, 0.0
        largest_contour = max(contours, key=cv2.contourArea)
        if method == "minEnclosingCircle":
            (x, y), radius = cv2.minEnclosingCircle(largest_contour)
        elif method == "area_based":
            M = cv2.moments(largest_contour)
            if M["m00"] == 0:  # the contour is too small
                return 0.0, 0.0, 0.0
            x = M["m10"] / M["m00"]
            y = M["m01"] / M["m00"]
            area = cv2.contourArea(largest_contour)
            radius = np.sqrt(area / np.pi)
        # # draw
        # cv2.circle(_mask, (int(x), int(y)), int(radius), 128, 2)
        # cv2.imwrite(f"mask_{i}.png", _mask)
        return x, y, radius
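
# A minimal sketch of `find_circle_properties` (illustrative only): given a
# binary light-source mask, it fits the largest contour with
# cv2.minEnclosingCircle and returns (x, y, radius) in pixels.
#
#   mask = torch.zeros(64, 64)
#   mask[20:30, 40:50] = 1.0  # a rough 10x10 blob
#   loader.find_circle_properties(mask, 0)
#   # -> approximately (44.5, 24.5, 7): the blob's enclosing circle
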

class Lightsource_3Maps_Loader(Lightsource_Regress_Loader):
    def __init__(self, config, num_lights=4):
        Lightsource_Regress_Loader.__init__(self, config, num_lights=num_lights)

    def build_gt_maps(self, coords, radii, H, W, kappa=0.4):
        yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
        prob_gt = torch.zeros((H, W))
        rad_gt = torch.zeros((H, W))
        eps = 1e-6
        for x_i, y_i, r_i in zip(coords[:, 0], coords[:, 1], radii):
            if r_i < 1.0:
                continue
            sigma = kappa * r_i
            g = torch.exp(-((xx - x_i) ** 2 + (yy - y_i) ** 2) / (2 * sigma**2))
            # a sharper Gaussian (sigma / sqrt(2)) carries the radius value
            g_prime = torch.exp(
                -((xx - x_i) ** 2 + (yy - y_i) ** 2) / (2 * (sigma / 1.414) ** 2)
            )
            prob_gt = torch.maximum(prob_gt, g)
            rad_gt = torch.maximum(rad_gt, g_prime * r_i)
        rad_gt = rad_gt / (prob_gt + eps)
        return prob_gt, rad_gt

    def __getitem__(self, index):
        # load base image
        img_path = self.data_list[index]
        base_img = Image.open(img_path).convert("RGB")

        gamma = np.random.uniform(1.8, 2.2)
        to_tensor = transforms.ToTensor()
        adjust_gamma = RandomGammaCorrection(gamma)
        adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
        color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)

        base_img = to_tensor(base_img)
        base_img = adjust_gamma(base_img)
        if self.transform_base is not None:
            base_img = self.transform_base(base_img)

        sigma_chi = 0.01 * np.random.chisquare(df=1)
        base_img = Normal(base_img, sigma_chi).sample()
        gain = np.random.uniform(0.5, 1.2)
        base_img = gain * base_img
        base_img = torch.clamp(base_img, min=0, max=1)

        # init flare and light imgs
        flare_imgs = []
        light_imgs = []
        position = [
            [[-224, 0], [-224, 0]],
            [[-224, 0], [0, 224]],
            [[0, 224], [-224, 0]],
            [[0, 224], [0, 224]],
        ]
        axis = random.sample(range(4), 4)
        axis[-1] = axis[0]
        flare_nums = int(
            random.random() * self.num_lights + 1
        )  # random number of flares, from 1 to num_lights

        for fn in range(flare_nums):
            choice_dataset = random.choices(
                range(len(self.flare_list)), self.data_ratio
            )[0]
            choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)
            flare_path = self.flare_list[choice_dataset][choice_index]
            flare_img = Image.open(flare_path).convert("RGB")
            flare_img = to_tensor(flare_img)
            flare_img = adjust_gamma(flare_img)
            flare_img = remove_background(flare_img)

            if self.light_flag:
                light_path = self.light_list[choice_dataset][choice_index]
                light_img = Image.open(light_path).convert("RGB")
                light_img = to_tensor(light_img)
                light_img = adjust_gamma(light_img)

            if self.transform_flare is not None:
                if self.light_flag:
                    flare_merge = torch.cat((flare_img, light_img), dim=0)
                    if flare_nums == 1:
                        dx = random.randint(-224, 224)
                        dy = random.randint(-224, 224)
                    else:
                        dx = random.randint(
                            position[axis[fn]][0][0], position[axis[fn]][0][1]
                        )
                        dy = random.randint(
                            position[axis[fn]][1][0], position[axis[fn]][1][1]
                        )
                        if -160 < dx < 160 and -160 < dy < 160:
                            if random.random() < 0.5:
                                dx = 160 if dx > 0 else -160
                            else:
                                dy = 160 if dy > 0 else -160
                    flare_merge = self.transform_flare(flare_merge)
                    flare_merge = TF.affine(
                        flare_merge, angle=0, translate=(dx, dy), scale=1.0, shear=0
                    )
                    flare_merge = TF.center_crop(
                        flare_merge, (self.img_size, self.img_size)
                    )
                else:
                    flare_img = self.transform_flare(flare_img)

            # change color
            if self.light_flag:
                flare_img, light_img = torch.split(flare_merge, 3, dim=0)
            else:
                flare_img = color_jitter(flare_img)
            flare_imgs.append(flare_img)
            if self.light_flag:
                light_img = torch.clamp(light_img, min=0, max=1)
                light_imgs.append(light_img)

        flare_img = torch.sum(torch.stack(flare_imgs), dim=0)
        flare_img = torch.clamp(flare_img, min=0, max=1)

        # flare blur
        blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
        flare_img = blur_transform(flare_img)
        flare_img = torch.clamp(flare_img, min=0, max=1)

        merge_img = torch.clamp(flare_img + base_img, min=0, max=1)

        if self.light_flag:
            light_img = torch.sum(torch.stack(light_imgs), dim=0)
            light_img = torch.clamp(light_img, min=0, max=1)
            base_img = torch.clamp(base_img + light_img, min=0, max=1)
            flare_img = torch.clamp(flare_img - light_img, min=0, max=1)

        flare_mask = None
        if self.mask_type is None:
            return {
                "gt": adjust_gamma_reverse(base_img),
                "flare": adjust_gamma_reverse(flare_img),
                "lq": adjust_gamma_reverse(merge_img),
                "gamma": gamma,
            }
        elif self.mask_type == "light":
            one = torch.ones_like(base_img)
            zero = torch.zeros_like(base_img)
            threshold_value = 0.01

            XYRs = torch.zeros((self.num_lights, 4))
            for i in range(flare_nums):
                luminance = (
                    0.3 * light_imgs[i][0]
                    + 0.59 * light_imgs[i][1]
                    + 0.11 * light_imgs[i][2]
                )
                flare_mask = torch.where(luminance > threshold_value, one, zero)
                light_source_cond = (flare_mask.sum(dim=0) > 0).float()
                x, y, r = self.find_circle_properties(light_source_cond, i)
                XYRs[i] = torch.tensor([x, y, r, 1.0])

            gt_prob, gt_rad = self.build_gt_maps(
                XYRs[:, :2], XYRs[:, 2], self.img_size, self.img_size
            )
            gt_prob = gt_prob.unsqueeze(0)  # shape: (1, H, W)
            gt_rad = gt_rad.unsqueeze(0)
            gt_rad /= self.img_size
            gt_maps = torch.cat((gt_prob, gt_rad), dim=0)  # shape: (2, H, W)

            XYRs[:, :3] = XYRs[:, :3] / self.img_size

            luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
            flare_mask = torch.where(luminance > threshold_value, one, zero)
            light_source_cond = (flare_mask.sum(dim=0) > 0).float()
            light_source_cond = torch.repeat_interleave(
                light_source_cond[None, ...], 1, dim=0
            )

            # box = self.crop(light_source_cond[0])
            box = self.lightsource_crop(light_source_cond[0])
            # random int between 0 and 15
            margin = random.randint(0, 15)
            if box[0] - margin >= 0:
                box[0] -= margin
            if box[1] + margin < self.img_size:
                box[1] += margin
            if box[2] - margin >= 0:
                box[2] -= margin
            if box[3] + margin < self.img_size:
                box[3] += margin
            top, bottom, left, right = box[2], box[3], box[0], box[1]

            merge_img = adjust_gamma_reverse(merge_img)
            cropped_mask = torch.full(
                (self.img_size, self.img_size), True, dtype=torch.bool
            )
            cropped_mask[top : bottom + 1, left : right + 1] = False
            channel3_mask = cropped_mask.unsqueeze(0).expand(3, -1, -1)
            masked_img = merge_img * (1 - channel3_mask.float())
            masked_img[channel3_mask] = 0.5

            return {
                "input": self.normalize(masked_img),  # normalized to [-1, 1]
                "light_masks": light_source_cond,
                "xyrs": gt_maps,
            }
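
# A minimal sketch of `build_gt_maps` (illustrative only; `config` as in the
# hypothetical sketch above): each light source becomes a Gaussian bump in the
# probability map, while the radius map stores roughly the radius value near
# each center.
#
#   loader = Lightsource_3Maps_Loader(config)
#   coords = torch.tensor([[128.0, 128.0], [300.0, 40.0]])  # (x, y) centers
#   radii = torch.tensor([12.0, 6.0])
#   prob_gt, rad_gt = loader.build_gt_maps(coords, radii, 512, 512)
#   prob_gt[128, 128]  # -> 1.0 at a center
#   rad_gt[128, 128]   # -> ~12.0 near that center
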

class TestImageLoader(Dataset):
    def __init__(
        self,
        dataroot_gt,
        dataroot_input,
        dataroot_mask,
        margin=0,
        img_size=512,
        noise_matching=False,
    ):
        super(TestImageLoader, self).__init__()
        self.gt_folder = dataroot_gt
        self.input_folder = dataroot_input
        self.mask_folder = dataroot_mask
        self.paths = glod_from_folder(
            [self.input_folder, self.gt_folder, self.mask_folder],
            ["input", "gt", "mask"],
        )
        self.margin = margin
        self.img_size = img_size
        self.noise_matching = noise_matching

    def __len__(self):
        return len(self.paths["input"])

    def __getitem__(self, index):
        img_name = self.paths["input"][index].split("/")[-1]
        num = img_name.split("_")[1].split(".")[0]

        # preprocess light source mask
        light_mask = np.array(Image.open(self.paths["mask"][index]))
        tmp_light_mask = np.zeros_like(light_mask[:, :, 0])
        tmp_light_mask[light_mask[:, :, 2] > 0] = 255
        cond = (light_mask[:, :, 0] > 0) & (light_mask[:, :, 1] > 0)
        tmp_light_mask[cond] = 0
        light_mask = tmp_light_mask

        # img for controlnet input
        control_img = np.repeat(light_mask[:, :, None], 3, axis=2)

        # crop region
        box = self.lightsource_crop(light_mask)
        if box[0] - self.margin >= 0:
            box[0] -= self.margin
        if box[1] + self.margin < self.img_size:
            box[1] += self.margin
        if box[2] - self.margin >= 0:
            box[2] -= self.margin
        if box[3] + self.margin < self.img_size:
            box[3] += self.margin

        # input image to be outpainted
        input_img = np.array(Image.open(self.paths["input"][index]))
        cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
        cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
        input_img[cropped_region == 1] = 128

        # image for blip
        blip_img = input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :]

        # noise matching
        input_img_matching = None
        if self.noise_matching:
            np_src_img = input_img / 255.0
            np_mask_rgb = np.repeat(cropped_region[:, :, None], 3, axis=2).astype(
                np.float32
            )
            matched_noise = self.get_matched_noise(np_src_img, np_mask_rgb)
            input_img_matching = (matched_noise * 255).astype(np.uint8)

        # mask image
        mask_img = (cropped_region * 255).astype(np.uint8)

        return {
            "blip_img": blip_img,
            "input_img": Image.fromarray(input_img),
            "input_img_matching": (
                Image.fromarray(input_img_matching)
                if input_img_matching is not None
                else Image.fromarray(input_img)
            ),
            "mask_img": Image.fromarray(mask_img),
            "control_img": Image.fromarray(control_img),
            "box": box,
            "output_name": "output_" + num + ".png",
        }

    def lightsource_crop(self, matrix):
        """Find the largest rectangle of 0s (light-source-free region) in a binary matrix."""

        def largestRectangleArea(heights):
            heights.append(0)
            stack = [-1]
            max_area = 0
            max_rectangle = (0, 0, 0, 0)  # (area, left, right, height)
            for i in range(len(heights)):
                while heights[i] < heights[stack[-1]]:
                    h = heights[stack.pop()]
                    w = i - stack[-1] - 1
                    area = h * w
                    if area > max_area:
                        max_area = area
                        max_rectangle = (area, stack[-1] + 1, i - 1, h)
                stack.append(i)
            heights.pop()
            return max_rectangle

        max_area = 0
        max_rectangle = [0, 0, 0, 0]  # (left, right, top, bottom)
        heights = [0] * len(matrix[0])
        for row in range(len(matrix)):
            for i, val in enumerate(matrix[row]):
                heights[i] = heights[i] + 1 if val == 0 else 0
            area, left, right, height = largestRectangleArea(heights)
            if area > max_area:
                max_area = area
                max_rectangle = [int(left), int(right), int(row - height + 1), int(row)]
        return list(max_rectangle)

    # this function is taken from https://github.com/parlance-zz/g-diffuser-bot
    def get_matched_noise(
        self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
    ):
        # helper fft routines that keep ortho normalization and auto-shift before and after fft
        def _fft2(data):
            if data.ndim > 2:  # has channels
                out_fft = np.zeros(
                    (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
                )
                for c in range(data.shape[2]):
                    c_data = data[:, :, c]
                    out_fft[:, :, c] = np.fft.fft2(
                        np.fft.fftshift(c_data), norm="ortho"
                    )
                    out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
            else:  # one channel
                out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
                out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
                out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
            return out_fft

        def _ifft2(data):
            if data.ndim > 2:  # has channels
                out_ifft = np.zeros(
                    (data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
                )
                for c in range(data.shape[2]):
                    c_data = data[:, :, c]
                    out_ifft[:, :, c] = np.fft.ifft2(
                        np.fft.fftshift(c_data), norm="ortho"
                    )
                    out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
            else:  # one channel
                out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
                out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
                out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
            return out_ifft

        def _get_gaussian_window(width, height, std=3.14, mode=0):
            window_scale_x = float(width / min(width, height))
            window_scale_y = float(height / min(width, height))
            window = np.zeros((width, height))
            x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
            for y in range(height):
                fy = (y / height * 2.0 - 1.0) * window_scale_y
                if mode == 0:
                    window[:, y] = np.exp(-(x**2 + fy**2) * std)
                else:
                    window[:, y] = (1 / ((x**2 + 1.0) * (fy**2 + 1.0))) ** (
                        std / 3.14
                    )  # hey wait a minute that's not gaussian
            return window

        def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
            np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
            if hardness != 1.0:
                hardened = np_mask_grey[:] ** hardness
            else:
                hardened = np_mask_grey[:]
            for c in range(3):
                np_mask_rgb[:, :, c] = hardened[:]
            return np_mask_rgb

        width = _np_src_image.shape[0]
        height = _np_src_image.shape[1]
        num_channels = _np_src_image.shape[2]

        np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
        img_mask = np_mask_grey > 1e-6
        ref_mask = np_mask_grey < 1e-3

        windowed_image = _np_src_image * (1.0 - _get_masked_window_rgb(np_mask_grey))
        windowed_image /= np.max(windowed_image)
        windowed_image += (
            np.average(_np_src_image) * np_mask_rgb
        )  # rather than leave the masked area black, we get better fft results by filling it with the average unmasked color

        src_fft = _fft2(windowed_image)  # get feature statistics from masked src img
        src_dist = np.absolute(src_fft)
        src_phase = src_fft / src_dist

        # use a generator with a static seed so outpainting is deterministic and
        # independent of the global seed
        rng = np.random.default_rng(0)

        noise_window = _get_gaussian_window(
            width, height, mode=1
        )  # start with simple gaussian noise
        noise_rgb = rng.random((width, height, num_channels))
        noise_grey = np.sum(noise_rgb, axis=2) / 3.0
        noise_rgb *= color_variation  # the colorfulness of the starting noise is blended to greyscale with a parameter
        for c in range(num_channels):
            noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey

        noise_fft = _fft2(noise_rgb)
        for c in range(num_channels):
            noise_fft[:, :, c] *= noise_window
        noise_rgb = np.real(_ifft2(noise_fft))
        shaped_noise_fft = _fft2(noise_rgb)
        shaped_noise_fft[:, :, :] = (
            np.absolute(shaped_noise_fft[:, :, :]) ** 2
            * (src_dist**noise_q)
            * src_phase
        )  # perform the actual shaping

        brightness_variation = 0.0  # color_variation  # todo: temporarily tying brightness variation to color variation for now
        contrast_adjusted_np_src = (
            _np_src_image[:] * (brightness_variation + 1.0) - brightness_variation * 2.0
        )

        # scikit-image is used for histogram matching, very convenient!
        shaped_noise = np.real(_ifft2(shaped_noise_fft))
        shaped_noise -= np.min(shaped_noise)
        shaped_noise /= np.max(shaped_noise)
        shaped_noise[img_mask, :] = skimage.exposure.match_histograms(
            shaped_noise[img_mask, :] ** 1.0,
            contrast_adjusted_np_src[ref_mask, :],
            channel_axis=1,
        )
        shaped_noise = (
            _np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
        )

        matched_noise = shaped_noise[:]
        return np.clip(matched_noise, 0.0, 1.0)
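
# A minimal usage sketch of `get_matched_noise` (illustrative only; directory
# names are hypothetical): it fills the masked region with noise whose spectrum
# and histogram are matched to the visible pixels, giving the diffusion model a
# better outpainting initialization than flat gray.
#
#   loader = TestImageLoader("gt/", "input/", "mask/", noise_matching=True)
#   src = np.random.rand(512, 512, 3)             # image in [0, 1]
#   mask = np.zeros((512, 512, 3), np.float32)
#   mask[:, 256:, :] = 1.0                        # outpaint the right half
#   filled = loader.get_matched_noise(src, mask)  # same shape, masked half filled
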

class CustomImageLoader(Dataset):
    def __init__(
        self, dataroot_input, left_outpaint, right_outpaint, up_outpaint, down_outpaint
    ):
        self.dataroot_input = dataroot_input
        self.left_outpaint = left_outpaint
        self.right_outpaint = right_outpaint
        self.up_outpaint = up_outpaint
        self.down_outpaint = down_outpaint
        self.H = 512 - (up_outpaint + down_outpaint)
        self.W = 512 - (left_outpaint + right_outpaint)
        self.img_size = 512
        self.img_lists = [
            os.path.join(dataroot_input, f)
            for f in os.listdir(dataroot_input)
            if f.endswith(".png") or f.endswith(".jpg")
        ]

    def __len__(self):
        return len(self.img_lists)

    def __getitem__(self, index):
        img_name = self.img_lists[index].split("/")[-1]

        # crop region: [left, right, top, bottom]
        box = [
            self.left_outpaint,
            511 - self.right_outpaint,
            self.up_outpaint,
            511 - self.down_outpaint,
        ]

        # input image to be outpainted
        input_img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
        paste_img = np.array(
            Image.open(self.img_lists[index]).resize((self.W, self.H), Image.LANCZOS)
        )
        input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :] = paste_img
        cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
        cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
        input_img[cropped_region == 1] = 128

        # image for blip
        blip_img = np.array(Image.open(self.img_lists[index]))

        # mask image
        mask_img = (cropped_region * 255).astype(np.uint8)

        return {
            "blip_img": blip_img,
            "input_img": Image.fromarray(input_img),
            "mask_img": Image.fromarray(mask_img),
            "box": box,
            "output_name": img_name,
        }


class HFCustomImageLoader(Dataset):
    def __init__(
        self,
        img_data,
        left_outpaint=64,
        right_outpaint=64,
        up_outpaint=64,
        down_outpaint=64,
    ):
        self.left_outpaint = left_outpaint
        self.right_outpaint = right_outpaint
        self.up_outpaint = up_outpaint
        self.down_outpaint = down_outpaint
        self.H = 512 - (up_outpaint + down_outpaint)
        self.W = 512 - (left_outpaint + right_outpaint)
        self.img_size = 512
        self.img_lists = [img_data]

    def __len__(self):
        return len(self.img_lists)

    def __getitem__(self, index):
        # crop region: [left, right, top, bottom]
        box = [
            self.left_outpaint,
            511 - self.right_outpaint,
            self.up_outpaint,
            511 - self.down_outpaint,
        ]

        # input image to be outpainted
        input_img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
        paste_img = np.array(
            self.img_lists[index].resize((self.W, self.H), Image.LANCZOS)
        )
        input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :] = paste_img
        cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
        cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
        input_img[cropped_region == 1] = 128

        # image for blip
        blip_img = np.array(self.img_lists[index])

        # mask image
        mask_img = (cropped_region * 255).astype(np.uint8)

        return {
            "blip_img": blip_img,
            "input_img": Image.fromarray(input_img),
            "mask_img": Image.fromarray(mask_img),
            "box": box,
        }


if __name__ == "__main__":
    pass