Spaces: Running on Zero
gr.SelectData can't be pickled fix
app.py CHANGED
@@ -239,15 +239,15 @@ def clear_image():
     return "", {}, None, {}, None

 @spaces.GPU(duration=5)
-def cam_inference(img, threshold, alpha, evt: gr.SelectData):
-    target_tag_index = tags[
+def cam_inference(img, threshold, alpha, selected_tag: str):
+    target_tag_index = tags[selected_tag]
     tensor = transform(img).unsqueeze(0)

+
     if torch.cuda.is_available():
         tensor = tensor.to(device, dtype=torch.float16)
     else:
         tensor = tensor.to(device)
-
     tensor.requires_grad_()

     gradients = {}

@@ -263,7 +263,7 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
     handle_backward = model.norm.register_full_backward_hook(hook_backward)

     probits = model(tensor)[0]
-
+
     model.zero_grad()
     probits[target_tag_index].backward(retain_graph=True)

@@ -470,7 +470,7 @@ with gr.Blocks(css=custom_css) as demo:

     label_box.select(
         fn=cam_inference,
-        inputs=[original_image_state, cam_slider, alpha_slider],
+        inputs=[original_image_state, cam_slider, alpha_slider, label_box],
         outputs=[image, cam_state],
         show_progress='minimal'
     )
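Context for the change: on a ZeroGPU Space, a function decorated with @spaces.GPU runs in a separate GPU worker process, so its arguments are pickled across a process boundary, and the gr.SelectData event object cannot be pickled. The commit therefore drops the evt: gr.SelectData parameter and instead passes the selected tag as a plain string by adding label_box to the event's inputs. Below is a minimal, self-contained sketch of the same idea, not the Space's actual code: component names mirror those in the diff, it assumes a gr.Label whose select event exposes the chosen tag name via evt.value, and it unpacks the event in a thin wrapper that runs in the main Gradio process so only a plain string reaches the GPU worker.

# Minimal sketch, not the Space's actual app.py. Assumes a gr.Label whose
# select event carries the chosen tag name in evt.value, plus stand-in
# sliders/state named after the components that appear in the diff.
import gradio as gr
import spaces  # ZeroGPU helper package; only takes effect on ZeroGPU hardware

tags = {"cat": 0, "dog": 1}  # placeholder tag -> index mapping

@spaces.GPU(duration=5)
def cam_inference(img, threshold, alpha, selected_tag: str):
    # Arguments are pickled into the GPU worker process, so only plain,
    # picklable values (str, float, images) are passed here.
    target_tag_index = tags[selected_tag]
    return f"CAM for tag #{target_tag_index} (threshold={threshold}, alpha={alpha})"

def on_label_select(img, threshold, alpha, evt: gr.SelectData):
    # Runs in the main Gradio process: unpack the un-picklable event object
    # into a plain string before crossing into the GPU worker.
    return cam_inference(img, threshold, alpha, evt.value)  # assumed: evt.value is the tag name

with gr.Blocks() as demo:
    original_image_state = gr.State()  # assumed: holds the uploaded image
    cam_slider = gr.Slider(0.0, 1.0, value=0.4, label="CAM threshold")
    alpha_slider = gr.Slider(0.0, 1.0, value=0.6, label="Overlay alpha")
    label_box = gr.Label(label="Predicted tags")
    result = gr.Textbox(label="Result")

    label_box.select(
        fn=on_label_select,
        inputs=[original_image_state, cam_slider, alpha_slider],
        outputs=[result],
        show_progress="minimal",
    )

demo.launch()

Either route works because the only thing the CAM computation needs from the click is the tag name; keeping gr.SelectData out of the @spaces.GPU signature is what removes the pickling failure.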