# Page header captured together with this file (not program code):
#   tristan-deep -- "replaced load_image" -- 969f59e
import warnings
from glob import glob
from pathlib import Path
import numpy as np
import torch
import tyro
from PIL import Image
from scipy.ndimage import binary_erosion, distance_transform_edt
from scipy.stats import ks_2samp
from zea import log
import fid_score
from plots import plot_metrics
from utils import load_image
def calculate_fid_score(denoised_image_dirs, ground_truth_dir):
    """Compute the FID between denoised images and a folder of reference PNGs.

    Args:
        denoised_image_dirs: A single path (``str`` or ``Path``) or a list of
            paths. It is forwarded unchanged (wrapped in a list when a single
            path is given) to
            ``fid_score.calculate_fid_with_cached_ground_truth``.
        ground_truth_dir: Directory searched (non-recursively) for reference
            ``*.png`` files.

    Returns:
        Whatever ``fid_score.calculate_fid_with_cached_ground_truth`` returns.

    Raises:
        ValueError: If ``denoised_image_dirs`` is neither a path nor a list,
            or if either the denoised list or the reference folder is empty.
    """
    if isinstance(denoised_image_dirs, (str, Path)):
        denoised_image_dirs = [denoised_image_dirs]
    elif not isinstance(denoised_image_dirs, list):
        raise ValueError("Input must be a path or list of paths")

    clean_images_folder = glob(str(ground_truth_dir) + "/*.png")
    print(f"Looking for clean images in: {ground_truth_dir}")
    print(f"Found {len(clean_images_folder)} clean images")

    # Fail early with a readable message: an empty set would otherwise turn
    # into a batch size of 0 further down.
    if not denoised_image_dirs:
        raise ValueError("No denoised images were given for the FID calculation")
    if not clean_images_folder:
        raise ValueError(f"No .png reference images found in: {ground_truth_dir}")

    # Never ask for a batch larger than the smaller of the two image sets.
    optimal_batch_size = min(8, len(denoised_image_dirs), len(clean_images_folder))
    print(f"Using batch size: {optimal_batch_size}")

    use_cuda = torch.cuda.is_available()
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="os.fork.*JAX is multithreaded")
        fid_value = fid_score.calculate_fid_with_cached_ground_truth(
            denoised_image_dirs,
            clean_images_folder,
            batch_size=optimal_batch_size,
            device="cuda" if use_cuda else "cpu",
            num_workers=2 if use_cuda else 0,
            dims=2048,
        )
    return fid_value
def gcnr(img1, img2):
    """Generalized Contrast-to-Noise Ratio.

    Both inputs are histogrammed on 256 shared bins that span their combined
    value range. The result is one minus the overlap of the two histograms
    after each has been rescaled to sum to one.
    """
    pooled = np.concatenate((img1, img2))
    shared_edges = np.histogram(pooled, bins=256)[1]

    hist1 = np.histogram(img1, bins=shared_edges, density=True)[0]
    hist2 = np.histogram(img2, bins=shared_edges, density=True)[0]

    # Rescale so that each histogram sums to one before comparing them.
    hist1 = hist1 / hist1.sum()
    hist2 = hist2 / hist2.sum()

    overlap = np.sum(np.minimum(hist1, hist2))
    return 1 - overlap
def cnr(img1, img2):
    """Contrast-to-Noise Ratio.

    Difference of the two means divided by the square root of the summed
    variances.
    """
    contrast = img1.mean() - img2.mean()
    noise = np.sqrt(img1.var() + img2.var())
    return contrast / noise
def calculate_cnr_gcnr(result_dehazed_cardiac_ultrasound, mask_path):
    """Compute CNR and gCNR for one dehazed image using its ROI mask.

    Nothing is written to disk; the values are only returned.

    Args:
        result_dehazed_cardiac_ultrasound: Image data that is indexed with
            boolean masks derived from the ROI mask (assumes a NumPy array
            whose leading dimensions match the mask -- TODO confirm).
        mask_path: Path of the ROI mask image. It is read as 8-bit grayscale;
            pixels equal to 255 select the first ROI and pixels equal to 128
            select the second ROI.

    Returns:
        A one-element list ``[[cnr, gcnr]]``.
    """
    # Close the mask file deterministically instead of relying on garbage
    # collection.
    with Image.open(mask_path) as mask_image:
        mask = np.array(mask_image.convert("L"))

    roi1_pixels = result_dehazed_cardiac_ultrasound[mask == 255]  # Foreground ROI
    roi2_pixels = result_dehazed_cardiac_ultrasound[mask == 128]  # Background/Noise ROI

    gcnr_val = gcnr(roi1_pixels, roi2_pixels)
    cnr_val = cnr(roi1_pixels, roi2_pixels)
    return [[cnr_val, gcnr_val]]
def calculate_ks_statistics(
    result_hazy_cardiac_ultrasound, result_dehazed_cardiac_ultrasound, mask_path
):
    """Two-sample KS test between hazy and dehazed pixels inside each ROI.

    The mask at ``mask_path`` is read as 8-bit grayscale. Region A is where
    the mask equals 255 and region B is where it equals 128.

    Returns:
        ``(stat_a, p_a, stat_b, p_b)``. The pair belonging to a region is
        ``(None, None)`` when that region selects no pixels.
    """
    roi_mask = np.array(Image.open(mask_path).convert("L"))

    def _region_ks(mask_value):
        # Compare the same pixel locations before and after dehazing.
        selector = roi_mask == mask_value
        before = result_hazy_cardiac_ultrasound[selector]
        after = result_dehazed_cardiac_ultrasound[selector]
        if before.size == 0 or after.size == 0:
            return None, None
        statistic, p_value = ks_2samp(before, after)
        return statistic, p_value

    region_a = _region_ks(255)  # region A
    region_b = _region_ks(128)  # region B
    return region_a[0], region_a[1], region_b[0], region_b[1]
def calculate_dice_asd(image_path, label_path, checkpoint_path, image_size=224):
    """Segment an image and score the result against a label mask.

    Args:
        image_path: Passed to ``inference`` (imported from the ``test`` module).
        label_path: Path of the label image; read as grayscale, resized to
            ``image_size`` x ``image_size`` with nearest-neighbour sampling
            and thresholded at 127.
        checkpoint_path: Passed to ``inference``.
        image_size: Side length used for the label resize and passed to
            ``inference``.

    Returns:
        ``(dice, asd)``. ``asd`` is NaN when either binary mask is empty.

    Raises:
        ImportError: If ``inference`` cannot be imported.
    """
    try:
        from test import inference  # Our Segmentation Method
    except ImportError:
        raise ImportError(
            "Segmentation method not available, skipping Dice/ASD calculation"
        )

    # The prediction is thresholded the same way as the label (> 127).
    predicted_mask = np.array(inference(image_path, checkpoint_path, image_size)) > 127

    resized_label = (
        Image.open(label_path)
        .convert("L")
        .resize((image_size, image_size), Image.NEAREST)
    )
    target_mask = np.array(resized_label) > 127

    # Dice; the small constant keeps the division defined for two empty masks.
    overlap = np.logical_and(predicted_mask, target_mask).sum()
    predicted_area = predicted_mask.sum()
    target_area = target_mask.sum()
    dice = 2 * overlap / (predicted_area + target_area + 1e-8)

    # ASD is undefined when one of the masks has no foreground.
    if predicted_area == 0 or target_area == 0:
        return dice, np.nan

    # Distance of every pixel to the nearest foreground pixel of each mask.
    dist_to_predicted = distance_transform_edt(~predicted_mask)
    dist_to_target = distance_transform_edt(~target_mask)

    # A mask minus its erosion leaves the boundary pixels.
    predicted_boundary = predicted_mask ^ binary_erosion(predicted_mask)
    target_boundary = target_mask ^ binary_erosion(target_mask)

    target_to_predicted = dist_to_predicted[target_boundary].mean()
    predicted_to_target = dist_to_target[predicted_boundary].mean()
    asd = (target_to_predicted + predicted_to_target) / 2
    return dice, asd
def _clamped_unit_score(value, worst, best):
    """Rescale ``value`` linearly so ``worst`` maps to 0 and ``best`` to 1, clamped to [0, 1]."""
    fraction = (value - worst) / (best - worst)
    return max(0, min(1, fraction))


def calculate_final_score(aggregates):
    """Combine aggregated metrics into a single score between 0 and 100.

    Each available metric is normalized to [0, 1] against a fixed range,
    scaled to 0-100 and weighted inside its group. The three group scores are
    then mixed as (FID + CNR + gCNR) : (KS^A + KS^B) : (Dice + ASD) = 5 : 3 : 2.

    A metric whose key is absent, whose value is ``None`` or whose value is
    NaN contributes nothing. NaN must be skipped explicitly because
    ``max(0, min(1, nan))`` evaluates to 1 and would award full marks.

    Args:
        aggregates: Mapping that may contain the keys ``fid``, ``cnr_mean``,
            ``gcnr_mean``, ``ks_roi1_ksstatistic_mean``,
            ``ks_roi2_ksstatistic_mean``, ``dice_mean`` and ``asd_mean``.

    Returns:
        The weighted score, or ``0`` if an error occurs while computing it
        (the error is printed, not raised).
    """
    # (group weight, ((aggregate key, worst value, best value, weight in group), ...))
    # For FID, KS^A and ASD lower is better, so "worst" is the larger bound.
    score_groups = (
        (
            5,
            (
                ("fid", 150.0, 60.0, 0.33),
                ("cnr_mean", 1.0, 1.5, 0.33),
                ("gcnr_mean", 0.5, 0.8, 0.34),
            ),
        ),
        (
            3,
            (
                ("ks_roi1_ksstatistic_mean", 0.3, 0.1, 0.5),
                ("ks_roi2_ksstatistic_mean", 0.0, 0.5, 0.5),
            ),
        ),
        (
            2,
            (
                ("dice_mean", 0.85, 0.95, 0.5),
                ("asd_mean", 2.0, 0.7, 0.5),
            ),
        ),
    )

    # Deliberately best-effort: callers rely on getting 0 back instead of an
    # exception when the aggregates are malformed.
    try:
        weighted_total = 0
        for group_weight, metric_specs in score_groups:
            group_score = 0
            for key, worst, best, metric_weight in metric_specs:
                value = aggregates.get(key)
                if value is None or np.isnan(value):
                    continue
                unit_score = _clamped_unit_score(value, worst, best)
                group_score += unit_score * 100 * metric_weight
            weighted_total += group_score * group_weight
        return weighted_total / 10
    except Exception as e:
        print(f"Error calculating final score: {str(e)}")
        return 0
def evaluate(folder: str, noisy_folder: str, roi_folder: str, reference_folder: str):
    """Evaluate the dehazing algorithm.

    Args:
        folder (str): Path to the folder containing the dehazed images.
            Used for evaluating all metrics.
        noisy_folder (str): Path to the folder containing the noisy images.
            Only used for KS statistics.
        roi_folder (str): Path to the folder containing the ROI images.
            Used for contrast and KS statistic metrics.
        reference_folder (str): Path to the folder containing the reference images.
            Used only for FID calculation.

    Returns:
        dict: ``fid``, the mean and standard deviation of CNR, gCNR, KS^A and
        KS^B (keys ``cnr_*``, ``gcnr_*``, ``ks_roi1_ksstatistic_*``,
        ``ks_roi2_ksstatistic_*``) and ``final_score``.

    Raises:
        AssertionError: If no ``.png`` filename is present in the dehazed,
            noisy and ROI folders at the same time.
        ValueError: If none of the matched images yields a full set of metrics.
    """
    folder = Path(folder)
    noisy_folder = Path(noisy_folder)
    roi_folder = Path(roi_folder)
    reference_folder = Path(reference_folder)

    folder_files = {f.name for f in folder.glob("*.png")}
    noisy_files = {f.name for f in noisy_folder.glob("*.png")}
    roi_files = {f.name for f in roi_folder.glob("*.png")}

    print(f"Found {len(folder_files)} .png files in output folder: {folder}")
    print(f"Found {len(noisy_files)} .png files in noisy folder: {noisy_folder}")
    print(f"Found {len(roi_files)} .png files in ROI folder: {roi_folder}")

    # Find intersection of filenames
    common_files = sorted(folder_files & roi_files & noisy_files)
    print(f"Found {len(common_files)} matching images in noisy/dehazed/roi folders")

    # Raised explicitly (rather than via ``assert``) so the check also runs
    # under ``python -O``; the exception type is unchanged.
    if not common_files:
        raise AssertionError("No matching .png files in all folders. Cannot proceed.")

    metrics = {"CNR": [], "gCNR": [], "KS_A": [], "KS_B": []}
    limits = {
        "CNR": [1.0, 1.5],
        "gCNR": [0.5, 0.8],
        "KS_A": [0.1, 0.3],
        "KS_B": [0.0, 0.5],
    }

    for name in common_files:
        dehazed_path = folder / name
        noisy_path = noisy_folder / name
        roi_path = str(roi_folder / name)

        try:
            img_dehazed = np.array(load_image(str(dehazed_path)))
            img_noisy = np.array(load_image(str(noisy_path)))
        except Exception as e:
            print(f"Error loading image {name}: {e}")
            continue

        # KS statistics
        ks_a, _, ks_b, _ = calculate_ks_statistics(img_noisy, img_dehazed, roi_path)
        if ks_a is None or ks_b is None:
            # A missing region yields ``None``, which would break the
            # mean/std aggregation below, so the image is left out entirely.
            print(f"Skipping {name}: ROI mask has no pixels for region A or region B")
            continue

        # CNR/gCNR
        cnr_gcnr = calculate_cnr_gcnr(img_dehazed, roi_path)

        metrics["CNR"].append(cnr_gcnr[0][0])
        metrics["gCNR"].append(cnr_gcnr[0][1])
        metrics["KS_A"].append(ks_a)
        metrics["KS_B"].append(ks_b)

    if not metrics["CNR"]:
        raise ValueError(
            "None of the matching images could be evaluated "
            "(loading failed or the ROI masks were incomplete)."
        )

    # Compute statistics
    stats = {
        k: (np.mean(v), np.std(v), np.min(v), np.max(v)) for k, v in metrics.items()
    }
    print("Contrast statistics:")
    for k, (mean, std, minv, maxv) in stats.items():
        print(f"{k}: mean={mean:.3f}, std={std:.3f}, min={minv:.3f}, max={maxv:.3f}")

    fig = plot_metrics(metrics, limits, "contrast_metrics.png")
    path = Path("contrast_metrics.png")
    save_kwargs = {"bbox_inches": "tight", "dpi": 300}
    fig.savefig(path, **save_kwargs)
    fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
    log.success(f"Metrics plot saved to {log.yellow(path)}")

    # Compute FID
    fid_image_paths = [str(folder / name) for name in common_files]
    # Named ``fid_value`` so the imported ``fid_score`` module is not shadowed.
    fid_value = calculate_fid_score(fid_image_paths, str(reference_folder))
    print(f"FID between {folder} and {reference_folder}: {fid_value:.3f}")

    # Create aggregates dictionary for final score calculation
    aggregates = {
        "fid": float(fid_value),
        "cnr_mean": float(stats["CNR"][0]),
        "cnr_std": float(stats["CNR"][1]),
        "gcnr_mean": float(stats["gCNR"][0]),
        "gcnr_std": float(stats["gCNR"][1]),
        "ks_roi1_ksstatistic_mean": float(stats["KS_A"][0]),
        "ks_roi1_ksstatistic_std": float(stats["KS_A"][1]),
        "ks_roi2_ksstatistic_mean": float(stats["KS_B"][0]),
        "ks_roi2_ksstatistic_std": float(stats["KS_B"][1]),
    }

    # Calculate final score
    final_score = calculate_final_score(aggregates)
    aggregates["final_score"] = float(final_score)

    return aggregates
if __name__ == "__main__":
    # Run ``evaluate`` through tyro's command-line wrapper when executed as a script.
    tyro.cli(evaluate)