isLinXu commited on
Commit
00fb916
Β·
1 Parent(s): ca1720d
Files changed (2) hide show
  1. app.py +98 -0
  2. requirements.txt +20 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torchvision import models, transforms
7
+
8
+ import warnings
9
+
10
+ warnings.filterwarnings("ignore")
11
+
12
+ # εŠ θ½½ζ¨‘εž‹
13
+ models_dict = {
14
+ 'DeepLabv3': models.segmentation.deeplabv3_resnet50(pretrained=True).eval(),
15
+ 'DeepLabv3+': models.segmentation.deeplabv3_resnet101(pretrained=True).eval(),
16
+ 'FCN-ResNet50': models.segmentation.fcn_resnet50(pretrained=True).eval(),
17
+ 'FCN-ResNet101': models.segmentation.fcn_resnet101(pretrained=True).eval(),
18
+ 'LRR': models.segmentation.lraspp_mobilenet_v3_large(pretrained=True).eval(),
19
+ }
20
+
21
+ # 图像钄倄理
22
+ image_transforms = transforms.Compose([
23
+ transforms.Resize(256),
24
+ transforms.CenterCrop(224),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(
27
+ mean=[0.485, 0.456, 0.406],
28
+ std=[0.229, 0.224, 0.225]
29
+ )
30
+ ])
31
+
32
+ def download_test_img():
33
+ # Images
34
+ torch.hub.download_url_to_file(
35
+ 'https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
36
+ 'bus.jpg')
37
+ torch.hub.download_url_to_file(
38
+ 'https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
39
+ 'dogs.jpg')
40
+ torch.hub.download_url_to_file(
41
+ 'https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
42
+ 'zidane.jpg')
43
+
44
+ def predict_segmentation(image, model_name):
45
+
46
+ # 图像钄倄理
47
+ image_tensor = image_transforms(image).unsqueeze(0)
48
+
49
+ # ζ¨‘εž‹ζŽ¨η†
50
+ with torch.no_grad():
51
+ output = models_dict[model_name](image_tensor)['out'][0]
52
+ output_predictions = output.argmax(0)
53
+ segmentation = F.interpolate(
54
+ output.float().unsqueeze(0),
55
+ size=image.size[::-1],
56
+ mode='bicubic',
57
+ align_corners=False
58
+ )[0].argmax(0).numpy()
59
+
60
+ # εˆ†ε‰²ε›Ύ
61
+ segmentation_image = np.uint8(segmentation)
62
+ segmentation_image = cv2.applyColorMap(segmentation_image, cv2.COLORMAP_JET)
63
+
64
+ # θžεˆε›Ύ
65
+ blend_image = cv2.addWeighted(np.array(image), 0.5, segmentation_image, 0.5, 0)
66
+ blend_image = cv2.cvtColor(blend_image, cv2.COLOR_BGR2RGB)
67
+
68
+ return segmentation_image, blend_image
69
+
70
+
71
+ import gradio as gr
72
+
73
+ examples = [
74
+ ['bus.jpg', 'DeepLabv3'],
75
+ ['dogs.jpg', 'DeepLabv3'],
76
+ ['zidane.jpg', 'DeepLabv3']
77
+ ]
78
+ download_test_img()
79
+ model_list = ['DeepLabv3', 'DeepLabv3+', 'FCN-ResNet50', 'FCN-ResNet101', 'LRR']
80
+ inputs = [
81
+ gr.inputs.Image(type='pil', label='εŽŸε§‹ε›Ύεƒ'),
82
+ gr.inputs.Dropdown(model_list, label='ι€‰ζ‹©ζ¨‘εž‹', default='DeepLabv3')
83
+ ]
84
+ outputs = [
85
+ gr.outputs.Image(type='pil',label='εˆ†ε‰²ε›Ύ'),
86
+ gr.outputs.Image(type='pil',label='θžεˆε›Ύ')
87
+ ]
88
+ interface = gr.Interface(
89
+ predict_segmentation,
90
+ inputs,
91
+ outputs,
92
+ examples=examples,
93
+ capture_session=True,
94
+ title='torchvision-segmentation-webui',
95
+ description='torchvision segmentation webui on gradio'
96
+ )
97
+
98
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wget~=3.2
2
+ opencv-python~=4.6.0.66
3
+ numpy~=1.23.0
4
+ torch~=1.13.1
5
+ torchvision~=0.14.1
6
+ pillow~=9.4.0
7
+ gradio~=3.42.0
8
+ ultralytics~=8.0.169
9
+ pyyaml~=6.0
10
+ wandb~=0.13.11
11
+ tqdm~=4.65.0
12
+ matplotlib~=3.7.1
13
+ pandas~=2.0.0
14
+ seaborn~=0.12.2
15
+ requests~=2.31.0
16
+ psutil~=5.9.4
17
+ thop~=0.1.1-2209072238
18
+ timm~=0.9.2
19
+ super-gradients~=3.2.0
20
+ openmim