# enpaiva's picture
# Update app.py
# a092a2b verified
# raw
# history blame
# 16.2 kB
import os
os.environ["GRADIO_TEMP_DIR"] = "./tmp"
import sys
import torch
import torchvision
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from transformers import (
DFineForObjectDetection,
RTDetrV2ForObjectDetection,
RTDetrImageProcessor,
)
# == Device configuration ==
# Run on CUDA when torch can see a GPU, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# == Model configurations ==
# Display name -> checkpoint id passed to `from_pretrained` plus the transformers
# class that loads it. Insertion order is the order offered in the model dropdown.
MODELS = {
    "Docling Layout Egret XLarge": dict(
        path="ds4sd/docling-layout-egret-xlarge",
        model_class=DFineForObjectDetection,
    ),
    "Docling Layout Egret Large": dict(
        path="ds4sd/docling-layout-egret-large",
        model_class=DFineForObjectDetection,
    ),
    "Docling Layout Egret Medium": dict(
        path="ds4sd/docling-layout-egret-medium",
        model_class=DFineForObjectDetection,
    ),
    "Docling Layout Heron 101": dict(
        path="ds4sd/docling-layout-heron-101",
        model_class=RTDetrV2ForObjectDetection,
    ),
    "Docling Layout Heron": dict(
        path="ds4sd/docling-layout-heron",
        model_class=RTDetrV2ForObjectDetection,
    ),
}

# == Class mappings ==
# Predicted label index (0-16, list position) -> human-readable layout class name.
classes_map = dict(enumerate([
    "Caption", "Footnote", "Formula", "List-item",
    "Page-footer", "Page-header", "Picture", "Section-header",
    "Table", "Text", "Title", "Document Index",
    "Code", "Checkbox-Selected", "Checkbox-Unselected",
    "Form", "Key-Value Region",
]))

# == Global model variables ==
# The currently loaded model/processor and the MODELS key they came from.
current_model = current_processor = current_model_name = None
# Last detection results, kept so the visualization can be redrawn (labels on/off,
# transparency) without re-running the model.
cached_results = None
def colormap(N=256, normalized=False):
    """Return an ``(N, 3)`` RGB palette with a distinct colour per index.

    Uses the classic bit-interleaving scheme (the PASCAL VOC label colormap):
    the low bits of index ``i`` are dealt round-robin to the red, green and
    blue channels, filling each channel from its most significant bit
    downwards. Index 0 maps to black.

    Args:
        N: number of palette entries.
        normalized: if True, return ``float32`` values in [0, 1] instead of
            ``uint8`` values in [0, 255].

    Returns:
        A numpy array of shape ``(N, 3)``.
    """
    palette = np.zeros((N, 3), dtype=np.uint8)
    remaining = np.arange(N)
    # Eight rounds of three bits each; round k fills bit (7 - k) of every channel.
    for shift in range(7, -1, -1):
        for channel in range(3):
            bit = (remaining >> channel) & 1
            palette[:, channel] |= (bit << shift).astype(np.uint8)
        remaining = remaining >> 3
    return palette.astype(np.float32) / 255.0 if normalized else palette
def iomin(box1, box2):
    """Intersection over Minimum (IoMin) of paired boxes.

    Boxes are rows in ``(x1, y1, x2, y2)`` corner format. ``box1`` and
    ``box2`` are paired row-wise with broadcasting, e.g. a ``(1, 4)`` box
    against an ``(M, 4)`` batch gives ``M`` values.

    The score is ``intersection_area / min(area1, area2)``. A box fully
    contained in the other therefore scores 1.0 even if it is much smaller.

    Pairs where the smaller area is not positive (zero-area or inverted
    boxes) score 0.0. Plain division would give NaN (0/0) there.

    Returns:
        A 1-D float tensor of IoMin values.
    """
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])
    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    min_area = torch.min(box1_area, box2_area)
    ratio = inter_area / min_area
    # Degenerate pairs would be NaN/inf here; report "no overlap" instead.
    return torch.where(min_area > 0, ratio, torch.zeros_like(ratio))
def nms_custom(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression using IoMin instead of IoU.

    Repeatedly keeps the highest-scoring remaining box. It then discards every
    remaining box whose IoMin with the kept box is greater than
    ``iou_threshold``. Class labels are not considered (class-agnostic).

    Args:
        boxes: ``(N, 4)`` tensor of ``(x1, y1, x2, y2)`` boxes.
        scores: ``(N,)`` tensor of confidences.
        iou_threshold: IoMin above which a lower-scored box is suppressed.

    Returns:
        ``torch.long`` indices of the kept boxes, in decreasing score order, on
        ``boxes.device`` (matching ``torchvision.ops.nms``).
    """
    keep = []
    _, order = scores.sort(descending=True)
    while order.numel() > 0:
        best = order[0]
        keep.append(best.item())
        rest = order[1:]
        if rest.numel() == 0:
            break
        overlaps = iomin(boxes[best].unsqueeze(0), boxes[rest])
        # Written as "not greater than" so a NaN overlap (degenerate box) keeps
        # the box instead of silently suppressing it.
        order = rest[~(overlaps > iou_threshold)]
    return torch.tensor(keep, dtype=torch.long, device=boxes.device)
def load_model_if_needed(model_name):
    """Make ``model_name`` the active model, loading it on first use.

    Looks the name up in ``MODELS`` and loads its ``RTDetrImageProcessor`` and
    model weights with ``from_pretrained``. The model is moved to ``device``
    and put in eval mode. The results are stored in the module-level
    ``current_processor`` / ``current_model`` / ``current_model_name`` globals.

    Returns:
        True if the requested model is (now) loaded, False if loading failed.
        Errors are printed, not raised.
    """
    global current_model, current_processor, current_model_name
    if current_model_name == model_name and current_model is not None:
        return True
    try:
        model_info = MODELS[model_name]
        model_path = model_info["path"]
        model_class = model_info["model_class"]
        # Release the previously loaded model before loading the next one, so two
        # checkpoints are never resident in memory at the same time. If the new
        # load fails, no model is active and the next request simply loads again.
        current_model = current_processor = current_model_name = None
        if device == "cuda":
            torch.cuda.empty_cache()
        print(f"Loading {model_name} from {model_path}")
        processor = RTDetrImageProcessor.from_pretrained(model_path)
        model = model_class.from_pretrained(model_path)
        model = model.to(device)
        model.eval()
        current_processor = processor
        current_model = model
        current_model_name = model_name
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
def _to_bgr_image(image_input):
    """Return a new 3-channel BGR array suitable for OpenCV drawing.

    Accepts a PIL image of any mode (converted to RGB first) or a numpy array
    laid out as RGB ``(H, W, 3)``, RGBA ``(H, W, 4)`` or grayscale ``(H, W)`` /
    ``(H, W, 1)``. The caller's image is never modified.

    Raises:
        ValueError: for any other input type or array shape.
    """
    if isinstance(image_input, Image.Image):
        if image_input.mode != "RGB":
            image_input = image_input.convert("RGB")
        return cv2.cvtColor(np.array(image_input), cv2.COLOR_RGB2BGR)
    if isinstance(image_input, np.ndarray):
        if image_input.ndim == 2:
            return cv2.cvtColor(image_input, cv2.COLOR_GRAY2BGR)
        if image_input.ndim == 3:
            channels = image_input.shape[2]
            if channels == 1:
                return cv2.cvtColor(np.ascontiguousarray(image_input[:, :, 0]), cv2.COLOR_GRAY2BGR)
            if channels == 3:
                return cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
            if channels == 4:
                return cv2.cvtColor(image_input, cv2.COLOR_RGBA2BGR)
        raise ValueError(f"Unsupported image array shape: {image_input.shape}")
    raise ValueError("Input must be PIL Image or numpy array")


def _draw_label(image, text, x_min, y_min, color):
    """Draw ``text`` in white on a filled ``color`` tag at a box's top-left corner.

    The tag normally sits just above the box. When that would start above the
    top of the image (box at the top edge), it is drawn just inside the box instead.
    """
    (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)
    tag_height = text_height + baseline + 4
    top = y_min - tag_height
    if top < 0:
        top = max(y_min, 0)
    bottom = top + tag_height
    cv2.rectangle(image, (x_min, top), (x_min + text_width + 8, bottom), color, -1)
    cv2.putText(image, text, (x_min + 4, bottom - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)


def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True):
    """Draw detections on an image and return the result as an RGB numpy array.

    Each detection gets:
    - a translucent filled rectangle, blended with weight ``alpha``;
    - an opaque 3-px border in its class colour;
    - when ``show_labels`` is True, a ``"<class>: <score>"`` tag.

    Borders and tags are drawn after the blend, so they stay fully visible at
    any ``alpha``.

    Args:
        image_input: PIL image, or RGB / RGBA / grayscale numpy array.
        bboxes: per-detection ``(x_min, y_min, x_max, y_max)`` pixel boxes
            (tensors or sequences).
        classes: per-detection class ids (tensors or ints). They are looked up
            in ``id_to_names`` and select the colour.
        scores: per-detection confidences (tensors or floats).
        id_to_names: mapping from class id to display name. Missing ids are
            shown as ``unknown_<id>``.
        alpha: weight of the filled-rectangle layer in the blend.
        show_labels: whether to draw the class/score tags.

    Returns:
        A new RGB array; the input image is not modified. A detection that
        cannot be parsed or drawn is skipped with a printed message.
    """
    image = _to_bgr_image(image_input)
    if len(bboxes) == 0:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # One colour per known class id; max(..., 1) keeps the modulo below valid
    # even for an empty id_to_names.
    cmap = colormap(N=max(len(id_to_names), 1), normalized=False)

    # Pass 1: normalise each detection and paint its fill onto the overlay.
    overlay = image.copy()
    detections = []
    for i in range(len(bboxes)):
        try:
            bbox, class_id, score = bboxes[i], classes[i], scores[i]
            if torch.is_tensor(bbox):
                bbox = bbox.cpu().numpy()
            if torch.is_tensor(class_id):
                class_id = class_id.item()
            if torch.is_tensor(score):
                score = score.item()
            x_min, y_min, x_max, y_max = map(int, bbox)
            class_id = int(class_id)
            class_name = id_to_names.get(class_id, f"unknown_{class_id}")
            color = tuple(int(c) for c in cmap[class_id % len(cmap)])
            cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
            detections.append((i, (x_min, y_min), (x_max, y_max), color, class_name, score))
        except Exception as e:
            print(f"Skipping box {i} due to error: {e}")

    # Blend the fills in place before drawing borders/tags so alpha cannot fade them.
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)

    # Pass 2: opaque borders and optional class/score tags.
    for i, top_left, bottom_right, color, class_name, score in detections:
        try:
            cv2.rectangle(image, top_left, bottom_right, color, 3)
            if show_labels:
                _draw_label(image, f"{class_name}: {score:.3f}", top_left[0], top_left[1], color)
        except Exception as e:
            print(f"Skipping box {i} due to error: {e}")

    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def toggle_labels_visualization(show_labels, alpha):
    """Redraw the last analysis with new label / transparency settings.

    Reuses ``cached_results`` (set by ``process_image``), so the model is not
    run again.

    Args:
        show_labels: draw class/score tags when True.
        alpha: blend weight of the filled box overlay.

    Returns:
        ``(image, message)`` for the output image and summary textbox.
        ``image`` is None when there are no cached results.
    """
    # Read-only access to the module-level cache, so no `global` is needed.
    if cached_results is None:
        return None, "⚠️ No cached results. Please analyze an image first."
    input_img, boxes, labels, scores = cached_results
    output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels)
    labels_status = "with labels" if show_labels else "without labels"
    info = f"✅ Visualization updated ({labels_status}) | {len(boxes)} detections"
    return output, info
def _apply_nms(boxes, scores, labels, iou_threshold, nms_method):
    """Run class-agnostic NMS over one image's detections.

    ``nms_method == "Custom IoMin"`` uses ``nms_custom`` (intersection over the
    smaller box). Any other value uses ``torchvision.ops.nms`` (standard IoU).
    A threshold of 1.0 or more disables suppression.

    Returns:
        ``(boxes, scores, labels)`` restricted to the kept detections.
    """
    if iou_threshold >= 1.0:
        return boxes, scores, labels
    if nms_method == "Custom IoMin":
        keep_indices = nms_custom(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
    else:
        keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
    return boxes[keep_indices], scores[keep_indices], labels[keep_indices]


def process_image(input_img, model_name, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
    """Detect layout regions in ``input_img`` and render them.

    Args:
        input_img: PIL image or numpy array from the Gradio image component.
        model_name: key of ``MODELS``; the model is loaded on demand.
        conf_threshold: score threshold passed to post-processing.
        iou_threshold: NMS overlap threshold (>= 1.0 disables NMS).
        nms_method: "Custom IoMin" or the standard IoU method.
        alpha: blend weight of the box fills, passed to ``visualize_bbox``.
        show_labels: whether to draw class/score tags.

    Returns:
        ``(image, message)`` for the output image and summary textbox.

        On success the detections are stored in ``cached_results``, so the label
        toggle and transparency slider can redraw without re-running the model.
        Empty results and errors clear that cache.
    """
    global cached_results
    if input_img is None:
        return None, "❌ Please upload an image first."
    # Load model if needed
    if not load_model_if_needed(model_name):
        return None, f"❌ Failed to load model {model_name}."
    try:
        # Prepare image
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)
        if input_img.mode != 'RGB':
            input_img = input_img.convert('RGB')

        # Process with model
        inputs = current_processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = current_model(**inputs)

        # PIL reports size as (width, height); target_sizes is given as (height, width).
        results = current_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )
        if not results:
            cached_results = None
            return np.array(input_img), "ℹ️ No detections found."

        result = results[0]
        boxes, scores, labels = result["boxes"], result["scores"], result["labels"]
        if len(boxes) == 0:
            cached_results = None
            return np.array(input_img), f"ℹ️ No detections above threshold {conf_threshold:.2f}."

        boxes, scores, labels = _apply_nms(boxes, scores, labels, iou_threshold, nms_method)

        # Cache results for label toggling
        cached_results = (input_img, boxes, labels, scores)

        output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels)
        labels_status = "with labels" if show_labels else "without labels"
        info = f"✅ Found {len(boxes)} detections ({labels_status}) | Model: {model_name} | NMS: {nms_method} | Conf: {conf_threshold:.2f}"
        return output, info
    except Exception as e:
        print(f"[ERROR] process_image failed: {e}")
        cached_results = None
        # input_img is known to be non-None here (checked above), so echo it back.
        return np.array(input_img), f"❌ Processing error: {str(e)}"
if __name__ == "__main__":
    # Startup banner.
    print("🚀 Starting Document Layout Analysis App")
    print(f"📱 Device: {device}")
    print(f"🤖 Available models: {len(MODELS)}")

    # Custom CSS for clean layout
    custom_css = """
.gradio-container {
max-width: 100% !important;
padding: 15px !important;
}
.control-panel {
background: #f8f9fa;
border-radius: 12px;
border: 1px solid #e9ecef;
padding: 20px;
margin-bottom: 15px;
}
.results-panel {
background: #f8f9fa;
border-radius: 12px;
border: 1px solid #e9ecef;
padding: 20px;
min-height: 600px;
}
"""

    # Create Gradio interface
    with gr.Blocks(
        title="📄 Document Layout Analysis",
        theme=gr.themes.Soft(),
        css=custom_css,
    ) as demo:
        # Header
        gr.HTML("""
<div style='text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 12px; margin-bottom: 20px;'>
<h1 style='margin: 0; font-size: 2.5em;'>🔍 Document Layout Analysis</h1>
<p style='margin: 8px 0 0 0; font-size: 1.1em; opacity: 0.9;'>Advanced document structure detection with Docling models </p>
</div>
""")

        # Main content in two columns
        with gr.Row():
            # LEFT COLUMN - Controls
            with gr.Column(scale=1):
                with gr.Group(elem_classes=["control-panel"]):
                    # 1. Image upload
                    gr.HTML("<h3>📄 Upload Image</h3>")
                    input_img = gr.Image(
                        label="Document Image",
                        type="pil",
                        height=300,
                        interactive=True,
                    )
                    # 2. Model selection (loaded on demand by process_image)
                    model_dropdown = gr.Dropdown(
                        choices=list(MODELS.keys()),
                        value="Docling Layout Egret XLarge",
                        label="AI Model",
                        info="Model will be loaded automatically",
                        interactive=True,
                    )
                    # 3. Detection and visualization parameters
                    with gr.Row():
                        conf_threshold = gr.Slider(
                            minimum=0.0, maximum=1.0, value=0.6, step=0.05,
                            label="Confidence", info="Detection threshold",
                        )
                        iou_threshold = gr.Slider(
                            minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                            label="NMS IoU", info="Suppression threshold",
                        )
                    with gr.Row():
                        nms_method = gr.Radio(
                            choices=["Standard IoU", "Custom IoMin"],
                            value="Standard IoU",
                            label="NMS Method", scale=2,
                        )
                        alpha_slider = gr.Slider(
                            minimum=0.0, maximum=1.0, value=0.3, step=0.1,
                            label="Transparency", scale=1,
                        )
                    # 4. Analyze button (last)
                    analyze_btn = gr.Button("🔍 Analyze Document", variant="primary", size="lg")

            # RIGHT COLUMN - Results
            with gr.Column(scale=1):
                with gr.Group(elem_classes=["results-panel"]):
                    gr.HTML("<h3>🎯 Analysis Results</h3>")
                    output_img = gr.Image(
                        label="Detected Layout",
                        type="numpy",
                        height=450,
                        interactive=False,
                    )
                    detection_info = gr.Textbox(
                        label="Detection Summary",
                        value="",
                        interactive=False,
                        lines=2,
                        placeholder="Results will appear here...",
                    )
                    # Labels toggle (redraws from cached results, no model run)
                    show_labels_checkbox = gr.Checkbox(
                        value=True,
                        label="Show Class Labels",
                        info="Toggle labels without reprocessing",
                        interactive=True,
                    )

        # Event handlers
        # Main analysis (full processing)
        analyze_btn.click(
            fn=process_image,
            inputs=[input_img, model_dropdown, conf_threshold, iou_threshold, nms_method, alpha_slider, show_labels_checkbox],
            outputs=[output_img, detection_info],
        )
        # Independent label toggle (no reprocessing)
        show_labels_checkbox.change(
            fn=toggle_labels_visualization,
            inputs=[show_labels_checkbox, alpha_slider],
            outputs=[output_img, detection_info],
        )
        # Also redraw when transparency changes (uses cached results if any)
        alpha_slider.change(
            fn=toggle_labels_visualization,
            inputs=[show_labels_checkbox, alpha_slider],
            outputs=[output_img, detection_info],
        )

    # Launch application
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        share=False,
        show_error=True,
        inbrowser=True,
    )