File size: 5,452 Bytes
969f59e
 
1c76709
 
969f59e
1c76709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
969f59e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from pathlib import Path

import numpy as np
from keras import ops
from PIL import Image
from skimage import filters, morphology
from zea.utils import translate


def L1(x):
    """Compute the L1 norm of a tensor.

    The L1 norm is the sum of the absolute values of all elements, see
    https://mathworld.wolfram.com/L1-Norm.html

    Args:
        x: Input tensor.

    Returns:
        The sum of ``|x|`` taken over every element of ``x``.
    """
    magnitudes = ops.abs(x)
    return ops.sum(magnitudes)


def smooth_L1(x, beta=0.4):
    """Smooth L1 loss, summed over all elements.

    Elements with ``|x| < beta`` are penalized quadratically
    (``0.5 * x**2 / beta``); all other elements are penalized linearly
    (``|x| - 0.5 * beta``). Small beta values make the loss similar to the
    L1 loss (for ``beta == 0`` it is exactly the L1 loss), while large beta
    values widen the quadratic, L2-like region around zero.

    Args:
        x: Input tensor (typically a residual).
        beta: Threshold between the quadratic and the linear regime.

    Returns:
        The sum of the element-wise smooth L1 values.
    """
    abs_x = ops.abs(x)
    if isinstance(beta, (int, float)) and beta == 0:
        # With beta == 0 the quadratic branch is never selected, but it would
        # still be evaluated as x**2 / 0 and produce inf/NaN intermediates.
        # The loss degenerates to the plain L1 norm, so return that directly.
        return ops.sum(abs_x)
    loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
    return ops.sum(loss)


def postprocess(data, normalization_range):
    """Turn model output into uint8 image data.

    The values are clipped to ``normalization_range``, mapped from that range
    onto (0, 255) with ``translate``, converted to a NumPy array, stripped of
    their trailing singleton axis and finally cast to ``uint8``.

    Args:
        data: Model output whose last axis has size 1.
        normalization_range: ``(min, max)`` value range used by the model.

    Returns:
        numpy.ndarray of dtype uint8 without the trailing axis.
    """
    clipped = ops.clip(data, *normalization_range)
    scaled = translate(clipped, normalization_range, (0, 255))
    pixels = np.squeeze(ops.convert_to_numpy(scaled), axis=-1)
    # Guard against values slightly outside [0, 255] before the integer cast.
    return np.clip(pixels, 0, 255).astype("uint8")


def preprocess(data, normalization_range):
    """Prepare uint8 image data for the model.

    The input is converted to a float32 tensor, mapped from (0, 255) onto
    ``normalization_range`` with ``translate`` and given a trailing axis of
    size 1.

    Args:
        data: Image(s) with values in [0, 255].
        normalization_range: ``(min, max)`` value range expected by the model.

    Returns:
        float32 tensor in ``normalization_range`` with an added last axis.
    """
    tensor = ops.convert_to_tensor(data, dtype="float32")
    rescaled = translate(tensor, (0, 255), normalization_range)
    return ops.expand_dims(rescaled, axis=-1)


def apply_bottom_preservation(
    output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
):
    """Apply bottom preservation with smooth windowed transition.

    A per-row blend weight is built along the height axis: 0 above the
    transition band (model output is kept), a raised-cosine ramp from 0 to 1
    inside the band, and 1 in the preserved bottom region (input is kept).
    If the transition band would extend above the top of the image, it is
    shortened to the rows that are actually available so the ramp still
    reaches 1 at the start of the preserved region instead of jumping there.

    Args:
        output_images: Model output images, (batch, height, width, channels)
        input_images: Original input images, (batch, height, width, channels)
        preserve_bottom_percent: Percentage of bottom to preserve from input (default 30%)
        transition_width: Percentage of image height for smooth transition (default 10%)

    Returns:
        Blended images with preserved bottom portion
    """
    # Only the height is needed; the unpacking still requires 4D input.
    _, height, _, _ = ops.shape(output_images)

    preserve_height = int(height * preserve_bottom_percent / 100.0)
    transition_height = int(height * transition_width / 100.0)

    # Row indices where the transition band and the preserved region begin,
    # clamped to the image.
    transition_start = max(0, height - preserve_height - transition_height)
    preserve_start = min(height, height - preserve_height)

    if transition_start >= preserve_start:
        # No room (or no request) for a transition band.
        transition_start = preserve_start
        transition_height = 0
    else:
        # Use the effective band length after clamping so the ramp spans
        # exactly [transition_start, preserve_start).
        transition_height = preserve_start - transition_start

    # Row coordinates shaped (height, 1, 1) so they broadcast over width and
    # channels.
    y_coords = ops.arange(height, dtype="float32")
    y_coords = ops.reshape(y_coords, (height, 1, 1))

    if transition_height > 0:
        # Smooth transition using cosine interpolation
        transition_region = ops.logical_and(
            y_coords >= transition_start, y_coords < preserve_start
        )

        transition_progress = (y_coords - transition_start) / transition_height
        transition_progress = ops.clip(transition_progress, 0.0, 1.0)

        # Use cosine for smooth transition (0.5 * (1 - cos(π * t)))
        cosine_weight = 0.5 * (1.0 - ops.cos(np.pi * transition_progress))

        blend_weight = ops.where(
            y_coords < transition_start,
            0.0,
            ops.where(
                transition_region,
                cosine_weight,
                1.0,
            ),
        )
    else:
        # No transition, just hard switch
        blend_weight = ops.where(y_coords >= preserve_start, 1.0, 0.0)

    # Add the batch axis: (1, height, 1, 1).
    blend_weight = ops.expand_dims(blend_weight, axis=0)

    blended_images = (1.0 - blend_weight) * output_images + blend_weight * input_images

    return blended_images


def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
    """Compute soft skeleton masks for a batch of images.

    Each image is clipped to ``input_range`` and mapped to (0, 1) with
    ``translate``. Values below ``threshold`` are zeroed, the image is
    Gaussian-smoothed (``sigma_pre``), binarized with Otsu's threshold,
    skeletonized, dilated with a disk of radius 2 and smoothed again
    (``sigma_post``). The stacked masks are then min-max normalized over the
    whole batch.

    Args:
        images: Batch of images whose last axis has size 1.
        input_range: ``(min, max)`` value range of ``images``.
        sigma_pre: Gaussian sigma applied before thresholding.
        sigma_post: Gaussian sigma applied to the dilated skeleton.
        threshold: Values below this (in the (0, 1) range) are set to 0.

    Returns:
        Tensor of skeleton masks in [0, 1] with a trailing axis of size 1 and
        the dtype of ``images``.
    """
    frames = np.clip(ops.convert_to_numpy(images), input_range[0], input_range[1])
    frames = np.squeeze(translate(frames, input_range, (0, 1)), axis=-1)

    masks = []
    for frame in frames:
        # Suppress weak responses in place before smoothing.
        frame[frame < threshold] = 0
        blurred = filters.gaussian(frame, sigma=sigma_pre)
        foreground = blurred > filters.threshold_otsu(blurred)
        thin = morphology.skeletonize(foreground)
        thick = morphology.dilation(thin, morphology.disk(2))
        masks.append(filters.gaussian(thick.astype(np.float32), sigma=sigma_post))

    stacked = np.expand_dims(np.array(masks), axis=-1)

    # normalize to [0, 1]; the epsilon avoids dividing by zero for flat masks
    lowest = np.min(stacked)
    highest = np.max(stacked)
    stacked = (stacked - lowest) / (highest - lowest + 1e-8)

    return ops.convert_to_tensor(stacked, dtype=images.dtype)


def load_image(filename, grayscale=True):
    """Load an image file and return a numpy array using PIL.

    Args:
        filename (str): The path to the image file.
        grayscale (bool, optional): Whether to convert the image to grayscale
            (PIL mode "L"). If False the image is converted to "RGB".
            Defaults to True.

    Returns:
        numpy.ndarray: A numpy array of the image.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    filename = Path(filename)
    if not filename.exists():
        raise FileNotFoundError(f"File {filename} does not exist")

    mode = "L" if grayscale else "RGB"
    # Open inside a context manager so the underlying file handle is released
    # deterministically; convert() yields an in-memory copy, so the array
    # stays valid after the file is closed.
    with Image.open(filename) as img:
        return np.array(img.convert(mode))