import cv2
import numpy as np
import torch
import torch.nn.functional as F
from collections import OrderedDict
from skimage import measure, morphology
from skimage.draw import disk


def load_mfdnet_checkpoint(model, weights):
    """Load an MFDNet checkpoint, stripping any nn.DataParallel prefix."""
    # Load to CPU so this also works on machines without a GPU; load_state_dict
    # copies the tensors onto the model's own device afterwards.
    checkpoint = torch.load(weights, map_location="cpu")
    new_state_dict = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        # Checkpoints saved from nn.DataParallel prefix every key with "module.".
        if key.startswith("module."):
            name = key[len("module.") :]
        else:
            name = key
        new_state_dict[name] = value
    model.load_state_dict(new_state_dict)

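# A minimal usage sketch (the model construction below is hypothetical; the
# actual MFDNet class and weight path are defined elsewhere in the repo):
#
#   model = MFDNet()                                    # any nn.Module with matching keys
#   load_mfdnet_checkpoint(model, "weights/mfdnet.pth")
#
# Stripping the "module." prefix lets checkpoints saved from an
# nn.DataParallel wrapper load into a plain single-device model.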

def adjust_gamma(image: torch.Tensor, gamma, device=None):
    """Apply per-sample gamma: image is [B,C,H,W] in [0,1], gamma is [B]."""
    if device is None:
        device = image.device
    gamma = gamma.float().to(device)
    # Broadcast each sample's gamma over its C, H, W dimensions.
    gamma_tensor = gamma.view(-1, 1, 1, 1) * torch.ones_like(image)
    image = torch.pow(image, gamma_tensor)
    out = torch.clamp(image, 0.0, 1.0)
    return out


def adjust_gamma_reverse(image: torch.Tensor, gamma, device=None):
    """Invert adjust_gamma: raises the image to 1/gamma per sample."""
    if device is None:
        device = image.device
    gamma = 1 / gamma.float().to(device)
    gamma_tensor = gamma.view(-1, 1, 1, 1) * torch.ones_like(image)
    image = torch.pow(image, gamma_tensor)
    out = torch.clamp(image, 0.0, 1.0)
    return out

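# Sanity sketch (assumed shapes): adjust_gamma followed by adjust_gamma_reverse
# with the same per-sample gamma is an identity up to clamping, since
# (x ** g) ** (1 / g) == x for x in (0, 1):
#
#   x = torch.rand(2, 3, 8, 8)
#   g = torch.tensor([2.2, 1.8])
#   assert torch.allclose(adjust_gamma_reverse(adjust_gamma(x, g), g), x, atol=1e-4)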

def predict_flare_from_6_channel(input_tensor, gamma, device="cpu"):
    # The input is a tensor in [B,C,H,W] with C=6: the first three channels are
    # the deflared scene and the last three are the predicted flare.
    deflare_img = input_tensor[:, :3, :, :]
    flare_img_predicted = input_tensor[:, 3:, :, :]

    # Recompose scene + flare in linear space, then return to gamma space.
    merge_img_predicted_linear = adjust_gamma(deflare_img, gamma, device) + adjust_gamma(
        flare_img_predicted, gamma, device
    )
    merge_img_predicted = adjust_gamma_reverse(
        torch.clamp(merge_img_predicted_linear, 1e-7, 1.0), gamma, device
    )
    return deflare_img, flare_img_predicted, merge_img_predicted

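# Usage sketch with stand-in data (a real call passes the 6-channel network
# output and the gamma used during data synthesis):
#
#   out = torch.rand(1, 6, 256, 256)        # hypothetical model output
#   gamma = torch.full((1,), 2.2)           # one gamma value per batch sample
#   deflare, flare, merge = predict_flare_from_6_channel(out, gamma)
#   # deflare, flare, merge are each (1, 3, 256, 256); `merge` reconstructs the
#   # flare-corrupted input and is typically compared against it in a loss.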

def predict_flare_from_3_channel(
    input_tensor, flare_mask, base_img, flare_img, merge_img, gamma
):
    # The input is a tensor in [B,C,H,W] with C=3; the flare is recovered as
    # the difference between the merged and input images in linear space.
    input_tensor_linear = adjust_gamma(input_tensor, gamma)
    merge_tensor_linear = adjust_gamma(merge_img, gamma)
    flare_img_predicted = adjust_gamma_reverse(
        torch.clamp(merge_tensor_linear - input_tensor_linear, 1e-7, 1.0), gamma
    )

    # Inside the flare mask, fall back to the ground-truth base/flare images.
    masked_deflare_img = input_tensor * (1 - flare_mask) + base_img * flare_mask
    masked_flare_img_predicted = (
        flare_img_predicted * (1 - flare_mask) + flare_img * flare_mask
    )

    return masked_deflare_img, masked_flare_img_predicted


def get_highlight_mask(image, threshold=0.99, luminance_mode=False):
    """Get the region close to overexposure.
    Args:
        image: image tensor in [B,C,H,W]. For inference, B is set to 1.
        threshold: luminance/greyscale threshold for the overexposed region.
        luminance_mode: threshold Rec.709 luminance instead of the channel mean.
    Returns:
        Binary mask in [B,H,W], with the dtype of `image`.
    """
    if luminance_mode:
        # Rec.709 luma weights over the three RGB channels.
        luminance = (
            0.2126 * image[:, 0, :, :]
            + 0.7152 * image[:, 1, :, :]
            + 0.0722 * image[:, 2, :, :]
        )
        binary_mask = luminance > threshold
    else:
        binary_mask = image.mean(dim=1) > threshold
    binary_mask = binary_mask.to(image.dtype)
    return binary_mask

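# Example (assumed input range [0, 1]): pixels whose channel mean exceeds the
# threshold are flagged, e.g. a saturated light source in an sRGB image.
#
#   img = torch.rand(1, 3, 64, 64)
#   m = get_highlight_mask(img, threshold=0.99)   # (1, 64, 64), values 0.0/1.0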

def refine_mask(mask, morph_size=0.01):
    """Refines a mask by applying morphological operations.
    Args:
      mask: A float array of shape [H, W].
      morph_size: Size of the morphological kernel relative to the long side of
        the image.

    Returns:
      Refined mask of shape [H, W].
    """
    mask_size = max(np.shape(mask))
    kernel_radius = 0.5 * morph_size * mask_size
    kernel = morphology.disk(int(np.ceil(kernel_radius)))
    # Opening removes small bright specks that are not real light sources.
    opened = morphology.binary_opening(mask, kernel)
    return opened


def _create_disk_kernel(kernel_size):
    """Create a normalized circular averaging kernel of the given size."""
    _EPS = 1e-7
    x = np.arange(kernel_size) - (kernel_size - 1) / 2
    xx, yy = np.meshgrid(x, x)
    rr = np.sqrt(xx**2 + yy**2)
    # 1 inside the inscribed circle, ~0 outside; _EPS avoids an all-zero kernel.
    kernel = np.float32(rr <= np.max(x)) + _EPS
    kernel = kernel / np.sum(kernel)
    return kernel

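# Quick check of the kernel's two invariants (shape and normalization):
#
#   k = _create_disk_kernel(5)
#   assert k.shape == (5, 5) and abs(k.sum() - 1.0) < 1e-5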

def blend_light_source(input_scene, pred_scene, threshold=0.99, luminance_mode=False):
    """Paste the (still saturated) light source from the input back into the
    deflared prediction, feathered with a disk kernel."""
    binary_mask = (
        get_highlight_mask(
            input_scene, threshold=threshold, luminance_mode=luminance_mode
        )
        > 0.5
    ).to("cpu", torch.bool)
    binary_mask = binary_mask.squeeze()  # (H, W)
    binary_mask = binary_mask.numpy()
    binary_mask = refine_mask(binary_mask)

    labeled = measure.label(binary_mask)
    properties = measure.regionprops(labeled)
    max_diameter = 0
    for p in properties:
        # The diameter of a circle with the same area as the region.
        max_diameter = max(max_diameter, p["equivalent_diameter"])

    mask = np.float32(binary_mask)
    kernel_size = round(1.5 * max_diameter)  # default scale is 1.5
    if kernel_size > 0:
        kernel = _create_disk_kernel(kernel_size)
        mask = cv2.filter2D(mask, -1, kernel)
        mask = np.clip(mask * 3.0, 0.0, 1.0)
        mask_rgb = np.stack([mask] * 3, axis=0)

        mask_rgb = torch.from_numpy(mask_rgb).to(input_scene.device, torch.float32)
        blend = input_scene * mask_rgb + pred_scene * (1 - mask_rgb)
    else:
        # No saturated region found: keep the prediction unchanged.
        blend = pred_scene
    return blend

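# Usage sketch with stand-in tensors (real inputs are the flare-corrupted
# scene and the network's deflared prediction, both (1, 3, H, W) in [0, 1]):
#
#   inp = torch.rand(1, 3, 256, 256)
#   pred = torch.rand(1, 3, 256, 256)
#   out = blend_light_source(inp, pred)   # saturated regions come from `inp`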

def blend_with_alpha(result, input_img, box, blur_size=31):
    """
    Apply alpha blending to paste the specified box region from input_img onto
    the result image, reducing boundary artifacts for a more natural composite.

    Args:
        result (np.ndarray): inpainting-generated image
        input_img (np.ndarray): original image
        box (tuple): (x_min, x_max, y_min, y_max) region pasted back from the original image
        blur_size (int): blur range for the mask; larger values give smoother
            transitions (recommended 15~50; must be odd for cv2.GaussianBlur)

    Returns:
        np.ndarray: image after alpha blending
    """

    x_min, x_max, y_min, y_max = box

    # cv2.GaussianBlur requires an odd kernel size; round even values up.
    if blur_size % 2 == 0:
        blur_size += 1

    # Hard alpha mask over the paste-back box.
    mask = np.zeros_like(result, dtype=np.float32)
    mask[y_min : y_max + 1, x_min : x_max + 1] = 1.0

    # Feather the mask edges with a Gaussian blur.
    mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)

    # Alpha blending.
    blended = (mask * input_img + (1 - mask) * result).astype(np.uint8)

    return blended

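# Usage sketch (hypothetical arrays; in practice `result` is the inpainted
# frame and `input_img` the original, both HxWx3 uint8):
#
#   result = np.zeros((128, 128, 3), dtype=np.uint8)
#   original = np.full((128, 128, 3), 255, dtype=np.uint8)
#   out = blend_with_alpha(result, original, box=(32, 95, 32, 95))
#   # `out` is white inside the box, black far outside, with a soft ramp between.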

def IoU(pred, target):
    """Binary intersection-over-union between two boolean/0-1 arrays."""
    assert pred.shape == target.shape, "Prediction and target must have the same shape."

    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()

    if union == 0:
        # Both masks empty: define IoU as a perfect match.
        return 1.0 if intersection == 0 else 0.0

    return intersection / union


def mean_IoU(y_true, y_pred, num_classes):
    """
    Calculate the mean Intersection over Union (mIoU) score.

    Args:
        y_true (np.ndarray): Ground truth labels (integer class values).
        y_pred (np.ndarray): Predicted labels (integer class values).
        num_classes (int): Number of classes.

    Returns:
        float: The mean IoU score across all classes.
    """
    iou_scores = []

    for cls in range(num_classes):
        # Create binary masks for the current class
        true_mask = y_true == cls
        pred_mask = y_pred == cls

        # Calculate intersection and union
        intersection = np.logical_and(true_mask, pred_mask)
        union = np.logical_or(true_mask, pred_mask)

        # Compute IoU for the current class
        if np.sum(union) == 0:
            # Handle edge case: no samples for this class
            iou_scores.append(np.nan)
        else:
            iou_scores.append(np.sum(intersection) / np.sum(union))

    # Calculate mean IoU, ignoring NaN values (classes without samples)
    mean_iou = np.nanmean(iou_scores)
    return mean_iou

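# Worked example on a 2x2 segmentation with two classes:
#
#   y_true = np.array([[0, 0], [1, 1]])
#   y_pred = np.array([[0, 1], [1, 1]])
#   mean_IoU(y_true, y_pred, num_classes=2)
#   # class 0: IoU = 1/2; class 1: IoU = 2/3; mean ≈ 0.583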

def RGB2YCbCr(img):
    """Convert a [3,H,W] RGB tensor in [0,1] to YCbCr (BT.601 studio swing)."""
    img = img * 255.0
    r, g, b = torch.split(img, 1, dim=0)

    y = (0.257 * r + 0.504 * g + 0.098 * b + 16) / 255.0
    cb = (-0.148 * r - 0.291 * g + 0.439 * b + 128) / 255.0
    cr = (0.439 * r - 0.368 * g - 0.071 * b + 128) / 255.0

    return torch.cat([y, cb, cr], dim=0)


def extract_peaks(prob_map, thr=0.5, pool=7):
    """
    prob_map: (H, W) probabilities after sigmoid
    pool: NMS window size; must be odd so the pooled map keeps the same shape
    return: tensor of peak coordinates [K, 2] as (x, y)
    """
    # Binary mask of confident pixels.
    pos = prob_map > thr

    # Non-maximum suppression: a pixel is a peak iff it equals the local max.
    nms = F.max_pool2d(
        prob_map.unsqueeze(0).unsqueeze(0),
        kernel_size=pool,
        stride=1,
        padding=pool // 2,
    )
    peaks = (prob_map == nms.squeeze()) & pos
    ys, xs = torch.nonzero(peaks, as_tuple=True)
    return torch.stack([xs, ys], dim=1)  # (K, 2)

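# Example: two well-separated confident pixels survive both the threshold and
# the local-max test, so two (x, y) peaks are returned:
#
#   p = torch.zeros(64, 64)
#   p[20, 30], p[40, 10] = 0.9, 0.8
#   extract_peaks(p, thr=0.5, pool=7)   # tensor([[30, 20], [10, 40]])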

def pick_radius(radius_map, centers, ksize=3):
    """
    radius_map: (H, W) values in [0, 1]
    centers: (K, 2) as (x, y)
    return: (K,) normalized radii in [0, 1]
    """
    pad = ksize // 2
    padded = F.pad(
        radius_map.unsqueeze(0).unsqueeze(0), (pad, pad, pad, pad), mode="reflect"
    )

    radii = []
    for x, y in centers:
        # Average a ksize x ksize patch around each center to denoise the estimate.
        patch = padded[..., y : y + ksize, x : x + ksize]
        radii.append(patch.mean())
    return torch.stack(radii)


def draw_mask(centers, radii, H, W):
    """
    centers: (K, 2) as (x, y)
    radii:   (K,) normalized radii in [0, 1]
    return:  (H, W) float32 mask
    """
    # Scale normalized radii to pixels (assumes a 256-pixel reference size);
    # copy instead of `*=` so the caller's tensor is not mutated in place.
    radii = radii * 256
    mask = np.zeros((H, W), dtype=np.float32)
    for (x, y), r in zip(centers, radii):
        rr, cc = disk((y.item(), x.item()), r.item(), shape=mask.shape)
        mask[rr, cc] = 1
    return mask
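

if __name__ == "__main__":
    # A minimal end-to-end sketch of the peak -> radius -> mask pipeline on
    # synthetic maps (stand-in data only; a real run would use the detection
    # and radius heads' sigmoid outputs).
    H = W = 64
    prob = torch.zeros(H, W)
    prob[20, 30], prob[40, 10] = 0.9, 0.8           # two synthetic light sources
    centers = extract_peaks(prob, thr=0.5, pool=7)  # (K, 2) as (x, y)

    radius_map = torch.full((H, W), 0.05)           # pretend radius-head output
    radii = pick_radius(radius_map, centers)        # (K,) normalized radii

    mask = draw_mask(centers, radii, H, W)          # radii scaled by 256 inside
    print("peaks:", centers.tolist(), "mask area:", int(mask.sum()))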