isLinXu committed
Commit 32a5465 · Parent: d42c518
Files changed (1)
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
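+ # Gradio demo: multi-object tracking (MOT) on an uploaded video with MMTracking.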
+ import os
+ # Install MMCV, MMDetection and MMTracking at startup.
+ os.system("pip install 'mmcv-full>=1.3.17,<=1.7.0'")
+ os.system("pip install mmdet==2.25.1")
+ os.system("git clone https://github.com/open-mmlab/mmtracking.git")
+ os.system("pip install -r mmtracking/requirements.txt")
+ os.system("pip install -v -e mmtracking/")
+ os.system("pip install 'mmtrack'")
+ import os.path as osp
+ import gradio as gr
+ import tempfile
+ from argparse import ArgumentParser
+
+ import mmcv
+
+ from mmtrack.apis import inference_mot, init_model
+
+ def parse_args():
+     parser = ArgumentParser()
+     parser.add_argument('--config', help='config file')
+     parser.add_argument('--input', help='input video file or folder')
+     parser.add_argument(
+         '--output', help='output video file (mp4 format) or folder')
+     parser.add_argument('--checkpoint', help='checkpoint file')
+     parser.add_argument(
+         '--score-thr',
+         type=float,
+         default=0.0,
+         help='The threshold of score to filter bboxes.')
+     parser.add_argument(
+         '--device', default='cuda:0', help='device used for inference')
+     parser.add_argument(
+         '--show',
+         action='store_true',
+         help='whether to show the results on the fly')
+     parser.add_argument(
+         '--backend',
+         choices=['cv2', 'plt'],
+         default='cv2',
+         help='the backend to visualize the results')
+     parser.add_argument('--fps', help='FPS of the output video')
+     args = parser.parse_args()
+     return args
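+
+ # Gradio callback: start from the CLI defaults in parse_args() and override
+ # them with the values supplied through the web UI.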
+ def track_mot(input, config, output, device, score_thr):
+     args = parse_args()
+     args.input = input
+     args.config = config
+     args.output = output
+     args.device = device
+     args.score_thr = score_thr
+     args.show = False
+     args.backend = 'cv2'
+
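+     # Note: args.checkpoint stays None here, so no explicit checkpoint file is
+     # loaded; the pretrained weights referenced inside the selected config are
+     # relied on instead (assumption based on the MMTracking MOT demo configs).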
+     # assert args.output or args.show
+     # load images
+     if osp.isdir(args.input):
+         imgs = sorted(
+             filter(lambda x: x.endswith(('.jpg', '.png', '.jpeg')),
+                    os.listdir(args.input)),
+             key=lambda x: int(x.split('.')[0]))
+         IN_VIDEO = False
+     else:
+         imgs = mmcv.VideoReader(args.input)
+         IN_VIDEO = True
+     # define output
+     OUT_VIDEO = False  # ensure OUT_VIDEO is defined even when no output path is given
+     if args.output is not None:
+         if args.output.endswith('.mp4'):
+             OUT_VIDEO = True
+             out_dir = tempfile.TemporaryDirectory()
+             out_path = out_dir.name
+             _out = args.output.rsplit(os.sep, 1)
+             if len(_out) > 1:
+                 os.makedirs(_out[0], exist_ok=True)
+         else:
+             out_path = args.output
+             os.makedirs(out_path, exist_ok=True)
+
+     fps = args.fps
+     if args.show or OUT_VIDEO:
+         if fps is None and IN_VIDEO:
+             fps = imgs.fps
+         if not fps:
+             raise ValueError('Please set the FPS for the output video.')
+         fps = int(fps)
+
+     # build the model from a config file and a checkpoint file
+     model = init_model(args.config, args.checkpoint, device=args.device)
+
+     prog_bar = mmcv.ProgressBar(len(imgs))
+     # test and show/save the images
+     for i, img in enumerate(imgs):
+         if isinstance(img, str):
+             img = osp.join(args.input, img)
+         result = inference_mot(model, img, frame_id=i)
+         if args.output is not None:
+             if IN_VIDEO or OUT_VIDEO:
+                 out_file = osp.join(out_path, f'{i:06d}.jpg')
+             else:
+                 out_file = osp.join(out_path, img.rsplit(os.sep, 1)[-1])
+         else:
+             out_file = None
+         model.show_result(
+             img,
+             result,
+             score_thr=args.score_thr,
+             show=args.show,
+             wait_time=int(1000. / fps) if fps else 0,
+             out_file=out_file,
+             backend=args.backend)
+         prog_bar.update()
+
+     if args.output and OUT_VIDEO:
+         print(f'making the output video at {args.output} with a FPS of {fps}')
+         mmcv.frames2video(out_path, args.output, fps=fps, fourcc='mp4v')
+         out_dir.cleanup()
+
+     # Return the path of the written video so Gradio can display it.
+     return args.output
+
+ if __name__ == '__main__':
+     # Gradio UI components (Gradio 3.x style API).
+     input_video = gr.Video(format="mp4", label="Input Video")
+     # Note: the default config path assumes a configs/ directory is available in
+     # the working directory; if only the cloned repo is present it may need the
+     # mmtracking/ prefix.
+     config = gr.Textbox(value="configs/mot/deepsort/sort_faster-rcnn_fpn_4e_mot17-private.py", label="Config file")
+     output = gr.Textbox(value="mot.mp4", label="Output video path")
+     device = gr.Radio(choices=["cpu", "cuda"], value="cpu", label="Device used for inference")
+     score_thr = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, label="Score threshold used to filter bboxes")
+     output_video = gr.Video(format="mp4", label="Output Video")
+
+     title = "MMTracking web demo"
+     description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmtracking/master/resources/mmtrack-logo.png' width='450'/></div>" \
+                   "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmtracking'>MMTracking</a> is an open-source video perception toolbox based on PyTorch and part of the OpenMMLab project. " \
+                   "It supports Video Object Detection (VID), Multiple Object Tracking (MOT), Single Object Tracking (SOT) and Video Instance Segmentation (VIS) with a unified framework.</p>"
+     article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmtracking'>MMTracking</a></p>" \
+               "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmtracking'>Gradio demo by gatilin</a></p>"
+
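+     # The inputs below are passed to track_mot positionally, so their order must
+     # match its signature: (input, config, output, device, score_thr).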
+     # Create Gradio interface
+     iface = gr.Interface(
+         fn=track_mot,
+         inputs=[
+             input_video, config, output, device, score_thr
+         ],
+         # outputs="playable_video",
+         outputs=output_video,
+         title=title, description=description, article=article,
+     )
+
+     # Launch Gradio interface
+     iface.launch()
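+     # On a Hugging Face Space the default launch() is sufficient; when running
+     # locally, iface.launch(share=True) could be used to get a temporary public URL.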