enpaiva commited on
Commit
fc2ea23
Β·
verified Β·
1 Parent(s): 62ca522

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -156
app.py CHANGED
@@ -54,6 +54,7 @@ classes_map = {
54
  current_model = None
55
  current_processor = None
56
  current_model_name = None
 
57
 
58
  def colormap(N=256, normalized=False):
59
  """Generate dynamic colormap."""
@@ -110,12 +111,12 @@ def nms_custom(boxes, scores, iou_threshold=0.5):
110
 
111
  return torch.tensor(keep, dtype=torch.long)
112
 
113
- def load_model(model_name):
114
- """Load the selected model automatically."""
115
  global current_model, current_processor, current_model_name
116
 
117
- if current_model_name == model_name:
118
- return current_model, current_processor
119
 
120
  try:
121
  model_info = MODELS[model_name]
@@ -133,11 +134,11 @@ def load_model(model_name):
133
  current_model = model
134
  current_model_name = model_name
135
 
136
- return model, processor
137
 
138
  except Exception as e:
139
  print(f"Error loading model: {e}")
140
- return None, None
141
 
142
  def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True):
143
  """Visualize bounding boxes with OpenCV."""
@@ -199,15 +200,32 @@ def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3,
199
 
200
  return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  def process_image(input_img, model_name, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
203
  """Process image with document layout detection."""
 
 
204
  if input_img is None:
205
  return None, "❌ Please upload an image first."
206
 
207
  # Load model if needed
208
- model, processor = load_model(model_name)
209
- if model is None or processor is None:
210
- return None, f"❌ Error loading model {model_name}."
211
 
212
  try:
213
  # Prepare image
@@ -218,20 +236,21 @@ def process_image(input_img, model_name, conf_threshold, iou_threshold, nms_meth
218
  input_img = input_img.convert('RGB')
219
 
220
  # Process with model
221
- inputs = processor(images=[input_img], return_tensors="pt")
222
  inputs = {k: v.to(device) for k, v in inputs.items()}
223
 
224
  with torch.no_grad():
225
- outputs = model(**inputs)
226
 
227
  # Post-process results
228
- results = processor.post_process_object_detection(
229
  outputs,
230
  target_sizes=torch.tensor([input_img.size[::-1]]),
231
  threshold=conf_threshold,
232
  )
233
 
234
  if not results or len(results) == 0:
 
235
  return np.array(input_img), "ℹ️ No detections found."
236
 
237
  result = results[0]
@@ -240,6 +259,7 @@ def process_image(input_img, model_name, conf_threshold, iou_threshold, nms_meth
240
  labels = result["labels"]
241
 
242
  if len(boxes) == 0:
 
243
  return np.array(input_img), f"ℹ️ No detections above threshold {conf_threshold:.2f}."
244
 
245
  # Apply NMS
@@ -247,23 +267,26 @@ def process_image(input_img, model_name, conf_threshold, iou_threshold, nms_meth
247
  if nms_method == "Custom IoMin":
248
  keep_indices = nms_custom(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
249
  else:
250
- # Use torchvision NMS with correct format
251
  keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
252
 
253
  boxes = boxes[keep_indices]
254
  scores = scores[keep_indices]
255
  labels = labels[keep_indices]
256
 
 
 
 
257
  # Visualize results
258
  output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels)
259
 
260
  labels_status = "with labels" if show_labels else "without labels"
261
- info = f"βœ… Found {len(boxes)} detections ({labels_status}) | Model: {model_name} | Confidence: {conf_threshold:.2f}"
262
 
263
  return output, info
264
 
265
  except Exception as e:
266
  print(f"[ERROR] process_image failed: {e}")
 
267
  error_msg = f"❌ Processing error: {str(e)}"
268
  if input_img is not None:
269
  return np.array(input_img), error_msg
@@ -274,43 +297,27 @@ if __name__ == "__main__":
274
  print(f"πŸ“± Device: {device}")
275
  print(f"πŸ€– Available models: {len(MODELS)}")
276
 
277
- # Custom CSS for compact layout
278
  custom_css = """
279
  .gradio-container {
280
- max-width: 1400px !important;
281
- margin: 0 auto !important;
282
- padding: 20px !important;
283
  }
284
 
285
- .controls-container {
286
  background: #f8f9fa;
287
  border-radius: 12px;
288
- border: 1px solid #dee2e6;
289
  padding: 20px;
290
- margin-bottom: 20px;
291
  }
292
 
293
- .results-container {
294
- background: #ffffff;
295
  border-radius: 12px;
296
- border: 1px solid #dee2e6;
297
  padding: 20px;
298
- }
299
-
300
- .section-divider {
301
- border-top: 2px solid #e9ecef;
302
- margin: 20px 0;
303
- padding-top: 20px;
304
- }
305
-
306
- .analyze-btn {
307
- background: linear-gradient(45deg, #667eea, #764ba2) !important;
308
- border: none !important;
309
- color: white !important;
310
- font-weight: bold !important;
311
- font-size: 18px !important;
312
- padding: 15px 30px !important;
313
- border-radius: 10px !important;
314
  }
315
  """
316
 
@@ -323,133 +330,113 @@ if __name__ == "__main__":
323
 
324
  # Header
325
  gr.HTML("""
326
- <div style='text-align: center; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 30px;'>
327
- <h1 style='margin: 0; font-size: 2.5em; text-shadow: 2px 2px 4px rgba(0,0,0,0.3);'>πŸ” Document Layout Analysis</h1>
328
- <p style='margin: 10px 0 0 0; font-size: 1.2em; opacity: 0.9;'>Compact interface for advanced document structure detection</p>
329
  </div>
330
  """)
331
-
332
- # LEFT COLUMN - Controls and Input
333
- with gr.Column(scale=1):
334
 
335
- # Controls Section
336
- with gr.Group(elem_classes=["controls-container"]):
337
- # 1. Image Upload (First)
338
- gr.HTML("<h3 style='margin-top: 0;'>πŸ“„ Upload Document</h3>")
339
- input_img = gr.Image(
340
- label="Document Image",
341
- type="pil",
342
- height=300,
343
- interactive=True
344
- )
345
-
346
- # Divider
347
- gr.HTML("<div class='section-divider'></div>")
348
-
349
- # 2. Model Selection (Second)
350
- gr.HTML("<h3>πŸ€– Model Selection</h3>")
351
- model_dropdown = gr.Dropdown(
352
- choices=list(MODELS.keys()),
353
- value="Egret XLarge",
354
- label="AI Model",
355
- info="Model will load automatically when analyzing",
356
- interactive=True
357
- )
358
-
359
- # Divider
360
- gr.HTML("<div class='section-divider'></div>")
361
-
362
- # 3. Detection Parameters (Third)
363
- gr.HTML("<h3>βš™οΈ Detection Settings</h3>")
364
-
365
- with gr.Row():
366
- conf_threshold = gr.Slider(
367
- minimum=0.0,
368
- maximum=1.0,
369
- value=0.6,
370
- step=0.05,
371
- label="Confidence Threshold",
372
- info="Minimum confidence for detections"
373
  )
374
 
375
- iou_threshold = gr.Slider(
376
- minimum=0.0,
377
- maximum=1.0,
378
- value=0.5,
379
- step=0.05,
380
- label="NMS IoU Threshold",
381
- info="Non-maximum suppression threshold"
 
382
  )
383
-
384
- with gr.Row():
385
- nms_method = gr.Radio(
386
- choices=["Custom IoMin", "Standard IoU"],
387
- value="Custom IoMin",
388
- label="NMS Algorithm",
389
- info="Choose suppression method"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  )
391
 
392
- alpha_slider = gr.Slider(
393
- minimum=0.0,
394
- maximum=1.0,
395
- value=0.3,
396
- step=0.1,
397
- label="Overlay Transparency",
398
- info="Transparency of detection overlays"
 
 
 
 
 
 
 
 
399
  )
400
-
401
- show_labels_checkbox = gr.Checkbox(
402
- value=True,
403
- label="Show Class Labels and Confidence Scores",
404
- info="Display detection labels on the output image",
405
- interactive=True
406
- )
407
-
408
- # Divider
409
- gr.HTML("<div class='section-divider'></div>")
410
-
411
- # 4. Analyze Button (Last)
412
- detect_btn = gr.Button(
413
- "πŸ” Analyze Document",
414
- variant="primary",
415
- size="lg",
416
- elem_classes=["analyze-btn"]
417
- )
418
-
419
- # RIGHT COLUMN - Results and Output
420
- with gr.Column(scale=1):
421
 
422
- # Results Section
423
- with gr.Group(elem_classes=["results-container"]):
424
- gr.HTML("<h3 style='margin-top: 0;'>🎯 Analysis Results</h3>")
425
-
426
- output_img = gr.Image(
427
- label="Analyzed Document",
428
- type="numpy",
429
- height=600,
430
- interactive=False
431
- )
432
-
433
- detection_info = gr.Textbox(
434
- label="Detection Summary",
435
- value="Ready for analysis. Upload an image and click 'Analyze Document'.",
436
- interactive=False,
437
- lines=2,
438
- show_copy_button=True
439
- )
440
-
441
- # Event Handler
442
- detect_btn.click(
443
  fn=process_image,
444
- inputs=[
445
- input_img,
446
- model_dropdown,
447
- conf_threshold,
448
- iou_threshold,
449
- nms_method,
450
- alpha_slider,
451
- show_labels_checkbox
452
- ],
 
 
 
 
 
 
453
  outputs=[output_img, detection_info]
454
  )
455
 
 
54
  current_model = None
55
  current_processor = None
56
  current_model_name = None
57
+ cached_results = None # Para guardar los resultados y poder cambiar labels sin reprocesar
58
 
59
  def colormap(N=256, normalized=False):
60
  """Generate dynamic colormap."""
 
111
 
112
  return torch.tensor(keep, dtype=torch.long)
113
 
114
+ def load_model_if_needed(model_name):
115
+ """Load the selected model if not already loaded."""
116
  global current_model, current_processor, current_model_name
117
 
118
+ if current_model_name == model_name and current_model is not None:
119
+ return True
120
 
121
  try:
122
  model_info = MODELS[model_name]
 
134
  current_model = model
135
  current_model_name = model_name
136
 
137
+ return True
138
 
139
  except Exception as e:
140
  print(f"Error loading model: {e}")
141
+ return False
142
 
143
  def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True):
144
  """Visualize bounding boxes with OpenCV."""
 
200
 
201
  return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
202
 
203
+ def toggle_labels_visualization(show_labels, alpha):
204
+ """Toggle labels without reprocessing the image."""
205
+ global cached_results
206
+
207
+ if cached_results is None:
208
+ return None, "⚠️ No cached results. Please analyze an image first."
209
+
210
+ input_img, boxes, labels, scores = cached_results
211
+
212
+ output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels)
213
+
214
+ labels_status = "with labels" if show_labels else "without labels"
215
+ info = f"βœ… Visualization updated ({labels_status}) | {len(boxes)} detections"
216
+
217
+ return output, info
218
+
219
  def process_image(input_img, model_name, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
220
  """Process image with document layout detection."""
221
+ global cached_results
222
+
223
  if input_img is None:
224
  return None, "❌ Please upload an image first."
225
 
226
  # Load model if needed
227
+ if not load_model_if_needed(model_name):
228
+ return None, f"❌ Failed to load model {model_name}."
 
229
 
230
  try:
231
  # Prepare image
 
236
  input_img = input_img.convert('RGB')
237
 
238
  # Process with model
239
+ inputs = current_processor(images=[input_img], return_tensors="pt")
240
  inputs = {k: v.to(device) for k, v in inputs.items()}
241
 
242
  with torch.no_grad():
243
+ outputs = current_model(**inputs)
244
 
245
  # Post-process results
246
+ results = current_processor.post_process_object_detection(
247
  outputs,
248
  target_sizes=torch.tensor([input_img.size[::-1]]),
249
  threshold=conf_threshold,
250
  )
251
 
252
  if not results or len(results) == 0:
253
+ cached_results = None
254
  return np.array(input_img), "ℹ️ No detections found."
255
 
256
  result = results[0]
 
259
  labels = result["labels"]
260
 
261
  if len(boxes) == 0:
262
+ cached_results = None
263
  return np.array(input_img), f"ℹ️ No detections above threshold {conf_threshold:.2f}."
264
 
265
  # Apply NMS
 
267
  if nms_method == "Custom IoMin":
268
  keep_indices = nms_custom(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
269
  else:
 
270
  keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
271
 
272
  boxes = boxes[keep_indices]
273
  scores = scores[keep_indices]
274
  labels = labels[keep_indices]
275
 
276
+ # Cache results for label toggling
277
+ cached_results = (input_img, boxes, labels, scores)
278
+
279
  # Visualize results
280
  output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels)
281
 
282
  labels_status = "with labels" if show_labels else "without labels"
283
+ info = f"βœ… Found {len(boxes)} detections ({labels_status}) | Model: {model_name} | NMS: {nms_method} | Conf: {conf_threshold:.2f}"
284
 
285
  return output, info
286
 
287
  except Exception as e:
288
  print(f"[ERROR] process_image failed: {e}")
289
+ cached_results = None
290
  error_msg = f"❌ Processing error: {str(e)}"
291
  if input_img is not None:
292
  return np.array(input_img), error_msg
 
297
  print(f"πŸ“± Device: {device}")
298
  print(f"πŸ€– Available models: {len(MODELS)}")
299
 
300
+ # Custom CSS for clean layout
301
  custom_css = """
302
  .gradio-container {
303
+ max-width: 100% !important;
304
+ padding: 15px !important;
 
305
  }
306
 
307
+ .control-panel {
308
  background: #f8f9fa;
309
  border-radius: 12px;
310
+ border: 1px solid #e9ecef;
311
  padding: 20px;
312
+ margin-bottom: 15px;
313
  }
314
 
315
+ .results-panel {
316
+ background: #f8f9fa;
317
  border-radius: 12px;
318
+ border: 1px solid #e9ecef;
319
  padding: 20px;
320
+ min-height: 600px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  }
322
  """
323
 
 
330
 
331
  # Header
332
  gr.HTML("""
333
+ <div style='text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 12px; margin-bottom: 20px;'>
334
+ <h1 style='margin: 0; font-size: 2.5em;'>πŸ” Document Layout Analysis</h1>
335
+ <p style='margin: 8px 0 0 0; font-size: 1.1em; opacity: 0.9;'>Advanced document structure detection</p>
336
  </div>
337
  """)
 
 
 
338
 
339
+ # Main content in two columns
340
+ with gr.Row():
341
+ # LEFT COLUMN - Controls (more compact)
342
+ with gr.Column(scale=1):
343
+ with gr.Group(elem_classes=["control-panel"]):
344
+
345
+ # 1. Image Upload (first)
346
+ gr.HTML("<h3>πŸ“„ Upload Image</h3>")
347
+ input_img = gr.Image(
348
+ label="Document Image",
349
+ type="pil",
350
+ height=300,
351
+ interactive=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  )
353
 
354
+ gr.HTML("<br><h3>πŸ€– Model Selection</h3>")
355
+ # 2. Model Selection (second, without buttons)
356
+ model_dropdown = gr.Dropdown(
357
+ choices=list(MODELS.keys()),
358
+ value="Egret XLarge",
359
+ label="AI Model",
360
+ info="Model will be loaded automatically",
361
+ interactive=True
362
  )
363
+
364
+ gr.HTML("<br><h3>βš™οΈ Parameters</h3>")
365
+ # 3. All parameters together (third)
366
+ with gr.Row():
367
+ conf_threshold = gr.Slider(
368
+ minimum=0.0, maximum=1.0, value=0.6, step=0.05,
369
+ label="Confidence", info="Detection threshold"
370
+ )
371
+ iou_threshold = gr.Slider(
372
+ minimum=0.0, maximum=1.0, value=0.5, step=0.05,
373
+ label="NMS IoU", info="Suppression threshold"
374
+ )
375
+
376
+ with gr.Row():
377
+ nms_method = gr.Radio(
378
+ choices=["Custom IoMin", "Standard IoU"],
379
+ value="Custom IoMin",
380
+ label="NMS Method", scale=2
381
+ )
382
+ alpha_slider = gr.Slider(
383
+ minimum=0.0, maximum=1.0, value=0.3, step=0.1,
384
+ label="Transparency", scale=1
385
+ )
386
+
387
+ gr.HTML("<br>")
388
+ # 4. Analyze button (last)
389
+ analyze_btn = gr.Button("πŸ” Analyze Document", variant="primary", size="lg")
390
+
391
+ # RIGHT COLUMN - Results
392
+ with gr.Column(scale=1):
393
+ with gr.Group(elem_classes=["results-panel"]):
394
+ gr.HTML("<h3>🎯 Analysis Results</h3>")
395
+
396
+ output_img = gr.Image(
397
+ label="Detected Layout",
398
+ type="numpy",
399
+ height=450,
400
+ interactive=False
401
  )
402
 
403
+ detection_info = gr.Textbox(
404
+ label="Detection Summary",
405
+ value="",
406
+ interactive=False,
407
+ lines=2,
408
+ placeholder="Results will appear here..."
409
+ )
410
+
411
+ # Labels toggle (independent control)
412
+ gr.HTML("<h4>🎨 Visualization</h4>")
413
+ show_labels_checkbox = gr.Checkbox(
414
+ value=True,
415
+ label="Show Class Labels",
416
+ info="Toggle labels without reprocessing",
417
+ interactive=True
418
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
 
420
+ # Event Handlers
421
+
422
+ # Main analysis (full processing)
423
+ analyze_btn.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  fn=process_image,
425
+ inputs=[input_img, model_dropdown, conf_threshold, iou_threshold, nms_method, alpha_slider, show_labels_checkbox],
426
+ outputs=[output_img, detection_info]
427
+ )
428
+
429
+ # Independent label toggle (no reprocessing)
430
+ show_labels_checkbox.change(
431
+ fn=toggle_labels_visualization,
432
+ inputs=[show_labels_checkbox, alpha_slider],
433
+ outputs=[output_img, detection_info]
434
+ )
435
+
436
+ # Also update visualization when transparency changes (if we have cached results)
437
+ alpha_slider.change(
438
+ fn=toggle_labels_visualization,
439
+ inputs=[show_labels_checkbox, alpha_slider],
440
  outputs=[output_img, detection_info]
441
  )
442