import os
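# Keep Gradio's temporary files in a local ./tmp folder rather than the system temp dir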
os.environ["GRADIO_TEMP_DIR"] = "./tmp"
import sys
import torch
import torchvision
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from transformers import (
DFineForObjectDetection,
RTDetrV2ForObjectDetection,
RTDetrImageProcessor,
)
# == Device configuration ==
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# == Model configurations ==
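# Egret checkpoints use the D-FINE detection head, Heron uses RT-DETRv2;
# both families share the RT-DETR image processor for pre-/post-processing.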
MODELS = {
"Egret XLarge": {
"path": "ds4sd/docling-layout-egret-xlarge",
"model_class": DFineForObjectDetection
},
"Egret Large": {
"path": "ds4sd/docling-layout-egret-large",
"model_class": DFineForObjectDetection
},
"Egret Medium": {
"path": "ds4sd/docling-layout-egret-medium",
"model_class": DFineForObjectDetection
},
"Heron 101": {
"path": "ds4sd/docling-layout-heron-101",
"model_class": RTDetrV2ForObjectDetection
},
"Heron": {
"path": "ds4sd/docling-layout-heron",
"model_class": RTDetrV2ForObjectDetection
}
}
# == Class mappings ==
classes_map = {
0: "Caption", 1: "Footnote", 2: "Formula", 3: "List-item",
4: "Page-footer", 5: "Page-header", 6: "Picture", 7: "Section-header",
8: "Table", 9: "Text", 10: "Title", 11: "Document Index",
12: "Code", 13: "Checkbox-Selected", 14: "Checkbox-Unselected",
15: "Form", 16: "Key-Value Region",
}
# == Global model variables ==
current_model = None
current_processor = None
current_model_name = None
def colormap(N=256, normalized=False):
"""Generate dynamic colormap."""
def bitget(byteval, idx):
return ((byteval & (1 << idx)) != 0)
cmap = np.zeros((N, 3), dtype=np.uint8)
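    # Distribute the bits of each index across R, G and B (bit 3j of i sets
    # bit 7-j of the red channel, and similarly for green and blue), the
    # PASCAL VOC palette scheme; consecutive ids get visually distinct colors.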
for i in range(N):
r = g = b = 0
c = i
for j in range(8):
r = r | (bitget(c, 0) << (7 - j))
g = g | (bitget(c, 1) << (7 - j))
b = b | (bitget(c, 2) << (7 - j))
c = c >> 3
cmap[i] = np.array([r, g, b])
if normalized:
cmap = cmap.astype(np.float32) / 255.0
return cmap
def iomin(box1, box2):
"""Intersection over Minimum (IoMin)."""
x1 = torch.max(box1[:, 0], box2[:, 0])
y1 = torch.max(box1[:, 1], box2[:, 1])
x2 = torch.min(box1[:, 2], box2[:, 2])
y2 = torch.min(box1[:, 3], box2[:, 3])
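    # Clamp negative extents to zero so disjoint boxes contribute no intersection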
inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
min_area = torch.min(box1_area, box2_area)
return inter_area / min_area
def nms_custom(boxes, scores, iou_threshold=0.5):
"""Custom NMS implementation using IoMin."""
keep = []
_, order = scores.sort(descending=True)
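    # Visit boxes from highest to lowest score; each kept box suppresses all
    # remaining boxes whose IoMin with it exceeds the threshold.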
while order.numel() > 0:
i = order[0]
keep.append(i.item())
if order.numel() == 1:
break
box_i = boxes[i].unsqueeze(0)
rest = order[1:]
ious = iomin(box_i, boxes[rest])
mask = (ious <= iou_threshold)
order = order[1:][mask]
return torch.tensor(keep, dtype=torch.long)
def load_model(model_name):
"""Load the selected model."""
global current_model, current_processor, current_model_name
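    # Reuse the already-loaded model instead of reloading it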
if current_model_name == model_name:
return f"βœ… Model {model_name} is already loaded!"
try:
model_info = MODELS[model_name]
model_path = model_info["path"]
model_class = model_info["model_class"]
print(f"Loading {model_name} from {model_path}")
processor = RTDetrImageProcessor.from_pretrained(model_path)
model = model_class.from_pretrained(model_path)
model = model.to(device)
model.eval()
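        # Publish the new model/processor globals only after loading succeeds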
current_processor = processor
current_model = model
current_model_name = model_name
return f"βœ… Successfully loaded {model_name}!"
except Exception as e:
print(f"Error loading model: {e}")
return f"❌ Error loading {model_name}: {str(e)}"
def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True):
"""Visualize bounding boxes with OpenCV."""
if isinstance(image_input, Image.Image):
image = np.array(image_input)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
elif isinstance(image_input, np.ndarray):
if len(image_input.shape) == 3 and image_input.shape[2] == 3:
image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
else:
image = image_input.copy()
else:
raise ValueError("Input must be PIL Image or numpy array")
if len(bboxes) == 0:
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
overlay = image.copy()
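    # Filled boxes are drawn on a separate overlay and alpha-blended back in a
    # single pass, so overlapping regions do not stack to full opacity.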
cmap = colormap(N=len(id_to_names), normalized=False)
for i in range(len(bboxes)):
try:
bbox = bboxes[i]
if torch.is_tensor(bbox):
bbox = bbox.cpu().numpy()
class_id = classes[i]
if torch.is_tensor(class_id):
class_id = class_id.item()
score = scores[i]
if torch.is_tensor(score):
score = score.item()
x_min, y_min, x_max, y_max = map(int, bbox)
class_id = int(class_id)
class_name = id_to_names.get(class_id, f"unknown_{class_id}")
color = tuple(int(c) for c in cmap[class_id % len(cmap)])
# Draw filled rectangle on overlay
cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
# Draw border on main image
cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)
# Add text label only if show_labels is True
if show_labels:
text = f"{class_name}: {score:.3f}"
(text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)
cv2.rectangle(image, (x_min, y_min - text_height - baseline - 4),
(x_min + text_width + 8, y_min), color, -1)
cv2.putText(image, text, (x_min + 4, y_min - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
except Exception as e:
print(f"Skipping box {i} due to error: {e}")
# Apply transparency
cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def process_image(input_img, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
"""Process image with document layout detection."""
if input_img is None:
        return None, "❌ Please upload an image first."
if current_model is None or current_processor is None:
        return None, "❌ Please load a model first."
try:
# Prepare image
if isinstance(input_img, np.ndarray):
input_img = Image.fromarray(input_img)
if input_img.mode != 'RGB':
input_img = input_img.convert('RGB')
# Process with model
inputs = current_processor(images=[input_img], return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = current_model(**inputs)
# Post-process results
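        # PIL's .size is (width, height); it is reversed below because the
        # processor expects target sizes as (height, width).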
results = current_processor.post_process_object_detection(
outputs,
target_sizes=torch.tensor([input_img.size[::-1]]),
threshold=conf_threshold,
)
        if not results:
            return np.array(input_img), "ℹ️ No detections found."
result = results[0]
boxes = result["boxes"]
scores = result["scores"]
labels = result["labels"]
if len(boxes) == 0:
            return np.array(input_img), f"ℹ️ No detections above threshold {conf_threshold:.2f}."
# Apply NMS
if iou_threshold < 1.0:
if nms_method == "Custom IoMin":
keep_indices = nms_custom(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
else:
                # Standard IoU NMS; the processor already returns boxes as (x1, y1, x2, y2)
keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
boxes = boxes[keep_indices]
scores = scores[keep_indices]
labels = labels[keep_indices]
# Visualize results
output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels)
labels_status = "with labels" if show_labels else "without labels"
info = f"βœ… Found {len(boxes)} detections ({labels_status}) | NMS: {nms_method} | Threshold: {conf_threshold:.2f}"
return output, info
except Exception as e:
print(f"[ERROR] process_image failed: {e}")
        error_msg = f"❌ Processing error: {str(e)}"
if input_img is not None:
return np.array(input_img), error_msg
return np.zeros((512, 512, 3), dtype=np.uint8), error_msg
def reset_interface():
"""Reset all interface components."""
return gr.update(value=None), gr.update(value=None), gr.update(value="")
if __name__ == "__main__":
print(f"πŸš€ Starting Document Layout Analysis App")
print(f"πŸ“± Device: {device}")
print(f"πŸ€– Available models: {len(MODELS)}")
# Custom CSS for full-width layout
custom_css = """
.gradio-container {
max-width: 100% !important;
padding: 20px !important;
}
.main-container {
width: 100% !important;
max-width: none !important;
}
.panel-left, .panel-right {
min-height: 600px;
padding: 20px;
background: #f8f9fa;
border-radius: 12px;
border: 1px solid #e9ecef;
}
.control-section {
margin-bottom: 20px;
padding: 15px;
background: white;
border-radius: 8px;
border: 1px solid #dee2e6;
}
.status-good { color: #28a745; font-weight: bold; }
.status-error { color: #dc3545; font-weight: bold; }
.status-info { color: #17a2b8; font-weight: bold; }
.toggle-labels {
background: linear-gradient(45deg, #667eea, #764ba2) !important;
border: none !important;
color: white !important;
font-weight: bold !important;
}
"""
# Create Gradio interface
with gr.Blocks(
title="πŸ“„ Document Layout Analysis - Full Width",
theme=gr.themes.Soft(),
css=custom_css
) as demo:
# Header
gr.HTML("""
<div style='text-align: center; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 30px;'>
            <h1 style='margin: 0; font-size: 3em; text-shadow: 2px 2px 4px rgba(0,0,0,0.3);'>🔍 Document Layout Analysis</h1>
<p style='margin: 10px 0 0 0; font-size: 1.3em; opacity: 0.9;'>Advanced document structure detection with multiple AI models</p>
</div>
""")
# Main content in two columns
with gr.Row():
# LEFT COLUMN - Controls and Input
with gr.Column(scale=1, elem_classes=["panel-left"]):
# Model Section
with gr.Group(elem_classes=["control-section"]):
gr.HTML("<h3>πŸ€– Model Configuration</h3>")
model_dropdown = gr.Dropdown(
choices=list(MODELS.keys()),
value="Egret XLarge",
label="Select Model",
info="Choose the AI model for document analysis",
interactive=True
)
with gr.Row():
                        load_btn = gr.Button("📥 Load Model", variant="primary", scale=1)
                        clear_btn = gr.Button("🗑️ Clear All", variant="secondary", scale=1)
model_status = gr.Textbox(
label="Model Status",
value="πŸ”„ No model loaded. Please select and load a model.",
interactive=False,
lines=2
)
# Image Upload Section
with gr.Group(elem_classes=["control-section"]):
gr.HTML("<h3>πŸ“„ Image Input</h3>")
input_img = gr.Image(
label="Upload Document Image",
type="pil",
height=400,
interactive=True
)
                    detect_btn = gr.Button("🔍 Analyze Document", variant="primary", size="lg")
# Parameters Section
with gr.Group(elem_classes=["control-section"]):
gr.HTML("<h3>βš™οΈ Detection Parameters</h3>")
conf_threshold = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.6,
step=0.05,
label="Confidence Threshold",
info="Minimum confidence for detections"
)
iou_threshold = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.5,
step=0.05,
label="NMS IoU Threshold",
info="Non-maximum suppression threshold"
)
nms_method = gr.Radio(
choices=["Custom IoMin", "Standard IoU"],
value="Custom IoMin",
label="NMS Algorithm",
info="Choose suppression method"
)
alpha_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.3,
step=0.1,
label="Overlay Transparency",
info="Transparency of detection overlays"
)
# RIGHT COLUMN - Results and Output
with gr.Column(scale=1, elem_classes=["panel-right"]):
# Results Section
with gr.Group(elem_classes=["control-section"]):
gr.HTML("<h3>🎯 Detection Results</h3>")
output_img = gr.Image(
label="Analyzed Document",
type="numpy",
height=500,
interactive=False
)
detection_info = gr.Textbox(
label="Analysis Summary",
value="",
interactive=False,
lines=3,
placeholder="Detection results will appear here..."
)
# Visualization Options Section
with gr.Group(elem_classes=["control-section"]):
gr.HTML("<h3>🎨 Visualization Options</h3>")
show_labels_checkbox = gr.Checkbox(
value=True,
label="Show Class Labels",
info="Display class names and confidence scores on detections",
interactive=True
)
# Event Handlers
load_btn.click(
fn=load_model,
inputs=[model_dropdown],
outputs=[model_status]
)
clear_btn.click(
fn=reset_interface,
outputs=[input_img, output_img, detection_info]
)
detect_btn.click(
fn=process_image,
inputs=[input_img, conf_threshold, iou_threshold, nms_method, alpha_slider, show_labels_checkbox],
outputs=[output_img, detection_info]
)
# Launch application
demo.launch(
server_name="0.0.0.0",
server_port=7860,
debug=True,
share=False,
show_error=True,
inbrowser=True
)