Spaces:
Build error
Build error
| #!/usr/bin/env python | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import pathlib | |
| import subprocess | |
| import tarfile | |
# When running on Hugging Face Spaces (SYSTEM=spaces), reinstall pinned
# versions of mmcv-full, mmdet, OpenCV (headless), pycocotools and detectron2
# before the third-party imports below are executed.
if os.getenv('SYSTEM') == 'spaces':
    import mim

    mim.uninstall('mmcv-full', confirm_yes=True)
    mim.install('mmcv-full==1.4.3', is_yes=True)
    mim.uninstall('mmdet', confirm_yes=True)
    mim.install('mmdet==2.20.0', is_yes=True)

    # Return codes are not checked, so every step is best-effort (e.g.
    # uninstalling a package that is not present just fails quietly).
    subprocess.call('pip uninstall -y opencv-python'.split())
    subprocess.call('pip uninstall -y opencv-python-headless'.split())
    subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
    # `-y` is required: without it conda stops at an interactive
    # "Proceed ([y]/n)?" prompt that a non-interactive build cannot answer.
    subprocess.call('conda install -y -c conda-forge pycocotools'.split())
    subprocess.call('pip install detectron2==0.5 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.7/index.html'.split())
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| from mmdet.apis import init_detector, inference_detector | |
| from utils import show_result | |
| import mmcv | |
| from mmcv import Config | |
| import os.path as osp | |
# Markdown shown at the top of the demo page: title, link to the project
# repository and an overview figure.
DESCRIPTION = (
    '# OpenPSG\n'
    'This is an official demo for [OpenPSG](https://github.com/Jingkang50/OpenPSG).\n'
    '<img id="overview" alt="overview" src="https://camo.githubusercontent.com/880346b66831a8212074787ba9a2301b4d700bd8f765ca11e4845ac0ab34c230/68747470733a2f2f6c6976652e737461746963666c69636b722e636f6d2f36353533352f35323139333837393637375f373531613465306237395f6b2e6a7067" />\n'
)

# HTML shown at the bottom of the demo page: a visitor-counter badge image.
FOOTER = (
    '<img id="visitor-badge" '
    'src="https://visitor-badge.glitch.me/badge?page_id=c-liangyu.openpsg" '
    'alt="visitor badge" />'
)
def parse_args() -> argparse.Namespace:
    """Read the demo's command-line options from ``sys.argv``.

    Returns:
        A namespace with ``device`` (str, default ``'cpu'``), ``theme``
        (str or ``None``), ``share`` (bool), ``port`` (int or ``None``) and
        ``enable_queue`` (``True`` unless ``--disable-queue`` is given).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--device', type=str, default='cpu')
    cli.add_argument('--theme', type=str)
    cli.add_argument('--share', action='store_true')
    cli.add_argument('--port', type=int)
    # Queueing defaults to on; this flag stores False into `enable_queue`.
    cli.add_argument('--disable-queue', action='store_false', dest='enable_queue')
    options = cli.parse_args()
    return options
def update_input_image(image: np.ndarray) -> dict:
    """Return a Gradio update that shows *image*, shrunk if it is too large.

    If the larger of the first two dimensions of *image* exceeds 1500, the
    image is resized uniformly (same factor on both axes) with ``cv2.resize``
    so that dimension becomes about 1500. Smaller images are passed through
    unchanged. A ``None`` image produces an update that clears the component.
    """
    if image is not None:
        factor = 1500 / max(image.shape[:2])
        if factor < 1:
            image = cv2.resize(image, None, fx=factor, fy=factor)
    return gr.Image.update(value=image)
def set_example_image(example: list) -> dict:
    """Return a Gradio update that loads a clicked example into the input.

    *example* is the clicked ``gr.Dataset`` sample row; its first entry is
    used as the new image value (the samples are built as ``[path]`` lists).
    """
    image_path = example[0]
    return gr.Image.update(value=image_path)
def infer(model, input_image, num_rel):
    """Run the detector on one image and return the rendered result.

    Args:
        model: detector built by ``init_detector``.
        input_image: the image to analyse, passed unchanged to
            ``inference_detector`` and ``show_result``.
        num_rel: forwarded as ``show_result(num_rel=...)``; presumably the
            number of relations to render (the UI slider is labelled
            "Number of Relations") — confirm against ``utils.show_result``.

    Returns:
        Whatever ``show_result`` returns for the one-stage model.
    """
    detections = inference_detector(model, input_image)
    rendered = show_result(
        input_image,
        detections,
        is_one_stage=True,
        num_rel=num_rel,
        show=True,
    )
    return rendered
def main():
    """Load the PSGTR detector and serve the OpenPSG Gradio demo.

    Parses command-line options, builds the model from the bundled config
    and checkpoint on ``args.device``, then lays out the UI:
    - an input image, a "Number of Relations" slider and a Run button;
    - a result gallery;
    - example thumbnails taken from ``images/**/*.jpg``.
    Finally it launches the server with the requested queue, port and share
    settings.
    """
    args = parse_args()

    model_ckt = 'OpenPSG/checkpoints/epoch_60.pth'
    cfg = Config.fromfile('OpenPSG/configs/psgtr/psgtr_r50_psg_inference.py')
    model = init_detector(cfg, model_ckt, device=args.device)

    def run_inference(image, relations):
        # Gradio event `inputs` must be UI components, so the loaded model
        # cannot be listed there. Bind it here and forward the image and the
        # slider value to `infer`.
        return infer(model, image, relations)

    with gr.Blocks(theme=args.theme, css='style.css') as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Group():
                    with gr.Row():
                        num_rel = gr.Slider(5,
                                            100,
                                            step=5,
                                            value=20,
                                            label='Number of Relations')
                with gr.Row():
                    run_button = gr.Button(value='Run')
            with gr.Column():
                with gr.Row():
                    result = gr.Gallery(label='Result', type='numpy')

        with gr.Row():
            paths = sorted(pathlib.Path('images').rglob('*.jpg'))
            example_images = gr.Dataset(components=[input_image],
                                        samples=[[path.as_posix()]
                                                 for path in paths])

        gr.Markdown(FOOTER)

        input_image.change(fn=update_input_image,
                           inputs=input_image,
                           outputs=input_image)

        # Previously `inputs=[model, input_image]`: the model is not a
        # component and the slider was never passed to `infer`.
        run_button.click(fn=run_inference,
                         inputs=[input_image, num_rel],
                         outputs=result)

        example_images.click(fn=set_example_image,
                             inputs=example_images,
                             outputs=input_image)

    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
# Start the demo only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()