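"""Gradio web demo for multiple object tracking (MOT) with MMTracking.

The script installs the MMTracking stack at startup, then wraps
`mmtrack.apis.inference_mot` in a small Gradio interface that takes an
input video, a model config, and a score threshold, and returns a video
with the tracked boxes drawn on each frame.
"""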
import os

# Install the MMTracking stack at startup (versions pinned for compatibility
# with mmtrack 0.x; the editable install below already provides the
# `mmtrack` package, so no separate PyPI install is needed).
os.system("pip install 'mmcv-full>=1.3.17,<=1.7.0'")
os.system("pip install mmdet==2.25.1")
os.system("git clone https://github.com/open-mmlab/mmtracking.git")
os.system("pip install -r mmtracking/requirements.txt")
os.system("pip install -v -e mmtracking/")

import os.path as osp
import gradio as gr
import tempfile
from argparse import ArgumentParser

import mmcv

from mmtrack.apis import inference_mot, init_model

def parse_args():
    parser = ArgumentParser()
    parser.add_argument('--config', help='config file')
    parser.add_argument('--input', help='input video file or folder')
    parser.add_argument(
        '--output', help='output video file (mp4 format) or folder')
    parser.add_argument('--checkpoint', help='checkpoint file')
    parser.add_argument(
        '--score-thr',
        type=float,
        default=0.0,
        help='The threshold of score to filter bboxes.')
    parser.add_argument(
        '--device', default='cuda:0', help='device used for inference')
    parser.add_argument(
        '--show',
        action='store_true',
        help='whether show the results on the fly')
    parser.add_argument(
        '--backend',
        choices=['cv2', 'plt'],
        default='cv2',
        help='the backend to visualize the results')
    parser.add_argument('--fps', type=int, help='FPS of the output video')
    args = parser.parse_args()
    return args
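
# Note: this argument set mirrors MMTracking's demo/demo_mot_vis.py; when the
# script runs as a Gradio app, parse_args() sees no CLI arguments, so it only
# supplies defaults that track_mot() then overrides with the UI values.
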
def track_mot(input, config, output, device, score_thr):
    # Start from the CLI defaults, then override them with the Gradio inputs.
    args = parse_args()
    args.input = input
    args.config = config
    args.output = output
    args.device = device
    args.score_thr = score_thr
    args.show = False
    args.backend = 'cv2'

    # load images
    if osp.isdir(args.input):
        imgs = sorted(
            filter(lambda x: x.endswith(('.jpg', '.png', '.jpeg')),
                   os.listdir(args.input)),
            key=lambda x: int(x.split('.')[0]))
        IN_VIDEO = False
    else:
        imgs = mmcv.VideoReader(args.input)
        IN_VIDEO = True
    # define output
    OUT_VIDEO = False
    if args.output is not None:
        if args.output.endswith('.mp4'):
            OUT_VIDEO = True
            # Frames are written to a temporary directory first, then
            # stitched into the final .mp4 after the loop.
            out_dir = tempfile.TemporaryDirectory()
            out_path = out_dir.name
            _out = args.output.rsplit(os.sep, 1)
            if len(_out) > 1:
                os.makedirs(_out[0], exist_ok=True)
        else:
            OUT_VIDEO = False
            out_path = args.output
            os.makedirs(out_path, exist_ok=True)
    fps = args.fps
    if args.show or OUT_VIDEO:
        if fps is None and IN_VIDEO:
            fps = imgs.fps
        if not fps:
            raise ValueError('Please set the FPS for the output video.')
        fps = int(fps)

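    # args.checkpoint stays None when driven from the UI, so model weights are
    # expected to come from checkpoint URLs embedded in the config itself (as
    # in MMTracking's SORT/DeepSORT demo configs).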
    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)

    prog_bar = mmcv.ProgressBar(len(imgs))
    # test and show/save the images
    for i, img in enumerate(imgs):
        if isinstance(img, str):
            img = osp.join(args.input, img)
        result = inference_mot(model, img, frame_id=i)
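        # `result` is a dict; for MOT models it typically holds per-class
        # 'det_bboxes' and 'track_bboxes' arrays, with track rows formatted as
        # [track_id, x1, y1, x2, y2, score].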
        if args.output is not None:
            if IN_VIDEO or OUT_VIDEO:
                out_file = osp.join(out_path, f'{i:06d}.jpg')
            else:
                out_file = osp.join(out_path, img.rsplit(os.sep, 1)[-1])
        else:
            out_file = None
        model.show_result(
            img,
            result,
            score_thr=args.score_thr,
            show=args.show,
            wait_time=int(1000. / fps) if fps else 0,
            out_file=out_file,
            backend=args.backend)
        prog_bar.update()

    if args.output and OUT_VIDEO:
        print(f'making the output video at {args.output} with {fps} FPS')
        mmcv.frames2video(out_path, args.output, fps=fps, fourcc='mp4v')
        out_dir.cleanup()
    # print("output:", out_dir)

    # return output
    # print("output:", out_dir)
    save_dir = 'mot.mp4'
    return save_dir
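
# Minimal programmatic usage sketch (illustrative paths; assumes the cloned
# repo's demo config and a local input video are available):
#   track_mot('demo.mp4',
#             'configs/mot/deepsort/sort_faster-rcnn_fpn_4e_mot17-private.py',
#             'mot.mp4', 'cpu', 0.3)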

if __name__ == '__main__':
    input_video = gr.Video(format="mp4", label="Input Video")
    config = gr.Textbox(value="configs/mot/deepsort/sort_faster-rcnn_fpn_4e_mot17-private.py", label="Config file")
    output = gr.Textbox(value="mot.mp4", label="Output video path (mp4)")
    device = gr.Radio(choices=["cpu", "cuda"], value="cpu", label="Device used for inference")
    score_thr = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, label="Score threshold for filtering bboxes")
    output_video = gr.Video(format="mp4", label="Output Video")

    title = "MMTracking web demo"
    description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmtracking/master/resources/mmtrack-logo.png' width='450''/><div>" \
                  "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmtracking'>MMTracking</a> MMTracking是一款基于PyTorch的视频目标感知开源工具箱,是OpenMMLab项目的一部分。" \
                  "OpenMMLab Video Perception Toolbox. It supports Video Object Detection (VID), Multiple Object Tracking (MOT), Single Object Tracking (SOT), Video Instance Segmentation (VIS) with a unified framework..</p>"
    article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmtracking'>MMTracking</a></p>" \
              "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmtracking'>gradio build by gatilin</a></a></p>"

    # Create Gradio interface
    iface = gr.Interface(
        fn=track_mot,
        inputs=[
            input_video, config, output, device, score_thr
        ],
        # outputs="playable_video",
        outputs=output_video,
        title=title, description=description, article=article,
    )

    # Launch Gradio interface
    iface.launch()
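
# Note: launch() serves on localhost by default; on a remote host one could
# pass share=True or server_name="0.0.0.0" to make the UI reachable.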