import os
os.environ["GRADIO_TEMP_DIR"] = "./tmp"
import sys
import torch
import torchvision
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from transformers import (
    DFineForObjectDetection,
    RTDetrV2ForObjectDetection,
    RTDetrImageProcessor,
)
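
# Note: DFineForObjectDetection and RTDetrV2ForObjectDetection are relatively
# recent additions to transformers; if these imports fail, upgrade the library.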

# == Device configuration ==
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# == Model configurations ==
MODELS = {
    "Docling Layout Egret XLarge": {
        "path": "ds4sd/docling-layout-egret-xlarge",
        "model_class": DFineForObjectDetection,
    },
    "Docling Layout Egret Large": {
        "path": "ds4sd/docling-layout-egret-large",
        "model_class": DFineForObjectDetection,
    },
    "Docling Layout Egret Medium": {
        "path": "ds4sd/docling-layout-egret-medium",
        "model_class": DFineForObjectDetection,
    },
    "Docling Layout Heron 101": {
        "path": "ds4sd/docling-layout-heron-101",
        "model_class": RTDetrV2ForObjectDetection,
    },
    "Docling Layout Heron": {
        "path": "ds4sd/docling-layout-heron",
        "model_class": RTDetrV2ForObjectDetection,
    },
}
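
# The Egret checkpoints use a D-FINE detection head and the Heron checkpoints
# an RT-DETRv2 head; all five share RT-DETR-style preprocessing, which is why
# a single RTDetrImageProcessor is used for every model (see
# load_model_if_needed below).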

# == Class mappings ==
classes_map = {
    0: "Caption", 1: "Footnote", 2: "Formula", 3: "List-item",
    4: "Page-footer", 5: "Page-header", 6: "Picture", 7: "Section-header",
    8: "Table", 9: "Text", 10: "Title", 11: "Document Index",
    12: "Code", 13: "Checkbox-Selected", 14: "Checkbox-Unselected",
    15: "Form", 16: "Key-Value Region",
}

# == Global model variables ==
current_model = None
current_processor = None
current_model_name = None
cached_results = None  # Cache of the last detections so labels can be toggled without reprocessing


def colormap(N=256, normalized=False):
    """Generate dynamic colormap."""
    def bitget(byteval, idx):
        return (byteval & (1 << idx)) != 0

    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << (7 - j))
            g = g | (bitget(c, 1) << (7 - j))
            b = b | (bitget(c, 2) << (7 - j))
            c = c >> 3
        cmap[i] = np.array([r, g, b])
    if normalized:
        cmap = cmap.astype(np.float32) / 255.0
    return cmap
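
# The scheme above is the classic PASCAL VOC colormap: the bits of each class
# index are interleaved across the R, G, and B channels, so consecutive ids get
# visually distinct colors (e.g. cmap[1] = (128, 0, 0), cmap[2] = (0, 128, 0)).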


def iomin(box1, box2):
    """Intersection over Minimum (IoMin)."""
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])
    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    min_area = torch.min(box1_area, box2_area)
    return inter_area / min_area
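
# Unlike IoU, IoMin divides the intersection by the smaller of the two box
# areas, so a box fully nested inside another scores 1.0 even when their IoU
# is tiny. Suppression based on IoMin is therefore much more aggressive
# against nested duplicate detections, which are common in layout analysis.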


def nms_custom(boxes, scores, iou_threshold=0.5):
    """Custom NMS implementation using IoMin."""
    keep = []
    _, order = scores.sort(descending=True)
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        box_i = boxes[i].unsqueeze(0)
        rest = order[1:]
        ious = iomin(box_i, boxes[rest])
        mask = ious <= iou_threshold
        order = order[1:][mask]
    return torch.tensor(keep, dtype=torch.long)
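
# Minimal usage sketch (hypothetical tensors, for illustration only):
#   boxes  = torch.tensor([[0., 0., 100., 100.], [10., 10., 50., 50.]])
#   scores = torch.tensor([0.9, 0.8])
#   nms_custom(boxes, scores, 0.5)  # -> tensor([0]); the nested box has
#                                   # IoMin = 1.0 and is dropped, whereas
#                                   # torchvision.ops.nms (IoU = 0.16) keeps it.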


def load_model_if_needed(model_name):
    """Load the selected model if not already loaded."""
    global current_model, current_processor, current_model_name
    if current_model_name == model_name and current_model is not None:
        return True
    try:
        model_info = MODELS[model_name]
        model_path = model_info["path"]
        model_class = model_info["model_class"]
        print(f"Loading {model_name} from {model_path}")
        processor = RTDetrImageProcessor.from_pretrained(model_path)
        model = model_class.from_pretrained(model_path)
        model = model.to(device)
        model.eval()
        current_processor = processor
        current_model = model
        current_model_name = model_name
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
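
# Only one model is held in memory at a time: switching models in the UI
# triggers a reload, while repeated runs with the same selection reuse the
# cached weights. Caveat: these module-level globals are shared across
# concurrent Gradio requests, so two users picking different models at the
# same moment could interleave loads.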


def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True):
    """Visualize bounding boxes with OpenCV."""
    if isinstance(image_input, Image.Image):
        image = np.array(image_input)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif isinstance(image_input, np.ndarray):
        if len(image_input.shape) == 3 and image_input.shape[2] == 3:
            image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
        else:
            image = image_input.copy()
    else:
        raise ValueError("Input must be PIL Image or numpy array")

    if len(bboxes) == 0:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    overlay = image.copy()
    cmap = colormap(N=len(id_to_names), normalized=False)
    for i in range(len(bboxes)):
        try:
            bbox = bboxes[i]
            if torch.is_tensor(bbox):
                bbox = bbox.cpu().numpy()
            class_id = classes[i]
            if torch.is_tensor(class_id):
                class_id = class_id.item()
            score = scores[i]
            if torch.is_tensor(score):
                score = score.item()
            x_min, y_min, x_max, y_max = map(int, bbox)
            class_id = int(class_id)
            class_name = id_to_names.get(class_id, f"unknown_{class_id}")
            color = tuple(int(c) for c in cmap[class_id % len(cmap)])
            # Draw filled rectangle on overlay
            cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
            # Draw border on main image
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)
            # Add text label only if show_labels is True
            if show_labels:
                text = f"{class_name}: {score:.3f}"
                (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)
                cv2.rectangle(image, (x_min, y_min - text_height - baseline - 4),
                              (x_min + text_width + 8, y_min), color, -1)
                cv2.putText(image, text, (x_min + 4, y_min - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
        except Exception as e:
            print(f"Skipping box {i} due to error: {e}")

    # Apply transparency
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
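
# OpenCV draws in BGR, so the input is converted from RGB on entry and back to
# RGB before returning. Filled rectangles go on a separate overlay that is
# alpha-blended onto the base image, keeping box interiors translucent while
# the borders and labels stay legible.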


def toggle_labels_visualization(show_labels, alpha):
    """Toggle labels without reprocessing the image."""
    global cached_results
    if cached_results is None:
        return None, "⚠️ No cached results. Please analyze an image first."
    input_img, boxes, labels, scores = cached_results
    output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels)
    labels_status = "with labels" if show_labels else "without labels"
    info = f"✅ Visualization updated ({labels_status}) | {len(boxes)} detections"
    return output, info


def process_image(input_img, model_name, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
    """Process image with document layout detection."""
    global cached_results
    if input_img is None:
        return None, "❌ Please upload an image first."
    # Load model if needed
    if not load_model_if_needed(model_name):
        return None, f"❌ Failed to load model {model_name}."
    try:
        # Prepare image
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)
        if input_img.mode != 'RGB':
            input_img = input_img.convert('RGB')
        # Process with model
        inputs = current_processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = current_model(**inputs)
        # Post-process results
        results = current_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )
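        # post_process_object_detection returns absolute (x_min, y_min, x_max,
        # y_max) pixel boxes; target_sizes expects (height, width), hence the
        # reversed tuple above, since PIL's .size is (width, height).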
        if not results or len(results) == 0:
            cached_results = None
            return np.array(input_img), "ℹ️ No detections found."
        result = results[0]
        boxes = result["boxes"]
        scores = result["scores"]
        labels = result["labels"]
        if len(boxes) == 0:
            cached_results = None
            return np.array(input_img), f"ℹ️ No detections above threshold {conf_threshold:.2f}."
        # Apply NMS
        if iou_threshold < 1.0:
            if nms_method == "Custom IoMin":
                keep_indices = nms_custom(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
            else:
                keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
            boxes = boxes[keep_indices]
            scores = scores[keep_indices]
            labels = labels[keep_indices]
        # Cache results for label toggling
        cached_results = (input_img, boxes, labels, scores)
        # Visualize results
        output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels)
        labels_status = "with labels" if show_labels else "without labels"
        info = f"✅ Found {len(boxes)} detections ({labels_status}) | Model: {model_name} | NMS: {nms_method} | Conf: {conf_threshold:.2f}"
        return output, info
    except Exception as e:
        print(f"[ERROR] process_image failed: {e}")
        cached_results = None
        error_msg = f"❌ Processing error: {str(e)}"
        if input_img is not None:
            return np.array(input_img), error_msg
        return np.zeros((512, 512, 3), dtype=np.uint8), error_msg


if __name__ == "__main__":
    print("🚀 Starting Document Layout Analysis App")
    print(f"📱 Device: {device}")
    print(f"🤖 Available models: {len(MODELS)}")

    # Custom CSS for clean layout
    custom_css = """
    .gradio-container {
        max-width: 100% !important;
        padding: 15px !important;
    }
    .control-panel {
        background: #f8f9fa;
        border-radius: 12px;
        border: 1px solid #e9ecef;
        padding: 20px;
        margin-bottom: 15px;
    }
    .results-panel {
        background: #f8f9fa;
        border-radius: 12px;
        border: 1px solid #e9ecef;
        padding: 20px;
        min-height: 600px;
    }
    """

    # Create Gradio interface
    with gr.Blocks(
        title="📄 Document Layout Analysis",
        theme=gr.themes.Soft(),
        css=custom_css,
    ) as demo:
        # Header
        gr.HTML("""
        <div style='text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 12px; margin-bottom: 20px;'>
            <h1 style='margin: 0; font-size: 2.5em;'>📄 Document Layout Analysis</h1>
            <p style='margin: 8px 0 0 0; font-size: 1.1em; opacity: 0.9;'>Advanced document structure detection with Docling models</p>
        </div>
        """)

        # Main content in two columns
        with gr.Row():
            # LEFT COLUMN - Controls (more compact)
            with gr.Column(scale=1):
                with gr.Group(elem_classes=["control-panel"]):
                    # 1. Image Upload (first)
                    gr.HTML("<h3>📄 Upload Image</h3>")
                    input_img = gr.Image(
                        label="Document Image",
                        type="pil",
                        height=300,
                        interactive=True,
                    )
                    # gr.HTML("<br><h3>🤖 Model Selection</h3>")
                    # 2. Model Selection (second, without buttons)
                    model_dropdown = gr.Dropdown(
                        choices=list(MODELS.keys()),
                        value="Docling Layout Egret XLarge",
                        label="AI Model",
                        info="Model will be loaded automatically",
                        interactive=True,
                    )
                    # gr.HTML("<br><h3>⚙️ Parameters</h3>")
                    # 3. All parameters together (third)
                    with gr.Row():
                        conf_threshold = gr.Slider(
                            minimum=0.0, maximum=1.0, value=0.6, step=0.05,
                            label="Confidence", info="Detection threshold",
                        )
                        iou_threshold = gr.Slider(
                            minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                            label="NMS IoU", info="Suppression threshold",
                        )
                    with gr.Row():
                        nms_method = gr.Radio(
                            choices=["Standard IoU", "Custom IoMin"],
                            value="Standard IoU",
                            label="NMS Method", scale=2,
                        )
                        alpha_slider = gr.Slider(
                            minimum=0.0, maximum=1.0, value=0.3, step=0.1,
                            label="Transparency", scale=1,
                        )
                    # gr.HTML("<br>")
                    # 4. Analyze button (last)
                    analyze_btn = gr.Button("🔍 Analyze Document", variant="primary", size="lg")

            # RIGHT COLUMN - Results
            with gr.Column(scale=1):
                with gr.Group(elem_classes=["results-panel"]):
                    gr.HTML("<h3>🎯 Analysis Results</h3>")
                    output_img = gr.Image(
                        label="Detected Layout",
                        type="numpy",
                        height=450,
                        interactive=False,
                    )
                    detection_info = gr.Textbox(
                        label="Detection Summary",
                        value="",
                        interactive=False,
                        lines=2,
                        placeholder="Results will appear here...",
                    )
                    # Labels toggle (independent control)
                    # gr.HTML("<h4>🎨 Visualization</h4>")
                    show_labels_checkbox = gr.Checkbox(
                        value=True,
                        label="Show Class Labels",
                        info="Toggle labels without reprocessing",
                        interactive=True,
                    )

        # Event Handlers
        # Main analysis (full processing)
        analyze_btn.click(
            fn=process_image,
            inputs=[input_img, model_dropdown, conf_threshold, iou_threshold, nms_method, alpha_slider, show_labels_checkbox],
            outputs=[output_img, detection_info],
        )
        # Independent label toggle (no reprocessing)
        show_labels_checkbox.change(
            fn=toggle_labels_visualization,
            inputs=[show_labels_checkbox, alpha_slider],
            outputs=[output_img, detection_info],
        )
        # Also update visualization when transparency changes (if we have cached results)
        alpha_slider.change(
            fn=toggle_labels_visualization,
            inputs=[show_labels_checkbox, alpha_slider],
            outputs=[output_img, detection_info],
        )
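
    # server_name="0.0.0.0" exposes the app on all interfaces, and 7860 is the
    # default port a Hugging Face Spaces container expects.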
    # Launch application
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        share=False,
        show_error=True,
        inbrowser=True,
    )