|
|
import cv2
|
|
|
import gradio as gr
|
|
|
import numpy as np
|
|
|
import spaces
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
|
from transformers import AutoModel
|
|
|
|
|
|
|
|
|
def calculate_ctr(mask: np.ndarray) -> float:
    """Estimate the cardiothoracic ratio (CTR) from a 2-D label mask.

    The ratio is the horizontal pixel extent (along axis 1, i.e. columns) of
    the heart divided by the horizontal pixel extent of both lungs combined.

    Args:
        mask: 2-D integer label map in which values 1 and 2 mark lung pixels
            and 3 marks heart pixels; every other value is ignored.

    Returns:
        ``heart_width / lung_width`` as a float, or ``nan`` when the ratio is
        undefined: no lung pixels, no heart pixels, or lungs that occupy a
        single column (zero width).
    """
    mask = np.asarray(mask)
    # Column indices (axis 1) of every lung pixel (labels 1 and 2) and heart pixel (3).
    lung_cols = np.nonzero(np.isin(mask, (1, 2)))[1]
    heart_cols = np.nonzero(mask == 3)[1]

    # min()/max() of an empty array raises ValueError; report "undefined" instead.
    if lung_cols.size == 0 or heart_cols.size == 0:
        return float("nan")

    lung_range = int(lung_cols.max() - lung_cols.min())
    heart_range = int(heart_cols.max() - heart_cols.min())

    # A zero-width lung extent would otherwise yield inf/nan plus a RuntimeWarning.
    if lung_range == 0:
        return float("nan")
    return heart_range / lung_range
|
|
|
|
|
|
|
|
|
def make_overlay(
    img: np.ndarray, mask: np.ndarray, alpha: float = 0.7
) -> np.ndarray:
    """Alpha-blend ``mask`` onto ``img`` and return the result as uint8.

    Each output value is ``alpha * img + (1 - alpha) * mask``, clipped to
    [0, 255] and truncated to an integer.

    Args:
        img: Base image; must broadcast against ``mask``.
        mask: Image (e.g. a colored segmentation or heatmap) blended on top.
        alpha: Weight of ``img``; ``mask`` receives ``1 - alpha``.

    Returns:
        The blended image with dtype ``np.uint8``.
    """
    overlay = alpha * img + (1 - alpha) * mask
    # For alpha outside [0, 1] the blend can leave the uint8 range, and a bare
    # astype would wrap around; clipping is a no-op for in-range values.
    return np.clip(overlay, 0, 255).astype(np.uint8)
|
|
|
|
|
|
|
|
|
def _model_input(model, radiograph: np.ndarray) -> torch.Tensor:
    """Preprocess ``radiograph`` with ``model.preprocess`` into a (1, 1, H, W) float tensor on ``device``."""
    x = model.preprocess(radiograph)
    x = torch.from_numpy(x).float().to(device)
    return rearrange(x, "h w -> 1 1 h w")


def _heatmap_overlay(mask: torch.Tensor, rgb: np.ndarray) -> np.ndarray:
    """Resize a single-channel model mask to ``rgb``'s size and blend it in as a JET heatmap."""
    h, w = rgb.shape[:2]
    mask = F.interpolate(mask, size=(h, w), mode="bilinear")
    # Assumes the mask holds probabilities in [0, 1] (TODO confirm); clipping
    # keeps out-of-range values from wrapping around in the uint8 cast.
    heat = (np.clip(mask.cpu().numpy()[0, 0], 0, 1) * 255).astype(np.uint8)
    heat = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; flip the channels to RGB to match ``rgb``.
    return make_overlay(rgb, heat[..., ::-1])


@spaces.GPU
def predict(Radiograph):
    """Run the three chest-radiograph models on one image.

    Args:
        Radiograph: 2-D grayscale image array from the ``gr.Image(image_mode="L")``
            input, or ``None`` when the user submits without an image.

    Returns:
        ``[info_string, preds, info_overlay, pna_overlay, ptx_overlay]``:
        a text summary (view, age, sex, CTR for frontal views), a dict of the
        pneumonia / pneumothorax ``"cls"`` scores, and three RGB overlays
        (heart & lung segmentation, pneumonia heatmap, pneumothorax heatmap).

    Raises:
        gr.Error: if no image was provided.
    """
    if Radiograph is None:
        raise gr.Error("Please upload a chest radiograph.")

    rg = cv2.cvtColor(Radiograph, cv2.COLOR_GRAY2RGB)
    h, w = rg.shape[:2]

    # --- Heart/lung segmentation plus view, age, and sex prediction ---
    x = _model_input(cxr_info_model, Radiograph)
    with torch.inference_mode():
        info_out = cxr_info_model(x)

    info_mask = F.interpolate(info_out["mask"], size=(h, w), mode="bilinear")
    info_mask = info_mask.argmax(1)[0]  # per-pixel class label, shape (h, w)
    # One 0/255 channel per foreground class 1..3 (class 0 is dropped).
    info_mask_3ch = F.one_hot(info_mask, num_classes=4)[..., 1:]
    info_mask_3ch = (info_mask_3ch.cpu().numpy() * 255).astype(np.uint8)
    info_overlay = make_overlay(rg, info_mask_3ch[..., ::-1])

    view = info_out["view"].argmax(1).item()
    info_string = ""
    if view in {0, 1}:
        info_string += "This is a frontal chest radiograph "
        info_string += "(AP projection)." if view == 0 else "(PA projection)."
    elif view == 2:
        info_string += "This is a lateral chest radiograph."

    age = info_out["age"].item()
    info_string += f"\nThe patient's predicted age is {round(age)} years."
    sex = "male" if info_out["female"].item() < 0.5 else "female"
    info_string += f"\nThe patient's predicted sex is {sex}."

    if view in {0, 1}:
        try:
            ctr = calculate_ctr(info_mask.cpu().numpy())
        except ValueError:
            # calculate_ctr can raise when the lung or heart mask is empty.
            ctr = float("nan")
        if np.isfinite(ctr):
            info_string += f"\nThe estimated cardiothoracic ratio (CTR) is {ctr:0.2f}."
        else:
            info_string += "\nThe cardiothoracic ratio (CTR) could not be estimated."
        if view == 0:
            info_string += (
                "\nNote that the cardiac silhouette is magnified in the AP projection."
            )

    if view == 2:
        info_string += (
            "\nNOTE: The below outputs are NOT VALID for lateral radiographs."
        )

    # --- Pneumonia ---
    x = _model_input(pna_model, Radiograph)
    with torch.inference_mode():
        pna_out = pna_model(x)
    pna_overlay = _heatmap_overlay(pna_out["mask"], rg)

    # --- Pneumothorax ---
    x = _model_input(ptx_model, Radiograph)
    with torch.inference_mode():
        ptx_out = ptx_model(x)
    ptx_overlay = _heatmap_overlay(ptx_out["mask"], rg)

    preds = {"Pneumonia": pna_out["cls"].item(), "Pneumothorax": ptx_out["cls"].item()}
    return [info_string, preds, info_overlay, pna_overlay, ptx_overlay]
|
|
|
|
|
|
|
|
|
# Input widget: predict() receives the uploaded radiograph as a grayscale array.
image = gr.Image(image_mode="L")

# Output widgets, declared in the same order as predict()'s return list.
info_textbox = gr.Textbox(show_label=False)
labels = gr.Label(show_label=False, show_heading=False)
heatmap1, heatmap2, heatmap3 = (
    gr.Image(image_mode="RGB", label=title)
    for title in ("Heart & Lungs", "Pneumonia", "Pneumothorax")
)
|
|
|
|
|
|
with gr.Blocks() as demo:
    # Page header: what the three models do, where the example image comes
    # from, and a not-for-clinical-use disclaimer.
    gr.Markdown(
        """
        # Deep Learning for Chest Radiographs

        This demo uses 3 models for chest radiographs:
        1) Heart and lungs segmentation, with age, view, and sex prediction <https://huggingface.co/ianpan/chest-x-ray-basic>
        2) Pneumonia classification and segmentation <https://huggingface.co/ianpan/pneumonia-cxr>
        3) Pneumothorax classification and segmentation <https://huggingface.co/ianpan/pneumothorax-cxr>

        Note that the pneumonia and pneumothorax heatmaps produced by these models are based on pixel-level segmentation maps.
        Thus, they are expected to be more accurate than implicit localization methods such as Grad-CAM.

        The example radiograph is my own, from when I had pneumonia.

        These models are for demonstration purposes only and have NOT been approved by any regulatory agency for clinical use. The user assumes
        any and all responsibility regarding their own use of these models and their outputs. Do NOT upload any images containing protected
        health information, as this demonstration is not compliant with patient privacy laws.

        Created by: Ian Pan, <https://ianpan.me>

        Last updated: December 27, 2024
        """
    )

    # Wire predict() into the page. Outputs are listed in the order predict() returns them.
    gr.Interface(
        fn=predict,
        inputs=image,
        outputs=[info_textbox, labels, heatmap1, heatmap2, heatmap3],
        examples=["examples/cxr.png"],
        # Gradio precomputes the example by calling predict(), so the models it
        # reads must exist when caching runs (presumably at launch — confirm for
        # the pinned Gradio version).
        cache_examples=True,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # predict() reads `device` and the three model globals defined below.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device `{device}` ...")

    def _load_model(repo_id):
        """Download a Hub model, put it in eval mode, and move it to `device`.

        trust_remote_code=True executes code shipped in the model repository,
        so only point this at repositories you trust.
        """
        model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
        model = model.eval()
        return model.to(device)

    cxr_info_model, pna_model, ptx_model = (
        _load_model(repo_id)
        for repo_id in (
            "ianpan/chest-x-ray-basic",
            "ianpan/pneumonia-cxr",
            "ianpan/pneumothorax-cxr",
        )
    )

    demo.launch(share=True)
|
|
|
|