from pathlib import Path
import numpy as np
from keras import ops
from PIL import Image
from skimage import filters, morphology
from zea.utils import translate
def L1(x):
"""L1 norm of a tensor.
Implementation of L1 norm: https://mathworld.wolfram.com/L1-Norm.html
"""
return ops.sum(ops.abs(x))
def smooth_L1(x, beta=0.4):
"""Smooth L1 loss function.
Implementation of Smooth L1 loss. Large beta values make it similar to L1 loss,
while small beta values make it similar to L2 loss.
"""
abs_x = ops.abs(x)
loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
return ops.sum(loss)
def postprocess(data, normalization_range):
"""Postprocess data from model output to image."""
data = ops.clip(data, *normalization_range)
data = translate(data, normalization_range, (0, 255))
data = ops.convert_to_numpy(data)
data = np.squeeze(data, axis=-1)
return np.clip(data, 0, 255).astype("uint8")
def preprocess(data, normalization_range):
"""Preprocess data for model input. Converts uint8 image(s) in [0, 255] to model input range."""
data = ops.convert_to_tensor(data, dtype="float32")
data = translate(data, (0, 255), normalization_range)
data = ops.expand_dims(data, axis=-1)
return data
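# Note: `preprocess` and `postprocess` are inverses up to clipping and rounding. A
# typical round trip (the model `net` and the (-1, 1) range are illustrative assumptions):
#     x = preprocess(images_uint8, normalization_range=(-1, 1))
#     y = net(x)
#     images_out = postprocess(y, normalization_range=(-1, 1))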
def apply_bottom_preservation(
output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
):
"""Apply bottom preservation with smooth windowed transition.
Args:
output_images: Model output images, (batch, height, width, channels)
input_images: Original input images, (batch, height, width, channels)
preserve_bottom_percent: Percentage of bottom to preserve from input (default 30%)
transition_width: Percentage of image height for smooth transition (default 10%)
Returns:
Blended images with preserved bottom portion
"""
output_shape = ops.shape(output_images)
batch_size, height, width, channels = output_shape
preserve_height = int(height * preserve_bottom_percent / 100.0)
transition_height = int(height * transition_width / 100.0)
transition_start = height - preserve_height - transition_height
preserve_start = height - preserve_height
transition_start = max(0, transition_start)
preserve_start = min(height, preserve_start)
if transition_start >= preserve_start:
transition_start = preserve_start
transition_height = 0
y_coords = ops.arange(height, dtype="float32")
y_coords = ops.reshape(y_coords, (height, 1, 1))
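    # Per-row coordinates shaped (height, 1, 1) so the derived weights broadcast over
    # width and channels; the batch axis is added further below.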
if transition_height > 0:
# Smooth transition using cosine interpolation
transition_region = ops.logical_and(
y_coords >= transition_start, y_coords < preserve_start
)
transition_progress = (y_coords - transition_start) / transition_height
transition_progress = ops.clip(transition_progress, 0.0, 1.0)
# Use cosine for smooth transition (0.5 * (1 - cos(π * t)))
cosine_weight = 0.5 * (1.0 - ops.cos(np.pi * transition_progress))
blend_weight = ops.where(
y_coords < transition_start,
0.0,
ops.where(
transition_region,
cosine_weight,
1.0,
),
)
else:
# No transition, just hard switch
blend_weight = ops.where(y_coords >= preserve_start, 1.0, 0.0)
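    # Add a leading batch axis so the weights broadcast against
    # (batch, height, width, channels); weight 1 selects the preserved input rows.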
blend_weight = ops.expand_dims(blend_weight, axis=0)
blended_images = (1.0 - blend_weight) * output_images + blend_weight * input_images
return blended_images
def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
"""Extract skeletons from the input images."""
images_np = ops.convert_to_numpy(images)
images_np = np.clip(images_np, input_range[0], input_range[1])
images_np = translate(images_np, input_range, (0, 1))
images_np = np.squeeze(images_np, axis=-1)
skeleton_masks = []
for img in images_np:
        # Suppress background below the intensity threshold, then smooth.
        img[img < threshold] = 0
        smoothed = filters.gaussian(img, sigma=sigma_pre)
        # Binarize with Otsu's threshold and reduce the structure to a 1-pixel skeleton.
        binary = smoothed > filters.threshold_otsu(smoothed)
        skeleton = morphology.skeletonize(binary)
        # Thicken and blur the skeleton to obtain a soft mask.
        skeleton = morphology.dilation(skeleton, morphology.disk(2))
        skeleton = filters.gaussian(skeleton.astype(np.float32), sigma=sigma_post)
skeleton_masks.append(skeleton)
skeleton_masks = np.array(skeleton_masks)
skeleton_masks = np.expand_dims(skeleton_masks, axis=-1)
    # Normalize jointly over the whole batch to [0, 1].
min_val, max_val = np.min(skeleton_masks), np.max(skeleton_masks)
skeleton_masks = (skeleton_masks - min_val) / (max_val - min_val + 1e-8)
return ops.convert_to_tensor(skeleton_masks, dtype=images.dtype)
def load_image(filename, grayscale=True):
"""Load an image file and return a numpy array using PIL.
Args:
filename (str): The path to the image file.
grayscale (bool, optional): Whether to convert the image to grayscale. Defaults to True.
Returns:
numpy.ndarray: A numpy array of the image.
Raises:
FileNotFoundError: If the file does not exist.
"""
filename = Path(filename)
if not filename.exists():
raise FileNotFoundError(f"File {filename} does not exist")
img = Image.open(filename)
if grayscale:
img = img.convert("L")
else:
img = img.convert("RGB")
arr = np.array(img)
return arr