import os os.environ["GRADIO_TEMP_DIR"] = "./tmp" import sys import torch import torchvision import gradio as gr import numpy as np import cv2 from PIL import Image from transformers import ( DFineForObjectDetection, RTDetrV2ForObjectDetection, RTDetrImageProcessor, ) # == Device configuration == device = 'cuda' if torch.cuda.is_available() else 'cpu' # == Model configurations == MODELS = { "Docling Layout Egret XLarge": { "path": "ds4sd/docling-layout-egret-xlarge", "model_class": DFineForObjectDetection }, "Docling Layout Egret Large": { "path": "ds4sd/docling-layout-egret-large", "model_class": DFineForObjectDetection }, "Docling Layout Egret Medium": { "path": "ds4sd/docling-layout-egret-medium", "model_class": DFineForObjectDetection }, "Docling Layout Heron 101": { "path": "ds4sd/docling-layout-heron-101", "model_class": RTDetrV2ForObjectDetection }, "Docling Layout Heron": { "path": "ds4sd/docling-layout-heron", "model_class": RTDetrV2ForObjectDetection } } # == Class mappings == classes_map = { 0: "Caption", 1: "Footnote", 2: "Formula", 3: "List-item", 4: "Page-footer", 5: "Page-header", 6: "Picture", 7: "Section-header", 8: "Table", 9: "Text", 10: "Title", 11: "Document Index", 12: "Code", 13: "Checkbox-Selected", 14: "Checkbox-Unselected", 15: "Form", 16: "Key-Value Region", } # == Global model variables == current_model = None current_processor = None current_model_name = None cached_results = None # Para guardar los resultados y poder cambiar labels sin reprocesar def colormap(N=256, normalized=False): """Generate dynamic colormap.""" def bitget(byteval, idx): return ((byteval & (1 << idx)) != 0) cmap = np.zeros((N, 3), dtype=np.uint8) for i in range(N): r = g = b = 0 c = i for j in range(8): r = r | (bitget(c, 0) << (7 - j)) g = g | (bitget(c, 1) << (7 - j)) b = b | (bitget(c, 2) << (7 - j)) c = c >> 3 cmap[i] = np.array([r, g, b]) if normalized: cmap = cmap.astype(np.float32) / 255.0 return cmap def iomin(box1, box2): """Intersection over Minimum (IoMin).""" x1 = torch.max(box1[:, 0], box2[:, 0]) y1 = torch.max(box1[:, 1], box2[:, 1]) x2 = torch.min(box1[:, 2], box2[:, 2]) y2 = torch.min(box1[:, 3], box2[:, 3]) inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0) box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1]) box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1]) min_area = torch.min(box1_area, box2_area) return inter_area / min_area def nms_custom(boxes, scores, iou_threshold=0.5): """Custom NMS implementation using IoMin.""" keep = [] _, order = scores.sort(descending=True) while order.numel() > 0: i = order[0] keep.append(i.item()) if order.numel() == 1: break box_i = boxes[i].unsqueeze(0) rest = order[1:] ious = iomin(box_i, boxes[rest]) mask = (ious <= iou_threshold) order = order[1:][mask] return torch.tensor(keep, dtype=torch.long) def load_model_if_needed(model_name): """Load the selected model if not already loaded.""" global current_model, current_processor, current_model_name if current_model_name == model_name and current_model is not None: return True try: model_info = MODELS[model_name] model_path = model_info["path"] model_class = model_info["model_class"] print(f"Loading {model_name} from {model_path}") processor = RTDetrImageProcessor.from_pretrained(model_path) model = model_class.from_pretrained(model_path) model = model.to(device) model.eval() current_processor = processor current_model = model current_model_name = model_name return True except Exception as e: print(f"Error loading model: {e}") return False def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True): """Visualize bounding boxes with OpenCV.""" if isinstance(image_input, Image.Image): image = np.array(image_input) image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) elif isinstance(image_input, np.ndarray): if len(image_input.shape) == 3 and image_input.shape[2] == 3: image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR) else: image = image_input.copy() else: raise ValueError("Input must be PIL Image or numpy array") if len(bboxes) == 0: return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) overlay = image.copy() cmap = colormap(N=len(id_to_names), normalized=False) for i in range(len(bboxes)): try: bbox = bboxes[i] if torch.is_tensor(bbox): bbox = bbox.cpu().numpy() class_id = classes[i] if torch.is_tensor(class_id): class_id = class_id.item() score = scores[i] if torch.is_tensor(score): score = score.item() x_min, y_min, x_max, y_max = map(int, bbox) class_id = int(class_id) class_name = id_to_names.get(class_id, f"unknown_{class_id}") color = tuple(int(c) for c in cmap[class_id % len(cmap)]) # Draw filled rectangle on overlay cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1) # Draw border on main image cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3) # Add text label only if show_labels is True if show_labels: text = f"{class_name}: {score:.3f}" (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2) cv2.rectangle(image, (x_min, y_min - text_height - baseline - 4), (x_min + text_width + 8, y_min), color, -1) cv2.putText(image, text, (x_min + 4, y_min - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2) except Exception as e: print(f"Skipping box {i} due to error: {e}") # Apply transparency cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image) return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) def toggle_labels_visualization(show_labels, alpha): """Toggle labels without reprocessing the image.""" global cached_results if cached_results is None: return None, "⚠️ No cached results. Please analyze an image first." input_img, boxes, labels, scores = cached_results output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels) labels_status = "with labels" if show_labels else "without labels" info = f"✅ Visualization updated ({labels_status}) | {len(boxes)} detections" return output, info def process_image(input_img, model_name, conf_threshold, iou_threshold, nms_method, alpha, show_labels): """Process image with document layout detection.""" global cached_results if input_img is None: return None, "❌ Please upload an image first." # Load model if needed if not load_model_if_needed(model_name): return None, f"❌ Failed to load model {model_name}." try: # Prepare image if isinstance(input_img, np.ndarray): input_img = Image.fromarray(input_img) if input_img.mode != 'RGB': input_img = input_img.convert('RGB') # Process with model inputs = current_processor(images=[input_img], return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = current_model(**inputs) # Post-process results results = current_processor.post_process_object_detection( outputs, target_sizes=torch.tensor([input_img.size[::-1]]), threshold=conf_threshold, ) if not results or len(results) == 0: cached_results = None return np.array(input_img), "ℹ️ No detections found." result = results[0] boxes = result["boxes"] scores = result["scores"] labels = result["labels"] if len(boxes) == 0: cached_results = None return np.array(input_img), f"ℹ️ No detections above threshold {conf_threshold:.2f}." # Apply NMS if iou_threshold < 1.0: if nms_method == "Custom IoMin": keep_indices = nms_custom(boxes=boxes, scores=scores, iou_threshold=iou_threshold) else: keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold) boxes = boxes[keep_indices] scores = scores[keep_indices] labels = labels[keep_indices] # Cache results for label toggling cached_results = (input_img, boxes, labels, scores) # Visualize results output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels) labels_status = "with labels" if show_labels else "without labels" info = f"✅ Found {len(boxes)} detections ({labels_status}) | Model: {model_name} | NMS: {nms_method} | Conf: {conf_threshold:.2f}" return output, info except Exception as e: print(f"[ERROR] process_image failed: {e}") cached_results = None error_msg = f"❌ Processing error: {str(e)}" if input_img is not None: return np.array(input_img), error_msg return np.zeros((512, 512, 3), dtype=np.uint8), error_msg if __name__ == "__main__": print(f"🚀 Starting Document Layout Analysis App") print(f"📱 Device: {device}") print(f"🤖 Available models: {len(MODELS)}") # Custom CSS for clean layout custom_css = """ .gradio-container { max-width: 100% !important; padding: 15px !important; } .control-panel { background: #f8f9fa; border-radius: 12px; border: 1px solid #e9ecef; padding: 20px; margin-bottom: 15px; } .results-panel { background: #f8f9fa; border-radius: 12px; border: 1px solid #e9ecef; padding: 20px; min-height: 600px; } """ # Create Gradio interface with gr.Blocks( title="📄 Document Layout Analysis", theme=gr.themes.Soft(), css=custom_css ) as demo: # Header gr.HTML("""
Advanced document structure detection with Docling models