isLinXu commited on
Commit
538b38c
·
1 Parent(s): 1b4b3e6

update files

Browse files
Files changed (2) hide show
  1. app.py +130 -0
  2. requirements.txt +20 -0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import mmcv
4
+ from mmengine import Config
5
+
6
+ os.system("pip install 'mmengine>=0.6.0'")
7
+ os.system("pip install 'mmcv>=2.0.0rc4,<2.1.0'")
8
+ os.system("pip install mmsegmentation")
9
+
10
+ import gradio as gr
11
+ import fnmatch
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ from mmseg.apis import init_model, inference_model, show_result_pyplot
16
+ from mmseg.apis import MMSegInferencer
17
+
18
+ import PIL
19
+ from mim import download
20
+ import warnings
21
+
22
+ warnings.filterwarnings("ignore")
23
+
24
+ mmseg_models_list = MMSegInferencer.list_models('mmseg')
25
+
26
+ path = "./checkpoint"
27
+ if not os.path.exists(path):
28
+ os.makedirs(path)
29
+
30
+
31
+ def clear_folder(folder_path):
32
+ import shutil
33
+ for filename in os.listdir(folder_path):
34
+ file_path = os.path.join(folder_path, filename)
35
+ try:
36
+ if os.path.isfile(file_path) or os.path.islink(file_path):
37
+ os.unlink(file_path)
38
+ elif os.path.isdir(file_path):
39
+ shutil.rmtree(file_path)
40
+ except Exception as e:
41
+ print(f"Failed to delete {file_path}. Reason: {e}")
42
+ print(f"Clear {folder_path} successfully.")
43
+
44
+
45
+ def save_image(img, img_path):
46
+ # Convert PIL image to OpenCV image
47
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
48
+ # Save OpenCV image
49
+ cv2.imwrite(img_path, img)
50
+
51
+
52
+ def download_test_image():
53
+ # Images
54
+ torch.hub.download_url_to_file(
55
+ 'https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
56
+ 'bus.jpg')
57
+ torch.hub.download_url_to_file(
58
+ 'https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
59
+ 'dogs.jpg')
60
+ torch.hub.download_url_to_file(
61
+ 'https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
62
+ 'zidane.jpg')
63
+
64
+
65
+ def download_cfg_checkpoint_model_name(model_name):
66
+ clear_folder("./checkpoint")
67
+ download(package='mmsegmentation',
68
+ configs=[model_name],
69
+ dest_root='./checkpoint')
70
+
71
+
72
+
73
+ # 定义推理函数
74
+ def predict(img, model_name):
75
+ # 保存输入图片
76
+ img_path = 'input_image.png'
77
+ save_image(img, img_path)
78
+ download_cfg_checkpoint_model_name(model_name)
79
+
80
+ config_path = [f for f in os.listdir(path) if fnmatch.fnmatch(f, "*.py")][0]
81
+ config_path = path + "/" + config_path
82
+
83
+ checkpoint_path = [f for f in os.listdir(path) if fnmatch.fnmatch(f, "*.pth")][0]
84
+ checkpoint_path = path + "/" + checkpoint_path
85
+
86
+ # 从配置文件和权重文件构建模型
87
+ device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
88
+
89
+ if device == 'cpu':
90
+ config_path = Config.fromfile(config_path)
91
+ # Remove pretrain model download for testing
92
+ config_path.model.pretrained = None
93
+ # Replace SyncBN with BN to inference on CPU
94
+ norm_cfg = dict(type='BN', requires_grad=True)
95
+ config_path.model.backbone.norm_cfg = norm_cfg
96
+ config_path.model.decode_head.norm_cfg = norm_cfg
97
+ config_path.model.auxiliary_head.norm_cfg = norm_cfg
98
+
99
+ model = init_model(config_path, checkpoint_path, device=device)
100
+
101
+ # 推理给定图像
102
+ result = inference_model(model, img_path)
103
+
104
+ # 保存可视化结果
105
+ vis_image = show_result_pyplot(model, img_path, result, show=False)
106
+ vis_image_path = 'output_image.png'
107
+ cv2.imwrite(vis_image_path, vis_image)
108
+ output_img = PIL.Image.open(vis_image_path)
109
+ # 返回输出图片
110
+ return output_img
111
+
112
+ download_test_image()
113
+ # 定义输入和输出界面
114
+ inputs_img = gr.inputs.Image(type='pil', label="Input Image")
115
+ model_list = gr.inputs.Dropdown(choices=[m for m in mmseg_models_list], label='Model')
116
+ outputs_img = gr.outputs.Image(type='pil', label="Output Image")
117
+
118
+ # 启动界面
119
+ title = "MMSegmentation segmentation web demo"
120
+ description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmsegmentation/main/resources/mmseg-logo.png' width='450''/><div>" \
121
+ "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmsegmentation'>MMSegmentation</a> MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 OpenMMLab 项目的一部分。" \
122
+ "OpenMMLab Semantic Segmentation Toolbox and Benchmark..</p>"
123
+ article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmsegmentation'>MMSegmentation</a></p>" \
124
+ "<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></a></p>"
125
+ examples = [["bus.jpg", "deeplabv3_r101-d8_4xb2-40k_cityscapes-512x1024"],
126
+ ["dogs.jpg", "pspnet_r50-d8_4xb2-40k_cityscapes-512x1024"],
127
+ ["zidane.jpg", "fcn_r101-d8_4xb4-80k_ade20k-512x512"]
128
+ ]
129
+ gr.Interface(fn=predict, inputs=[inputs_img, model_list], outputs=outputs_img, examples=examples,
130
+ title=title, description=description, article=article).launch()
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wget~=3.2
2
+ opencv-python~=4.6.0.66
3
+ numpy~=1.23.0
4
+ torch~=1.13.1
5
+ torchvision~=0.14.1
6
+ pillow~=9.4.0
7
+ gradio~=3.42.0
8
+ ultralytics~=8.0.169
9
+ pyyaml~=6.0
10
+ wandb~=0.13.11
11
+ tqdm~=4.65.0
12
+ matplotlib~=3.7.1
13
+ pandas~=2.0.0
14
+ seaborn~=0.12.2
15
+ requests~=2.31.0
16
+ psutil~=5.9.4
17
+ thop~=0.1.1-2209072238
18
+ timm~=0.9.2
19
+ super-gradients~=3.2.0
20
+ openmim