import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import SimpleITK as sitk  # noqa: N813
import torch
from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
from cinema import ConvUNetR
from pathlib import Path
import spaces

# cache directories
cache_dir = Path("/tmp/.cinema")
cache_dir.mkdir(parents=True, exist_ok=True)

def inference(
    images: np.ndarray,
    view: str,
    transform: Compose,
    model: ConvUNetR,
    progress=gr.Progress(),
) -> np.ndarray:
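    """Segment every retained SAX time frame and return the stacked label maps.

    `images` is expected to be an (x, y, z, t) array; the returned array has the
    same shape, with one label map per time frame.
    """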
    # set device and dtype
    dtype, device = torch.float32, torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        device = torch.device("cuda")
        if torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16

    # inference
    model.to(device)
    n_slices, n_frames = images.shape[-2:]
    labels_list = []
    for t in range(0, n_frames):
        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
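        # add a channel axis; the transform scales intensities to [0, 1] and pads the frame to 192 x 192 x 16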
        batch = transform({view: torch.from_numpy(images[None, ..., t])})
        batch = {
            k: v[None, ...].to(device=device, dtype=torch.float32)
            for k, v in batch.items()
        }
        with (
            torch.no_grad(),
            torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
        ):
            logits = model(batch)[view]
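        # argmax over the class dimension gives the label map; drop the padded slices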
        labels_list.append(torch.argmax(logits, dim=1)[0, ..., :n_slices])
    labels = torch.stack(labels_list, dim=-1).detach().cpu().numpy()
    return labels


def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
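    """Download the requested ACDC image and finetuned model, run segmentation, and plot."""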
    # Fixed parameters
    view = "sax"
    split = "train" if image_id <= 100 else "test"
    trained_dataset = {
        "ACDC": "acdc",
        "M&MS": "mnms",
        "M&MS2": "mnms2",
    }[str(trained_dataset)]

    # Download and load model
    progress(0, desc="Downloading model and data...")
    image_path = hf_hub_download(
        repo_id="mathpluscode/ACDC",
        repo_type="dataset",
        filename=f"{split}/patient{image_id:03d}/patient{image_id:03d}_sax_t.nii.gz",
        cache_dir=cache_dir,
    )
    model = ConvUNetR.from_finetuned(
        repo_id="mathpluscode/CineMA",
        model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
        config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
        cache_dir=cache_dir,
    )

    # Load and process data
    transform = Compose(
        [
            ScaleIntensityd(keys=view),
            SpatialPadd(
                keys=view,
                spatial_size=(192, 192, 16),
                method="end",
                lazy=True,
                allow_missing_keys=True,
            ),
        ]
    )
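    # read the NIfTI image as an (x, y, z, t) array and keep every t_step-th frame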
    images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
    images = images[..., ::t_step]

    labels = inference(images, view, transform, model, progress)
    progress(1, desc="Plotting results...")

    # Create segmentation visualization
    n_slices, n_frames = labels.shape[-2:]
    fig1, axs = plt.subplots(n_frames, n_slices, figsize=(n_slices, n_frames), dpi=300)
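    # overlay each class (1 = RV, 2 = MYO, 3 = LV) as a semi-transparent RGBA mask on the grayscale slice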
    for t in range(n_frames):
        for z in range(n_slices):
            axs[t, z].imshow(images[..., z, t], cmap="gray")
            axs[t, z].imshow(
                (labels[..., z, t, None] == 1)
                * np.array([108 / 255, 142 / 255, 191 / 255, 0.6])
            )
            axs[t, z].imshow(
                (labels[..., z, t, None] == 2)
                * np.array([214 / 255, 182 / 255, 86 / 255, 0.6])
            )
            axs[t, z].imshow(
                (labels[..., z, t, None] == 3)
                * np.array([130 / 255, 179 / 255, 102 / 255, 0.6])
            )
            axs[t, z].set_xticks([])
            axs[t, z].set_yticks([])
            if z == 0:
                axs[t, z].set_ylabel(f"t = {t * t_step}")
    fig1.suptitle(f"Subject {image_id} in {split} split")
    axs[0, n_slices // 2].set_title("SAX Slices")
    fig1.tight_layout()
    plt.subplots_adjust(wspace=0, hspace=0)

    # Create volume plot
    xs = np.arange(n_frames) * t_step
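    # each voxel is 1 mm x 1 mm x 10 mm = 10 mm^3; dividing by 1000 converts mm^3 to ml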
    rv_volumes = np.sum(labels == 1, axis=(0, 1, 2)) * 10 / 1000
    myo_volumes = np.sum(labels == 2, axis=(0, 1, 2)) * 10 / 1000
    lv_volumes = np.sum(labels == 3, axis=(0, 1, 2)) * 10 / 1000
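    # ejection fraction = (max volume - min volume) / max volume, i.e. (EDV - ESV) / EDV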
    lvef = (max(lv_volumes) - min(lv_volumes)) / max(lv_volumes) * 100
    rvef = (max(rv_volumes) - min(rv_volumes)) / max(rv_volumes) * 100

    fig2, ax = plt.subplots(figsize=(4, 4), dpi=120)
    ax.plot(xs, rv_volumes, color="#6C8EBF", label="RV")
    ax.plot(xs, myo_volumes, color="#D6B656", label="MYO")
    ax.plot(xs, lv_volumes, color="#82B366", label="LV")
    ax.set_xlabel("Frame")
    ax.set_ylabel("Volume (ml)")
    ax.set_title(f"LVEF = {lvef:.2f}%, RVEF = {rvef:.2f}%")
    ax.legend(loc="lower right")
    fig2.tight_layout()

    return fig1, fig2


# Create the Gradio interface
theme = gr.themes.Ocean(
    primary_hue="red",
    secondary_hue="purple",
)

with gr.Blocks(
    theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
) as demo:
    gr.Markdown(
        """
# CineMA: A Foundation Model for Cine Cardiac MRI 🎥🫀

Below is an example of ejection fraction prediction inference. For more examples, check out our [GitHub](https://github.com/mathpluscode/CineMA).
"""
    )
    with gr.Row():
        with gr.Column(scale=0.4):
            gr.Markdown("## Description")
            gr.Markdown("""
Please adjust the settings in the panels on the right and click the button to run the inference.

### Data

The available data is from ACDC. All images have been resampled to 1 mm × 1 mm × 10 mm and centre-cropped to 192 mm × 192 mm for each SAX slice.
Images 1-100 are from the training set, and images 101-150 are from the test set.

### Model

The available models are finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are 3 models finetuned with different seeds: 0, 1, and 2. The default model is the one finetuned on the ACDC dataset with seed 0.

### Visualization

The left panel shows the segmentation of the ventricles and myocardium for every sampled time frame across all SAX slices.
The right panel plots the ventricle and myocardium volumes across the inference time frames.
""")
        with gr.Column(scale=0.3):
            gr.Markdown("## Data Settings")
            image_id = gr.Slider(
                minimum=1,
                maximum=150,
                step=1,
                label="Choose an ACDC image, ID is between 1 and 150",
                value=1,
            )
            t_step = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                label="Choose the gap between time frames",
                value=2,
            )
        with gr.Column(scale=0.3):
            gr.Markdown("## Model Setting")
            trained_dataset = gr.Dropdown(
                choices=["ACDC", "M&MS", "M&MS2"],
                label="Choose which dataset the segmentation model was finetuned on",
                value="ACDC",
            )
            seed = gr.Slider(
                minimum=0,
                maximum=2,
                step=1,
                label="Choose which seed the finetuning used",
                value=0,
            )
            run_button = gr.Button("Run segmentation inference", variant="primary")

    with gr.Row():
        segmentation_plot = gr.Plot(label="Ventricle and Myocardium Segmentation")
        volume_plot = gr.Plot(label="Ejection Fraction Prediction")

    run_button.click(
        fn=run_inference,
        inputs=[trained_dataset, seed, image_id, t_step],
        outputs=[segmentation_plot, volume_plot],
    )

demo.launch()