Spaces: Runtime error
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| import cv2 | |
| from AV.models.network import PGNet | |
| from AV.Tools.AVclassifiation import AVclassifiation | |
| from AV.Tools.utils_test import paint_border_overlap, extract_ordered_overlap_big, Normalize, sigmoid, recompone_overlap, \ | |
| kill_border | |
| from AV.config import config_test_general as cfg | |
| import torch.autograd as autograd | |
| import numpy as np | |
| import os | |
| from datetime import datetime | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face access token read from the "HF_token" environment variable
# (None when unset); passed to hf_hub_download when fetching checkpoints.
hf_token = os.getenv("HF_token")
def creatMask(Image, threshold=5):
    """Create a binary field-of-view (FOV) mask for a fundus image.

    The image is thresholded (3-channel input is first converted to grey with
    ``cv2.COLOR_BGR2GRAY``). The largest contour is kept and filled, and the
    resulting mask is cleaned with a morphological opening (3x3 cross kernel,
    3 iterations).

    Args:
        Image: HxWx3 image, or a single-channel HxW image (e.g. the green
            channel). Note: inside this function the name shadows PIL's
            ``Image`` module.
        threshold: pixels with intensity >= threshold count as foreground.
            Defaults to 5.

    Returns:
        (ResultImg, Mask):
            ResultImg: a copy of ``Image`` with pixels outside the FOV set to
                white (255).
            Mask: a uint8 HxW array that is 255 inside the FOV and 0 outside.
                If no foreground contour exists, Mask is all zeros (and
                ResultImg is all white) instead of raising.
    """
    if len(Image.shape) == 3:  # multi-channel image
        gray = cv2.cvtColor(Image, cv2.COLOR_BGR2GRAY)
        foreground = gray >= threshold
    else:  # single-channel (e.g. green channel) image
        foreground = Image >= threshold
    foreground = np.uint8(foreground)

    # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
    # (contours, hierarchy) in 2.x/4.x; the contours are always at [-2].
    contours = cv2.findContours(foreground, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    Mask = np.zeros(Image.shape[:2], dtype=np.uint8)
    if len(contours) > 0:
        # Keep only the largest blob and fill it (thickness -1).
        areas = [cv2.contourArea(c) for c in contours]
        max_index = int(np.argmax(areas))
        cv2.drawContours(Mask, contours, max_index, 1, -1)

    ResultImg = Image.copy()
    if len(Image.shape) == 3:
        ResultImg[Mask == 0] = (255, 255, 255)
    else:
        ResultImg[Mask == 0] = 255

    Mask[Mask > 0] = 255
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    Mask = cv2.morphologyEx(Mask, cv2.MORPH_OPEN, kernel, iterations=3)
    return ResultImg, Mask
def shift_rgb(img, *args):
    """Add a constant offset to each channel of an 8-bit image, with clipping.

    Each offset is applied through a 256-entry lookup table whose values are
    clipped to [0, 255]. Indexing the table with the image is equivalent to
    ``cv2.LUT`` for uint8 input.

    Args:
        img: uint8 image, either HxW (single channel) or HxWxC.
        *args: one offset per channel, applied to channels 0, 1, 2, ... in
            order. Channels with no corresponding offset are copied unchanged
            (previously they were left uninitialized). For a 2-D image every
            offset is applied to the original image in turn, so only the last
            offset takes effect (matching the previous behavior).

    Returns:
        A new array with the same shape and dtype as ``img``. The input is not
        modified.
    """
    max_value = 255
    # Start from a copy so that channels without a shift keep their values.
    result_img = img.copy()
    identity = np.arange(0, max_value + 1).astype("float32")
    for channel, shift in enumerate(args):
        lut = np.clip(identity + shift, 0, max_value).astype(img.dtype)
        if img.ndim == 2:
            result_img = lut[img]
        else:
            result_img[..., channel] = lut[img[..., channel]]
    return result_img
def CAM(x, img_path, rate=0.8, ind=0):
    """Shift the per-channel means of ``x`` toward those of a reference image.

    For each of the three channels, the difference between the reference
    image's mean and ``x``'s mean is scaled by ``rate`` and truncated to int.
    That offset is then added to every pixel of the channel, clipped to
    [0, 255]. Finally, pixels outside the field-of-view mask computed from
    ``x`` are set to 0.

    Args:
        x: HxWx3 image array; cast to uint8. In this file it is the RGB
            (possibly CLAHE-enhanced) image being segmented.
        img_path: path of the reference image whose channel means are the
            target. The caller in this file passes the path of the image
            being processed.
        rate: fraction of the mean difference to apply.
        ind: unused; kept for backward compatibility.

    Returns:
        uint8 HxWx3 array with shifted channels and the background zeroed.
    """
    x = np.uint8(x)
    # NOTE(review): x is RGB here, but creatMask converts with BGR2GRAY, so
    # the R/B weights are swapped. Left as-is to preserve existing results.
    _, fov_mask = creatMask(x, threshold=10)

    # Close the file promptly, and normalise the mode (L/P/RGBA -> RGB) so
    # that indexing three channels below is always valid.
    with Image.open(img_path) as reference_file:
        reference = np.asarray(reference_file.convert("RGB"))

    shifts = [
        int((np.mean(reference[:, :, c]) - np.mean(x[:, :, c])) * rate)
        for c in range(3)
    ]
    y = shift_rgb(x, *shifts)
    y[fov_mask == 0, :] = 0
    return y
def modelEvalution_out_big(net, use_cuda=False, dataset='', is_kill_border=True, input_ch=3,
                           config=None, output_dir='', evaluate_metrics=False):
    """Segment arteries/veins in a single image and return the colored result.

    Args:
        net: state dict loaded (``strict=False``) into a freshly built
            ``PGNet``. Missing or unexpected keys are reported with ``print``.
        use_cuda: move the network to the GPU.
        dataset: path of the image to segment.
        is_kill_border: forwarded to ``GetResult_out_big``.
        input_ch: number of input channels for ``PGNet``.
        config: configuration object. This function reads
            ``use_global_semantic``, ``use_centerness``,
            ``centerness_map_size`` and ``use_resize``; ``GetResult_out_big``
            reads further fields.
        output_dir: forwarded to ``AVclassifiation``.
        evaluate_metrics: unused; kept for backward compatibility.

    Returns:
        Whatever ``AVclassifiation`` returns. The Gradio caller converts it
        with ``Image.fromarray``.

    Raises:
        OSError: if OpenCV cannot read the image at ``dataset``.
    """
    n_classes = 3
    Net = PGNet(use_global_semantic=config.use_global_semantic, input_ch=input_ch,
                num_classes=n_classes, use_cuda=use_cuda, pretrained=False,
                centerness=config.use_centerness,
                centerness_map_size=config.centerness_map_size)
    load_result = Net.load_state_dict(net, strict=False)
    # strict=False silently skips mismatched keys; surface them so that a
    # checkpoint/architecture mismatch does not go unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"load_state_dict: {len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected keys")
    if use_cuda:
        Net.cuda()
    Net.eval()

    image_path = dataset
    image0 = cv2.imread(image_path)
    if image0 is None:
        raise OSError(f"cv2.imread could not read image: {image_path}")
    test_image_height, test_image_width = image0.shape[:2]

    if config.use_resize:
        scaling = None
        if min(test_image_height, test_image_width) <= 256:
            # Small images: upscale so that the short side becomes 512.
            scaling = 512 / min(test_image_height, test_image_width)
        elif max(test_image_height, test_image_width) >= 2048:
            # Large images: downscale so that the long side becomes 2048.
            scaling = 2048 / max(test_image_height, test_image_width)
        if scaling is not None:
            test_image_width = int(test_image_width * scaling)
            test_image_height = int(test_image_height * scaling)

    ArteryPred, VeinPred, VesselPred, _mask, _, _, _ = GetResult_out_big(
        Net, 0,
        use_cuda=use_cuda,
        dataset=image_path,
        is_kill_border=is_kill_border,
        config=config,
        resize_w_h=(test_image_width, test_image_height),
    )

    # Pack the (1, H, W) predictions into (1, 1, H, W) float32 batches. The
    # assignment raises if the prediction size differs from the expected size.
    shape = (1, 1, test_image_height, test_image_width)
    ArteryPredAll = np.zeros(shape, np.float32)
    VeinPredAll = np.zeros(shape, np.float32)
    VesselPredAll = np.zeros(shape, np.float32)
    ArteryPredAll[0] = ArteryPred
    VeinPredAll[0] = VeinPred
    VesselPredAll[0] = VesselPred

    image_color = AVclassifiation(output_dir, ArteryPredAll, VeinPredAll, VesselPredAll, 1, image_path)
    return image_color
def _enhance_contrast_lab(img_rgb):
    """Apply CLAHE (clipLimit=2, 8x8 tiles) to the L channel of an RGB uint8 image; return RGB."""
    lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB)
    l_channel, a_channel, b_channel = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(8, 8))
    lab_clahe = cv2.merge((clahe.apply(l_channel), a_channel, b_channel))
    return cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)


def _predict_patches(net, patches_imgs, global_images, n_classes, patch_size, batch_size,
                     use_cuda, use_global_semantic):
    """Run ``net`` over (N, C, h, w) patches in mini-batches; return sigmoid maps (N, n_classes, patch_size, patch_size)."""
    device = torch.device("cuda" if use_cuda else "cpu")
    patch_num = patches_imgs.shape[0]
    pred_patches = np.zeros((patch_num, n_classes, patch_size, patch_size), np.float32)
    # Inference only: no_grad avoids building autograd graphs for every batch.
    with torch.no_grad():
        for begin in range(0, patch_num, batch_size):
            end = begin + batch_size
            patches = torch.as_tensor(np.ascontiguousarray(patches_imgs[begin:end]),
                                      dtype=torch.float32).to(device)
            if use_global_semantic:
                global_input = torch.as_tensor(np.ascontiguousarray(global_images[begin:end]),
                                               dtype=torch.float32).to(device)
            else:
                # Without global semantics, the network receives the patches
                # as its second input too. Taking them after the device move
                # keeps both inputs on the same device.
                global_input = patches
            output, _ = net(patches, global_input)
            probs = sigmoid(np.float32(output.cpu().numpy()))
            pred_patches[begin:end] = probs[:, :, :patch_size, :patch_size]
    return pred_patches


def GetResult_out_big(Net, k, use_cuda=False, dataset='', is_kill_border=False, config=None,
                      resize_w_h=None):
    """Predict artery/vessel/vein probability maps for one image by overlapping patches.

    Pipeline:
      1. Read the image with OpenCV and build a mask with ``creatMask``.
      2. Optionally resize the image and mask to ``resize_w_h``.
      3. Convert BGR -> RGB; optionally apply CLAHE and ``CAM``.
      4. Tile the image into overlapping patches and run ``Net`` on them.
      5. Stitch the patches back together and crop to the image size.

    Args:
        Net: network called as ``Net(patches, global_input)``, returning a
            2-tuple whose first element is the logits tensor.
        k: unused; kept for backward compatibility.
        use_cuda: run inference on the GPU (``Net`` must already be there).
        dataset: path of the image.
        is_kill_border: if True, apply ``kill_border`` with the mask.
        config: configuration object. Reads ``use_resize``, ``patch_size``,
            ``stride_height``, ``stride_width``, ``dataset``, ``use_CAM`` and
            ``use_global_semantic``.
        resize_w_h: (width, height) target size, used when
            ``config.use_resize`` is set.

    Returns:
        A 7-tuple (ArteryPred, VeinPred, VesselPred, Mask, ArteryPred,
        VeinPred, VesselPred). Each element is a float32 (1, H, W) array.
        Prediction channels are mapped as 0 -> artery, 1 -> vessel, 2 -> vein.

    Raises:
        OSError: if OpenCV cannot read the image.
    """
    ImgName = dataset
    Img0 = cv2.imread(ImgName)
    if Img0 is None:
        raise OSError(f"cv2.imread could not read image: {ImgName}")
    # With threshold=-1 every uint8 pixel passes the threshold, so the mask
    # presumably covers (nearly) the whole frame — TODO confirm intent.
    _, Mask0 = creatMask(Img0, threshold=-1)
    Mask = (Mask0 > 0).astype(np.float32)
    if config.use_resize:
        Img0 = cv2.resize(Img0, resize_w_h)
        Mask = cv2.resize(Mask, resize_w_h, interpolation=cv2.INTER_NEAREST)

    height, width = Img0.shape[:2]
    n_classes = 3
    patch_size = config.patch_size
    stride_height = config.stride_height
    stride_width = config.stride_width
    batch_size = 2

    Img = cv2.cvtColor(Img0, cv2.COLOR_BGR2RGB)
    # Use the passed-in config (not the module-level cfg) for consistency
    # with every other setting read here.
    if config.dataset == 'all':
        Img = _enhance_contrast_lab(Img)
    if config.use_CAM:
        Img = CAM(Img, dataset)
    Img = np.float32(Img / 255.)

    Img_enlarged = paint_border_overlap(Img, patch_size, patch_size, stride_height, stride_width)
    patches_imgs, global_images = extract_ordered_overlap_big(Img_enlarged, patch_size, patch_size,
                                                              stride_height, stride_width)
    # NHWC -> NCHW, then normalise.
    patches_imgs = Normalize(np.transpose(patches_imgs, (0, 3, 1, 2)))
    global_images = Normalize(np.transpose(global_images, (0, 3, 1, 2)))

    pred_patches = _predict_patches(Net, patches_imgs, global_images, n_classes, patch_size,
                                    batch_size, use_cuda, config.use_global_semantic)

    new_height, new_width = Img_enlarged.shape[0], Img_enlarged.shape[1]
    pred_img = recompone_overlap(pred_patches, new_height, new_width, stride_height, stride_width)
    # Drop the padding added by paint_border_overlap.
    pred_img = pred_img[:, 0:height, 0:width]
    if is_kill_border:
        pred_img = kill_border(pred_img, Mask)

    ArteryPred = np.float32(pred_img[0, :, :])[np.newaxis, :, :]
    VeinPred = np.float32(pred_img[2, :, :])[np.newaxis, :, :]
    VesselPred = np.float32(pred_img[1, :, :])[np.newaxis, :, :]
    Mask = Mask[np.newaxis, :, :]
    return ArteryPred, VeinPred, VesselPred, Mask, ArteryPred, VeinPred, VesselPred
def out_test(cfg, model_path='', output_dir='', evaluate_metrics=False, img_name='out_test'):
    """Load a checkpoint and segment a single image.

    Args:
        cfg: configuration object (shadows the module-level ``cfg``). This
            function reads ``use_cuda`` and ``input_nc``; the rest is read
            downstream.
        model_path: path to a checkpoint whose content is passed to
            ``load_state_dict``.
        output_dir: forwarded to ``modelEvalution_out_big``.
        evaluate_metrics: forwarded; unused downstream.
        img_name: path of the image to segment.

    Returns:
        The colored result from ``modelEvalution_out_big``.
    """
    device = torch.device("cuda" if cfg.use_cuda else "cpu")
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources (here, the
    # project's own Hub repo). Consider weights_only=True if the checkpoint is
    # a plain state dict.
    net = torch.load(model_path, map_location=device)
    image_color = modelEvalution_out_big(net,
                                         use_cuda=cfg.use_cuda,
                                         dataset=img_name,
                                         input_ch=cfg.input_nc,
                                         config=cfg,
                                         output_dir=output_dir,
                                         evaluate_metrics=evaluate_metrics)
    return image_color
def segment_by_out_test(image, model_name):
    """Gradio callback: segment an uploaded fundus image with the selected model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None
            when nothing was uploaded.
        model_name: model key. ``G_{model_name}.pkl`` is downloaded from the
            ``weidai00/RIP-AV-sulab`` Hub repo, and ``cfg.set_dataset`` is
            called with the same key.

    Returns:
        PIL image built from the colored segmentation result.

    Raises:
        gr.Error: if no image was uploaded.
    """
    import tempfile

    print("✅ 传到后端的模型名:", model_name)
    # Validate the input before the (possibly slow) checkpoint download.
    if image is None:
        raise gr.Error("请上传一张图像(upload a fundus image)。")

    model_path = hf_hub_download(
        repo_id="weidai00/RIP-AV-sulab",  # model repository name
        filename=f"G_{model_name}.pkl",   # checkpoint file name
        repo_type="model",                # repo_type is required for model repos
        token=hf_token
    )
    cfg.set_dataset(model_name)

    # The pipeline reads the image from disk (OpenCV and PIL), so write it to
    # a uniquely named file. A second-resolution timestamp name could collide
    # between concurrent requests, and those files were never removed.
    os.makedirs("./examples", exist_ok=True)
    fd, temp_path = tempfile.mkstemp(prefix="tmp_upload_", suffix=".png", dir="./examples")
    os.close(fd)
    try:
        image.save(temp_path)
        image_color = out_test(cfg, model_path=model_path, output_dir='', evaluate_metrics=False,
                               img_name=temp_path)
    finally:
        # Best-effort cleanup; a failed delete must not mask the real result.
        try:
            os.remove(temp_path)
        except OSError:
            pass
    return Image.fromarray(image_color)
def gradio_interface():
    """Build the Gradio Blocks UI for retinal artery/vein segmentation and launch it.

    The "RUN" button passes the uploaded image and the selected model key to
    ``segment_by_out_test`` and displays the returned image. The queue is
    enabled before launching.
    """
    # Markdown table describing the selectable models (shown in an accordion).
    model_info_md = """
### 📘 模型说明
| 模型(model name) | 数据集(dataset) | patch size |running time |
|------|--------|------------|--------|
| DRIVE | 小分辨率血管图像 | 256 |30s以内|
| HRF | 高分辨率图像(健康、青光眼等)| 256 | 2min以内|
| LES | 视盘中心图像适配 | 256 |2min以内|
| UKBB | UKBB图像 | 256 |2min以内 |
| 通用模型(512) | 超清图像,适配性强 | 512 |2min以内|
"""
    # (label shown in the UI, key passed to the callback / checkpoint suffix)
    model_choices = [
        ("1: DRIVE专用模型", "DRIVE"),
        ("2: HRF专用模型", "hrf"),
        ("3: LES专用模型", "LES"),
        ("4: UKBB专用模型", "ukbb"),
        ("5: 通用模型(general)", "all"),
    ]
    # Each row fills [image_input, model_select] when clicked.
    example_rows = [
        ["examples/DRIVE.tif", "DRIVE"],
        ["examples/LES.png", "LES"],
        ["examples/hrf.png", "hrf"],
        ["examples/ukbb.png", "ukbb"],
        ["examples/all.jpg", "all"],
    ]
    citations = (
        "📚 **专用模型引用cite**: RIP-AV: Joint Representative Instance Pre-training with Context Aware Network for Retinal Artery/Vein Segmentation",
        "📚 **通用模型引用cite**: An Efficient and Interpretable Foundation Model for Retinal Image Analysis in Disease Diagnosis.",
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 👁️ 眼底图像动静脉血管分割(Retinal image artery and vein segmentation)")
        gr.Markdown("上传眼底图像,选择一个模型开始处理,结果将自动生成。(Upload the retinal image, select a model to start processing, and the results will be generated automatically.)")

        with gr.Row():
            image_input = gr.Image(type="pil", label="📤 上传图像(upload)", height=300)

        with gr.Row():
            with gr.Column():
                model_select = gr.Radio(
                    choices=model_choices,
                    label="🎯 选择模型",
                    value="DRIVE",
                    interactive=True,
                )
                submit_btn = gr.Button("🚀 开始分割(RUN)")
            with gr.Column():
                output_image = gr.Image(label="🖼️ 分割结果(Result)")

        gr.Markdown("### 📁 示例图像examples(点击自动加载)")
        gr.Examples(
            examples=example_rows,
            inputs=[image_input, model_select],
            label="示例图像",
            examples_per_page=5,
        )

        with gr.Accordion("📖 模型说明-Description(点击展开)", open=False):
            gr.Markdown(model_info_md)

        # Wire the run button to the segmentation callback.
        submit_btn.click(
            fn=segment_by_out_test,
            inputs=[image_input, model_select],
            outputs=[output_image],
        )

        for citation in citations:
            gr.Markdown(citation)

    demo.queue()
    demo.launch()
if __name__ == '__main__':
    # Launch the Gradio web UI.
    #
    # Offline, single-image usage kept for reference (not executed). Note that
    # out_test also needs model_path pointing to a downloaded checkpoint,
    # because its default ('') cannot be loaded by torch.load:
    # cfg.set_dataset('all')
    # image_color = out_test(cfg = cfg, evaluate_metrics=False, img_name=r'.\AV\data\AV-DRIVE\test\images\01_test.tif')
    # Image.fromarray(image_color).save('image_color.png')
    #print(cfg.patch_size)
    gradio_interface()