import gradio as gr
import torch
import numpy as np
import pytesseract
from PIL import Image
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3Tokenizer,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Config,
)
# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Constants
NUM_LABELS = 5  # 0: regular text, 1: title, 2: H1, 3: H2, 4: H3
def create_student_model(num_labels=NUM_LABELS):
    """Create a distilled (smaller) version of LayoutLMv3."""
    student_config = LayoutLMv3Config(
        hidden_size=384,        # vs 768 in the base model
        num_attention_heads=6,  # vs 12
        intermediate_size=1536, # vs 3072
        num_hidden_layers=8,    # vs 12
        num_labels=num_labels,
    )
    return LayoutLMv3ForTokenClassification(student_config)
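
# Optional sanity check on the size reduction (parameter counts vary a bit
# across transformers versions, so treat the printed number as approximate):
# student = create_student_model()
# print(f"Student parameters: {sum(p.numel() for p in student.parameters()):,}")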
def load_model():
    """Create the model, feature extractor, and tokenizer."""
    print("Creating model components...")

    # Feature extractor resizes and normalizes the page image; OCR is
    # handled separately with pytesseract, so apply_ocr is disabled.
    feature_extractor = LayoutLMv3FeatureExtractor(
        do_resize=True,
        size=224,
        apply_ocr=False,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )

    # The tokenizer comes from the base checkpoint (the vocabulary is unchanged)
    tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")

    # Create the student model
    model = create_student_model(num_labels=NUM_LABELS)
    model.to(device)
    # For demo purposes, we'll use random weights.
    # In production, you would load your trained weights here.
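    # A minimal sketch of restoring a trained checkpoint (hypothetical
    # file name; point it at whatever your training run saved):
    # state_dict = torch.load("student_layoutlmv3.pt", map_location=device)
    # model.load_state_dict(state_dict)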
| print("Model components created successfully!") | |
| return model, feature_extractor, tokenizer | |
def perform_ocr(image):
    """Extract words and normalized bounding boxes from an image with Tesseract."""
    try:
        # Convert PIL image to numpy array
        img_array = np.array(image)

        # Get OCR data
        ocr_data = pytesseract.image_to_data(img_array, output_type=pytesseract.Output.DICT)

        words = []
        boxes = []
        img_width, img_height = image.size
        confidences = ocr_data["conf"]

        for i in range(len(ocr_data["text"])):
            # conf values may be strings or numbers depending on the
            # pytesseract version, so parse with float()
            if float(confidences[i]) > 30:  # filter low-confidence detections
                word = ocr_data["text"][i].strip()
                if word:  # only keep non-empty words
                    x, y, w, h = (ocr_data["left"][i], ocr_data["top"][i],
                                  ocr_data["width"][i], ocr_data["height"][i])
                    # Normalize coordinates to [0, 1]
                    normalized_box = [
                        x / img_width,
                        y / img_height,
                        (x + w) / img_width,
                        (y + h) / img_height,
                    ]
                    words.append(word)
                    boxes.append(normalized_box)

        return words, boxes
    except Exception as e:
        print(f"OCR failed: {e}")
        # Fallback so the demo still produces output
        return ["sample", "text"], [[0, 0, 0.5, 0.1], [0.5, 0, 1.0, 0.1]]
def extract_headings_from_image(image, model, feature_extractor, tokenizer):
    """Extract headings from an uploaded image using the model."""
    try:
        # Ensure a 3-channel image (PNG uploads may be RGBA)
        image = image.convert("RGB")

        # Perform OCR to get words and boxes
        words, boxes = perform_ocr(image)
        if not words:
            return {"ERROR": ["No text found in image"]}

        # Process the page image
        pixel_values = feature_extractor(image, return_tensors="pt")["pixel_values"]
        pixel_values = pixel_values.to(device)

        # Cap the word count so the tokenized sequence fits in 512 tokens
        max_words = min(len(words), 500)  # leave room for special tokens
        words = words[:max_words]
        boxes = boxes[:max_words]
        # LayoutLMv3 expects box coordinates on a 0-1000 scale; clamp to
        # guard against OCR boxes that spill past the image edge
        scaled_boxes = [
            [min(1000, max(0, int(coord * 1000))) for coord in box]
            for box in boxes
        ]
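        # For example, a box spanning the left half of a line near the top,
        # [0.0, 0.02, 0.5, 0.05], becomes [0, 20, 500, 50]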
        # Tokenize words together with their boxes
        encoding = tokenizer(
            words,
            boxes=scaled_boxes,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Move to device
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        bbox = encoding["bbox"].to(device)
        # Run inference
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                bbox=bbox,
                pixel_values=pixel_values,
            )

        # Per-token label predictions
        predictions = torch.argmax(outputs.logits, dim=-1).cpu().numpy()[0]

        # Map token-level predictions back to words
        word_ids = encoding.word_ids(batch_index=0)

        headings = {"TITLE": [], "H1": [], "H2": [], "H3": []}
        label_map = {0: "TEXT", 1: "TITLE", 2: "H1", 3: "H2", 4: "H3"}
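        # word_ids maps each token position back to its source word index;
        # e.g. ["Annual", "Report"] might tokenize to word ids
        # [None, 0, 1, 1, ..., None], where None marks special and padding tokens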
        # Group consecutive words that share a heading label into one heading.
        # Only the first sub-token of each word is considered; otherwise a
        # word split into several tokens would be appended repeatedly.
        current_heading = {"text": "", "level": None}
        prev_word_id = None

        for word_id, pred in zip(word_ids, predictions):
            if word_id is None or word_id >= len(words) or word_id == prev_word_id:
                continue
            prev_word_id = word_id

            predicted_label = label_map.get(pred, "TEXT")
            if predicted_label != "TEXT":
                if current_heading["level"] == predicted_label:
                    # Continue building the current heading
                    current_heading["text"] += " " + words[word_id]
                else:
                    # Save the previous heading, if any, then start a new one
                    if current_heading["text"] and current_heading["level"]:
                        headings[current_heading["level"]].append(current_heading["text"].strip())
                    current_heading = {"text": words[word_id], "level": predicted_label}
            else:
                # Regular text closes any open heading
                if current_heading["text"] and current_heading["level"]:
                    headings[current_heading["level"]].append(current_heading["text"].strip())
                current_heading = {"text": "", "level": None}

        # Save the final heading
        if current_heading["text"] and current_heading["level"]:
            headings[current_heading["level"]].append(current_heading["text"].strip())
        # Drop empty levels and return
        headings = {k: v for k, v in headings.items() if v}
        if not headings:
            return {"INFO": ["No headings detected - this might be a model training issue"]}
        return headings

    except Exception as e:
        return {"ERROR": [f"Processing failed: {e}"]}
# Load model once at startup (this runs when the Space boots)
print("Loading model...")
model, feature_extractor, tokenizer = load_model()
print("Model loaded successfully!")
def process_document(image):
    """Main entry point for an uploaded document image."""
    if image is None:
        return "Please upload an image"

    print("Processing uploaded image...")

    # Extract headings
    headings = extract_headings_from_image(image, model, feature_extractor, tokenizer)

    # Format output as Markdown
    result = "## Extracted Document Structure:\n\n"

    if "ERROR" in headings:
        result += f"❌ **Error:** {headings['ERROR'][0]}\n"
        return result

    if "INFO" in headings:
        result += f"ℹ️ **Info:** {headings['INFO'][0]}\n"
        return result
    # Render the detected headings at their respective Markdown levels
    level_prefix = {"TITLE": "#", "H1": "##", "H2": "###", "H3": "####"}
    for level, texts in headings.items():
        result += f"**{level}:**\n"
        for text in texts:
            result += f"{level_prefix[level]} {text}\n"
        result += "\n"

    if not any(headings.values()):
        result += "⚠️ No headings were detected in this image.\n\n"
        result += "**Possible reasons:**\n"
        result += "- The model needs training on actual data\n"
        result += "- The image quality is too low\n"
        result += "- The document doesn't contain clear headings\n"

    return result
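
# The pipeline can also be exercised without the UI (hypothetical file name):
# print(process_document(Image.open("page.png").convert("RGB")))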
# Create the Gradio interface
demo = gr.Interface(
    fn=process_document,
    inputs=gr.Image(type="pil", label="Upload Document Image"),
    outputs=gr.Markdown(label="Extracted Headings"),
    title="📄 PDF Heading Extractor",
    description="""
    Upload an image of a document to extract its heading hierarchy.

    **Note:** This is a demo version using an untrained model.
    The actual model would need to be trained on DocLayNet data for accurate results.
    """,
    examples=None,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()