John Ho commited on
Commit
a083e2f
Β·
1 Parent(s): a2b5338

added point and object_detection features

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import spaces, torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from PIL import Image
 
5
 
6
 
7
  @spaces.GPU
@@ -15,17 +16,23 @@ def load_model():
15
 
16
 
17
  @spaces.GPU
18
- def point(im: Image.Image, object_name: str):
 
 
19
  model = load_model()
20
- return model.detect(im, object_name)["objects"]
 
 
 
21
 
22
 
23
  demo = gr.Interface(
24
- fn=point,
25
  inputs=[
26
  gr.Image(label="Input Image", type="pil"),
27
  gr.Textbox(label="Object to Detect"),
 
28
  ],
29
  outputs=gr.Textbox(label="Output Text"),
30
  )
31
- demo.launch()
 
2
  import spaces, torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from PIL import Image
5
+ from typing import Literal
6
 
7
 
8
  @spaces.GPU
 
16
 
17
 
18
  @spaces.GPU
19
+ def detect(
20
+ im: Image.Image, object_name: str, mode: Literal["point", "object_detection"]
21
+ ):
22
  model = load_model()
23
+ if mode == "point":
24
+ return model.point(im, object_name)["points"]
25
+ elif mode == "object_detection":
26
+ return model.detect(im, object_name)["objects"]
27
 
28
 
29
  demo = gr.Interface(
30
+ fn=detect,
31
  inputs=[
32
  gr.Image(label="Input Image", type="pil"),
33
  gr.Textbox(label="Object to Detect"),
34
+ gr.Dropdown(label="Mode", choices=["point", "object_detection"]),
35
  ],
36
  outputs=gr.Textbox(label="Output Text"),
37
  )
38
+ demo.launch(mcp_server=True)