Spaces:
Running
on
Zero
Running
on
Zero
fix pt2
Browse files
app.py
CHANGED
|
@@ -236,13 +236,29 @@ def create_tags(threshold, sorted_tag_score: dict):
|
|
| 236 |
return text_no_impl, filtered_tag_score
|
| 237 |
|
| 238 |
def clear_image():
|
| 239 |
-
return "", {}, None, {}, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
@spaces.GPU(duration=5)
|
| 242 |
def cam_inference(img, threshold, alpha, selected_tag: str):
|
| 243 |
-
|
| 244 |
-
|
| 245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
if torch.cuda.is_available():
|
| 248 |
tensor = tensor.to(device, dtype=torch.float16)
|
|
@@ -438,6 +454,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 438 |
original_image_state = gr.State() # stash a copy of the input image
|
| 439 |
sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
|
| 440 |
cam_state = gr.State()
|
|
|
|
| 441 |
with gr.Row():
|
| 442 |
with gr.Column():
|
| 443 |
image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
|
|
@@ -458,7 +475,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 458 |
image.clear(
|
| 459 |
fn=clear_image,
|
| 460 |
inputs=[],
|
| 461 |
-
outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
|
| 462 |
)
|
| 463 |
|
| 464 |
threshold_slider.input(
|
|
@@ -469,8 +486,14 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 469 |
)
|
| 470 |
|
| 471 |
label_box.select(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
fn=cam_inference,
|
| 473 |
-
inputs=[original_image_state, cam_slider, alpha_slider,
|
| 474 |
outputs=[image, cam_state],
|
| 475 |
show_progress='minimal'
|
| 476 |
)
|
|
|
|
| 236 |
return text_no_impl, filtered_tag_score
|
| 237 |
|
| 238 |
def clear_image():
|
| 239 |
+
return "", {}, None, {}, None, None
|
| 240 |
+
|
| 241 |
+
def extract_selected_tag(evt: gr.SelectData):
|
| 242 |
+
# evt is a gr.SelectData; keep it out of GPU calls
|
| 243 |
+
try:
|
| 244 |
+
return evt.value
|
| 245 |
+
except Exception:
|
| 246 |
+
return None
|
| 247 |
|
| 248 |
@spaces.GPU(duration=5)
|
| 249 |
def cam_inference(img, threshold, alpha, selected_tag: str):
|
| 250 |
+
if img is None or not selected_tag:
|
| 251 |
+
return img, None
|
| 252 |
|
| 253 |
+
# Map to index
|
| 254 |
+
if selected_tag not in tags:
|
| 255 |
+
key = selected_tag.replace("_", " ")
|
| 256 |
+
if key not in tags:
|
| 257 |
+
return img, None
|
| 258 |
+
selected_tag = key
|
| 259 |
+
|
| 260 |
+
target_tag_index = tags[selected_tag]
|
| 261 |
+
tensor = transform(img).unsqueeze(0)
|
| 262 |
|
| 263 |
if torch.cuda.is_available():
|
| 264 |
tensor = tensor.to(device, dtype=torch.float16)
|
|
|
|
| 454 |
original_image_state = gr.State() # stash a copy of the input image
|
| 455 |
sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
|
| 456 |
cam_state = gr.State()
|
| 457 |
+
selected_tag_state = gr.State(value=None)
|
| 458 |
with gr.Row():
|
| 459 |
with gr.Column():
|
| 460 |
image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
|
|
|
|
| 475 |
image.clear(
|
| 476 |
fn=clear_image,
|
| 477 |
inputs=[],
|
| 478 |
+
outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state, selected_tag_state]
|
| 479 |
)
|
| 480 |
|
| 481 |
threshold_slider.input(
|
|
|
|
| 486 |
)
|
| 487 |
|
| 488 |
label_box.select(
|
| 489 |
+
fn=extract_selected_tag,
|
| 490 |
+
inputs=None,
|
| 491 |
+
outputs=selected_tag_state,
|
| 492 |
+
show_progress='hidden',
|
| 493 |
+
queue=False # This should be a very fast operation
|
| 494 |
+
).then(
|
| 495 |
fn=cam_inference,
|
| 496 |
+
inputs=[original_image_state, cam_slider, alpha_slider, selected_tag_state],
|
| 497 |
outputs=[image, cam_state],
|
| 498 |
show_progress='minimal'
|
| 499 |
)
|