import gradio as gr
import torch
import json
from PIL import Image, ImageDraw
import numpy as np
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3Tokenizer,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Config
)
import pytesseract
from datasets import load_dataset
import os

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Constants
NUM_LABELS = 5  # 0: regular text, 1: title, 2: H1, 3: H2, 4: H3


def create_student_model(num_labels=5):
    """Create a distilled version of LayoutLMv3"""
    student_config = LayoutLMv3Config(
        hidden_size=384,         # vs 768 original
        num_attention_heads=6,   # vs 12 original
        intermediate_size=1536,  # vs 3072 original
        num_hidden_layers=8,     # vs 12 original
        num_labels=num_labels
    )
    model = LayoutLMv3ForTokenClassification(student_config)
    return model


def load_model():
    """Load the model and components"""
    print("Creating model components...")

    # Create feature extractor
    feature_extractor = LayoutLMv3FeatureExtractor(
        do_resize=True,
        size=224,
        apply_ocr=False,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5]
    )

    # Create tokenizer
    tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")

    # Create student model
    model = create_student_model(num_labels=NUM_LABELS)
    model.to(device)

    # For demo purposes, we'll use random weights
    # In production, you would load your trained weights here
    print("Model components created successfully!")

    return model, feature_extractor, tokenizer


def perform_ocr(image):
    """Extract text and bounding boxes from image using OCR"""
    try:
        # Convert PIL image to numpy array
        img_array = np.array(image)

        # Get OCR data
        ocr_data = pytesseract.image_to_data(img_array, output_type=pytesseract.Output.DICT)

        words = []
        boxes = []
        confidences = ocr_data['conf']

        for i in range(len(ocr_data['text'])):
            if int(confidences[i]) > 30:  # Filter low confidence
                word = ocr_data['text'][i].strip()
                if word:  # Only add non-empty words
                    x, y, w, h = (ocr_data['left'][i], ocr_data['top'][i],
                                  ocr_data['width'][i], ocr_data['height'][i])

                    # Normalize coordinates
                    img_width, img_height = image.size
                    normalized_box = [
                        x / img_width,
                        y / img_height,
                        (x + w) / img_width,
                        (y + h) / img_height
                    ]

                    words.append(word)
                    boxes.append(normalized_box)

        return words, boxes

    except Exception as e:
        print(f"OCR failed: {e}")
        return ["sample", "text"], [[0, 0, 0.5, 0.1], [0.5, 0, 1.0, 0.1]]


def extract_headings_from_image(image, model, feature_extractor, tokenizer):
    """Extract headings from uploaded image using the model"""
    try:
        # Perform OCR to get words and boxes
        words, boxes = perform_ocr(image)

        if not words:
            return {"ERROR": ["No text found in image"]}

        # Prepare inputs for the model
        # Process image
        pixel_values = feature_extractor(image, return_tensors="pt")["pixel_values"]
        pixel_values = pixel_values.to(device)

        # Process text and boxes (limit to first 512 tokens)
        max_words = min(len(words), 500)  # Leave room for special tokens
        words = words[:max_words]
        boxes = boxes[:max_words]

        # Convert boxes to the format expected by LayoutLMv3 (0-1000 scale)
        scaled_boxes = []
        for box in boxes:
            scaled_box = [
                int(box[0] * 1000),
                int(box[1] * 1000),
                int(box[2] * 1000),
                int(box[3] * 1000)
            ]
            scaled_boxes.append(scaled_box)

        # Tokenize
        encoding = tokenizer(
            words,
            boxes=scaled_boxes,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Move to device
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        bbox = encoding["bbox"].to(device)

        # Run inference
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                bbox=bbox,
                pixel_values=pixel_values
            )

        # Get predictions
        predictions = torch.argmax(outputs.logits, dim=-1).cpu().numpy()[0]

        # Map predictions back to words
        word_ids = encoding.word_ids(batch_index=0)

        # Extract headings by label
        headings = {"TITLE": [], "H1": [], "H2": [], "H3": []}
        label_map = {0: "TEXT", 1: "TITLE", 2: "H1", 3: "H2", 4: "H3"}

        current_heading = {"text": "", "level": None}
        prev_word_id = None

        for word_id, pred in zip(word_ids, predictions):
            # Only use the first sub-word token of each word so words aren't duplicated
            if word_id is not None and word_id < len(words) and word_id != prev_word_id:
                prev_word_id = word_id
                predicted_label = label_map.get(pred, "TEXT")

                if predicted_label != "TEXT":
                    if current_heading["level"] == predicted_label:
                        # Continue building current heading
                        current_heading["text"] += " " + words[word_id]
                    else:
                        # Save previous heading if it exists
                        if current_heading["text"] and current_heading["level"]:
                            headings[current_heading["level"]].append(current_heading["text"].strip())
                        # Start new heading
                        current_heading = {"text": words[word_id], "level": predicted_label}
                else:
                    # Save current heading when we hit regular text
                    if current_heading["text"] and current_heading["level"]:
                        headings[current_heading["level"]].append(current_heading["text"].strip())
                        current_heading = {"text": "", "level": None}

        # Save final heading
        if current_heading["text"] and current_heading["level"]:
            headings[current_heading["level"]].append(current_heading["text"].strip())

        # Remove empty lists and return
        headings = {k: v for k, v in headings.items() if v}

        if not headings:
            return {"INFO": ["No headings detected - this might be a model training issue"]}

        return headings

    except Exception as e:
        return {"ERROR": [f"Processing failed: {str(e)}"]}


# Load model (this will happen when the Space starts)
print("Loading model...")
model, feature_extractor, tokenizer = load_model()
print("Model loaded successfully!")


def process_document(image):
    """Main function to process uploaded document"""
    if image is None:
        return "Please upload an image"

    print("Processing uploaded image...")

    # Extract headings
    headings = extract_headings_from_image(image, model, feature_extractor, tokenizer)

    # Format output
    result = "## Extracted Document Structure:\n\n"

    if "ERROR" in headings:
        result += f"❌ **Error:** {headings['ERROR'][0]}\n"
        return result

    if "INFO" in headings:
        result += f"ℹ️ **Info:** {headings['INFO'][0]}\n"
        return result

    # Display found headings
    for level, texts in headings.items():
        result += f"**{level}:**\n"
        for text in texts:
            if level == "TITLE":
                result += f"# {text}\n"
            elif level == "H1":
                result += f"## {text}\n"
            elif level == "H2":
                result += f"### {text}\n"
            elif level == "H3":
                result += f"#### {text}\n"
        result += "\n"

    if not any(headings.values()):
        result += "⚠️ No headings were detected in this image.\n\n"
        result += "**Possible reasons:**\n"
        result += "- The model needs training on actual data\n"
        result += "- The image quality is too low\n"
        result += "- The document doesn't contain clear headings\n"

    return result


# Create Gradio interface
demo = gr.Interface(
    fn=process_document,
    inputs=gr.Image(type="pil", label="Upload Document Image"),
    outputs=gr.Markdown(label="Extracted Headings"),
    title="📄 PDF Heading Extractor",
    description="""
    Upload an image of a document to extract its heading hierarchy.

    **Note:** This is a demo version using an untrained model.
    The actual model would need to be trained on DocLayNet data for accurate results.
    """,
    examples=None,
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()
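
# ---------------------------------------------------------------------------
# Sketch (assumption, not part of the demo above): load_model() currently keeps
# the randomly initialized student weights. Once the distilled model has been
# trained, its checkpoint could be saved and then restored inside load_model()
# instead of the random initialization. The filename "student_weights.pt" is
# hypothetical and used only for illustration.
#
#     # after training:
#     torch.save(model.state_dict(), "student_weights.pt")
#
#     # inside load_model(), after create_student_model(...):
#     model.load_state_dict(
#         torch.load("student_weights.pt", map_location=device)
#     )
# ---------------------------------------------------------------------------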