from pathlib import Path

import numpy as np
from keras import ops
from PIL import Image
from skimage import filters, morphology

from zea.utils import translate


def L1(x):
    """L1 norm of a tensor.

    Implementation of the L1 norm: https://mathworld.wolfram.com/L1-Norm.html
    """
    return ops.sum(ops.abs(x))


def smooth_L1(x, beta=0.4):
    """Smooth L1 loss function.

    Implementation of the Smooth L1 loss: quadratic for residuals smaller than beta
    and linear otherwise. Small beta values make it behave like the L1 loss, while
    large beta values make it behave like a (scaled) L2 loss.
    """
    abs_x = ops.abs(x)
    loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
    return ops.sum(loss)
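
# Illustrative values (a sketch, not part of the original module): with the
# default beta=0.4,
#     smooth_L1(ops.convert_to_tensor([0.1, -2.0]))
# takes the quadratic branch for 0.1  -> 0.5 * 0.1**2 / 0.4 = 0.0125
# and the linear branch for -2.0      -> 2.0 - 0.5 * 0.4    = 1.8
# so the summed loss is 1.8125.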


def postprocess(data, normalization_range):
    """Postprocess model output back to uint8 image(s) in [0, 255]."""
    data = ops.clip(data, *normalization_range)
    data = translate(data, normalization_range, (0, 255))
    data = ops.convert_to_numpy(data)
    data = np.squeeze(data, axis=-1)
    return np.clip(data, 0, 255).astype("uint8")


def preprocess(data, normalization_range):
    """Preprocess data for model input.

    Converts uint8 image(s) in [0, 255] to model input range.
    """
    data = ops.convert_to_tensor(data, dtype="float32")
    data = translate(data, (0, 255), normalization_range)
    data = ops.expand_dims(data, axis=-1)
    return data
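
# Illustrative round trip (a sketch, not part of the original module):
#     x = preprocess(uint8_images, (-1.0, 1.0))  # (batch, h, w) -> (batch, h, w, 1) in [-1, 1]
#     y = postprocess(model(x), (-1.0, 1.0))     # back to uint8 arrays of shape (batch, h, w)
# where `model` stands for any Keras model that maps the input range to itself,
# and (-1.0, 1.0) is an assumed normalization range.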


def apply_bottom_preservation(
    output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
):
    """Apply bottom preservation with a smooth windowed transition.

    Args:
        output_images: Model output images, (batch, height, width, channels).
        input_images: Original input images, (batch, height, width, channels).
        preserve_bottom_percent: Percentage of the bottom to preserve from the input
            (default 30%).
        transition_width: Percentage of the image height used for the smooth
            transition (default 10%).

    Returns:
        Blended images with the bottom portion preserved from the input.
    """
    # Assumes statically known image dimensions, since the row indices below are
    # computed with plain Python arithmetic.
    output_shape = ops.shape(output_images)
    batch_size, height, width, channels = output_shape

    # Number of rows to preserve and number of rows used for the cross-fade.
    preserve_height = int(height * preserve_bottom_percent / 100.0)
    transition_height = int(height * transition_width / 100.0)

    # Row indices where the transition starts and where full preservation starts.
    transition_start = height - preserve_height - transition_height
    preserve_start = height - preserve_height

    transition_start = max(0, transition_start)
    preserve_start = min(height, preserve_start)

    if transition_start >= preserve_start:
        transition_start = preserve_start
        transition_height = 0

    # Per-row blend weight, broadcast over width and channels.
    y_coords = ops.arange(height, dtype="float32")
    y_coords = ops.reshape(y_coords, (height, 1, 1))

    if transition_height > 0:
        transition_region = ops.logical_and(
            y_coords >= transition_start, y_coords < preserve_start
        )

        transition_progress = (y_coords - transition_start) / transition_height
        transition_progress = ops.clip(transition_progress, 0.0, 1.0)

        # Raised-cosine window: 0 at the top of the transition band, 1 at the bottom.
        cosine_weight = 0.5 * (1.0 - ops.cos(np.pi * transition_progress))

        blend_weight = ops.where(
            y_coords < transition_start,
            0.0,
            ops.where(
                transition_region,
                cosine_weight,
                1.0,
            ),
        )
    else:
        # No transition band: hard switch to the preserved region.
        blend_weight = ops.where(y_coords >= preserve_start, 1.0, 0.0)

    # Add the batch dimension: weight 0 keeps the model output, weight 1 keeps the input.
    blend_weight = ops.expand_dims(blend_weight, axis=0)

    blended_images = (1.0 - blend_weight) * output_images + blend_weight * input_images

    return blended_images
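
# Illustrative usage (a sketch, not part of the original module): keep the bottom
# 30% of each frame from the network input, with a 10%-high cosine cross-fade
# above it, or disable the cross-fade entirely:
#     blended = apply_bottom_preservation(outputs, inputs)
#     hard_cut = apply_bottom_preservation(outputs, inputs, 30.0, 0.0)
# Here `outputs` and `inputs` stand for (batch, height, width, channels) tensors.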


def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
    """Extract skeletons from the input images."""
    # Work in numpy on images normalized to [0, 1], dropping the channel axis.
    images_np = ops.convert_to_numpy(images)
    images_np = np.clip(images_np, input_range[0], input_range[1])
    images_np = translate(images_np, input_range, (0, 1))
    images_np = np.squeeze(images_np, axis=-1)

    skeleton_masks = []
    for img in images_np:
        # Suppress the background, smooth, binarize (Otsu), skeletonize, then
        # thicken and soften the skeleton into a soft mask.
        img[img < threshold] = 0
        smoothed = filters.gaussian(img, sigma=sigma_pre)
        binary = smoothed > filters.threshold_otsu(smoothed)
        skeleton = morphology.skeletonize(binary)
        skeleton = morphology.dilation(skeleton, morphology.disk(2))
        skeleton = filters.gaussian(skeleton.astype(np.float32), sigma=sigma_post)
        skeleton_masks.append(skeleton)

    skeleton_masks = np.array(skeleton_masks)
    skeleton_masks = np.expand_dims(skeleton_masks, axis=-1)

    # Rescale the masks to [0, 1] over the whole batch.
    min_val, max_val = np.min(skeleton_masks), np.max(skeleton_masks)
    skeleton_masks = (skeleton_masks - min_val) / (max_val - min_val + 1e-8)

    return ops.convert_to_tensor(skeleton_masks, dtype=images.dtype)
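
# Illustrative usage (a sketch, not part of the original module): for model
# outputs normalized to [-1, 1], produce soft skeleton masks in [0, 1] with the
# same (batch, height, width, 1) shape:
#     masks = extract_skeleton(outputs, input_range=(-1.0, 1.0))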


def load_image(filename, grayscale=True):
    """Load an image file and return a numpy array using PIL.

    Args:
        filename (str): The path to the image file.
        grayscale (bool, optional): Whether to convert the image to grayscale. Defaults to True.

    Returns:
        numpy.ndarray: A numpy array of the image.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    filename = Path(filename)
    if not filename.exists():
        raise FileNotFoundError(f"File {filename} does not exist")

    img = Image.open(filename)
    if grayscale:
        img = img.convert("L")
    else:
        img = img.convert("RGB")

    arr = np.array(img)
    return arr
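

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It exercises the
# helpers above on random data and assumes a configured Keras backend; the
# normalization range (-1.0, 1.0) is an assumption for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    normalization_range = (-1.0, 1.0)

    # Fake uint8 images: (batch, height, width).
    images = rng.integers(0, 256, size=(2, 64, 64), dtype=np.uint8)

    # Round trip through the model input/output helpers.
    model_in = preprocess(images, normalization_range)
    restored = postprocess(model_in, normalization_range)
    print("round-trip max error:", np.abs(restored.astype(int) - images.astype(int)).max())

    # Blend a fake "model output" with the input over the bottom 30% of rows.
    fake_output = ops.zeros_like(model_in)
    blended = apply_bottom_preservation(fake_output, model_in)
    print("blended shape:", ops.shape(blended))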