dots-ocr-idcard / app.py
tommulder's picture
Prepare for Hugging Face Spaces deployment
e300623
raw
history blame
8.38 kB
"""HF Dots.OCR Text Extraction Endpoint
This FastAPI application provides a Hugging Face Space endpoint for Dots.OCR
text extraction with ROI support and standardized field extraction schema.
"""
import logging
import time
import uuid
import json
import re
from typing import List, Optional, Dict, Any
from contextlib import asynccontextmanager
import cv2
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import torch
from PIL import Image
import io
import base64
# Dots.OCR imports
try:
    from dots_ocr import DotsOCR

    DOTS_OCR_AVAILABLE = True
except ImportError:
    # Without the package the service runs with a mock OCR implementation.
    DOTS_OCR_AVAILABLE = False
    # Log through a named logger, not the module-level ``logging.warning``:
    # that helper implicitly calls ``logging.basicConfig()`` (WARNING level)
    # when the root logger has no handlers, which would turn the later
    # ``logging.basicConfig(level=logging.INFO)`` into a silent no-op and
    # suppress every INFO message. A named logger with no handlers still
    # reaches stderr via logging's last-resort handler.
    logging.getLogger(__name__).warning(
        "Dots.OCR not available - using mock implementation"
    )
# Import local field extraction utilities
from field_extraction import FieldExtractor
# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide OCR model handle. It is None until ``lifespan`` runs at
# startup, after which it holds a DotsOCR instance or the string "mock".
dots_ocr_model: Optional[Any] = None
class BoundingBox(BaseModel):
    """Region of interest given as fractions of the image width and height.

    (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner.
    Each coordinate is validated to lie in [0.0, 1.0]. The ordering
    x1 <= x2 / y1 <= y2 is not validated by this model.
    """

    x1: float = Field(..., description="Top-left x coordinate", ge=0.0, le=1.0)
    y1: float = Field(..., description="Top-left y coordinate", ge=0.0, le=1.0)
    x2: float = Field(..., description="Bottom-right x coordinate", ge=0.0, le=1.0)
    y2: float = Field(..., description="Bottom-right y coordinate", ge=0.0, le=1.0)
class ExtractedField(BaseModel):
    """A single normalized field value, with its confidence and provenance."""

    # Canonical key, matching one of the attribute names on ExtractedFields.
    field_name: str = Field(..., description="Standardized field name")
    # None when the field was recognised but no value could be read.
    value: Optional[str] = Field(default=None, description="Extracted field value")
    # Validated to the closed interval [0.0, 1.0].
    confidence: float = Field(
        ...,
        description="Extraction confidence",
        ge=0.0,
        le=1.0,
    )
    source: str = Field(..., description="Extraction source (e.g., 'ocr')")
class ExtractedFields(BaseModel):
    """Every standardized field that may be extracted from an identity document.

    Each attribute defaults to None, meaning "not extracted". Declaration
    order is kept stable because it is also the serialization order.
    """

    # Document identity
    document_number: Optional[ExtractedField] = None
    document_type: Optional[ExtractedField] = None
    issuing_country: Optional[ExtractedField] = None
    issuing_authority: Optional[ExtractedField] = None

    # Holder identity
    surname: Optional[ExtractedField] = None
    given_names: Optional[ExtractedField] = None
    nationality: Optional[ExtractedField] = None
    date_of_birth: Optional[ExtractedField] = None
    gender: Optional[ExtractedField] = None
    place_of_birth: Optional[ExtractedField] = None

    # Validity period
    date_of_issue: Optional[ExtractedField] = None
    date_of_expiry: Optional[ExtractedField] = None

    # Additional identifiers
    personal_number: Optional[ExtractedField] = None
    optional_data_1: Optional[ExtractedField] = None
    optional_data_2: Optional[ExtractedField] = None
class MRZData(BaseModel):
    """Values parsed from a document's Machine Readable Zone.

    Every text attribute is optional and defaults to None. ``confidence``
    defaults to 0.0 and is validated to [0.0, 1.0].
    """

    document_type: Optional[str] = Field(
        default=None, description="MRZ document type (TD1|TD2|TD3)"
    )
    issuing_country: Optional[str] = Field(
        default=None, description="Issuing country code"
    )
    surname: Optional[str] = Field(default=None, description="Surname from MRZ")
    given_names: Optional[str] = Field(
        default=None, description="Given names from MRZ"
    )
    document_number: Optional[str] = Field(
        default=None, description="Document number from MRZ"
    )
    nationality: Optional[str] = Field(
        default=None, description="Nationality code from MRZ"
    )
    date_of_birth: Optional[str] = Field(
        default=None, description="Date of birth from MRZ"
    )
    gender: Optional[str] = Field(default=None, description="Gender from MRZ")
    date_of_expiry: Optional[str] = Field(
        default=None, description="Date of expiry from MRZ"
    )
    personal_number: Optional[str] = Field(
        default=None, description="Personal number from MRZ"
    )
    # Unparsed MRZ text as it was found.
    raw_mrz: Optional[str] = Field(default=None, description="Raw MRZ text")
    confidence: float = Field(
        default=0.0,
        description="MRZ extraction confidence",
        ge=0.0,
        le=1.0,
    )
class OCRDetection(BaseModel):
    """OCR result for one processed image: optional MRZ plus structured fields."""

    # None when no MRZ was found in the recognised text.
    mrz_data: Optional[MRZData] = Field(
        default=None, description="MRZ data if detected"
    )
    # Required; individual fields inside may still be None.
    extracted_fields: ExtractedFields = Field(
        ..., description="Extracted field data"
    )
class OCRResponse(BaseModel):
    """Response body returned by the OCR endpoint."""

    # Per-request identifier, useful for correlating logs.
    request_id: str = Field(..., description="Unique request identifier")
    media_type: str = Field(..., description="Media type processed")
    # Wall time spent handling the request, in seconds.
    processing_time: float = Field(..., description="Processing time in seconds")
    detections: List[OCRDetection] = Field(
        ..., description="List of OCR detections"
    )
# FieldExtractor is provided by the local field_extraction module (imported above).
def crop_image_by_roi(image: np.ndarray, roi: "BoundingBox") -> np.ndarray:
    """Crop *image* to the region described by a normalized ROI.

    Args:
        image: Array whose first two axes are (height, width). Any trailing
            channel axis is preserved.
        roi: Object with ``x1``, ``y1``, ``x2`` and ``y2`` attributes,
            expressed as fractions of the image width/height (a
            ``BoundingBox``).

    Returns:
        The slice ``image[y1:y2, x1:x2]`` (a NumPy view). The pixel
        coordinates are obtained by truncating ``fraction * size`` and
        clamping the result into the image bounds.

    Raises:
        ValueError: If the clamped region has zero width or height, for
            example when ``x2 <= x1``. An empty array cannot be OCR'd, so
            the caller is told explicitly instead.
    """
    h, w = image.shape[:2]
    # Clamp the start into [0, size]. Clamp the end to at least the start,
    # so an inverted box collapses to zero size rather than a negative slice.
    x1 = max(0, min(int(roi.x1 * w), w))
    y1 = max(0, min(int(roi.y1 * h), h))
    x2 = max(x1, min(int(roi.x2 * w), w))
    y2 = max(y1, min(int(roi.y2 * h), h))
    if x2 == x1 or y2 == y1:
        raise ValueError(
            f"ROI selects an empty region: x=[{x1}, {x2}), y=[{y1}, {y2}) "
            f"in a {w}x{h} image"
        )
    return image[y1:y2, x1:x2]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the Dots.OCR model at startup and log at shutdown.

    Sets the module-level ``dots_ocr_model`` to a ``DotsOCR`` instance when
    the package imported and construction succeeds. Otherwise it is set to
    the sentinel string ``"mock"``, so the service still starts (development
    mode).

    Args:
        app: The FastAPI application. It is required by the lifespan
            protocol and is not used here.
    """
    global dots_ocr_model
    logger.info("Loading Dots.OCR model...")
    try:
        if DOTS_OCR_AVAILABLE:
            dots_ocr_model = DotsOCR()
            logger.info("Dots.OCR model loaded successfully")
        else:
            logger.warning("Dots.OCR not available - using mock implementation")
            dots_ocr_model = "mock"
    except Exception as e:
        # Deliberately non-fatal so the service can boot in mock mode. Log
        # with the traceback, though: the message alone rarely says why
        # loading failed.
        logger.exception("Failed to load Dots.OCR model: %s", e)
        dots_ocr_model = "mock"
    yield
    logger.info("Shutting down Dots.OCR endpoint...")
# ASGI application. ``lifespan`` loads the OCR model before requests are served.
app = FastAPI(
    lifespan=lifespan,
    title="KYB Dots.OCR Text Extraction",
    version="1.0.0",
    description="Dots.OCR for identity document text extraction with ROI support",
)
@app.get("/health")
async def health_check():
    """Report liveness of the service.

    The payload is static. It does not check whether the OCR model loaded,
    so a service running in mock mode also reports ``"healthy"``.
    """
    return dict(status="healthy", version="1.0.0")
def _decode_upload(image_data: bytes) -> np.ndarray:
    """Decode uploaded image bytes into a 3-channel BGR array for OpenCV.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # ``convert("RGB")`` forces the otherwise lazy decode. It also
        # normalizes grayscale, palette, RGBA, CMYK, ... images to 3 channels;
        # without it ``cv2.cvtColor(..., COLOR_RGB2BGR)`` rejects, e.g., a
        # single-channel grayscale upload.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)


def _apply_roi(image_cv: np.ndarray, roi: Optional[str]) -> np.ndarray:
    """Return *image_cv* cropped to the JSON-encoded ROI, or unchanged if none.

    This is best effort by design. A malformed or unusable ROI is logged as a
    warning and ignored, and the full image is returned so the request can
    still succeed.
    """
    if not roi:
        return image_cv
    try:
        roi_bbox = BoundingBox(**json.loads(roi))
        return crop_image_by_roi(image_cv, roi_bbox)
    except Exception as e:
        logger.warning(f"Invalid ROI provided: {e}")
        return image_cv


def _run_ocr(image_cv: np.ndarray) -> str:
    """Run OCR on *image_cv* and return the recognised text joined by spaces.

    A fixed mock string is returned when Dots.OCR is unavailable or the model
    is in mock mode.
    """
    if DOTS_OCR_AVAILABLE and dots_ocr_model != "mock":
        # NOTE(review): assumes the model returns an iterable of results that
        # each have a ``.text`` attribute, as the original code did. Confirm
        # this against the dots_ocr API.
        ocr_results = dots_ocr_model(image_cv)
        return " ".join(result.text for result in ocr_results)
    logger.info("Using mock OCR implementation")
    return "MOCK OCR TEXT - Document Number: NLD123456789 Surname: MULDER Given Names: THOMAS"


@app.post("/v1/id/ocr", response_model=OCRResponse)
async def extract_text(
    file: UploadFile = File(..., description="Image file to process"),
    roi: Optional[str] = Form(None, description="ROI coordinates as JSON string"),
):
    """Extract text and identity-document fields from an uploaded image.

    Args:
        file: Uploaded image. Any format or mode that Pillow can open is
            accepted; it is converted to RGB before processing.
        roi: Optional JSON object with normalized ``x1``, ``y1``, ``x2`` and
            ``y2`` keys. An invalid or empty ROI is logged and ignored, and
            the full image is processed instead.

    Returns:
        An ``OCRResponse`` with a single detection holding MRZ data and the
        structured fields produced by ``FieldExtractor``.

    Raises:
        HTTPException: 503 if the model has not been loaded, 400 if the
            upload cannot be decoded as an image, and 500 for any other
            failure during OCR or field extraction.
    """
    if dots_ocr_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump if the system clock is adjusted mid-request.
    start_time = time.perf_counter()
    request_id = str(uuid.uuid4())

    try:
        image_cv = _decode_upload(await file.read())
        image_cv = _apply_roi(image_cv, roi)
        ocr_text = _run_ocr(image_cv)

        detection = OCRDetection(
            mrz_data=FieldExtractor.extract_mrz(ocr_text),
            extracted_fields=FieldExtractor.extract_fields(ocr_text),
        )
        return OCRResponse(
            request_id=request_id,
            media_type="image",
            processing_time=time.perf_counter() - start_time,
            detections=[detection],
        )
    except HTTPException:
        # Already carries the right status (e.g. 400 for an undecodable
        # upload), so don't rewrap it as a 500.
        raise
    except Exception as e:
        logger.exception(f"OCR extraction failed: {e}")
        raise HTTPException(status_code=500, detail=f"OCR extraction failed: {str(e)}") from e
if __name__ == "__main__":
    # Direct `python app.py` entry point. It binds all interfaces on 7860,
    # the default port that Hugging Face Spaces routes traffic to.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")