Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copyright (c) Facebook, Inc. and its affiliates. | |
| This source code is licensed under the MIT license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import argparse | |
| import pathlib | |
| from argparse import ArgumentParser | |
| from typing import Optional | |
| import h5py | |
| import numpy as np | |
| from runstats import Statistics | |
| from skimage.metrics import peak_signal_noise_ratio, structural_similarity | |
| from fastmri import transforms | |
| def mse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray: | |
| """Compute Mean Squared Error (MSE)""" | |
| return np.mean((gt - pred) ** 2) | |
| def nmse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray: | |
| """Compute Normalized Mean Squared Error (NMSE)""" | |
| return np.array(np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2) | |
| def psnr( | |
| gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None | |
| ) -> np.ndarray: | |
| """Compute Peak Signal to Noise Ratio metric (PSNR)""" | |
| if maxval is None: | |
| maxval = gt.max() | |
| return peak_signal_noise_ratio(gt, pred, data_range=maxval) | |
| def ssim( | |
| gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None | |
| ) -> np.ndarray: | |
| """Compute Structural Similarity Index Metric (SSIM)""" | |
| if not gt.ndim == 3: | |
| raise ValueError("Unexpected number of dimensions in ground truth.") | |
| if not gt.ndim == pred.ndim: | |
| raise ValueError("Ground truth dimensions does not match pred.") | |
| maxval = gt.max() if maxval is None else maxval | |
| ssim = np.array([0]) | |
| for slice_num in range(gt.shape[0]): | |
| ssim = ssim + structural_similarity( | |
| gt[slice_num], pred[slice_num], data_range=maxval | |
| ) | |
| return ssim / gt.shape[0] | |
| METRIC_FUNCS = dict( | |
| MSE=mse, | |
| NMSE=nmse, | |
| PSNR=psnr, | |
| SSIM=ssim, | |
| ) | |
| class Metrics: | |
| """ | |
| Maintains running statistics for a given collection of metrics. | |
| """ | |
| def __init__(self, metric_funcs): | |
| """ | |
| Parameters | |
| ---------- | |
| metric_funcs : dict | |
| A dictionary where the keys are metric names (as strings) and the values | |
| are Python functions for evaluating the corresponding metrics. | |
| """ | |
| self.metrics = {metric: Statistics() for metric in metric_funcs} | |
| def push(self, target, recons): | |
| for metric, func in METRIC_FUNCS.items(): | |
| self.metrics[metric].push(func(target, recons)) | |
| def means(self): | |
| return {metric: stat.mean() for metric, stat in self.metrics.items()} | |
| def stddevs(self): | |
| return {metric: stat.stddev() for metric, stat in self.metrics.items()} | |
| def __repr__(self): | |
| means = self.means() | |
| stddevs = self.stddevs() | |
| metric_names = sorted(list(means)) | |
| return " ".join( | |
| f"{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}" | |
| for name in metric_names | |
| ) | |
| def evaluate(args, recons_key): | |
| metrics = Metrics(METRIC_FUNCS) | |
| for tgt_file in args.target_path.iterdir(): | |
| with h5py.File(tgt_file, "r") as target, h5py.File( | |
| args.predictions_path / tgt_file.name, "r" | |
| ) as recons: | |
| if args.acquisition and args.acquisition != target.attrs["acquisition"]: | |
| continue | |
| if args.acceleration and target.attrs["acceleration"] != args.acceleration: | |
| continue | |
| target = target[recons_key][()] | |
| recons = recons["reconstruction"][()] | |
| target = transforms.center_crop( | |
| target, (target.shape[-1], target.shape[-1]) | |
| ) | |
| recons = transforms.center_crop( | |
| recons, (target.shape[-1], target.shape[-1]) | |
| ) | |
| metrics.push(target, recons) | |
| return metrics | |
| if __name__ == "__main__": | |
| parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |
| parser.add_argument( | |
| "--target-path", | |
| type=pathlib.Path, | |
| required=True, | |
| help="Path to the ground truth data", | |
| ) | |
| parser.add_argument( | |
| "--predictions-path", | |
| type=pathlib.Path, | |
| required=True, | |
| help="Path to reconstructions", | |
| ) | |
| parser.add_argument( | |
| "--challenge", | |
| choices=["singlecoil", "multicoil"], | |
| required=True, | |
| help="Which challenge", | |
| ) | |
| parser.add_argument("--acceleration", type=int, default=None) | |
| parser.add_argument( | |
| "--acquisition", | |
| choices=[ | |
| "CORPD_FBK", | |
| "CORPDFS_FBK", | |
| "AXT1", | |
| "AXT1PRE", | |
| "AXT1POST", | |
| "AXT2", | |
| "AXFLAIR", | |
| ], | |
| default=None, | |
| help=( | |
| "If set, only volumes of the specified acquisition type are used " | |
| "for evaluation. By default, all volumes are included." | |
| ), | |
| ) | |
| args = parser.parse_args() | |
| recons_key = ( | |
| "reconstruction_rss" if args.challenge == "multicoil" else "reconstruction_esc" | |
| ) | |
| metrics = evaluate(args, recons_key) | |
| print(metrics) | |