Mungert commited on
Commit
7d85a18
·
verified ·
1 Parent(s): 27fd625

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -18
app.py CHANGED
@@ -34,7 +34,7 @@ def pick_dtype(device: str) -> torch.dtype:
34
  return torch.bfloat16 if major >= 8 else torch.float16 # Ampere+ -> bf16
35
  if device == "mps":
36
  return torch.float16
37
- return torch.float16 # CPU
38
 
39
  def move_to_device(batch, device: str):
40
  if isinstance(batch, dict):
@@ -82,6 +82,52 @@ def trim_generated(generated_ids, inputs):
82
  return [out_ids for out_ids in generated_ids]
83
  return [out_ids[len(in_seq):] for in_seq, out_ids in zip(in_ids, generated_ids)]
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # --- Load model/processor ON CPU at import time (required for ZeroGPU) ---
86
  print(f"Loading model and processor for {MODEL_ID} on CPU startup (ZeroGPU safe)...")
87
  model = None
@@ -168,8 +214,7 @@ def run_inference_localization(
168
  return decoded_output[0] if decoded_output else ""
169
 
170
  # --- Gradio processing function (ZeroGPU-visible) ---
171
- # Decorate the function Gradio calls so Spaces detects a GPU entry point.
172
- @spaces.GPU(duration=120) # keep GPU attached briefly between calls (seconds)
173
  def predict_click_location(input_pil_image: Image.Image, instruction: str):
174
  if not model_loaded or not processor or not model:
175
  return f"Model not loaded. Error: {load_error_message}", None, "device: n/a | dtype: n/a"
@@ -229,21 +274,17 @@ def predict_click_location(input_pil_image: Image.Image, instruction: str):
229
 
230
  # 4) Parse coordinates and draw marker
231
  output_image_with_click = resized_image.copy().convert("RGB")
232
- match = re.search(r"Click\((\d+),\s*(\d+)\)", coordinates_str)
233
- if match:
234
- try:
235
- x = int(match.group(1))
236
- y = int(match.group(2))
237
- draw = ImageDraw.Draw(output_image_with_click)
238
- radius = max(5, min(resized_width // 100, resized_height // 100, 15))
239
- bbox = (x - radius, y - radius, x + radius, y + radius)
240
- draw.ellipse(bbox, outline="red", width=max(2, radius // 4))
241
- print(f"Predicted and drawn click at: ({x}, {y}) on resized image ({resized_width}x{resized_height})")
242
- except Exception as e:
243
- print(f"Error drawing on image: {e}")
244
- traceback.print_exc()
245
  else:
246
- print(f"Could not parse 'Click(x, y)' from model output: {coordinates_str}")
247
 
248
  return coordinates_str, output_image_with_click, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"
249
 
@@ -293,7 +334,7 @@ else:
293
 
294
  with gr.Column(scale=1):
295
  output_coords_component = gr.Textbox(
296
- label="Predicted Coordinates (Format: Click(x, y))",
297
  interactive=False
298
  )
299
  output_image_component = gr.Image(
 
34
  return torch.bfloat16 if major >= 8 else torch.float16 # Ampere+ -> bf16
35
  if device == "mps":
36
  return torch.float16
37
+ return torch.float32 # CPU: FP32 is usually fastest & most stable
38
 
39
  def move_to_device(batch, device: str):
40
  if isinstance(batch, dict):
 
82
  return [out_ids for out_ids in generated_ids]
83
  return [out_ids[len(in_seq):] for in_seq, out_ids in zip(in_ids, generated_ids)]
84
 
85
# --- Parsing helper: normalize various UI-TARS click formats to (x, y) ---
def parse_click_coordinates(text: str, img_w: int, img_h: int):
    """
    Return (x, y) in image pixel coordinates, clamped to
    [0, img_w-1] x [0, img_h-1], or None when no click can be parsed.

    Handles:
      - Click(start_box='(x,y)') / Click(end_box='(x,y)')
      - Click(box='(x1,y1,x2,y2)') -> box center
      - Click(x, y)
      - Click({'x': ..., 'y': ...}) / Click({"x": ..., "y": ...})
    Preference: start_box > end_box when both exist.
    """
    s = str(text)

    def _clamp(x: int, y: int):
        # Keep the point inside the image so downstream drawing never
        # goes out of bounds.
        return max(0, min(x, img_w - 1)), max(0, min(y, img_h - 1))

    # 1) start_box / end_box pairs; start_box wins when both are present.
    pairs = re.findall(
        r"(start_box|end_box)\s*=\s*['\"]\(\s*(\d+)\s*,\s*(\d+)\s*\)['\"]", s
    )
    if pairs:
        for wanted in ("start_box", "end_box"):
            hit = next(
                ((int(x), int(y)) for k, x, y in pairs if k == wanted), None
            )
            if hit:
                return _clamp(*hit)

    # 2) box='(x1,y1,x2,y2)' -> center of the box.
    m = re.search(
        r"box\s*=\s*['\"]\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)['\"]", s
    )
    if m:
        x1, y1, x2, y2 = map(int, m.groups())
        return _clamp((x1 + x2) // 2, (y1 + y2) // 2)

    # 3) Direct Click(x, y). Does not collide with the dict form below:
    #    `\(\s*(\d+)` cannot match an opening '{'.
    m = re.search(r"Click\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", s)
    if m:
        return _clamp(int(m.group(1)), int(m.group(2)))

    # 4) JSON-ish dicts, e.g. Click({'x': 10, 'y': 20}).
    # FIX: the previous pattern ended with `[^)}]*\)`, which excludes '}' and
    # therefore could never step over the dict's closing brace — the dict form
    # never matched. Match the x/y fields and drop the closing requirement.
    m = re.search(
        r"Click\s*\(\s*[{\[][^)]*['\"]?x['\"]?\s*:\s*(\d+)\s*,\s*['\"]?y['\"]?\s*:\s*(\d+)",
        s,
    )
    if m:
        return _clamp(int(m.group(1)), int(m.group(2)))

    return None
130
+
131
  # --- Load model/processor ON CPU at import time (required for ZeroGPU) ---
132
  print(f"Loading model and processor for {MODEL_ID} on CPU startup (ZeroGPU safe)...")
133
  model = None
 
214
  return decoded_output[0] if decoded_output else ""
215
 
216
  # --- Gradio processing function (ZeroGPU-visible) ---
217
+ @spaces.GPU(duration=120) # keep GPU attached briefly between calls (seconds)
 
218
  def predict_click_location(input_pil_image: Image.Image, instruction: str):
219
  if not model_loaded or not processor or not model:
220
  return f"Model not loaded. Error: {load_error_message}", None, "device: n/a | dtype: n/a"
 
274
 
275
  # 4) Parse coordinates and draw marker
276
  output_image_with_click = resized_image.copy().convert("RGB")
277
+ coords = parse_click_coordinates(coordinates_str, resized_width, resized_height)
278
+
279
+ if coords is not None:
280
+ x, y = coords
281
+ draw = ImageDraw.Draw(output_image_with_click)
282
+ radius = max(5, min(resized_width // 100, resized_height // 100, 15))
283
+ bbox = (x - radius, y - radius, x + radius, y + radius)
284
+ draw.ellipse(bbox, outline="red", width=max(2, radius // 4))
285
+ print(f"Predicted and drawn click at: ({x}, {y}) on resized image ({resized_width}x{resized_height})")
 
 
 
 
286
  else:
287
+ print(f"Could not parse a click from model output: {coordinates_str}")
288
 
289
  return coordinates_str, output_image_with_click, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"
290
 
 
334
 
335
  with gr.Column(scale=1):
336
  output_coords_component = gr.Textbox(
337
+ label="Predicted Coordinates (raw model output)",
338
  interactive=False
339
  )
340
  output_image_component = gr.Image(