import gradio as gr

from grounded_sam.inference import grounded_segmentation
from grounded_sam.plot import plot_detections, plot_detections_plotly

def app_fn(
    image: gr.Image,
    labels: str,
    threshold: float,
    bounding_box_selection: bool,
) -> "plotly.graph_objects.Figure":
    # Grounding DINO expects each text prompt to end with a period.
    labels = labels.split("\n")
    labels = [label if label.endswith(".") else label + "." for label in labels]
    # Two-stage pipeline: Grounding DINO proposes boxes, SAM turns them into masks.
    image_array, detections = grounded_segmentation(image, labels, threshold, True)
    # Render the image, boxes, and masks as an interactive Plotly figure.
    fig_detection = plot_detections_plotly(image_array, detections, bounding_box_selection)
    return fig_detection

if __name__ == "__main__":
    title = "Grounding SAM - Text-to-Segmentation Model"

    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(
            """
            Grounded SAM is a text-to-segmentation model that generates segmentation masks from natural-language descriptions.
            This demo uses Grounding DINO in tandem with SAM to generate segmentation masks from text.
            The workflow is as follows (a rough code sketch of this pipeline follows the app code below):

            1. Select text labels to generate bounding boxes with Grounding DINO.
            2. Prompt the SAM model with those bounding boxes to generate segmentation masks.
            3. Refine the masks if needed.
            4. Visualize the segmentation masks.

            ### Notes
            - To pass multiple labels, separate them with new lines.
            - Generating the masks may take a few seconds, since the input runs through two models.
            - Refinement is applied by default by converting each mask to a polygon and back to a mask with OpenCV (also sketched below).
            - This Space uses a concise implementation; the full code is available on [GitHub](https://github.com/EduardoPach/grounded-sam).
            """
        )
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.05, label="Box Threshold")
            labels = gr.Textbox(lines=2, max_lines=5, label="Labels")
            bounding_box_selection = gr.Checkbox(label="Allow Box Selection")
        btn = gr.Button("Run")
        with gr.Row():
            img = gr.Image(type="pil")
            fig = gr.Plot(label="Segmentation Mask")
        btn.click(fn=app_fn, inputs=[img, labels, threshold, bounding_box_selection], outputs=[fig])
        gr.Examples(
            examples=[
                ["input_image.jpeg", "a person.\na mountain.", 0.3, False],
            ],
            inputs=[img, labels, threshold, bounding_box_selection],
            outputs=[fig],
            fn=app_fn,
            cache_examples="lazy",
            label="Try this example input!",
        )

    demo.launch()
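
For reference, the full `grounded_segmentation` implementation lives in the linked GitHub repo. As a rough idea of what the two-stage workflow described above does, here is a minimal sketch built on the `transformers` Grounding DINO and SAM checkpoints; the checkpoint names, the function name, and the return shape are assumptions for illustration, not the Space's actual code.

```python
# Hypothetical sketch of the two-stage pipeline; the real grounded_segmentation
# (with mask refinement and score filtering) is in the linked GitHub repo.
import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    SamModel,
    SamProcessor,
)

def grounded_segmentation_sketch(image: Image.Image, labels: list, threshold: float):
    # Stage 1: Grounding DINO turns the period-terminated text labels into boxes.
    dino_id = "IDEA-Research/grounding-dino-tiny"  # assumed checkpoint
    dino_processor = AutoProcessor.from_pretrained(dino_id)
    dino = AutoModelForZeroShotObjectDetection.from_pretrained(dino_id)
    inputs = dino_processor(images=image, text=" ".join(labels), return_tensors="pt")
    with torch.no_grad():
        outputs = dino(**inputs)
    detections = dino_processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=threshold,
        text_threshold=threshold,
        target_sizes=[image.size[::-1]],  # (height, width)
    )[0]
    boxes = detections["boxes"].tolist()  # assumes at least one box was found

    # Stage 2: SAM is prompted with those boxes and returns one mask per box.
    sam_id = "facebook/sam-vit-base"  # assumed checkpoint
    sam_processor = SamProcessor.from_pretrained(sam_id)
    sam = SamModel.from_pretrained(sam_id)
    sam_inputs = sam_processor(image, input_boxes=[boxes], return_tensors="pt")
    with torch.no_grad():
        sam_outputs = sam(**sam_inputs)
    masks = sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(),
        sam_inputs["original_sizes"].cpu(),
        sam_inputs["reshaped_input_sizes"].cpu(),
    )[0]
    return np.array(image), boxes, masks
```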
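
The "convert the mask to a polygon and back" refinement note deserves a sketch of its own: the round trip smooths jagged mask borders and discards stray blobs. Below is a minimal version with OpenCV, assuming a single-object binary mask and a keep-the-largest-contour heuristic; the repo's refinement may differ in both respects.

```python
import cv2
import numpy as np

def refine_mask_sketch(mask: np.ndarray) -> np.ndarray:
    """Round-trip a binary mask through its largest polygon contour (assumed heuristic)."""
    mask_u8 = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return mask  # nothing to refine
    # Keep only the largest contour: small disconnected blobs and holes are dropped.
    polygon = max(contours, key=cv2.contourArea)
    refined = np.zeros_like(mask_u8)
    cv2.fillPoly(refined, [polygon], 255)
    return refined > 0
```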