import os
import cv2
import numpy as np
import skimage
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
from skimage.draw import disk
from skimage import measure, morphology
from collections import OrderedDict

def load_mfdnet_checkpoint(model, weights):
    """Load an MFDNet checkpoint, stripping any DataParallel 'module.' prefix."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(weights, map_location=device)
    new_state_dict = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        # Checkpoints saved through nn.DataParallel prefix every key with "module.".
        name = key[7:] if key.startswith("module") else key
        new_state_dict[name] = value
    model.load_state_dict(new_state_dict)

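# Usage sketch (hypothetical model class and checkpoint path, not from this file):
#   >>> model = MFDNet()                              # assumed model definition
#   >>> load_mfdnet_checkpoint(model, "weights/mfdnet.pth")
#   >>> model.eval()
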
def adjust_gamma(image: torch.Tensor, gamma, device):
    # image is [B, C, H, W]; gamma is [B], one exponent per batch element
    gamma = gamma.float().to(device)
    gamma_tensor = gamma.view(-1, 1, 1, 1) * torch.ones_like(image)
    image = torch.pow(image, gamma_tensor)
    out = torch.clamp(image, 0.0, 1.0)
    return out

def adjust_gamma_reverse(image: torch.Tensor, gamma, device):
    # Inverse of adjust_gamma: raise the image to 1 / gamma per batch element.
    gamma = 1 / gamma.float().to(device)
    gamma_tensor = gamma.view(-1, 1, 1, 1) * torch.ones_like(image)
    image = torch.pow(image, gamma_tensor)
    out = torch.clamp(image, 0.0, 1.0)
    return out

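# Round-trip sketch: encoding with adjust_gamma and decoding with
# adjust_gamma_reverse should approximately recover the input.
#   >>> x = torch.rand(2, 3, 8, 8)
#   >>> g = torch.full((2,), 2.2)
#   >>> torch.allclose(adjust_gamma_reverse(adjust_gamma(x, g, "cpu"), g, "cpu"), x, atol=1e-4)
#   True
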
def predict_flare_from_6_channel(input_tensor, gamma, device="cpu"):
    # The input is a tensor in [B, C, H, W]; C here is 6:
    # channels 0-2 are the deflared scene, channels 3-5 the predicted flare.
    deflare_img = input_tensor[:, :3, :, :]
    flare_img_predicted = input_tensor[:, 3:, :, :]
    # Recombine scene and flare in linear space, then go back to gamma space.
    merge_img_predicted_linear = adjust_gamma(deflare_img, gamma, device) + adjust_gamma(
        flare_img_predicted, gamma, device
    )
    merge_img_predicted = adjust_gamma_reverse(
        torch.clamp(merge_img_predicted_linear, 1e-7, 1.0), gamma, device
    )
    return deflare_img, flare_img_predicted, merge_img_predicted

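# Usage sketch (illustrative): a restoration network that stacks the predicted
# clean scene and the predicted flare layer into one 6-channel output.
#   >>> out = torch.rand(1, 6, 16, 16)        # stand-in for network output
#   >>> g = torch.full((1,), 2.2)
#   >>> deflare, flare, merged = predict_flare_from_6_channel(out, g)
#   >>> merged.shape
#   torch.Size([1, 3, 16, 16])
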
def predict_flare_from_3_channel(
    input_tensor, flare_mask, base_img, flare_img, merge_img, gamma, device="cpu"
):
    # The input is a tensor in [B, C, H, W]; C here is 3.
    input_tensor_linear = adjust_gamma(input_tensor, gamma, device)
    merge_tensor_linear = adjust_gamma(merge_img, gamma, device)
    flare_img_predicted = adjust_gamma_reverse(
        torch.clamp(merge_tensor_linear - input_tensor_linear, 1e-7, 1.0), gamma, device
    )
    # Paste ground truth back inside the flare mask; keep predictions outside it.
    masked_deflare_img = input_tensor * (1 - flare_mask) + base_img * flare_mask
    masked_flare_img_predicted = (
        flare_img_predicted * (1 - flare_mask) + flare_img * flare_mask
    )
    return masked_deflare_img, masked_flare_img_predicted

def get_highlight_mask(image, threshold=0.99, luminance_mode=False):
    """Get the area close to overexposure.

    Args:
        image: image tensor in [B, C, H, W]. For inference, B is set to 1.
        threshold: luminance/greyscale threshold for the exposed region.
        luminance_mode: threshold on Rec. 709 luminance instead of the channel mean.

    Returns:
        Binary mask in [B, 1, H, W].
    """
    if luminance_mode:
        # Rec. 709 luma weights over the 3 RGB channels.
        luminance = (
            0.2126 * image[:, 0, :, :]
            + 0.7152 * image[:, 1, :, :]
            + 0.0722 * image[:, 2, :, :]
        )
        binary_mask = (luminance > threshold).unsqueeze(1)
    else:
        binary_mask = image.mean(dim=1, keepdim=True) > threshold
    binary_mask = binary_mask.to(image.dtype)
    return binary_mask

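# Example: a mid-grey image has no pixels above the 0.99 threshold.
#   >>> get_highlight_mask(torch.full((1, 3, 4, 4), 0.5)).sum()
#   tensor(0.)
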
def refine_mask(mask, morph_size=0.01):
    """Refine a mask by applying morphological opening.

    Args:
        mask: A float array of shape [H, W].
        morph_size: Size of the morphological kernel relative to the long side
            of the image.

    Returns:
        Refined boolean mask of shape [H, W].
    """
    mask_size = max(np.shape(mask))
    kernel_radius = 0.5 * morph_size * mask_size
    kernel = morphology.disk(np.ceil(kernel_radius))
    opened = morphology.binary_opening(mask, kernel)
    return opened

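# Example: an isolated single-pixel speckle is erased by the opening, since it
# is smaller than the disk kernel.
#   >>> m = np.zeros((100, 100)); m[50, 50] = 1.0
#   >>> refine_mask(m).sum()
#   0
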
def _create_disk_kernel(kernel_size):
    _EPS = 1e-7
    # Build a normalized circular (disk-shaped) averaging kernel.
    x = np.arange(kernel_size) - (kernel_size - 1) / 2
    xx, yy = np.meshgrid(x, x)
    rr = np.sqrt(xx**2 + yy**2)
    kernel = np.float32(rr <= np.max(x)) + _EPS
    kernel = kernel / np.sum(kernel)
    return kernel

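# The kernel is uniform inside the disk and sums to 1, so filtering with it is
# a circular box blur.
#   >>> k = _create_disk_kernel(5)
#   >>> bool(np.isclose(k.sum(), 1.0))
#   True
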
def blend_light_source(input_scene, pred_scene, threshold=0.99, luminance_mode=False):
    binary_mask = (
        get_highlight_mask(
            input_scene, threshold=threshold, luminance_mode=luminance_mode
        )
        > 0.5
    ).to("cpu", torch.bool)
    binary_mask = binary_mask.squeeze()  # (H, W)
    binary_mask = binary_mask.numpy()
    binary_mask = refine_mask(binary_mask)
    labeled = measure.label(binary_mask)
    properties = measure.regionprops(labeled)
    max_diameter = 0
    for p in properties:
        # The diameter of a circle with the same area as the region.
        max_diameter = max(max_diameter, p["equivalent_diameter"])
    mask = np.float32(binary_mask)
    kernel_size = round(1.5 * max_diameter)  # default factor is 1.5
    if kernel_size > 0:
        # Feather the mask with a disk blur, then boost and clip it.
        kernel = _create_disk_kernel(kernel_size)
        mask = cv2.filter2D(mask, -1, kernel)
        mask = np.clip(mask * 3.0, 0.0, 1.0)
        mask_rgb = np.stack([mask] * 3, axis=0)
        mask_rgb = torch.from_numpy(mask_rgb).to(input_scene.device, torch.float32)
        blend = input_scene * mask_rgb + pred_scene * (1 - mask_rgb)
    else:
        blend = pred_scene
    return blend

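# Usage sketch (illustrative): keep the original pixels around saturated light
# sources, where the prediction is unreliable, and use the deflared prediction
# everywhere else.
#   >>> inp, pred = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
#   >>> blend_light_source(inp, pred).shape
#   torch.Size([1, 3, 64, 64])
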
def blend_with_alpha(result, input_img, box, blur_size=31):
    """
    Alpha-blend the specified box region from input_img onto the result image
    to reduce boundary artifacts and make the paste-back look natural.

    Args:
        result (np.ndarray): inpainting-generated image
        input_img (np.ndarray): original image
        box (tuple): (x_min, x_max, y_min, y_max) region to paste back from the
            original image
        blur_size (int): blur kernel size for the mask; must be odd, and larger
            values create smoother transitions (odd values around 15~51 work well)

    Returns:
        np.ndarray: image after alpha blending
    """
    x_min, x_max, y_min, y_max = box
    # Hard alpha mask over the box region.
    mask = np.zeros_like(result, dtype=np.float32)
    mask[y_min : y_max + 1, x_min : x_max + 1] = 1.0
    # Feather the edges with a Gaussian blur.
    mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)
    # Alpha blending.
    blended = (mask * input_img + (1 - mask) * result).astype(np.uint8)
    return blended

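# Usage sketch (illustrative arrays): paste a 100x100 crop from the original
# back into the inpainted result with a feathered seam.
#   >>> result = np.zeros((256, 256, 3), dtype=np.uint8)
#   >>> original = np.full((256, 256, 3), 255, dtype=np.uint8)
#   >>> out = blend_with_alpha(result, original, box=(50, 149, 50, 149))
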
def IoU(pred, target):
    assert pred.shape == target.shape, "Prediction and target must have the same shape."
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    if union == 0:
        # Both masks empty: treat as perfect agreement.
        return 1.0 if intersection == 0 else 0.0
    return intersection / union

def mean_IoU(y_true, y_pred, num_classes):
    """
    Calculate the mean Intersection over Union (mIoU) score.

    Args:
        y_true (np.ndarray): Ground-truth labels (integer class values).
        y_pred (np.ndarray): Predicted labels (integer class values).
        num_classes (int): Number of classes.

    Returns:
        float: The mean IoU score across all classes.
    """
    iou_scores = []
    for cls in range(num_classes):
        # Create binary masks for the current class.
        true_mask = y_true == cls
        pred_mask = y_pred == cls
        # Calculate intersection and union.
        intersection = np.logical_and(true_mask, pred_mask)
        union = np.logical_or(true_mask, pred_mask)
        # Compute IoU for the current class.
        if np.sum(union) == 0:
            # Edge case: class absent from both prediction and ground truth.
            iou_scores.append(np.nan)
        else:
            iou_scores.append(np.sum(intersection) / np.sum(union))
    # Mean IoU, ignoring NaN values (classes without samples).
    mean_iou = np.nanmean(iou_scores)
    return mean_iou

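# Worked example: class 0 scores IoU 1/2 and class 1 scores 2/3, so the mean
# is 7/12.
#   >>> float(mean_IoU(np.array([0, 0, 1, 1]), np.array([0, 1, 1, 1]), 2))  # ~0.583
#   0.5833333333333333
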
def RGB2YCbCr(img):
    # img is [C, H, W] in [0, 1]; BT.601 studio-swing (16-235) conversion.
    img = img * 255.0
    r, g, b = torch.split(img, 1, dim=0)
    y = (0.257 * r + 0.504 * g + 0.098 * b + 16) / 255.0
    cb = (-0.148 * r - 0.291 * g + 0.439 * b + 128) / 255.0
    cr = (0.439 * r - 0.368 * g - 0.071 * b + 128) / 255.0
    return torch.cat([y, cb, cr], dim=0)

def extract_peaks(prob_map, thr=0.5, pool=7):
    """
    prob_map: (H, W) probabilities after sigmoid
    return: tensor of peak coordinates [K, 2] as (x, y)
    """
    # Binary mask of confident pixels.
    pos = prob_map > thr
    # Non-maximum suppression: a pixel is a peak iff it equals the local max
    # within a pool x pool window.
    nms = F.max_pool2d(
        prob_map.unsqueeze(0).unsqueeze(0),
        kernel_size=pool,
        stride=1,
        padding=pool // 2,
    )
    peaks = (prob_map == nms.squeeze()) & pos
    ys, xs = torch.nonzero(peaks, as_tuple=True)
    return torch.stack([xs, ys], dim=1)  # (K, 2)

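# Example: a single confident pixel survives thresholding and NMS.
#   >>> p = torch.zeros(5, 5); p[2, 3] = 0.8
#   >>> extract_peaks(p, thr=0.5, pool=3)
#   tensor([[3, 2]])
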
def pick_radius(radius_map, centers, ksize=3):
    """
    radius_map: (H, W) values in [0, 1]
    centers: (K, 2) as (x, y)
    return: (K,) radii, averaged over a ksize x ksize window
    """
    pad = ksize // 2
    padded = F.pad(
        radius_map.unsqueeze(0).unsqueeze(0), (pad, pad, pad, pad), mode="reflect"
    )
    radii = []
    for x, y in centers:
        # Mean of the ksize x ksize patch centered on the peak.
        patch = padded[..., y : y + ksize, x : x + ksize]
        radii.append(patch.mean())
    return torch.stack(radii)

def draw_mask(centers, radii, H, W):
    """
    centers: (K, 2) as (x, y)
    radii: (K,) normalized radii
    return: (H, W) float32 mask
    """
    # Scale normalized radii to pixels (factor of 256); avoid the in-place
    # `radii *= 256`, which would mutate the caller's tensor.
    radii = radii * 256
    mask = np.zeros((H, W), dtype=np.float32)
    for (x, y), r in zip(centers, radii):
        rr, cc = disk((y.item(), x.item()), r.item(), shape=mask.shape)
        mask[rr, cc] = 1
    return mask

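if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original pipeline):
    # build a synthetic probability map with one confident peak and a constant
    # normalized-radius map, then run peaks -> radii -> mask end to end.
    H, W = 64, 64
    prob = torch.zeros(H, W)
    prob[32, 40] = 0.9  # one detection at (x=40, y=32)
    radius_map = torch.full((H, W), 0.05)  # normalized radius everywhere

    centers = extract_peaks(prob, thr=0.5, pool=7)  # -> tensor([[40, 32]])
    radii = pick_radius(radius_map, centers, ksize=3)  # -> ~0.05 per peak
    mask = draw_mask(centers, radii, H, W)  # disk of ~12.8 px radius
    print(centers.tolist(), radii.tolist(), int(mask.sum()))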