Spaces:
Sleeping
Sleeping
import os
import shutil

import albumentations as A
import gradio as gr
import monai
import numpy as np
import openslide
import torch
from albumentations.pytorch import ToTensorV2
from monai.networks.nets import UNet
from PIL import Image
# 2-D U-Net taking 3-channel input and emitting a single-channel map; the
# hyper-parameters must match the checkpoint loaded just below.
model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    num_res_units=4,
    dropout=0.15,
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
)

# Weights are mapped onto the CPU regardless of the device they were saved from.
model.load_state_dict(
    torch.load("best_model.pth", map_location=torch.device('cpu'))
)

# Switch to evaluation mode for inference.
model.eval()
def process_image(image):
    """Run the module-level U-Net on an uploaded image and return its mask.

    Args:
        image: NumPy array supplied by the ``gr.Image(type="numpy",
            image_mode="RGB")`` input, or ``None`` when the user submitted
            without uploading anything. Pixel values are assumed to lie in
            [0, 255] -- TODO confirm for non-uint8 uploads.

    Returns:
        A 2-D NumPy array of per-pixel sigmoid probabilities for the image
        resized to 512 x 512, or ``None`` if no image was supplied.
    """
    # Gradio passes None for an empty image input; returning None leaves the
    # output component blank instead of failing on `None / 255.0`.
    if image is None:
        return None

    # Scale to [0, 1] and use float32, the dtype fed to the network.
    image = image / 255.0
    image = image.astype(np.float32)

    inference_transforms = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),
    ])
    image = inference_transforms(image=image)["image"]
    image = image.unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        mask_pred = torch.sigmoid(model(image))

    # Drop the batch and channel axes to obtain a single 2-D map.
    return mask_pred[0, 0, :, :].numpy()
# "Image-to-Mask" tab: the uploaded RGB image is passed to process_image as a
# NumPy array and the returned mask is shown as a greyscale ("L") image.
interface_image = gr.Interface(
    fn=process_image,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            height=400,
            type="numpy",
            width=400,
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",
            height=400,
            width=400,
        )
    ],
)
def process_slide(slide_path):
    """Open an uploaded whole-slide image and return a small preview of it.

    Args:
        slide_path: Path of the file uploaded through the ``gr.File`` input,
            or ``None`` when nothing was uploaded. A ``.zip`` upload is
            treated as an archived ``.mrxs`` slide (the ``.mrxs`` file plus
            its companion data) and is unpacked into ``cache_mrxs`` first.

    Returns:
        The image returned by ``OpenSlide.get_thumbnail((512, 512))``, or
        ``None`` if no file was supplied.

    Raises:
        FileNotFoundError: If a zip archive contains no top-level ``.mrxs``
            file.
    """
    if slide_path is None:
        return None

    # NOTE(review): some Gradio versions pass a tempfile-like wrapper instead
    # of a plain path; fall back to its `.name` in that case -- confirm
    # against the pinned Gradio version.
    if not isinstance(slide_path, (str, os.PathLike)):
        slide_path = slide_path.name
    slide_path = os.fspath(slide_path)

    if not slide_path.lower().endswith(".zip"):
        slide_file = slide_path
    else:  # mrxs slide files
        # Start from an empty cache so files left over from a previous upload
        # cannot be mistaken for the current slide.
        shutil.rmtree("cache_mrxs", ignore_errors=True)
        # The format is given explicitly because shutil infers it from the
        # extension case-sensitively.
        shutil.unpack_archive(slide_path, "cache_mrxs", format="zip")
        slide_name = next(
            (
                entry
                for entry in sorted(os.listdir("cache_mrxs"))
                if entry.endswith("mrxs")
            ),
            None,
        )
        if slide_name is None:
            raise FileNotFoundError(
                "No .mrxs file found at the top level of the uploaded archive"
            )
        slide_file = os.path.join("cache_mrxs", slide_name)

    # Return the rendered thumbnail (not the slide handle) so that gr.Image
    # can display it, and close the slide once the thumbnail has been read.
    with openslide.OpenSlide(slide_file) as slide:
        return slide.get_thumbnail((512, 512))
# "Slide-to-Mask" tab: a slide file goes in, process_slide's result is shown
# as an RGB image.
slide_file_input = gr.File(
    label="Input slide file (input zip for `.mrxs` files)",
)
slide_image_output = gr.Image(
    image_mode="RGB",
    label="Model Prediction",
    width=400,
    height=400,
)
interface_slide = gr.Interface(
    fn=process_slide,
    inputs=[slide_file_input],
    outputs=[slide_image_output],
)
# Combine both interfaces into one app, one tab each, and start the server.
demo = gr.TabbedInterface(
    interface_list=[interface_image, interface_slide],
    tab_names=["Image-to-Mask", "Slide-to-Mask"],
)
demo.launch()