# LightsOut-demo/utils/dataset.py
import os
import cv2
import glob
import random
import numpy as np
import skimage.exposure
import yaml
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset
from torch.distributions import Normal
# from utils.utils import RGB2YCbCr
class RandomGammaCorrection(object):
    def __init__(self, gamma=None):
        self.gamma = gamma

    def __call__(self, image):
        if self.gamma is None:
            # pick a random gamma per call; gamma = 1 leaves the image unchanged
            gamma = random.choice([0.5, 1, 2])
            return TF.adjust_gamma(image, gamma, gain=1)
        elif isinstance(self.gamma, tuple):
            gamma = random.uniform(*self.gamma)
            return TF.adjust_gamma(image, gamma, gain=1)
        elif self.gamma == 0:
            return image
        else:
            return TF.adjust_gamma(image, self.gamma, gain=1)
def remove_background(image):
    # stretch values to [0, rgb_max] so the dark background maps to 0; note that
    # the call sites pass a [C, H, W] tensor, so the reductions below run over
    # the first two dimensions
    image = np.float32(np.array(image))
_EPS = 1e-7
rgb_max = np.max(image, (0, 1))
rgb_min = np.min(image, (0, 1))
image = (image - rgb_min) * rgb_max / (rgb_max - rgb_min + _EPS)
image = torch.from_numpy(image)
return image
def glod_from_folder(folder_list, index_list):
    """Collect the sorted image paths of each folder into a dict keyed by index_list."""
    ext = ["png", "jpeg", "jpg", "bmp", "tif"]
    index_dict = {}
    for i, folder_name in enumerate(folder_list):
        data_list = []
        for e in ext:
            data_list.extend(glob.glob(folder_name + "/*." + e))
        data_list.sort()
        index_dict[index_list[i]] = data_list
    return index_dict
class Flare_Image_Loader(Dataset):
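    """Synthesize flare-corrupted training pairs in the Flare7K/Flare7K++ style:
    a gamma-adjusted base image is combined with randomly transformed scattering
    flares, optional reflective flares, and optional light-source images."""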
def __init__(self, image_path, transform_base, transform_flare, mask_type=None):
self.ext = ["png", "jpeg", "jpg", "bmp", "tif"]
self.data_list = []
        for e in self.ext:
            self.data_list.extend(glob.glob(image_path + "/*." + e))
self.flare_dict = {}
self.flare_list = []
self.flare_name_list = []
self.reflective_flag = False
self.reflective_dict = {}
self.reflective_list = []
self.reflective_name_list = []
self.light_flag = False
self.light_dict = {}
self.light_list = []
self.light_name_list = []
        # mask_type may be None, "luminance", "color", "flare", or "light"
        self.mask_type = mask_type
self.img_size = transform_base["img_size"]
self.transform_base = transforms.Compose(
[
transforms.RandomCrop(
(self.img_size, self.img_size),
pad_if_needed=True,
padding_mode="reflect",
),
transforms.RandomHorizontalFlip(),
# transforms.RandomVerticalFlip(),
]
)
self.transform_flare = transforms.Compose(
[
transforms.RandomAffine(
degrees=(0, 360),
scale=(transform_flare["scale_min"], transform_flare["scale_max"]),
translate=(
transform_flare["translate"] / 1440,
transform_flare["translate"] / 1440,
),
shear=(-transform_flare["shear"], transform_flare["shear"]),
),
transforms.CenterCrop((self.img_size, self.img_size)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
]
)
self.normalize = transforms.Compose(
[
transforms.Normalize([0.5], [0.5]),
]
)
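        # per-sub-dataset sampling weights; subclasses fill this in (e.g. from
        # config["data_ratio"] in Flare7kpp_Pair_Loader)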
self.data_ratio = []
def lightsource_crop(self, matrix):
"""Find the largest rectangle of 1s in a binary matrix."""
def largestRectangleArea(heights):
heights.append(0)
stack = [-1]
max_area = 0
max_rectangle = (0, 0, 0, 0) # (area, left, right, height)
for i in range(len(heights)):
while heights[i] < heights[stack[-1]]:
h = heights[stack.pop()]
w = i - stack[-1] - 1
area = h * w
if area > max_area:
max_area = area
max_rectangle = (area, stack[-1] + 1, i - 1, h)
stack.append(i)
heights.pop()
return max_rectangle
max_area = 0
max_rectangle = [0, 0, 0, 0] # (left, right, top, bottom)
heights = torch.zeros(matrix.shape[1])
for row in range(matrix.shape[0]):
temp = 1 - matrix[row]
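            # temp is 1 on background pixels, so this extends each column's run of
            # consecutive zeros and resets it wherever a light-source pixel appears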
heights = (heights + temp) * temp
area, left, right, height = largestRectangleArea(heights.tolist())
if area > max_area:
max_area = area
max_rectangle = [int(left), int(right), int(row - height + 1), int(row)]
return torch.tensor(max_rectangle)
def __getitem__(self, index):
# load base image
img_path = self.data_list[index]
base_img = Image.open(img_path).convert("RGB")
gamma = np.random.uniform(1.8, 2.2)
to_tensor = transforms.ToTensor()
adjust_gamma = RandomGammaCorrection(gamma)
adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)
if self.transform_base is not None:
base_img = to_tensor(base_img)
base_img = adjust_gamma(base_img)
base_img = self.transform_base(base_img)
else:
base_img = to_tensor(base_img)
base_img = adjust_gamma(base_img)
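        # add synthetic sensor noise whose per-sample variance is drawn from a
        # chi-square prior, following the Flare7K synthesis recipe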
sigma_chi = 0.01 * np.random.chisquare(df=1)
base_img = Normal(base_img, sigma_chi).sample()
gain = np.random.uniform(0.5, 1.2)
        flare_DC_offset = np.random.uniform(-0.02, 0.02)  # only used by the commented-out offset below
base_img = gain * base_img
base_img = torch.clamp(base_img, min=0, max=1)
choice_dataset = random.choices(
[i for i in range(len(self.flare_list))], self.data_ratio
)[0]
choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)
# load flare and light source image
if self.light_flag:
            assert len(self.flare_list) == len(
                self.light_list
            ), "Error: the numbers of light-source and flare datasets do not match!"
            for i in range(len(self.flare_list)):
                assert len(self.flare_list[i]) == len(
                    self.light_list[i]
                ), f"Error: the numbers of light-source and flare images do not match in dataset {i}!"
flare_path = self.flare_list[choice_dataset][choice_index]
light_path = self.light_list[choice_dataset][choice_index]
light_img = Image.open(light_path).convert("RGB")
light_img = to_tensor(light_img)
light_img = adjust_gamma(light_img)
else:
flare_path = self.flare_list[choice_dataset][choice_index]
flare_img = Image.open(flare_path).convert("RGB")
if self.reflective_flag:
reflective_path_list = self.reflective_list[choice_dataset]
if len(reflective_path_list) != 0:
reflective_path = random.choice(reflective_path_list)
reflective_img = Image.open(reflective_path).convert("RGB")
else:
reflective_img = None
flare_img = to_tensor(flare_img)
flare_img = adjust_gamma(flare_img)
if self.reflective_flag and reflective_img is not None:
reflective_img = to_tensor(reflective_img)
reflective_img = adjust_gamma(reflective_img)
flare_img = torch.clamp(flare_img + reflective_img, min=0, max=1)
flare_img = remove_background(flare_img)
if self.transform_flare is not None:
if self.light_flag:
flare_merge = torch.cat((flare_img, light_img), dim=0)
flare_merge = self.transform_flare(flare_merge)
else:
flare_img = self.transform_flare(flare_img)
# change color
if self.light_flag:
# flare_merge=color_jitter(flare_merge)
flare_img, light_img = torch.split(flare_merge, 3, dim=0)
else:
flare_img = color_jitter(flare_img)
# flare blur
blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
flare_img = blur_transform(flare_img)
# flare_img=flare_img+flare_DC_offset
flare_img = torch.clamp(flare_img, min=0, max=1)
# merge image
merge_img = flare_img + base_img
merge_img = torch.clamp(merge_img, min=0, max=1)
if self.light_flag:
base_img = base_img + light_img
base_img = torch.clamp(base_img, min=0, max=1)
flare_img = flare_img - light_img
flare_img = torch.clamp(flare_img, min=0, max=1)
flare_mask = None
        if self.mask_type is None:
return {
"gt": adjust_gamma_reverse(base_img),
"flare": adjust_gamma_reverse(flare_img),
"lq": adjust_gamma_reverse(merge_img),
"gamma": gamma,
}
elif self.mask_type == "luminance":
# calculate mask (the mask is 3 channel)
one = torch.ones_like(base_img)
zero = torch.zeros_like(base_img)
luminance = 0.3 * flare_img[0] + 0.59 * flare_img[1] + 0.11 * flare_img[2]
threshold_value = 0.99**gamma
flare_mask = torch.where(luminance > threshold_value, one, zero)
elif self.mask_type == "color":
one = torch.ones_like(base_img)
zero = torch.zeros_like(base_img)
threshold_value = 0.99**gamma
flare_mask = torch.where(merge_img > threshold_value, one, zero)
elif self.mask_type == "flare":
one = torch.ones_like(base_img)
zero = torch.zeros_like(base_img)
threshold_value = 0.7**gamma
flare_mask = torch.where(flare_img > threshold_value, one, zero)
elif self.mask_type == "light":
            # Deprecated: we don't need the light mask anymore
one = torch.ones_like(base_img)
zero = torch.zeros_like(base_img)
luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
threshold_value = 0.01
flare_mask = torch.where(luminance > threshold_value, one, zero)
        light_source_cond = (flare_mask[0] + flare_mask[1] + flare_mask[2]) > 0
light_source_cond = light_source_cond.float()
light_source_cond = torch.repeat_interleave(
light_source_cond[None, ...], 3, dim=0
)
# box = self.crop(light_source_cond[0])
box = self.lightsource_crop(light_source_cond[0])
        # expand (or shrink, when negative) the crop box by a random margin in [-15, 15]
margin = random.randint(-15, 15)
if box[0] - margin >= 0:
box[0] -= margin
if box[1] + margin < self.img_size:
box[1] += margin
if box[2] - margin >= 0:
box[2] -= margin
if box[3] + margin < self.img_size:
box[3] += margin
top, bottom, left, right = box[2], box[3], box[0], box[1]
merge_img = adjust_gamma_reverse(merge_img)
        cropped_mask = torch.ones((self.img_size, self.img_size))
        cropped_mask[top : bottom + 1, left : right + 1] = 0
        cropped_mask = cropped_mask[None, ...]  # add a channel dim -> (1, H, W)
channel3_mask = cropped_mask.repeat(3, 1, 1)
masked_img = merge_img * (1 - channel3_mask)
masked_img[channel3_mask == 1] = 0.5
return {
"pixel_values": self.normalize(merge_img),
"masks": cropped_mask,
"masked_images": self.normalize(masked_img),
"conditioning_pixel_values": light_source_cond,
}
def __len__(self):
return len(self.data_list)
def load_scattering_flare(self, flare_name, flare_path):
flare_list = []
        for e in self.ext:
            flare_list.extend(glob.glob(flare_path + "/*." + e))
flare_list = sorted(flare_list)
self.flare_name_list.append(flare_name)
self.flare_dict[flare_name] = flare_list
self.flare_list.append(flare_list)
len_flare_list = len(self.flare_dict[flare_name])
if len_flare_list == 0:
print("ERROR: scattering flare images are not loaded properly")
        else:
            print(
                f"Scattering Flare Image: {flare_name} is loaded successfully with {len_flare_list} examples"
            )
# print("Now we have", len(self.flare_list), "scattering flare images")
def load_light_source(self, light_name, light_path):
# The number of the light source images should match the number of scattering flares
light_list = []
        for e in self.ext:
            light_list.extend(glob.glob(light_path + "/*." + e))
light_list = sorted(light_list)
        self.light_name_list.append(light_name)
self.light_dict[light_name] = light_list
self.light_list.append(light_list)
len_light_list = len(self.light_dict[light_name])
if len_light_list == 0:
print("ERROR: Light Source images are not loaded properly")
        else:
            self.light_flag = True
            print(
                f"Light Source Image: {light_name} is loaded successfully with {len_light_list} examples"
            )
# print("Now we have", len(self.light_list), "light source images")
def load_reflective_flare(self, reflective_name, reflective_path):
        reflective_list = []
        if reflective_path is not None:
            for e in self.ext:
                reflective_list.extend(glob.glob(reflective_path + "/*." + e))
reflective_list = sorted(reflective_list)
self.reflective_name_list.append(reflective_name)
self.reflective_dict[reflective_name] = reflective_list
self.reflective_list.append(reflective_list)
len_reflective_list = len(self.reflective_dict[reflective_name])
if len_reflective_list == 0 and reflective_path is not None:
print("ERROR: reflective flare images are not loaded properly")
        else:
            self.reflective_flag = True
            print(
                f"Reflective Flare Image: {reflective_name} is loaded successfully with {len_reflective_list} examples"
            )
        # print("Now we have", len(self.reflective_list), "reflective flare images")
class Flare7kpp_Pair_Loader(Flare_Image_Loader):
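    """Flare_Image_Loader driven by a config dict: registers the scattering,
    reflective, and light-source sub-datasets listed in the config."""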
def __init__(self, config):
Flare_Image_Loader.__init__(
self,
config["image_path"],
config["transform_base"],
config["transform_flare"],
config["mask_type"],
)
scattering_dict = config["scattering_dict"]
reflective_dict = config["reflective_dict"]
light_dict = config["light_dict"]
        # default to a uniform data ratio if config["data_ratio"] is not declared
if "data_ratio" not in config or len(config["data_ratio"]) == 0:
self.data_ratio = [1] * len(scattering_dict)
else:
self.data_ratio = config["data_ratio"]
if len(scattering_dict) != 0:
for key in scattering_dict.keys():
self.load_scattering_flare(key, scattering_dict[key])
if len(reflective_dict) != 0:
for key in reflective_dict.keys():
self.load_reflective_flare(key, reflective_dict[key])
if len(light_dict) != 0:
for key in light_dict.keys():
self.load_light_source(key, light_dict[key])
class Lightsource_Regress_Loader(Flare7kpp_Pair_Loader):
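    """Variant that composites up to num_lights flares per image and returns the
    normalized masked image together with per-light (x, y, r) circle targets."""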
def __init__(self, config, num_lights=4):
Flare7kpp_Pair_Loader.__init__(self, config)
self.transform_flare = transforms.Compose(
[
transforms.RandomAffine(
degrees=(0, 360),
scale=(
config["transform_flare"]["scale_min"],
config["transform_flare"]["scale_max"],
),
shear=(
-config["transform_flare"]["shear"],
config["transform_flare"]["shear"],
),
),
# transforms.CenterCrop((self.img_size, self.img_size)),
]
)
self.mask_type = "light"
self.num_lights = num_lights
def __getitem__(self, index):
# load base image
img_path = self.data_list[index]
base_img = Image.open(img_path).convert("RGB")
gamma = np.random.uniform(1.8, 2.2)
to_tensor = transforms.ToTensor()
adjust_gamma = RandomGammaCorrection(gamma)
adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)
base_img = to_tensor(base_img)
base_img = adjust_gamma(base_img)
if self.transform_base is not None:
base_img = self.transform_base(base_img)
sigma_chi = 0.01 * np.random.chisquare(df=1)
base_img = Normal(base_img, sigma_chi).sample()
gain = np.random.uniform(0.5, 1.2)
base_img = gain * base_img
base_img = torch.clamp(base_img, min=0, max=1)
# init flare and light imgs
flare_imgs = []
light_imgs = []
position = [
[[-224, 0], [-224, 0]],
[[-224, 0], [0, 224]],
[[0, 224], [-224, 0]],
[[0, 224], [0, 224]],
]
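        # the four entries of position are translation ranges for the four image
        # quadrants; shuffle them so each flare lands in a different quadrant (with
        # four flares, the last one reuses the first flare's quadrant)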
axis = random.sample(range(4), 4)
axis[-1] = axis[0]
        flare_nums = int(
            random.random() * self.num_lights + 1
        )  # random number of flares, from 1 to self.num_lights
for fn in range(flare_nums):
choice_dataset = random.choices(
[i for i in range(len(self.flare_list))], self.data_ratio
)[0]
choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)
flare_path = self.flare_list[choice_dataset][choice_index]
flare_img = Image.open(flare_path).convert("RGB")
flare_img = to_tensor(flare_img)
flare_img = adjust_gamma(flare_img)
flare_img = remove_background(flare_img)
if self.light_flag:
light_path = self.light_list[choice_dataset][choice_index]
light_img = Image.open(light_path).convert("RGB")
light_img = to_tensor(light_img)
light_img = adjust_gamma(light_img)
if self.transform_flare is not None:
if self.light_flag:
flare_merge = torch.cat((flare_img, light_img), dim=0)
if flare_nums == 1:
dx = random.randint(-224, 224)
dy = random.randint(-224, 224)
else:
dx = random.randint(
position[axis[fn]][0][0], position[axis[fn]][0][1]
)
dy = random.randint(
position[axis[fn]][1][0], position[axis[fn]][1][1]
)
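                        # keep light sources near the border: if a light would land
                        # in the central region, push it out to +/-160 px on one axis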
if -160 < dx < 160 and -160 < dy < 160:
if random.random() < 0.5:
dx = 160 if dx > 0 else -160
else:
dy = 160 if dy > 0 else -160
flare_merge = self.transform_flare(flare_merge)
flare_merge = TF.affine(
flare_merge, angle=0, translate=(dx, dy), scale=1.0, shear=0
)
flare_merge = TF.center_crop(
flare_merge, (self.img_size, self.img_size)
)
else:
flare_img = self.transform_flare(flare_img)
# change color
if self.light_flag:
flare_img, light_img = torch.split(flare_merge, 3, dim=0)
else:
flare_img = color_jitter(flare_img)
flare_imgs.append(flare_img)
if self.light_flag:
light_img = torch.clamp(light_img, min=0, max=1)
light_imgs.append(light_img)
flare_img = torch.sum(torch.stack(flare_imgs), dim=0)
flare_img = torch.clamp(flare_img, min=0, max=1)
# flare blur
blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
flare_img = blur_transform(flare_img)
flare_img = torch.clamp(flare_img, min=0, max=1)
merge_img = torch.clamp(flare_img + base_img, min=0, max=1)
if self.light_flag:
light_img = torch.sum(torch.stack(light_imgs), dim=0)
light_img = torch.clamp(light_img, min=0, max=1)
base_img = torch.clamp(base_img + light_img, min=0, max=1)
flare_img = torch.clamp(flare_img - light_img, min=0, max=1)
flare_mask = None
        if self.mask_type is None:
return {
"gt": adjust_gamma_reverse(base_img),
"flare": adjust_gamma_reverse(flare_img),
"lq": adjust_gamma_reverse(merge_img),
"gamma": gamma,
}
elif self.mask_type == "light":
one = torch.ones_like(base_img)
zero = torch.zeros_like(base_img)
threshold_value = 0.01
# flare_masks_list = []
XYRs = torch.zeros((self.num_lights, 4))
for i in range(flare_nums):
luminance = (
0.3 * light_imgs[i][0]
+ 0.59 * light_imgs[i][1]
+ 0.11 * light_imgs[i][2]
)
flare_mask = torch.where(luminance > threshold_value, one, zero)
light_source_cond = (flare_mask.sum(dim=0) > 0).float()
x, y, r = self.find_circle_properties(light_source_cond, i)
XYRs[i] = torch.tensor([x, y, r, 1.0])
XYRs[:, :3] = XYRs[:, :3] / self.img_size
luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
flare_mask = torch.where(luminance > threshold_value, one, zero)
light_source_cond = (flare_mask.sum(dim=0) > 0).float()
            light_source_cond = light_source_cond[None, ...]  # (1, H, W)
# box = self.crop(light_source_cond[0])
box = self.lightsource_crop(light_source_cond[0])
            # expand the crop box by a random margin in [0, 15]
margin = random.randint(0, 15)
if box[0] - margin >= 0:
box[0] -= margin
if box[1] + margin < self.img_size:
box[1] += margin
if box[2] - margin >= 0:
box[2] -= margin
if box[3] + margin < self.img_size:
box[3] += margin
top, bottom, left, right = box[2], box[3], box[0], box[1]
merge_img = adjust_gamma_reverse(merge_img)
cropped_mask = torch.full(
(self.img_size, self.img_size), True, dtype=torch.bool
)
cropped_mask[top : bottom + 1, left : right + 1] = False
channel3_mask = cropped_mask.unsqueeze(0).expand(3, -1, -1)
masked_img = merge_img * (1 - channel3_mask.float())
masked_img[channel3_mask] = 0.5
return {
"input": self.normalize(masked_img), # normalize to [-1, 1]
"light_masks": light_source_cond,
"xyrs": XYRs,
}
def find_circle_properties(self, mask, i, method="minEnclosingCircle"):
"""
Find the properties of the light source circle in the mask.
"""
_mask = (mask.numpy() * 255).astype(np.uint8)
_, binary_mask = cv2.threshold(_mask, 127, 255, cv2.THRESH_BINARY)
contours, _ = cv2.findContours(
binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
)
if len(contours) == 0:
return 0.0, 0.0, 0.0
largest_contour = max(contours, key=cv2.contourArea)
if method == "minEnclosingCircle":
(x, y), radius = cv2.minEnclosingCircle(largest_contour)
elif method == "area_based":
M = cv2.moments(largest_contour)
if M["m00"] == 0: # if the contour is too small
return 0.0, 0.0, 0.0
x = M["m10"] / M["m00"]
y = M["m01"] / M["m00"]
area = cv2.contourArea(largest_contour)
radius = np.sqrt(area / np.pi)
# # draw
# cv2.circle(_mask, (int(x), int(y)), int(radius), 128, 2)
# cv2.imwrite(f"mask_{i}.png", _mask)
return x, y, radius
class Lightsource_3Maps_Loader(Lightsource_Regress_Loader):
def __init__(self, config, num_lights=4):
Lightsource_Regress_Loader.__init__(self, config, num_lights=num_lights)
def build_gt_maps(self, coords, radii, H, W, kappa=0.4):
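        """Rasterize ground-truth maps: prob_gt holds a Gaussian bump with
        sigma = kappa * r at each light centre, and rad_gt (after the division
        below) peaks at roughly the light's radius at that centre."""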
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
prob_gt = torch.zeros((H, W))
rad_gt = torch.zeros((H, W))
eps = 1e-6
for x_i, y_i, r_i in zip(coords[:, 0], coords[:, 1], radii):
if r_i < 1.0:
continue
sigma = kappa * r_i
g = torch.exp(-((xx - x_i) ** 2 + (yy - y_i) ** 2) / (2 * sigma**2))
g_prime = torch.exp(
-((xx - x_i) ** 2 + (yy - y_i) ** 2) / (2 * (sigma / 1.414) ** 2)
)
prob_gt = torch.maximum(prob_gt, g)
rad_gt = torch.maximum(rad_gt, g_prime * r_i)
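        # g_prime uses sigma / sqrt(2), so it is sharper than g; dividing by prob_gt
        # makes rad_gt approach r_i at each centre and decay smoothly away from it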
rad_gt = rad_gt / (prob_gt + eps)
return prob_gt, rad_gt
def __getitem__(self, index):
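        # mirrors Lightsource_Regress_Loader.__getitem__, but returns dense
        # probability/radius maps (via build_gt_maps) instead of raw XYR targets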
# load base image
img_path = self.data_list[index]
base_img = Image.open(img_path).convert("RGB")
gamma = np.random.uniform(1.8, 2.2)
to_tensor = transforms.ToTensor()
adjust_gamma = RandomGammaCorrection(gamma)
adjust_gamma_reverse = RandomGammaCorrection(1 / gamma)
color_jitter = transforms.ColorJitter(brightness=(0.8, 3), hue=0.0)
base_img = to_tensor(base_img)
base_img = adjust_gamma(base_img)
if self.transform_base is not None:
base_img = self.transform_base(base_img)
sigma_chi = 0.01 * np.random.chisquare(df=1)
base_img = Normal(base_img, sigma_chi).sample()
gain = np.random.uniform(0.5, 1.2)
base_img = gain * base_img
base_img = torch.clamp(base_img, min=0, max=1)
# init flare and light imgs
flare_imgs = []
light_imgs = []
position = [
[[-224, 0], [-224, 0]],
[[-224, 0], [0, 224]],
[[0, 224], [-224, 0]],
[[0, 224], [0, 224]],
]
axis = random.sample(range(4), 4)
axis[-1] = axis[0]
        flare_nums = int(
            random.random() * self.num_lights + 1
        )  # random number of flares, from 1 to self.num_lights
for fn in range(flare_nums):
choice_dataset = random.choices(
[i for i in range(len(self.flare_list))], self.data_ratio
)[0]
choice_index = random.randint(0, len(self.flare_list[choice_dataset]) - 1)
flare_path = self.flare_list[choice_dataset][choice_index]
flare_img = Image.open(flare_path).convert("RGB")
flare_img = to_tensor(flare_img)
flare_img = adjust_gamma(flare_img)
flare_img = remove_background(flare_img)
if self.light_flag:
light_path = self.light_list[choice_dataset][choice_index]
light_img = Image.open(light_path).convert("RGB")
light_img = to_tensor(light_img)
light_img = adjust_gamma(light_img)
if self.transform_flare is not None:
if self.light_flag:
flare_merge = torch.cat((flare_img, light_img), dim=0)
if flare_nums == 1:
dx = random.randint(-224, 224)
dy = random.randint(-224, 224)
else:
dx = random.randint(
position[axis[fn]][0][0], position[axis[fn]][0][1]
)
dy = random.randint(
position[axis[fn]][1][0], position[axis[fn]][1][1]
)
if -160 < dx < 160 and -160 < dy < 160:
if random.random() < 0.5:
dx = 160 if dx > 0 else -160
else:
dy = 160 if dy > 0 else -160
flare_merge = self.transform_flare(flare_merge)
flare_merge = TF.affine(
flare_merge, angle=0, translate=(dx, dy), scale=1.0, shear=0
)
flare_merge = TF.center_crop(
flare_merge, (self.img_size, self.img_size)
)
else:
flare_img = self.transform_flare(flare_img)
# change color
if self.light_flag:
flare_img, light_img = torch.split(flare_merge, 3, dim=0)
else:
flare_img = color_jitter(flare_img)
flare_imgs.append(flare_img)
if self.light_flag:
light_img = torch.clamp(light_img, min=0, max=1)
light_imgs.append(light_img)
flare_img = torch.sum(torch.stack(flare_imgs), dim=0)
flare_img = torch.clamp(flare_img, min=0, max=1)
# flare blur
blur_transform = transforms.GaussianBlur(21, sigma=(0.1, 3.0))
flare_img = blur_transform(flare_img)
flare_img = torch.clamp(flare_img, min=0, max=1)
merge_img = torch.clamp(flare_img + base_img, min=0, max=1)
if self.light_flag:
light_img = torch.sum(torch.stack(light_imgs), dim=0)
light_img = torch.clamp(light_img, min=0, max=1)
base_img = torch.clamp(base_img + light_img, min=0, max=1)
flare_img = torch.clamp(flare_img - light_img, min=0, max=1)
flare_mask = None
        if self.mask_type is None:
return {
"gt": adjust_gamma_reverse(base_img),
"flare": adjust_gamma_reverse(flare_img),
"lq": adjust_gamma_reverse(merge_img),
"gamma": gamma,
}
elif self.mask_type == "light":
one = torch.ones_like(base_img)
zero = torch.zeros_like(base_img)
threshold_value = 0.01
# flare_masks_list = []
XYRs = torch.zeros((self.num_lights, 4))
for i in range(flare_nums):
luminance = (
0.3 * light_imgs[i][0]
+ 0.59 * light_imgs[i][1]
+ 0.11 * light_imgs[i][2]
)
flare_mask = torch.where(luminance > threshold_value, one, zero)
light_source_cond = (flare_mask.sum(dim=0) > 0).float()
x, y, r = self.find_circle_properties(light_source_cond, i)
XYRs[i] = torch.tensor([x, y, r, 1.0])
gt_prob, gt_rad = self.build_gt_maps(
XYRs[:, :2], XYRs[:, 2], self.img_size, self.img_size
)
gt_prob = gt_prob.unsqueeze(0) # shape: (1, H, W)
gt_rad = gt_rad.unsqueeze(0)
gt_rad /= self.img_size
gt_maps = torch.cat((gt_prob, gt_rad), dim=0) # shape: (2, H, W)
XYRs[:, :3] = XYRs[:, :3] / self.img_size
luminance = 0.3 * light_img[0] + 0.59 * light_img[1] + 0.11 * light_img[2]
flare_mask = torch.where(luminance > threshold_value, one, zero)
light_source_cond = (flare_mask.sum(dim=0) > 0).float()
            light_source_cond = light_source_cond[None, ...]  # (1, H, W)
# box = self.crop(light_source_cond[0])
box = self.lightsource_crop(light_source_cond[0])
            # expand the crop box by a random margin in [0, 15]
margin = random.randint(0, 15)
if box[0] - margin >= 0:
box[0] -= margin
if box[1] + margin < self.img_size:
box[1] += margin
if box[2] - margin >= 0:
box[2] -= margin
if box[3] + margin < self.img_size:
box[3] += margin
top, bottom, left, right = box[2], box[3], box[0], box[1]
merge_img = adjust_gamma_reverse(merge_img)
cropped_mask = torch.full(
(self.img_size, self.img_size), True, dtype=torch.bool
)
cropped_mask[top : bottom + 1, left : right + 1] = False
channel3_mask = cropped_mask.unsqueeze(0).expand(3, -1, -1)
masked_img = merge_img * (1 - channel3_mask.float())
masked_img[channel3_mask] = 0.5
return {
"input": self.normalize(masked_img), # normalize to [-1, 1]
"light_masks": light_source_cond,
"xyrs": gt_maps,
}
class TestImageLoader(Dataset):
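    """Loads input/gt/light-mask triplets for testing and prepares the grey-filled
    input, crop box, and masks expected by the outpainting pipeline."""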
def __init__(
self,
dataroot_gt,
dataroot_input,
dataroot_mask,
margin=0,
img_size=512,
noise_matching=False,
):
super(TestImageLoader, self).__init__()
self.gt_folder = dataroot_gt
self.input_folder = dataroot_input
self.mask_folder = dataroot_mask
self.paths = glod_from_folder(
[self.input_folder, self.gt_folder, self.mask_folder],
["input", "gt", "mask"],
)
self.margin = margin
self.img_size = img_size
self.noise_matching = noise_matching
def __len__(self):
return len(self.paths["input"])
def __getitem__(self, index):
        img_name = self.paths["input"][index].split("/")[-1]
        # file names are expected to look like "<prefix>_<num>.<ext>"
        num = img_name.split("_")[1].split(".")[0]
# preprocess light source mask
light_mask = np.array(Image.open(self.paths["mask"][index]))
tmp_light_mask = np.zeros_like(light_mask[:, :, 0])
tmp_light_mask[light_mask[:, :, 2] > 0] = 255
cond = (light_mask[:, :, 0] > 0) & (light_mask[:, :, 1] > 0)
tmp_light_mask[cond] = 0
light_mask = tmp_light_mask
# img for controlnet input
control_img = np.repeat(light_mask[:, :, None], 3, axis=2)
# crop region
box = self.lightsource_crop(light_mask)
if box[0] - self.margin >= 0:
box[0] -= self.margin
if box[1] + self.margin < self.img_size:
box[1] += self.margin
if box[2] - self.margin >= 0:
box[2] -= self.margin
if box[3] + self.margin < self.img_size:
box[3] += self.margin
# input image to be outpainted
input_img = np.array(Image.open(self.paths["input"][index]))
cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
input_img[cropped_region == 1] = 128
# image for blip
blip_img = input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :]
# noise matching
input_img_matching = None
if self.noise_matching:
np_src_img = input_img / 255.0
np_mask_rgb = np.repeat(cropped_region[:, :, None], 3, axis=2).astype(
np.float32
)
matched_noise = self.get_matched_noise(np_src_img, np_mask_rgb)
input_img_matching = (matched_noise * 255).astype(np.uint8)
# mask image
mask_img = (cropped_region * 255).astype(np.uint8)
return {
"blip_img": blip_img,
"input_img": Image.fromarray(input_img),
"input_img_matching": (
Image.fromarray(input_img_matching)
if input_img_matching is not None
else Image.fromarray(input_img)
),
"mask_img": Image.fromarray(mask_img),
"control_img": Image.fromarray(control_img),
"box": box,
"output_name": "output_" + num + ".png",
}
def lightsource_crop(self, matrix):
"""Find the largest rectangle of 1s in a binary matrix."""
def largestRectangleArea(heights):
heights.append(0)
stack = [-1]
max_area = 0
max_rectangle = (0, 0, 0, 0) # (area, left, right, height)
for i in range(len(heights)):
while heights[i] < heights[stack[-1]]:
h = heights[stack.pop()]
w = i - stack[-1] - 1
area = h * w
if area > max_area:
max_area = area
max_rectangle = (area, stack[-1] + 1, i - 1, h)
stack.append(i)
heights.pop()
return max_rectangle
max_area = 0
max_rectangle = [0, 0, 0, 0] # (left, right, top, bottom)
heights = [0] * len(matrix[0])
for row in range(len(matrix)):
for i, val in enumerate(matrix[row]):
heights[i] = heights[i] + 1 if val == 0 else 0
area, left, right, height = largestRectangleArea(heights)
if area > max_area:
max_area = area
max_rectangle = [int(left), int(right), int(row - height + 1), int(row)]
return list(max_rectangle)
# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
def get_matched_noise(
self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
):
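        """Fill the masked region with noise whose FFT magnitude and colour
        statistics match the unmasked source content (from g-diffuser-bot)."""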
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data):
if data.ndim > 2: # has channels
out_fft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_fft[:, :, c] = np.fft.fft2(
np.fft.fftshift(c_data), norm="ortho"
)
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
else: # one channel
out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
return out_fft
def _ifft2(data):
if data.ndim > 2: # has channels
out_ifft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_ifft[:, :, c] = np.fft.ifft2(
np.fft.fftshift(c_data), norm="ortho"
)
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
else: # one channel
out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
return out_ifft
def _get_gaussian_window(width, height, std=3.14, mode=0):
window_scale_x = float(width / min(width, height))
window_scale_y = float(height / min(width, height))
window = np.zeros((width, height))
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
for y in range(height):
fy = (y / height * 2.0 - 1.0) * window_scale_y
if mode == 0:
window[:, y] = np.exp(-(x**2 + fy**2) * std)
else:
window[:, y] = (1 / ((x**2 + 1.0) * (fy**2 + 1.0))) ** (
std / 3.14
) # hey wait a minute that's not gaussian
return window
def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
if hardness != 1.0:
hardened = np_mask_grey[:] ** hardness
else:
hardened = np_mask_grey[:]
for c in range(3):
np_mask_rgb[:, :, c] = hardened[:]
return np_mask_rgb
width = _np_src_image.shape[0]
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
windowed_image = _np_src_image * (1.0 - _get_masked_window_rgb(np_mask_grey))
windowed_image /= np.max(windowed_image)
windowed_image += (
np.average(_np_src_image) * np_mask_rgb
) # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
src_fft = _fft2(windowed_image) # get feature statistics from masked src img
src_dist = np.absolute(src_fft)
src_phase = src_fft / src_dist
# create a generator with a static seed to make outpainting deterministic / only follow global seed
rng = np.random.default_rng(0)
noise_window = _get_gaussian_window(
width, height, mode=1
) # start with simple gaussian noise
noise_rgb = rng.random((width, height, num_channels))
noise_grey = np.sum(noise_rgb, axis=2) / 3.0
noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
for c in range(num_channels):
noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey
noise_fft = _fft2(noise_rgb)
for c in range(num_channels):
noise_fft[:, :, c] *= noise_window
noise_rgb = np.real(_ifft2(noise_fft))
shaped_noise_fft = _fft2(noise_rgb)
shaped_noise_fft[:, :, :] = (
np.absolute(shaped_noise_fft[:, :, :]) ** 2
* (src_dist**noise_q)
* src_phase
) # perform the actual shaping
brightness_variation = 0.0 # color_variation # todo: temporarily tying brightness variation to color variation for now
contrast_adjusted_np_src = (
_np_src_image[:] * (brightness_variation + 1.0) - brightness_variation * 2.0
)
# scikit-image is used for histogram matching, very convenient!
shaped_noise = np.real(_ifft2(shaped_noise_fft))
shaped_noise -= np.min(shaped_noise)
shaped_noise /= np.max(shaped_noise)
shaped_noise[img_mask, :] = skimage.exposure.match_histograms(
shaped_noise[img_mask, :] ** 1.0,
contrast_adjusted_np_src[ref_mask, :],
channel_axis=1,
)
shaped_noise = (
_np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
)
matched_noise = shaped_noise[:]
return np.clip(matched_noise, 0.0, 1.0)
class CustomImageLoader(Dataset):
def __init__(
self, dataroot_input, left_outpaint, right_outpaint, up_outpaint, down_outpaint
):
self.dataroot_input = dataroot_input
self.left_outpaint = left_outpaint
self.right_outpaint = right_outpaint
self.up_outpaint = up_outpaint
self.down_outpaint = down_outpaint
self.H = 512 - (up_outpaint + down_outpaint)
self.W = 512 - (left_outpaint + right_outpaint)
self.img_size = 512
self.img_lists = [
os.path.join(dataroot_input, f)
for f in os.listdir(dataroot_input)
if f.endswith(".png") or f.endswith(".jpg")
]
def __len__(self):
return len(self.img_lists)
def __getitem__(self, index):
img_name = self.img_lists[index].split("/")[-1]
# crop region
box = [
self.left_outpaint,
511 - self.right_outpaint,
self.up_outpaint,
511 - self.down_outpaint,
] # [left, right, top, bottom]
# box = self.lightsource_crop(light_mask)
# if box[0] - self.margin >= 0:
# box[0] -= self.margin
# if box[1] + self.margin < self.img_size:
# box[1] += self.margin
# if box[2] - self.margin >= 0:
# box[2] -= self.margin
# if box[3] + self.margin < self.img_size:
# box[3] += self.margin
# input image to be outpainted
input_img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
paste_img = np.array(
Image.open(self.img_lists[index]).resize((self.W, self.H), Image.LANCZOS)
)
input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :] = paste_img
cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
input_img[cropped_region == 1] = 128
# image for blip
blip_img = np.array(Image.open(self.img_lists[index]))
# # noise matching
# input_img_matching = None
# if self.noise_matching:
# np_src_img = input_img / 255.0
# np_mask_rgb = np.repeat(cropped_region[:, :, None], 3, axis=2).astype(
# np.float32
# )
# matched_noise = self.get_matched_noise(np_src_img, np_mask_rgb)
# input_img_matching = (matched_noise * 255).astype(np.uint8)
# mask image
mask_img = (cropped_region * 255).astype(np.uint8)
return {
"blip_img": blip_img,
"input_img": Image.fromarray(input_img),
# "input_img": (
# Image.fromarray(input_img_matching)
# if input_img_matching is not None
# else Image.fromarray(input_img)
# ),
"mask_img": Image.fromarray(mask_img),
"box": box,
"output_name": img_name,
}
class HFCustomImageLoader(Dataset):
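    """Same preparation as CustomImageLoader, but wraps a single in-memory PIL
    image (presumably for the Hugging Face demo, hence the HF prefix)."""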
def __init__(
self, img_data, left_outpaint=64, right_outpaint=64, up_outpaint=64, down_outpaint=64
):
self.left_outpaint = left_outpaint
self.right_outpaint = right_outpaint
self.up_outpaint = up_outpaint
self.down_outpaint = down_outpaint
self.H = 512 - (up_outpaint + down_outpaint)
self.W = 512 - (left_outpaint + right_outpaint)
self.img_size = 512
self.img_lists = [img_data]
def __len__(self):
return len(self.img_lists)
def __getitem__(self, index):
# img_name = self.img_lists[index].split("/")[-1]
# crop region
box = [
self.left_outpaint,
511 - self.right_outpaint,
self.up_outpaint,
511 - self.down_outpaint,
] # [left, right, top, bottom]
# input image to be outpainted
input_img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
paste_img = np.array(self.img_lists[index].resize((self.W, self.H), Image.LANCZOS))
input_img[box[2] : box[3] + 1, box[0] : box[1] + 1, :] = paste_img
cropped_region = np.ones((self.img_size, self.img_size), dtype=np.uint8)
cropped_region[box[2] : box[3] + 1, box[0] : box[1] + 1] = 0
input_img[cropped_region == 1] = 128
# image for blip
blip_img = np.array(self.img_lists[index])
# # noise matching
# input_img_matching = None
# if self.noise_matching:
# np_src_img = input_img / 255.0
# np_mask_rgb = np.repeat(cropped_region[:, :, None], 3, axis=2).astype(
# np.float32
# )
# matched_noise = self.get_matched_noise(np_src_img, np_mask_rgb)
# input_img_matching = (matched_noise * 255).astype(np.uint8)
# mask image
mask_img = (cropped_region * 255).astype(np.uint8)
return {
"blip_img": blip_img,
"input_img": Image.fromarray(input_img),
"mask_img": Image.fromarray(mask_img),
"box": box,
}
if __name__ == "__main__":
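    # Minimal smoke test (a sketch, not part of the training pipeline): build a
    # Flare7kpp_Pair_Loader from a YAML config and fetch one sample. The config
    # path and the keys sketched below are assumptions; point them at your local
    # copies of the datasets.
    config_path = "configs/dataset.yaml"  # hypothetical path
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        # expected keys: image_path, transform_base (img_size), transform_flare
        # (scale_min, scale_max, translate, shear), mask_type, scattering_dict,
        # reflective_dict, light_dict, and optionally data_ratio
        dataset = Flare7kpp_Pair_Loader(config)
        sample = dataset[0]
        for key, value in sample.items():
            print(key, value.shape if isinstance(value, torch.Tensor) else value)
    else:
        print(f"{config_path} not found; skipping the smoke test")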