Update app.py

app.py CHANGED
@@ -34,7 +34,7 @@ def pick_dtype(device: str) -> torch.dtype:
         return torch.bfloat16 if major >= 8 else torch.float16  # Ampere+ -> bf16
     if device == "mps":
         return torch.float16
-    return torch.
+    return torch.float32  # CPU: FP32 is usually fastest & most stable
 
 def move_to_device(batch, device: str):
     if isinstance(batch, dict):
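For context, the hunk above only touches the last line of pick_dtype. A minimal sketch of how the full helper plausibly reads, assuming a cuda branch and the `major` lookup that the visible lines imply (neither appears in the diff):

import torch

def pick_dtype(device: str) -> torch.dtype:
    if device == "cuda":
        major, _minor = torch.cuda.get_device_capability()  # assumed source of 'major'
        return torch.bfloat16 if major >= 8 else torch.float16  # Ampere+ -> bf16
    if device == "mps":
        return torch.float16
    return torch.float32  # CPU: FP32 is usually fastest & most stable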
@@ -82,6 +82,52 @@ def trim_generated(generated_ids, inputs):
         return [out_ids for out_ids in generated_ids]
     return [out_ids[len(in_seq):] for in_seq, out_ids in zip(in_ids, generated_ids)]
 
+# --- Parsing helper: normalize various UI-TARS click formats to (x, y) ---
+def parse_click_coordinates(text: str, img_w: int, img_h: int):
+    """
+    Returns (x, y) in image coordinates, clamped to bounds, or None.
+    Handles:
+      - Click(start_box='(x,y)') / Click(end_box='(x,y)')
+      - Click(box='(x1,y1,x2,y2)') -> center
+      - Click(x, y)
+      - Click({'x':..., 'y':...}) / Click({"x":...,"y":...})
+    Preference: start_box > end_box when both exist.
+    """
+    s = str(text)
+
+    # 1) start_box / end_box
+    pairs = re.findall(r"(start_box|end_box)\s*=\s*['\"]\(\s*(\d+)\s*,\s*(\d+)\s*\)['\"]", s)
+    if pairs:
+        start = next(((int(x), int(y)) for k, x, y in pairs if k == "start_box"), None)
+        if start:
+            x, y = start
+            return max(0, min(x, img_w - 1)), max(0, min(y, img_h - 1))
+        end = next(((int(x), int(y)) for k, x, y in pairs if k == "end_box"), None)
+        if end:
+            x, y = end
+            return max(0, min(x, img_w - 1)), max(0, min(y, img_h - 1))
+
+    # 2) box='(x1,y1,x2,y2)' -> center
+    m = re.search(r"box\s*=\s*['\"]\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)['\"]", s)
+    if m:
+        x1, y1, x2, y2 = map(int, m.groups())
+        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
+        return max(0, min(cx, img_w - 1)), max(0, min(cy, img_h - 1))
+
+    # 3) Direct Click(x, y)
+    m = re.search(r"Click\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", s)
+    if m:
+        x, y = int(m.group(1)), int(m.group(2))
+        return max(0, min(x, img_w - 1)), max(0, min(y, img_h - 1))
+
+    # 4) JSON-ish dicts (trailing [^)]* lets the match consume the closing brace)
+    m = re.search(r"Click\s*\(\s*[{[][^)}]*['\"]?x['\"]?\s*:\s*(\d+)\s*,\s*['\"]?y['\"]?\s*:\s*(\d+)[^)]*\)", s)
+    if m:
+        x, y = int(m.group(1)), int(m.group(2))
+        return max(0, min(x, img_w - 1)), max(0, min(y, img_h - 1))
+
+    return None
+
 # --- Load model/processor ON CPU at import time (required for ZeroGPU) ---
 print(f"Loading model and processor for {MODEL_ID} on CPU startup (ZeroGPU safe)...")
 model = None
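A few illustrative calls against the new parser; the 1000x800 image size is arbitrary, and the expected results follow directly from the regexes above:

samples = [
    "Click(start_box='(312, 415)')",      # -> (312, 415)
    "Click(box='(100, 200, 300, 400)')",  # -> box center (200, 300)
    "Click(640, 960)",                    # -> y clamped: (640, 799)
    "Click({'x': 50, 'y': 60})",          # -> (50, 60)
    "no click found",                     # -> None
]
for s in samples:
    print(s, "->", parse_click_coordinates(s, 1000, 800))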
@@ -168,8 +214,7 @@ def run_inference_localization(
     return decoded_output[0] if decoded_output else ""
 
 # --- Gradio processing function (ZeroGPU-visible) ---
-#
-@spaces.GPU(duration=120)  # keep GPU attached briefly between calls (seconds)
+@spaces.GPU(duration=120)  # keep GPU attached briefly between calls (seconds)
 def predict_click_location(input_pil_image: Image.Image, instruction: str):
     if not model_loaded or not processor or not model:
         return f"Model not loaded. Error: {load_error_message}", None, "device: n/a | dtype: n/a"
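For readers new to ZeroGPU: a GPU is attached only while a `@spaces.GPU`-decorated function runs, which is why the model is loaded on CPU at import time. A minimal sketch of the pattern; `gpu_step` and its body are illustrative, not the app's actual run_inference_localization:

import spaces
import torch

@spaces.GPU(duration=120)  # GPU stays attached for up to 120 s per call
def gpu_step(model, batch: dict):
    # CUDA is only available inside the decorated call, so weights and
    # inputs move to the GPU here, not at import time.
    model = model.to("cuda")
    batch = {k: v.to("cuda") if torch.is_tensor(v) else v for k, v in batch.items()}
    with torch.inference_mode():
        return model.generate(**batch)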
@@ -229,21 +274,17 @@ def predict_click_location(input_pil_image: Image.Image, instruction: str):
 
     # 4) Parse coordinates and draw marker
     output_image_with_click = resized_image.copy().convert("RGB")
-
-
-
-
-
-
-
-
-
-            print(f"Predicted and drawn click at: ({x}, {y}) on resized image ({resized_width}x{resized_height})")
-        except Exception as e:
-            print(f"Error drawing on image: {e}")
-            traceback.print_exc()
+    coords = parse_click_coordinates(coordinates_str, resized_width, resized_height)
+
+    if coords is not None:
+        x, y = coords
+        draw = ImageDraw.Draw(output_image_with_click)
+        radius = max(5, min(resized_width // 100, resized_height // 100, 15))
+        bbox = (x - radius, y - radius, x + radius, y + radius)
+        draw.ellipse(bbox, outline="red", width=max(2, radius // 4))
+        print(f"Predicted and drawn click at: ({x}, {y}) on resized image ({resized_width}x{resized_height})")
     else:
-        print(f"Could not parse
+        print(f"Could not parse a click from model output: {coordinates_str}")
 
     return coordinates_str, output_image_with_click, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"
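The marker geometry in isolation, under assumed dimensions; the radius scales with image size but is pinned to the 5-15 px range:

from PIL import Image, ImageDraw

w, h = 1000, 800                              # assumed image size
img = Image.new("RGB", (w, h), "white")
x, y = 640, 320                               # assumed click point
radius = max(5, min(w // 100, h // 100, 15))  # here max(5, min(10, 8, 15)) = 8
draw = ImageDraw.Draw(img)
draw.ellipse((x - radius, y - radius, x + radius, y + radius),
             outline="red", width=max(2, radius // 4))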
@@ -293,7 +334,7 @@ else:
 
         with gr.Column(scale=1):
             output_coords_component = gr.Textbox(
-                label="Predicted Coordinates (
+                label="Predicted Coordinates (Normalized)",
                 interactive=False
             )
             output_image_component = gr.Image(
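For completeness, how the relabeled Textbox sits in its output column; every kwarg not visible in the hunk is an assumption:

import gradio as gr

with gr.Blocks() as demo:
    with gr.Column(scale=1):
        output_coords_component = gr.Textbox(
            label="Predicted Coordinates (Normalized)",
            interactive=False,
        )
        output_image_component = gr.Image(
            label="Image with Predicted Click",  # assumed label
            type="pil",                          # assumed; the app returns a PIL image
        )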
|