Spaces:
Runtime error
Runtime error
| import os | |
| from pathlib import Path | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from ultralyticsplus import YOLO, render_result, postprocess_classify_output | |
| from utils import load_models_from_txt_files, get_dataset_id_from_model_id, get_task_from_readme | |
# Folder that get_examples() creates and fills with example images.
EXAMPLE_IMAGE_DIR = 'example_images'

# Default model id and dataset id for each of the three tasks.
DEFAULT_DET_MODEL_ID = 'keremberke/yolov8m-valorant-detection'
DEFAULT_DET_DATASET_ID = 'keremberke/valorant-object-detection'

DEFAULT_SEG_MODEL_ID = 'keremberke/yolov8s-building-segmentation'
DEFAULT_SEG_DATASET_ID = 'keremberke/satellite-building-segmentation'

DEFAULT_CLS_MODEL_ID = 'keremberke/yolov8m-chest-xray-classification'
DEFAULT_CLS_DATASET_ID = 'keremberke/chest-xray-classification'

# Selectable model ids per task, plus a task-name -> id-list lookup table.
det_model_ids, seg_model_ids, cls_model_ids = load_models_from_txt_files()
task_to_model_ids = dict(detect=det_model_ids, segment=seg_model_ids, classify=cls_model_ids)

# One loaded model per task, paired with the id it was loaded from.
# predict() compares against these ids and reloads a model only when the
# requested id differs.
det_model_id = DEFAULT_DET_MODEL_ID
det_model = YOLO(det_model_id)

seg_model_id = DEFAULT_SEG_MODEL_ID
seg_model = YOLO(seg_model_id)

cls_model_id = DEFAULT_CLS_MODEL_ID
cls_model = YOLO(cls_model_id)
def get_examples(task):
    """Build the example rows shown in the UI for one task.

    For every model id registered under ``task`` this loads the "mini"
    configuration of the model's dataset, takes up to two items from its
    "validation" split, writes each item's image to ``EXAMPLE_IMAGE_DIR``
    as a JPEG, and records one row per image.

    Args:
        task: Key into ``task_to_model_ids`` ('detect', 'segment' or
            'classify' in this file).

    Returns:
        A list of ``[absolute_image_path, model_id, 0.25]`` rows, in the
        order the images were written.
    """
    out_dir = Path(EXAMPLE_IMAGE_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)

    model_ids = task_to_model_ids[task]
    examples = []
    for model_id in model_ids:
        validation_split = load_dataset(
            get_dataset_id_from_model_id(model_id), name="mini"
        )["validation"]
        sample_count = min(2, len(validation_split))
        for sample_index in range(sample_count):
            image = validation_split[sample_index]["image"]
            # One row is appended per saved image, so the number of rows so
            # far doubles as the running file index across all models.
            target = str(out_dir / f"{task}_example_{len(examples)}.jpg")
            image.save(target, format='JPEG', quality=100)
            examples.append([os.path.abspath(target), model_id, 0.25])
    return examples
# Build the example rows for each task at import time, in the order
# detect -> segment -> classify.
det_examples, seg_examples, cls_examples = (
    get_examples(task) for task in ('detect', 'segment', 'classify')
)
def predict(image, model_id, threshold):
    """Run the model identified by ``model_id`` on ``image``.

    The task is resolved by checking which of the module-level id lists
    (detection, then segmentation, then classification) contains
    ``model_id``. Each task keeps one loaded model in a module global; it is
    replaced only when the requested id differs from the one already loaded.

    Args:
        image: Input passed straight through to ``model.predict`` and, for
            detection/segmentation, to ``render_result``.
        model_id: Id of the model to use; must appear in one of the id lists.
        threshold: Value stored in ``model.overrides['conf']`` before
            inference.

    Returns:
        The value of ``render_result(...)`` for detection and segmentation
        models, or of ``postprocess_classify_output(...)`` for
        classification models, computed from the first prediction result.

    Raises:
        ValueError: If ``model_id`` is not in any of the id lists.
    """
    global det_model, det_model_id
    global seg_model, seg_model_id
    global cls_model, cls_model_id

    # Resolve the task and bring that task's cached model in line with the
    # requested id in a single pass.
    if model_id in det_model_ids:
        task = 'detect'
        if det_model_id != model_id:
            det_model, det_model_id = YOLO(model_id), model_id
        model = det_model
    elif model_id in seg_model_ids:
        task = 'segment'
        if seg_model_id != model_id:
            seg_model, seg_model_id = YOLO(model_id), model_id
        model = seg_model
    elif model_id in cls_model_ids:
        task = 'classify'
        if cls_model_id != model_id:
            cls_model, cls_model_id = YOLO(model_id), model_id
        model = cls_model
    else:
        raise ValueError(f"Invalid model_id: {model_id}")

    # Apply the confidence setting chosen in the UI, then run inference.
    model.overrides['conf'] = threshold
    results = model.predict(image)
    print(model_id)
    print(task)

    first_result = results[0]
    if task == 'classify':
        # Classification output is post-processed rather than drawn.
        return postprocess_classify_output(model, result=first_result)
    # Detection and segmentation predictions are drawn onto the input image.
    return render_result(model=model, image=image, result=first_result)
# Three-tab UI (detection / segmentation / classification). Each tab has an
# image input, a model dropdown, a threshold slider, a run button, an output
# widget and two columns of clickable examples, all wired to predict().
with gr.Blocks() as demo:
    gr.Markdown("""# <p align='center'><a href="https://github.com/keremberke/awesome-yolov8-models" target='_blank'><img width='500px' src='https://user-images.githubusercontent.com/34196005/215836968-fb54e066-a524-4caf-b469-92bbaa96f921.gif' /></a></p>
<p style='text-align: center'>
<br> <a href='https://yolov8.xyz' target='_blank'>project website</a> | <a href='https://github.com/keremberke/awesome-yolov8-models' target='_blank'>project github</a>
</p>
<p style='text-align: center'>
Follow me for more!
<br> <a href='https://twitter.com/_keremberke' target='_blank'>twitter</a> | <a href='https://github.com/keremberke' target='_blank'>github</a> | <a href='https://www.linkedin.com/in/kerem-berke-bba6a5204/' target='_blank'>linkedin</a>
</p>
""")

    with gr.Tab("Detection"):
        with gr.Row():
            with gr.Column():
                detect_input = gr.Image()
                detect_model_id = gr.Dropdown(choices=det_model_ids, label="Model:", value=DEFAULT_DET_MODEL_ID, interactive=True)
                detect_threshold = gr.Slider(maximum=1, step=0.01, value=0.25, label="Threshold:", interactive=True)
                detect_button = gr.Button("Detect!")
            with gr.Column():
                detect_output = gr.Image(label="Predictions:", interactive=False)
        with gr.Row():
            half_ind = len(det_examples) // 2
            # First column gets the second half of the examples, second
            # column gets the first half.
            for example_rows in (det_examples[half_ind:], det_examples[:half_ind]):
                with gr.Column():
                    gr.Examples(
                        example_rows,
                        inputs=[detect_input, detect_model_id, detect_threshold],
                        outputs=detect_output,
                        fn=predict,
                        cache_examples=False,
                        run_on_click=False,
                    )

    with gr.Tab("Segmentation"):
        with gr.Row():
            with gr.Column():
                segment_input = gr.Image()
                segment_model_id = gr.Dropdown(choices=seg_model_ids, label="Model:", value=DEFAULT_SEG_MODEL_ID, interactive=True)
                segment_threshold = gr.Slider(maximum=1, step=0.01, value=0.25, label="Threshold:", interactive=True)
                segment_button = gr.Button("Segment!")
            with gr.Column():
                segment_output = gr.Image(label="Predictions:", interactive=False)
        with gr.Row():
            half_ind = len(seg_examples) // 2
            # Same column split as the detection tab.
            for example_rows in (seg_examples[half_ind:], seg_examples[:half_ind]):
                with gr.Column():
                    gr.Examples(
                        example_rows,
                        inputs=[segment_input, segment_model_id, segment_threshold],
                        outputs=segment_output,
                        fn=predict,
                        cache_examples=False,
                        run_on_click=False,
                    )

    with gr.Tab("Classification"):
        with gr.Row():
            with gr.Column():
                classify_input = gr.Image()
                classify_model_id = gr.Dropdown(choices=cls_model_ids, label="Model:", value=DEFAULT_CLS_MODEL_ID, interactive=True)
                classify_threshold = gr.Slider(maximum=1, step=0.01, value=0.25, label="Threshold:", interactive=True)
                classify_button = gr.Button("Classify!")
            with gr.Column():
                classify_output = gr.Label(
                    label="Predictions:", show_label=True, num_top_classes=5
                )
        with gr.Row():
            half_ind = len(cls_examples) // 2
            # Same column split as the detection tab.
            for example_rows in (cls_examples[half_ind:], cls_examples[:half_ind]):
                with gr.Column():
                    gr.Examples(
                        example_rows,
                        inputs=[classify_input, classify_model_id, classify_threshold],
                        outputs=classify_output,
                        fn=predict,
                        cache_examples=False,
                        run_on_click=False,
                    )

    # Button handlers; api_name registers each one as a named endpoint.
    detect_button.click(
        predict,
        inputs=[detect_input, detect_model_id, detect_threshold],
        outputs=detect_output,
        api_name="detect",
    )
    segment_button.click(
        predict,
        inputs=[segment_input, segment_model_id, segment_threshold],
        outputs=segment_output,
        api_name="segment",
    )
    classify_button.click(
        predict,
        inputs=[classify_input, classify_model_id, classify_threshold],
        outputs=classify_output,
        api_name="classify",
    )
# Turn on request queuing through Blocks.queue() before launching. The
# `enable_queue` keyword of launch() is deprecated in Gradio 3.x and is no
# longer accepted by Gradio 4, where passing it raises a TypeError at startup.
demo.queue().launch()