Spaces:
Runtime error
Runtime error
AAAAAAAyq
commited on
Commit
·
4b45202
1
Parent(s):
e03ed2b
Better points mode & Fix the Contours button bug
Browse files- app_gradio.py +4 -4
- utils/tools.py +4 -3
app_gradio.py
CHANGED
|
@@ -221,7 +221,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
|
|
| 221 |
input_size_slider.render()
|
| 222 |
|
| 223 |
with gr.Row():
|
| 224 |
-
|
| 225 |
|
| 226 |
with gr.Column():
|
| 227 |
segment_btn_e = gr.Button("Segment Everything", variant='primary')
|
|
@@ -298,7 +298,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
|
|
| 298 |
info='Our model was trained on a size of 1024')
|
| 299 |
with gr.Row():
|
| 300 |
with gr.Column():
|
| 301 |
-
|
| 302 |
text_box = gr.Textbox(label="text prompt", value="a black dog")
|
| 303 |
|
| 304 |
with gr.Column():
|
|
@@ -334,7 +334,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
|
|
| 334 |
iou_threshold,
|
| 335 |
conf_threshold,
|
| 336 |
mor_check,
|
| 337 |
-
|
| 338 |
retina_check,
|
| 339 |
],
|
| 340 |
outputs=segm_img_e)
|
|
@@ -350,7 +350,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
|
|
| 350 |
iou_threshold,
|
| 351 |
conf_threshold,
|
| 352 |
mor_check,
|
| 353 |
-
|
| 354 |
retina_check,
|
| 355 |
text_box,
|
| 356 |
],
|
|
|
|
| 221 |
input_size_slider.render()
|
| 222 |
|
| 223 |
with gr.Row():
|
| 224 |
+
contour_check_e = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
|
| 225 |
|
| 226 |
with gr.Column():
|
| 227 |
segment_btn_e = gr.Button("Segment Everything", variant='primary')
|
|
|
|
| 298 |
info='Our model was trained on a size of 1024')
|
| 299 |
with gr.Row():
|
| 300 |
with gr.Column():
|
| 301 |
+
contour_check_t = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
|
| 302 |
text_box = gr.Textbox(label="text prompt", value="a black dog")
|
| 303 |
|
| 304 |
with gr.Column():
|
|
|
|
| 334 |
iou_threshold,
|
| 335 |
conf_threshold,
|
| 336 |
mor_check,
|
| 337 |
+
contour_check_e,
|
| 338 |
retina_check,
|
| 339 |
],
|
| 340 |
outputs=segm_img_e)
|
|
|
|
| 350 |
iou_threshold,
|
| 351 |
conf_threshold,
|
| 352 |
mor_check,
|
| 353 |
+
contour_check_t,
|
| 354 |
retina_check,
|
| 355 |
text_box,
|
| 356 |
],
|
utils/tools.py
CHANGED
|
@@ -400,16 +400,17 @@ def point_prompt(masks, points, point_label, target_height, target_width): # nu
|
|
| 400 |
for point in points
|
| 401 |
]
|
| 402 |
onemask = np.zeros((h, w))
|
|
|
|
| 403 |
for i, annotation in enumerate(masks):
|
| 404 |
if type(annotation) == dict:
|
| 405 |
-
mask = annotation[
|
| 406 |
else:
|
| 407 |
mask = annotation
|
| 408 |
for i, point in enumerate(points):
|
| 409 |
if mask[point[1], point[0]] == 1 and point_label[i] == 1:
|
| 410 |
-
onemask
|
| 411 |
if mask[point[1], point[0]] == 1 and point_label[i] == 0:
|
| 412 |
-
onemask
|
| 413 |
onemask = onemask >= 1
|
| 414 |
return onemask, 0
|
| 415 |
|
|
|
|
| 400 |
for point in points
|
| 401 |
]
|
| 402 |
onemask = np.zeros((h, w))
|
| 403 |
+
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
|
| 404 |
for i, annotation in enumerate(masks):
|
| 405 |
if type(annotation) == dict:
|
| 406 |
+
mask = annotation['segmentation']
|
| 407 |
else:
|
| 408 |
mask = annotation
|
| 409 |
for i, point in enumerate(points):
|
| 410 |
if mask[point[1], point[0]] == 1 and point_label[i] == 1:
|
| 411 |
+
onemask[mask] = 1
|
| 412 |
if mask[point[1], point[0]] == 1 and point_label[i] == 0:
|
| 413 |
+
onemask[mask] = 0
|
| 414 |
onemask = onemask >= 1
|
| 415 |
return onemask, 0
|
| 416 |
|