enpaiva committed
Commit 2020af8 · verified
1 Parent(s): 3fddaf0

Create app.py

Files changed (1)
  app.py  +405 -0
app.py ADDED
@@ -0,0 +1,405 @@
import os

# Make sure the Gradio temp dir exists before pointing Gradio at it
os.makedirs("./tmp", exist_ok=True)
os.environ["GRADIO_TEMP_DIR"] = "./tmp"

import torch
import torchvision  # needed for the standard IoU-based NMS (torchvision.ops.nms)
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    DFineForObjectDetection,
    RTDetrImageProcessor,
)

# == select device ==
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Available models (Hugging Face Hub repo ids)
MODELS = {
    "Egret XLarge": "ds4sd/docling-layout-egret-xlarge",
    "Egret Large": "ds4sd/docling-layout-egret-large",
    "Egret Medium": "ds4sd/docling-layout-egret-medium",
    "Heron 101": "ds4sd/docling-layout-heron-101",
    "Heron": "ds4sd/docling-layout-heron",
}

# Class-id to label mapping for the Docling layout models
classes_map = {
    0: "Caption",
    1: "Footnote",
    2: "Formula",
    3: "List-item",
    4: "Page-footer",
    5: "Page-header",
    6: "Picture",
    7: "Section-header",
    8: "Table",
    9: "Text",
    10: "Title",
    11: "Document Index",
    12: "Code",
    13: "Checkbox-Selected",
    14: "Checkbox-Unselected",
    15: "Form",
    16: "Key-Value Region",
}

# Color mapping for visualization (one color per class id)
colors = [
    "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FECA57",
    "#FF9FF3", "#54A0FF", "#5F27CD", "#00D2D3", "#FF9F43",
    "#10AC84", "#EE5A24", "#0ABDE3", "#006BA6", "#F79F1F",
    "#A3CB38", "#FDA7DF",
]

# Global state for the currently loaded model
current_model = None
current_processor = None
current_model_name = None

def iomin(box1, box2):
    """
    Intersection over Minimum (IoMin)
    box1: Tensor[1, 4]
    box2: Tensor[N, 4]
    Returns: Tensor[N]
    """
    # Intersection (box1 broadcasts against every box in box2)
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])
    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)

    # Areas: normalize by the smaller of the two boxes instead of the union
    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    min_area = torch.min(box1_area, box2_area)

    return inter_area / min_area
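
# Quick sanity check (hypothetical values): IoMin vs. plain IoU for a box
# nested inside a larger one.
#   outer = torch.tensor([[0., 0., 10., 10.]])   # area 100
#   inner = torch.tensor([[2., 2., 6., 6.]])     # area 16
#   iomin(outer, inner)  ->  tensor([1.])        # fully contained
# Plain IoU here would only be 16 / 100 = 0.16, so normalizing by the smaller
# area makes nested duplicates much easier to suppress.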

def nms(boxes, scores, iou_threshold=0.5):
    """
    Custom NMS implementation using IoMin
    """
    keep = []
    _, order = scores.sort(descending=True)

    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())

        if order.numel() == 1:
            break

        box_i = boxes[i].unsqueeze(0)  # [1, 4]
        rest = order[1:]
        ious = iomin(box_i, boxes[rest])

        # Keep only the remaining boxes that do not overlap too much
        mask = (ious <= iou_threshold)
        order = rest[mask]

    return torch.tensor(keep, dtype=torch.long)
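
# Example (hypothetical tensors): a small box fully inside a larger one.
#   boxes = torch.tensor([[0., 0., 100., 100.], [40., 40., 60., 60.]])
#   scores = torch.tensor([0.9, 0.8])
#   nms(boxes, scores, iou_threshold=0.5)  ->  tensor([0])
# The inner box has IoMin 1.0 with the kept box and is suppressed, whereas its
# plain IoU is only 0.04, so standard IoU-based NMS would keep both.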

def load_model(model_name):
    """
    Load the selected model (skipped if it is already the active one)
    """
    global current_model, current_processor, current_model_name

    if current_model_name == model_name:
        return f"✅ Model {model_name} is already loaded!"

    try:
        print(f"Loading model: {model_name}")
        model_path = MODELS[model_name]

        processor = RTDetrImageProcessor.from_pretrained(model_path)
        model = DFineForObjectDetection.from_pretrained(model_path)
        model = model.to(device)
        model.eval()

        current_processor = processor
        current_model = model
        current_model_name = model_name

        return f"✅ Successfully loaded {model_name}!"

    except Exception as e:
        return f"❌ Error loading {model_name}: {str(e)}"

def visualize_bbox(image, boxes, labels, scores, classes_map, colors):
    """
    Visualize bounding boxes on the image
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Input image must be a PIL Image or a numpy array")

    # Create a copy to draw on
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    # Try to use a TrueType font, falling back to PIL's built-in default
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

    for box, label_id, score in zip(boxes, labels, scores):
        # Convert tensors to plain Python numbers if needed
        if torch.is_tensor(label_id):
            label_id = label_id.item()
        if torch.is_tensor(score):
            score = score.item()

        label = classes_map.get(int(label_id), f"Class_{label_id}")
        color = colors[int(label_id) % len(colors)]

        # Convert box coordinates to integers
        x1, y1, x2, y2 = [int(coord) for coord in box]

        # Draw rectangle
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Draw label background
        text = f"{label}: {score:.2f}"
        if font:
            bbox = draw.textbbox((x1, y1), text, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
        else:
            # Estimate text size if no font is available
            text_width = len(text) * 10
            text_height = 20

        draw.rectangle([x1, y1 - text_height - 4, x1 + text_width + 4, y1], fill=color)
        draw.text((x1 + 2, y1 - text_height - 2), text, fill="white", font=font)

    return np.array(draw_image)
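
# visualize_bbox returns a numpy array (rather than a PIL image) because the
# Gradio output component below is declared with type="numpy".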

def recognize_image(input_img, conf_threshold, iou_threshold, nms_method):
    """
    Process an image with the Docling layout model
    """
    if input_img is None:
        return None, "Please upload an image first."

    if current_model is None or current_processor is None:
        return None, "Please load a model first."

    try:
        # Ensure the image is a PIL Image
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)

        # Convert to RGB if needed
        if input_img.mode != 'RGB':
            input_img = input_img.convert('RGB')

        # Preprocess the image and move the tensors to the model's device
        inputs = current_processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Run inference
        with torch.no_grad():
            outputs = current_model(**inputs)

        # Post-process results; PIL's size is (width, height) while the
        # processor expects (height, width)
        results = current_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )
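
        # Each entry of `results` corresponds to one input image and is a dict
        # with "scores", "labels" and "boxes" in absolute (x1, y1, x2, y2)
        # pixel coordinates.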

        if not results or len(results) == 0:
            return np.array(input_img), "No detections found."

        result = results[0]

        # Get results
        boxes = result["boxes"]
        scores = result["scores"]
        labels = result["labels"]

        if len(boxes) == 0:
            return np.array(input_img), "No detections above confidence threshold."

        # Apply NMS if requested (a threshold of 1.0 disables it)
        if iou_threshold < 1.0:
            if nms_method == "Custom IoMin":
                # Use the custom NMS with IoMin defined above
                keep_indices = nms(
                    boxes=boxes,
                    scores=scores,
                    iou_threshold=iou_threshold,
                )
            else:
                # Use the standard IoU-based NMS from torchvision
                keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)

            boxes = boxes[keep_indices]
            scores = scores[keep_indices]
            labels = labels[keep_indices]

        # Handle the single-detection case (indexing can drop a dimension)
        if len(boxes.shape) == 1:
            boxes = boxes.unsqueeze(0)
            scores = scores.unsqueeze(0)
            labels = labels.unsqueeze(0)

        # Visualize results
        output = visualize_bbox(
            input_img,
            boxes,
            labels,
            scores,
            classes_map,
            colors,
        )

        detection_info = f"Found {len(boxes)} detections after NMS ({nms_method})"
        return output, detection_info

    except Exception as e:
        print(f"[ERROR] recognize_image failed: {e}")
        error_msg = f"Error during processing: {str(e)}"
        # Return the original image on error
        if input_img is not None:
            return np.array(input_img), error_msg
        return np.zeros((512, 512, 3), dtype=np.uint8), error_msg

def gradio_reset():
    # Clear the input image, the output image and the info textbox
    return gr.update(value=None), gr.update(value=None), gr.update(value="")

if __name__ == "__main__":
    print(f"Using device: {device}")

    # Create header HTML
    header_html = """
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>🔍 Document Layout Analysis</h1>
        <p>Using Docling layout models for document structure detection</p>
        <p>Select a model, upload an image, and adjust the parameters to detect document elements</p>
    </div>
    """

    with gr.Blocks(title="Document Layout Analysis", theme=gr.themes.Soft()) as demo:
        gr.HTML(header_html)

        with gr.Row():
            with gr.Column():
                # Model selection
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value="Egret XLarge",
                    label="🤖 Select Model",
                    info="Choose which Docling model to use",
                )

                load_btn = gr.Button("📥 Load Model", variant="secondary")
                model_status = gr.Textbox(
                    label="Model Status",
                    interactive=False,
                    value="No model loaded",
                )

                input_img = gr.Image(
                    label="📄 Upload Document Image",
                    interactive=True,
                    type="pil",
                )

                with gr.Row():
                    clear = gr.Button("🗑️ Clear")
                    predict = gr.Button("🔍 Detect Layout", interactive=True, variant="primary")

                with gr.Row():
                    conf_threshold = gr.Slider(
                        label="Confidence Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.6,
                        info="Minimum confidence score for detections",
                    )

                with gr.Row():
                    iou_threshold = gr.Slider(
                        label="NMS IoU Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.5,
                        info="IoU threshold for Non-Maximum Suppression",
                    )

                    nms_method = gr.Radio(
                        choices=["Custom IoMin", "Standard IoU"],
                        value="Custom IoMin",
                        label="NMS Method",
                        info="Choose NMS algorithm",
                    )

                # Legend
                with gr.Accordion("📋 Detected Classes", open=False):
                    legend_html = "<div style='display: grid; grid-template-columns: repeat(2, 1fr); gap: 10px;'>"
                    for class_id, class_name in classes_map.items():
                        color = colors[class_id % len(colors)]
                        legend_html += f"""
                        <div style='display: flex; align-items: center; padding: 5px;'>
                            <div style='width: 20px; height: 20px; background-color: {color}; margin-right: 10px; border: 1px solid #ccc;'></div>
                            <span>{class_name}</span>
                        </div>
                        """
                    legend_html += "</div>"
                    gr.HTML(legend_html)

            with gr.Column():
                gr.HTML("<h3>🎯 Detection Results</h3>")
                output_img = gr.Image(
                    label="Detected Layout",
                    interactive=False,
                    type="numpy",
                )

                detection_info = gr.Textbox(
                    label="Detection Info",
                    interactive=False,
                    value="",
                )

        # Event handlers
        load_btn.click(
            load_model,
            inputs=[model_dropdown],
            outputs=[model_status],
        )

        clear.click(
            gradio_reset,
            inputs=None,
            outputs=[input_img, output_img, detection_info],
        )

        predict.click(
            recognize_image,
            inputs=[input_img, conf_threshold, iou_threshold, nms_method],
            outputs=[output_img, detection_info],
        )

    # Launch the demo
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        share=False,
    )
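
# Running locally (a sketch, assuming the dependencies are installed):
#   pip install torch torchvision transformers gradio pillow
#   python app.py            # then open http://localhost:7860
# Note: DFineForObjectDetection requires a recent transformers release.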