Spaces:
Sleeping
Sleeping
switch to click model
Browse files
app.py
CHANGED
|
@@ -6,10 +6,9 @@ import torch
|
|
| 6 |
import html
|
| 7 |
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
| 8 |
|
| 9 |
-
pretrained_repo_name = 'ivelin/donut-refexp-
|
| 10 |
pretrained_revision = 'main'
|
| 11 |
-
# revision
|
| 12 |
-
# revision: '41210d7c42a22e77711711ec45508a6b63ec380f' # : IoU=0.42
|
| 13 |
# use 'main' for latest revision
|
| 14 |
print(f"Loading model checkpoint: {pretrained_repo_name}")
|
| 15 |
|
|
@@ -31,7 +30,7 @@ def process_refexp(image: Image, prompt: str):
|
|
| 31 |
pixel_values = processor(image, return_tensors="pt").pixel_values
|
| 32 |
|
| 33 |
# prepare decoder inputs
|
| 34 |
-
task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><
|
| 35 |
prompt = task_prompt.replace("{user_input}", prompt)
|
| 36 |
decoder_input_ids = processor.tokenizer(
|
| 37 |
prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
|
@@ -61,37 +60,28 @@ def process_refexp(image: Image, prompt: str):
|
|
| 61 |
fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
|
| 62 |
seqjson = processor.token2json(sequence)
|
| 63 |
|
| 64 |
-
# safeguard in case predicted sequence does not include a
|
| 65 |
-
|
| 66 |
-
if
|
| 67 |
print(
|
| 68 |
-
f"
|
| 69 |
-
|
| 70 |
-
return
|
| 71 |
|
| 72 |
-
print(f"predicted
|
| 73 |
-
# safeguard in case text prediction is missing some
|
| 74 |
# or coordinates are not valid numeric values
|
| 75 |
try:
|
| 76 |
-
|
| 77 |
except ValueError:
|
| 78 |
-
|
| 79 |
try:
|
| 80 |
-
|
| 81 |
except ValueError:
|
| 82 |
-
|
| 83 |
-
try:
|
| 84 |
-
xmax = float(bbox.get("xmax", 1))
|
| 85 |
-
except ValueError:
|
| 86 |
-
xmax = 1
|
| 87 |
-
try:
|
| 88 |
-
ymax = float(bbox.get("ymax", 1))
|
| 89 |
-
except ValueError:
|
| 90 |
-
ymax = 1
|
| 91 |
# replace str with float coords
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
print(f"predicted bounding box with float coordinates: {bbox}")
|
| 95 |
|
| 96 |
print(f"image object: {image}")
|
| 97 |
print(f"image size: {image.size}")
|
|
@@ -99,26 +89,25 @@ def process_refexp(image: Image, prompt: str):
|
|
| 99 |
print(f"image width, height: {width, height}")
|
| 100 |
print(f"processed prompt: {prompt}")
|
| 101 |
|
| 102 |
-
# safeguard in case text prediction is missing some
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
xmax = math.floor(width*bbox["xmax"])
|
| 106 |
-
ymax = math.floor(height*bbox["ymax"])
|
| 107 |
|
| 108 |
print(
|
| 109 |
-
f"to image pixel values:
|
| 110 |
-
|
| 111 |
-
shape = [(xmin, ymin), (xmax, ymax)]
|
| 112 |
|
| 113 |
-
#
|
| 114 |
img1 = ImageDraw.Draw(image)
|
| 115 |
-
img1.rectangle(shape, outline="green", width=5)
|
| 116 |
-
img1.rectangle(shape, outline="white", width=2)
|
| 117 |
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
-
title = "Demo: Donut 🍩 for UI RefExp (by GuardianUI)"
|
| 122 |
description = "Gradio Demo for Donut RefExp task, an instance of `VisionEncoderDecoderModel` fine-tuned on [UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) Dataset (UI Referring Expression). To use it, simply upload your image and type a prompt and click 'submit', or click one of the examples to load them. See the model training <a href='https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Fine_tune_Donut_on_UI_RefExp.ipynb' target='_parent'>Colab Notebook</a> for this space. Read more at the links below."
|
| 123 |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
|
| 124 |
examples = [["example_1.jpg", "select the setting icon from top right corner"],
|
|
|
|
| 6 |
import html
|
| 7 |
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
| 8 |
|
| 9 |
+
pretrained_repo_name = 'ivelin/donut-refexp-click'
|
| 10 |
pretrained_revision = 'main'
|
| 11 |
+
# revision can be git commit hash, branch or tag
|
|
|
|
| 12 |
# use 'main' for latest revision
|
| 13 |
print(f"Loading model checkpoint: {pretrained_repo_name}")
|
| 14 |
|
|
|
|
| 30 |
pixel_values = processor(image, return_tensors="pt").pixel_values
|
| 31 |
|
| 32 |
# prepare decoder inputs
|
| 33 |
+
task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_center>"
|
| 34 |
prompt = task_prompt.replace("{user_input}", prompt)
|
| 35 |
decoder_input_ids = processor.tokenizer(
|
| 36 |
prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
|
|
|
| 60 |
fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
|
| 61 |
seqjson = processor.token2json(sequence)
|
| 62 |
|
| 63 |
+
# safeguard in case predicted sequence does not include a target_center token
|
| 64 |
+
center_point = seqjson.get('target_center')
|
| 65 |
+
if center_point is None:
|
| 66 |
print(
|
| 67 |
+
f"predicted sequence has no target_center, seq:{sequence}")
|
| 68 |
+
center_point = {"x": 0, "y": 0}
|
| 69 |
+
return center_point
|
| 70 |
|
| 71 |
+
print(f"predicted center_point with text coordinates: {center_point}")
|
| 72 |
+
# safeguard in case text prediction is missing some center point coordinates
|
| 73 |
# or coordinates are not valid numeric values
|
| 74 |
try:
|
| 75 |
+
x = float(center_point.get("x", 0))
|
| 76 |
except ValueError:
|
| 77 |
+
x = 0
|
| 78 |
try:
|
| 79 |
+
y = float(center_point.get("y", 0))
|
| 80 |
except ValueError:
|
| 81 |
+
y = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
# replace str with float coords
|
| 83 |
+
center_point = {"x": x, "y": y, "decoder output sequence": sequence}
|
| 84 |
+
print(f"predicted center_point with float coordinates: {center_point}")
|
|
|
|
| 85 |
|
| 86 |
print(f"image object: {image}")
|
| 87 |
print(f"image size: {image.size}")
|
|
|
|
| 89 |
print(f"image width, height: {width, height}")
|
| 90 |
print(f"processed prompt: {prompt}")
|
| 91 |
|
| 92 |
+
# safeguard in case text prediction is missing some center point coordinates
|
| 93 |
+
x = math.floor(width*center_point["x"])
|
| 94 |
+
y = math.floor(height*center_point["y"])
|
|
|
|
|
|
|
| 95 |
|
| 96 |
print(
|
| 97 |
+
f"to image pixel values: x, y: {x, y}")
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
# draw center point circle
|
| 100 |
img1 = ImageDraw.Draw(image)
|
|
|
|
|
|
|
| 101 |
|
| 102 |
+
r = 1
|
| 103 |
+
shape = [(x-r, y-r), (x+r, y+r)]
|
| 104 |
+
img1.ellipse(shape, outline="green", width=10)
|
| 105 |
+
img1.ellipse(shape, outline="white", width=5)
|
| 106 |
+
|
| 107 |
+
return image, center_point
|
| 108 |
|
| 109 |
|
| 110 |
+
title = "Demo: Donut 🍩 for UI RefExp - Center Point (by GuardianUI)"
|
| 111 |
description = "Gradio Demo for Donut RefExp task, an instance of `VisionEncoderDecoderModel` fine-tuned on [UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) Dataset (UI Referring Expression). To use it, simply upload your image and type a prompt and click 'submit', or click one of the examples to load them. See the model training <a href='https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Fine_tune_Donut_on_UI_RefExp.ipynb' target='_parent'>Colab Notebook</a> for this space. Read more at the links below."
|
| 112 |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
|
| 113 |
examples = [["example_1.jpg", "select the setting icon from top right corner"],
|