enpaiva committed on
Commit 8304cf1 · verified · 1 Parent(s): afc4872

Update app.py

Files changed (1):
  app.py +163 -215
app.py CHANGED
@@ -5,22 +5,39 @@ import sys
 import torch
 import gradio as gr
 import numpy as np
-from PIL import Image, ImageDraw, ImageFont
+import cv2
+from PIL import Image
 from transformers import (
     DFineForObjectDetection,
+    RTDetrV2ForObjectDetection,
     RTDetrImageProcessor,
 )
 
 # == select device ==
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-# Available models
+# Available models with their corresponding model classes
 MODELS = {
-    "Egret XLarge": "ds4sd/docling-layout-egret-xlarge",
-    "Egret Large": "ds4sd/docling-layout-egret-large",
-    "Egret Medium": "ds4sd/docling-layout-egret-medium",
-    "Heron 101": "ds4sd/docling-layout-heron-101",
-    "Heron": "ds4sd/docling-layout-heron"
+    "Egret XLarge": {
+        "path": "ds4sd/docling-layout-egret-xlarge",
+        "model_class": DFineForObjectDetection
+    },
+    "Egret Large": {
+        "path": "ds4sd/docling-layout-egret-large",
+        "model_class": DFineForObjectDetection
+    },
+    "Egret Medium": {
+        "path": "ds4sd/docling-layout-egret-medium",
+        "model_class": DFineForObjectDetection
+    },
+    "Heron 101": {
+        "path": "ds4sd/docling-layout-heron-101",
+        "model_class": RTDetrV2ForObjectDetection
+    },
+    "Heron": {
+        "path": "ds4sd/docling-layout-heron",
+        "model_class": RTDetrV2ForObjectDetection
+    }
 }
 
 # Classes mapping for the docling model
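
Note: each `MODELS` entry now pairs a checkpoint path with the `transformers` class needed to load it, so `load_model` (further down) can pick `DFineForObjectDetection` or `RTDetrV2ForObjectDetection` per checkpoint. A minimal sketch of how such an entry is consumed, using only names that appear in this diff:

```python
from transformers import DFineForObjectDetection, RTDetrImageProcessor

MODELS = {
    "Egret XLarge": {
        "path": "ds4sd/docling-layout-egret-xlarge",
        "model_class": DFineForObjectDetection,
    },
}

info = MODELS["Egret XLarge"]
processor = RTDetrImageProcessor.from_pretrained(info["path"])  # one processor type for all checkpoints
model = info["model_class"].from_pretrained(info["path"])       # the model class varies per checkpoint
```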
@@ -44,34 +61,40 @@ classes_map = {
     16: "Key-Value Region",
 }
 
-# Color mapping for visualization
-colors = [
-    "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FECA57",
-    "#FF9FF3", "#54A0FF", "#5F27CD", "#00D2D3", "#FF9F43",
-    "#10AC84", "#EE5A24", "#0ABDE3", "#006BA6", "#F79F1F",
-    "#A3CB38", "#FDA7DF"
-]
-
 # Global variables for model
 current_model = None
 current_processor = None
 current_model_name = None
 
+def colormap(N=256, normalized=False):
+    """Generate the color map."""
+    def bitget(byteval, idx):
+        return ((byteval & (1 << idx)) != 0)
+
+    cmap = np.zeros((N, 3), dtype=np.uint8)
+    for i in range(N):
+        r = g = b = 0
+        c = i
+        for j in range(8):
+            r = r | (bitget(c, 0) << (7 - j))
+            g = g | (bitget(c, 1) << (7 - j))
+            b = b | (bitget(c, 2) << (7 - j))
+            c = c >> 3
+        cmap[i] = np.array([r, g, b])
+
+    if normalized:
+        cmap = cmap.astype(np.float32) / 255.0
+
+    return cmap
+
 def iomin(box1, box2):
-    """
-    Intersection over Minimum (IoMin)
-    box1: Tensor[1, 4]
-    box2: Tensor[N, 4]
-    Returns: Tensor[N]
-    """
-    # Intersection
+    """Intersection over Minimum (IoMin)"""
     x1 = torch.max(box1[:, 0], box2[:, 0])
     y1 = torch.max(box1[:, 1], box2[:, 1])
     x2 = torch.min(box1[:, 2], box2[:, 2])
     y2 = torch.min(box1[:, 3], box2[:, 3])
     inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
 
-    # Areas
     box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
     box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
     min_area = torch.min(box1_area, box2_area)
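
Note: the new `colormap` helper looks like the standard PASCAL VOC bit-interleaving palette, replacing the hard-coded hex list. `iomin` itself computes intersection over min(area1, area2) rather than intersection over union; the practical difference is that a box fully nested inside another scores 1.0 under IoMin even when its IoU is tiny, so IoMin-based suppression catches nested duplicates. A small self-contained check (a reference reimplementation of the same formula, not extra app code):

```python
import torch

def iomin_ref(box1, box2):
    # Same formula as iomin() above: intersection / min(area1, area2).
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])
    inter = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    return inter / torch.min(area1, area2)

big = torch.tensor([[0.0, 0.0, 100.0, 100.0]])
small = torch.tensor([[10.0, 10.0, 30.0, 30.0]])  # fully inside `big`
print(iomin_ref(big, small))  # tensor([1.]) -- IoU here would only be ~0.04
```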
@@ -79,9 +102,7 @@ def iomin(box1, box2):
     return inter_area / min_area
 
 def nms(boxes, scores, iou_threshold=0.5):
-    """
-    Custom NMS implementation using IoMin
-    """
+    """Custom NMS implementation using IoMin"""
     keep = []
     _, order = scores.sort(descending=True)
 
@@ -92,7 +113,7 @@ def nms(boxes, scores, iou_threshold=0.5):
         if order.numel() == 1:
             break
 
-        box_i = boxes[i].unsqueeze(0) # [1, 4]
+        box_i = boxes[i].unsqueeze(0)
         rest = order[1:]
         ious = iomin(box_i, boxes[rest])
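
Note: this hunk only shows the middle of the greedy loop. For orientation, a self-contained sketch of the full loop consistent with these lines; the `while`/`keep.append` framing is an assumption, since those lines sit outside the hunk:

```python
import torch

def nms_sketch(boxes, scores, iou_fn, iou_threshold=0.5):
    keep = []
    _, order = scores.sort(descending=True)
    while order.numel() > 0:                   # assumed framing; not shown in the hunk
        i = order[0].item()
        keep.append(i)                         # highest-scoring survivor is always kept
        if order.numel() == 1:
            break
        box_i = boxes[i].unsqueeze(0)          # [1, 4]
        rest = order[1:]
        ious = iou_fn(box_i, boxes[rest])      # e.g. iomin() from above
        order = rest[ious <= iou_threshold]    # drop boxes overlapping the kept one
    return torch.tensor(keep, dtype=torch.long)
```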
 
@@ -102,9 +123,7 @@ def nms(boxes, scores, iou_threshold=0.5):
     return torch.tensor(keep, dtype=torch.long)
 
 def load_model(model_name):
-    """
-    Load the selected model
-    """
+    """Load the selected model"""
     global current_model, current_processor, current_model_name
 
     if current_model_name == model_name:
@@ -112,10 +131,12 @@ def load_model(model_name):
 
     try:
         print(f"Loading model: {model_name}")
-        model_path = MODELS[model_name]
+        model_info = MODELS[model_name]
+        model_path = model_info["path"]
+        model_class = model_info["model_class"]
 
         processor = RTDetrImageProcessor.from_pretrained(model_path)
-        model = DFineForObjectDetection.from_pretrained(model_path)
+        model = model_class.from_pretrained(model_path)
         model = model.to(device)
         model.eval()
 
@@ -128,64 +149,61 @@ def load_model(model_name):
     except Exception as e:
         return f"❌ Error loading {model_name}: {str(e)}"
 
-def visualize_bbox(image, boxes, labels, scores, classes_map, colors):
-    """
-    Visualize bounding boxes on image
-    """
-    if isinstance(image, np.ndarray):
-        image = Image.fromarray(image)
-    elif not isinstance(image, Image.Image):
-        raise ValueError("Input image must be PIL Image or numpy array")
-
-    # Create a copy to draw on
-    draw_image = image.copy()
-    draw = ImageDraw.Draw(draw_image)
-
-    # Try to use a font, fallback to default if not available
-    try:
-        font = ImageFont.truetype("arial.ttf", 20)
-    except:
+def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3):
+    """Visualize bounding boxes with transparent overlays using OpenCV"""
+    if isinstance(image_input, Image.Image):
+        image = np.array(image_input)
+        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+    elif isinstance(image_input, np.ndarray):
+        if len(image_input.shape) == 3 and image_input.shape[2] == 3:
+            image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
+        else:
+            image = image_input.copy()
+    else:
+        raise ValueError("Input must be PIL Image or numpy array")
+
+    overlay = image.copy()
+    cmap = colormap(N=len(id_to_names), normalized=False)
+
+    if len(bboxes) == 0:
+        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+    for i in range(len(bboxes)):
         try:
-            font = ImageFont.load_default()
-        except:
-            font = None
-
-    for box, label_id, score in zip(boxes, labels, scores):
-        # Convert tensor to int if needed
-        if torch.is_tensor(label_id):
-            label_id = label_id.item()
-        if torch.is_tensor(score):
-            score = score.item()
+            bbox = bboxes[i]
+            if torch.is_tensor(bbox):
+                bbox = bbox.cpu().numpy()
 
-        label = classes_map.get(int(label_id), f"Class_{label_id}")
-        color = colors[int(label_id) % len(colors)]
-
-        # Convert box coordinates to integers
-        x1, y1, x2, y2 = [int(coord) for coord in box]
-
-        # Draw rectangle
-        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
-
-        # Draw label background
-        text = f"{label}: {score:.2f}"
-        if font:
-            bbox = draw.textbbox((x1, y1), text, font=font)
-            text_width = bbox[2] - bbox[0]
-            text_height = bbox[3] - bbox[1]
-        else:
-            # Estimate text size if no font available
-            text_width = len(text) * 10
-            text_height = 20
+            class_id = classes[i]
+            if torch.is_tensor(class_id):
+                class_id = class_id.item()
 
-        draw.rectangle([x1, y1-text_height-4, x1+text_width+4, y1], fill=color)
-        draw.text((x1+2, y1-text_height-2), text, fill="white", font=font)
-
-    return np.array(draw_image)
+            score = scores[i]
+            if torch.is_tensor(score):
+                score = score.item()
+
+            x_min, y_min, x_max, y_max = map(int, bbox)
+            class_id = int(class_id)
+            class_name = id_to_names.get(class_id, f"unknown_{class_id}")
 
-def recognize_image(input_img, conf_threshold, iou_threshold, nms_method):
-    """
-    Process image with docling layout model
-    """
+            text = f"{class_name}:{score:.3f}"
+            color = tuple(int(c) for c in cmap[class_id % len(cmap)])
+
+            cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
+            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 2)
+
+            (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
+            cv2.rectangle(image, (x_min, y_min - text_height - baseline), (x_min + text_width, y_min), color, -1)
+            cv2.putText(image, text, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+
+        except Exception as e:
+            print(f"Skipping box {i} due to error: {e}")
+
+    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
+    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+def recognize_image(input_img, conf_threshold, iou_threshold, nms_method, alpha):
+    """Process image with docling layout model"""
     if input_img is None:
         return None, "Please upload an image first."
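
Note on the new `visualize_bbox`: filled rectangles go onto a separate `overlay` copy, outlines and labels go onto `image`, and a single `cv2.addWeighted` call at the end blends the two, which is what produces translucent box fills with crisp borders. A minimal standalone illustration of that pattern:

```python
import numpy as np
import cv2

image = np.full((120, 200, 3), 255, dtype=np.uint8)  # white BGR canvas
overlay = image.copy()

cv2.rectangle(overlay, (20, 20), (120, 100), (0, 0, 255), -1)  # filled box, overlay only
cv2.rectangle(image, (20, 20), (120, 100), (0, 0, 255), 2)     # crisp outline, image only

alpha = 0.3
cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)    # image = 0.3*overlay + 0.7*image
```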
 
@@ -193,23 +211,18 @@ def recognize_image(input_img, conf_threshold, iou_threshold, nms_method):
         return None, "Please load a model first."
 
     try:
-        # Ensure image is PIL Image
         if isinstance(input_img, np.ndarray):
             input_img = Image.fromarray(input_img)
 
-        # Convert to RGB if needed
         if input_img.mode != 'RGB':
             input_img = input_img.convert('RGB')
 
-        # Process image
         inputs = current_processor(images=[input_img], return_tensors="pt")
         inputs = {k: v.to(device) for k, v in inputs.items()}
 
-        # Run inference
         with torch.no_grad():
             outputs = current_model(**inputs)
 
-        # Post-process results
         results = current_processor.post_process_object_detection(
             outputs,
             target_sizes=torch.tensor([input_img.size[::-1]]),
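
Note: `input_img.size[::-1]` is the usual PIL-to-transformers flip; PIL reports `(width, height)` while `post_process_object_detection` expects `target_sizes` as `(height, width)`:

```python
from PIL import Image

img = Image.new("RGB", (640, 480))
print(img.size)        # (640, 480) -- PIL order: (width, height)
print(img.size[::-1])  # (480, 640) -- (height, width), as target_sizes expects
```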
@@ -220,8 +233,6 @@ def recognize_image(input_img, conf_threshold, iou_threshold, nms_method):
             return np.array(input_img), "No detections found."
 
         result = results[0]
-
-        # Get results
         boxes = result["boxes"]
         scores = result["scores"]
         labels = result["labels"]
@@ -229,50 +240,28 @@ def recognize_image(input_img, conf_threshold, iou_threshold, nms_method):
         if len(boxes) == 0:
             return np.array(input_img), "No detections above confidence threshold."
 
-        # Apply NMS if requested
         if iou_threshold < 1.0:
             if nms_method == "Custom IoMin":
-                # Use custom NMS with IoMin
-                keep_indices = nms(
-                    boxes=boxes,
-                    scores=scores,
-                    iou_threshold=iou_threshold
-                )
+                keep_indices = nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
             else:
-                # Use standard torchvision NMS
-                keep_indices = torch.ops.torchvision.nms(
-                    boxes=boxes,
-                    scores=scores,
-                    iou_threshold=iou_threshold
-                )
+                keep_indices = torch.ops.torchvision.nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
 
             boxes = boxes[keep_indices]
             scores = scores[keep_indices]
             labels = labels[keep_indices]
 
-        # Handle single detection case
         if len(boxes.shape) == 1:
            boxes = boxes.unsqueeze(0)
             scores = scores.unsqueeze(0)
             labels = labels.unsqueeze(0)
 
-        # Visualize results
-        output = visualize_bbox(
-            input_img,
-            boxes,
-            labels,
-            scores,
-            classes_map,
-            colors
-        )
-
+        output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha)
         detection_info = f"Found {len(boxes)} detections after NMS ({nms_method})"
         return output, detection_info
 
     except Exception as e:
         print(f"[ERROR] recognize_image failed: {e}")
         error_msg = f"Error during processing: {str(e)}"
-        # Return original image on error
         if input_img is not None:
             return np.array(input_img), error_msg
         return np.zeros((512, 512, 3), dtype=np.uint8), error_msg
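
One thing to watch in the "Standard IoU" branch: `torch.ops.torchvision.nms` is the raw registered op, which is only available once torchvision has been imported somewhere in the process, and `app.py` does not import it directly. The public wrapper is the more conventional spelling; an equivalent call for reference:

```python
import torch
import torchvision  # importing torchvision is what registers the op

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0],    # heavy overlap with box 0
                      [50.0, 50.0, 60.0, 60.0]])
scores = torch.tensor([0.9, 0.8, 0.7])

keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
print(keep)  # tensor([0, 2]) -- the near-duplicate box 1 is suppressed
```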
@@ -283,123 +272,82 @@ def gradio_reset():
 if __name__ == "__main__":
     print(f"Using device: {device}")
 
-    # Create header HTML
-    header_html = """
-    <div style="text-align: center; margin-bottom: 20px;">
-        <h1>🔍 Document Layout Analysis</h1>
-        <p>Using Docling Layout Models for document structure detection</p>
-        <p>Select a model, upload an image and adjust the parameters to detect document elements</p>
-    </div>
+    # Custom CSS for better scrolling and layout
+    custom_css = """
+    .gradio-container {
+        max-width: 1200px !important;
+        margin: auto !important;
+    }
+    .main-content {
+        overflow-y: auto !important;
+        max-height: 100vh !important;
+    }
     """
 
-    with gr.Blocks(title="Document Layout Analysis", theme=gr.themes.Soft()) as demo:
-        gr.HTML(header_html)
+    with gr.Blocks(title="Document Layout Analysis", theme=gr.themes.Soft(), css=custom_css) as demo:
+        # Header
+        gr.HTML("""
+        <div style="text-align: center; margin-bottom: 20px;">
+            <h1>🔍 Document Layout Analysis</h1>
+            <p>Using Docling Layout Models for document structure detection</p>
+        </div>
+        """)
 
         with gr.Row():
-            with gr.Column():
+            # Left Column - Controls
+            with gr.Column(scale=1):
                 # Model selection
                 model_dropdown = gr.Dropdown(
                     choices=list(MODELS.keys()),
                     value="Egret XLarge",
-                    label="🤖 Select Model",
-                    info="Choose which Docling model to use"
+                    label="🤖 Select Model"
                 )
 
-                load_btn = gr.Button("📥 Load Model", variant="secondary")
-                model_status = gr.Textbox(
-                    label="Model Status",
-                    interactive=False,
-                    value="No model loaded"
-                )
+                load_btn = gr.Button("📥 Load Model", variant="secondary", size="sm")
+                model_status = gr.Textbox(label="Model Status", interactive=False, value="No model loaded", max_lines=2)
 
-                input_img = gr.Image(
-                    label="📄 Upload Document Image",
-                    interactive=True,
-                    type="pil"
-                )
+                input_img = gr.Image(label="📄 Upload Image", type="pil", height=300)
 
                 with gr.Row():
-                    clear = gr.Button("🗑️ Clear")
-                    predict = gr.Button("🔍 Detect Layout", interactive=True, variant="primary")
-
-                with gr.Row():
-                    conf_threshold = gr.Slider(
-                        label="Confidence Threshold",
-                        minimum=0.0,
-                        maximum=1.0,
-                        step=0.05,
-                        value=0.6,
-                        info="Minimum confidence score for detections"
-                    )
-
-                with gr.Row():
-                    iou_threshold = gr.Slider(
-                        label="NMS IoU Threshold",
-                        minimum=0.0,
-                        maximum=1.0,
-                        step=0.05,
-                        value=0.5,
-                        info="IoU threshold for Non-Maximum Suppression"
-                    )
+                    clear = gr.Button("🗑️ Clear", size="sm")
+                    predict = gr.Button("🔍 Detect", variant="primary", size="sm")
 
-                nms_method = gr.Radio(
-                    choices=["Custom IoMin", "Standard IoU"],
-                    value="Custom IoMin",
-                    label="NMS Method",
-                    info="Choose NMS algorithm"
-                )
+                # Parameters
+                conf_threshold = gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Confidence Threshold")
+                iou_threshold = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="NMS IoU Threshold")
+                nms_method = gr.Radio(["Custom IoMin", "Standard IoU"], value="Custom IoMin", label="NMS Method")
+                alpha_slider = gr.Slider(0.0, 1.0, value=0.3, step=0.1, label="Overlay Transparency")
 
-                # Legend
-                with gr.Accordion("📋 Detected Classes", open=False):
-                    legend_html = "<div style='display: grid; grid-template-columns: repeat(2, 1fr); gap: 10px;'>"
-                    for class_id, class_name in classes_map.items():
-                        color = colors[class_id % len(colors)]
-                        legend_html += f"""
-                        <div style='display: flex; align-items: center; padding: 5px;'>
-                            <div style='width: 20px; height: 20px; background-color: {color}; margin-right: 10px; border: 1px solid #ccc;'></div>
-                            <span>{class_name}</span>
-                        </div>
-                        """
-                    legend_html += "</div>"
-                    gr.HTML(legend_html)
-
-            with gr.Column():
+            # Right Column - Results
+            with gr.Column(scale=1):
                 gr.HTML("<h3>🎯 Detection Results</h3>")
-                output_img = gr.Image(
-                    label="Detected Layout",
-                    interactive=False,
-                    type="numpy"
-                )
-
-                detection_info = gr.Textbox(
-                    label="Detection Info",
-                    interactive=False,
-                    value=""
-                )
+                output_img = gr.Image(label="Detected Layout", interactive=False, type="numpy", height=400)
+                detection_info = gr.Textbox(label="Detection Info", interactive=False, max_lines=2)
 
-        # Event handlers
-        load_btn.click(
-            load_model,
-            inputs=[model_dropdown],
-            outputs=[model_status]
-        )
-
-        clear.click(
-            gradio_reset,
-            inputs=None,
-            outputs=[input_img, output_img, detection_info]
-        )
+        # Legend at the bottom
+        with gr.Accordion("📋 Detected Classes", open=False):
+            cmap = colormap(N=len(classes_map), normalized=False)
+            legend_items = []
+            for class_id, class_name in classes_map.items():
+                color_rgb = cmap[class_id % len(cmap)]
+                color_hex = f"#{color_rgb[0]:02x}{color_rgb[1]:02x}{color_rgb[2]:02x}"
+                legend_items.append(f'<span style="display:inline-block;width:15px;height:15px;background-color:{color_hex};margin-right:5px;border:1px solid #ccc;"></span>{class_name}')
+
+            legend_html = f"""
+            <div style='display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; font-size: 14px;'>
+                {''.join([f'<div>{item}</div>' for item in legend_items])}
+            </div>
+            """
+            gr.HTML(legend_html)
 
+        # Event handlers
+        load_btn.click(load_model, inputs=[model_dropdown], outputs=[model_status])
+        clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img, detection_info])
         predict.click(
             recognize_image,
-            inputs=[input_img, conf_threshold, iou_threshold, nms_method],
+            inputs=[input_img, conf_threshold, iou_threshold, nms_method, alpha_slider],
             outputs=[output_img, detection_info]
         )
 
-    # Launch the demo
-    demo.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        debug=True,
-        share=False
-    )
+    # Launch
+    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, share=False)
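
Note: the legend now derives its swatches from the same `colormap` palette used for drawing, keeping legend and overlay colors in sync. For reference, the hex conversion for class id 1 under this palette (a condensed restatement of the `colormap()` added above, same bit logic):

```python
import numpy as np

def colormap(N=256):
    def bitget(v, idx):
        return (v >> idx) & 1
    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r |= bitget(c, 0) << (7 - j)
            g |= bitget(c, 1) << (7 - j)
            b |= bitget(c, 2) << (7 - j)
            c >>= 3
        cmap[i] = (r, g, b)
    return cmap

cmap = colormap(N=17)  # one color per entry in classes_map
print(f"#{cmap[1][0]:02x}{cmap[1][1]:02x}{cmap[1][2]:02x}")  # "#800000" for class id 1
```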
 