Spaces:
Paused
Paused
| """HF Dots.OCR Text Extraction Endpoint | |
| This FastAPI application provides a Hugging Face Space endpoint for Dots.OCR | |
| text extraction with ROI support and standardized field extraction schema. | |
| """ | |
| import logging | |
| import time | |
| import uuid | |
| import json | |
| import re | |
| from typing import List, Optional, Dict, Any | |
| from contextlib import asynccontextmanager | |
| import cv2 | |
| import numpy as np | |
| from fastapi import FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field | |
| import torch | |
| from PIL import Image | |
| import io | |
| import base64 | |
# Dots.OCR is optional: when it cannot be imported the service falls back to
# a mock OCR backend (see the OCR branch in the request handler).
try:
    from dots_ocr import DotsOCR
    DOTS_OCR_AVAILABLE = True
except ImportError:
    DOTS_OCR_AVAILABLE = False
    # Use a named logger rather than the module-level ``logging.warning``: the
    # latter implicitly runs ``logging.basicConfig()`` on the root logger
    # (level WARNING) when no handlers exist yet, which turns the later
    # ``logging.basicConfig(...)`` call that sets INFO level into a no-op and
    # silently drops every INFO message from this service.
    logging.getLogger(__name__).warning("Dots.OCR not available - using mock implementation")
| # Import local field extraction utilities | |
| from field_extraction import FieldExtractor | |
# Process-wide logging at INFO level, plus this module's own logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
# OCR engine shared by all requests. It stays None until the application
# lifespan hook runs, which stores either a DotsOCR instance or the sentinel
# string "mock".
dots_ocr_model: Optional[Any] = None
class BoundingBox(BaseModel):
    """Rectangular region of interest in normalized image coordinates.

    Every coordinate must lie in ``[0.0, 1.0]``; ``crop_image_by_roi``
    multiplies x values by the image width and y values by the image height.
    ``(x1, y1)`` is the top-left corner and ``(x2, y2)`` the bottom-right
    corner. The relative ordering of the two corners is not validated here.
    """

    # Top-left corner.
    x1: float = Field(..., description="Top-left x coordinate", ge=0.0, le=1.0)
    y1: float = Field(..., description="Top-left y coordinate", ge=0.0, le=1.0)
    # Bottom-right corner.
    x2: float = Field(..., description="Bottom-right x coordinate", ge=0.0, le=1.0)
    y2: float = Field(..., description="Bottom-right y coordinate", ge=0.0, le=1.0)
class ExtractedField(BaseModel):
    """One standardized field read from a document.

    ``value`` may be ``None``; ``confidence`` is required and bounded to
    ``[0.0, 1.0]``; ``source`` names where the value came from.
    """

    field_name: str = Field(default=..., description="Standardized field name")
    value: Optional[str] = Field(default=None, description="Extracted field value")
    confidence: float = Field(default=..., description="Extraction confidence", ge=0.0, le=1.0)
    source: str = Field(default=..., description="Extraction source (e.g., 'ocr')")
class ExtractedFields(BaseModel):
    """Container for every standardized identity-document field.

    Each attribute is optional and defaults to ``None``; when set it is an
    ``ExtractedField`` carrying the value, its confidence and its source.
    """

    # Document identity
    document_number: Optional[ExtractedField] = Field(default=None)
    document_type: Optional[ExtractedField] = Field(default=None)
    issuing_country: Optional[ExtractedField] = Field(default=None)
    issuing_authority: Optional[ExtractedField] = Field(default=None)
    # Document holder
    surname: Optional[ExtractedField] = Field(default=None)
    given_names: Optional[ExtractedField] = Field(default=None)
    nationality: Optional[ExtractedField] = Field(default=None)
    date_of_birth: Optional[ExtractedField] = Field(default=None)
    gender: Optional[ExtractedField] = Field(default=None)
    place_of_birth: Optional[ExtractedField] = Field(default=None)
    # Validity period
    date_of_issue: Optional[ExtractedField] = Field(default=None)
    date_of_expiry: Optional[ExtractedField] = Field(default=None)
    # Additional identifiers / free-form data
    personal_number: Optional[ExtractedField] = Field(default=None)
    optional_data_1: Optional[ExtractedField] = Field(default=None)
    optional_data_2: Optional[ExtractedField] = Field(default=None)
class MRZData(BaseModel):
    """Values parsed from a document's Machine Readable Zone (MRZ).

    All textual attributes are optional (``None`` when absent). ``raw_mrz``
    keeps the unparsed MRZ text. ``confidence`` defaults to 0.0 and is
    bounded to ``[0.0, 1.0]``.
    """

    document_type: Optional[str] = Field(default=None, description="MRZ document type (TD1|TD2|TD3)")
    issuing_country: Optional[str] = Field(default=None, description="Issuing country code")
    surname: Optional[str] = Field(default=None, description="Surname from MRZ")
    given_names: Optional[str] = Field(default=None, description="Given names from MRZ")
    document_number: Optional[str] = Field(default=None, description="Document number from MRZ")
    nationality: Optional[str] = Field(default=None, description="Nationality code from MRZ")
    date_of_birth: Optional[str] = Field(default=None, description="Date of birth from MRZ")
    gender: Optional[str] = Field(default=None, description="Gender from MRZ")
    date_of_expiry: Optional[str] = Field(default=None, description="Date of expiry from MRZ")
    personal_number: Optional[str] = Field(default=None, description="Personal number from MRZ")
    raw_mrz: Optional[str] = Field(default=None, description="Raw MRZ text")
    confidence: float = Field(default=0.0, description="MRZ extraction confidence", ge=0.0, le=1.0)
class OCRDetection(BaseModel):
    """Result for one processed image: structured fields plus optional MRZ data."""

    mrz_data: Optional[MRZData] = Field(default=None, description="MRZ data if detected")
    extracted_fields: ExtractedFields = Field(default=..., description="Extracted field data")
class OCRResponse(BaseModel):
    """Response body returned by the OCR extraction endpoint.

    All four attributes are required: a per-request identifier, the kind of
    media processed, the elapsed processing time in seconds, and the list of
    detections.
    """

    request_id: str = Field(default=..., description="Unique request identifier")
    media_type: str = Field(default=..., description="Media type processed")
    processing_time: float = Field(default=..., description="Processing time in seconds")
    detections: List[OCRDetection] = Field(default=..., description="List of OCR detections")
| # FieldExtractor is now imported from the shared module | |
def crop_image_by_roi(image: np.ndarray, roi: "BoundingBox") -> np.ndarray:
    """Crop ``image`` to the region described by a normalized ROI.

    ROI x values are multiplied by the image width and y values by the image
    height, truncated to integer pixel indices, then clamped to the image
    bounds.

    Args:
        image: Array whose first two axes are (height, width); any trailing
            channel axis is kept.
        roi: Object exposing ``x1``, ``y1``, ``x2``, ``y2`` as fractions of
            the image size (normally a ``BoundingBox``).

    Returns:
        The slice ``image[y1:y2, x1:x2]`` (a NumPy view, not a copy).

    Raises:
        ValueError: If the clamped region has zero width or height, e.g. the
            corners are swapped, the box sits on the far edge, or it is
            thinner than one pixel. An empty crop cannot be OCR'd, so this is
            reported instead of silently returning an empty array.
    """
    height, width = image.shape[:2]
    # Clamp into the image; max(x1, ...) / max(y1, ...) keep the far edge at or
    # beyond the near edge so the slice is never reversed.
    x1 = max(0, min(int(roi.x1 * width), width))
    y1 = max(0, min(int(roi.y1 * height), height))
    x2 = max(x1, min(int(roi.x2 * width), width))
    y2 = max(y1, min(int(roi.y2 * height), height))
    if x2 == x1 or y2 == y1:
        raise ValueError(
            f"ROI selects an empty region ({x1},{y1})-({x2},{y2}) "
            f"of a {width}x{height} image"
        )
    return image[y1:y2, x1:x2]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the OCR engine once at startup and keep it for the app's lifetime.

    Sets the module-level ``dots_ocr_model`` to a ``DotsOCR`` instance when the
    package imported successfully. Otherwise, or if constructing the model
    raises, it is set to the string ``"mock"``. Load failures are logged, not
    raised, so the service still starts (in mock mode).

    Decorated with ``asynccontextmanager`` because ``FastAPI(lifespan=...)``
    expects an async context manager; bare async-generator lifespans are
    deprecated by Starlette.

    Args:
        app: The FastAPI application passed in by the framework (unused here).
    """
    global dots_ocr_model
    logger.info("Loading Dots.OCR model...")
    try:
        if DOTS_OCR_AVAILABLE:
            # Load Dots.OCR model
            dots_ocr_model = DotsOCR()
            logger.info("Dots.OCR model loaded successfully")
        else:
            logger.warning("Dots.OCR not available - using mock implementation")
            dots_ocr_model = "mock"
    except Exception as e:
        # Deliberate best-effort: don't raise, fall back to mock mode for
        # development. logger.exception keeps the traceback for diagnosis.
        logger.exception(f"Failed to load Dots.OCR model: {e}")
        dots_ocr_model = "mock"
    yield
    logger.info("Shutting down Dots.OCR endpoint...")
# The ASGI application. `lifespan` runs model loading before requests are served.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="KYB Dots.OCR Text Extraction",
    description="Dots.OCR for identity document text extraction with ROI support",
)
async def health_check():
    """Report that the service process is up.

    Returns:
        dict: Always ``{"status": "healthy", "version": "1.0.0"}``. It does not
        inspect whether the OCR model has been loaded.
    """
    # NOTE(review): no route decorator is visible in this source; presumably an
    # ``@app.get(...)`` registration was lost -- confirm the endpoint is mounted.
    payload = dict(status="healthy", version="1.0.0")
    return payload
async def extract_text(
    file: UploadFile = File(..., description="Image file to process"),
    roi: Optional[str] = Form(None, description="ROI coordinates as JSON string")
):
    """Extract text and structured fields from an identity document image.

    Args:
        file: Uploaded image. Any image Pillow can open is accepted and
            normalized to RGB before OCR.
        roi: Optional JSON object with ``x1``, ``y1``, ``x2``, ``y2`` given as
            fractions of the image size. A malformed ROI, or one selecting no
            pixels, is logged and ignored, and the full image is processed.

    Returns:
        OCRResponse with one detection holding the fields and MRZ data
        produced by ``FieldExtractor`` from the OCR text.

    Raises:
        HTTPException: 503 if no model has been loaded, 400 if the upload
            cannot be decoded as an image, 500 for any other failure.
    """
    # NOTE(review): no route decorator is visible in this source; presumably an
    # ``@app.post(...)`` registration was lost -- confirm the endpoint is mounted.
    if dots_ocr_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    start_time = time.perf_counter()
    request_id = str(uuid.uuid4())
    try:
        image_cv = _decode_upload_to_bgr(await file.read())
        image_cv = _crop_to_roi_or_full(image_cv, roi)
        # Run OCR
        if DOTS_OCR_AVAILABLE and dots_ocr_model != "mock":
            # NOTE(review): this synchronous model call blocks the event loop
            # for its whole duration; consider offloading to a thread pool once
            # the model's thread-safety is confirmed.
            ocr_results = dots_ocr_model(image_cv)
            ocr_text = " ".join(result.text for result in ocr_results)
        else:
            # Mock implementation for development
            ocr_text = "MOCK OCR TEXT - Document Number: NLD123456789 Surname: MULDER Given Names: THOMAS"
            logger.info("Using mock OCR implementation")
        # Structured fields first, then MRZ, both parsed from the same text.
        extracted_fields = FieldExtractor.extract_fields(ocr_text)
        mrz_data = FieldExtractor.extract_mrz(ocr_text)
        detection = OCRDetection(mrz_data=mrz_data, extracted_fields=extracted_fields)
        return OCRResponse(
            request_id=request_id,
            media_type="image",
            processing_time=time.perf_counter() - start_time,
            detections=[detection],
        )
    except HTTPException:
        # Already carries the correct status (e.g. 400 for an undecodable upload);
        # don't rewrap it as a 500.
        raise
    except Exception as e:
        logger.exception(f"OCR extraction failed: {e}")
        raise HTTPException(status_code=500, detail=f"OCR extraction failed: {str(e)}") from e


def _decode_upload_to_bgr(image_data: bytes) -> np.ndarray:
    """Decode uploaded image bytes into a 3-channel BGR array for OpenCV.

    Raises:
        HTTPException: 400 if Pillow cannot open, decode or convert the bytes.
    """
    try:
        with Image.open(io.BytesIO(image_data)) as image:
            # convert() forces a full decode (so truncated files fail here). It
            # also normalizes grayscale, palette, RGBA and CMYK images to RGB.
            # Without it, single-channel images crash
            # cv2.cvtColor(..., COLOR_RGB2BGR).
            rgb_image = image.convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
    return cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)


def _crop_to_roi_or_full(image_cv: np.ndarray, roi: Optional[str]) -> np.ndarray:
    """Crop ``image_cv`` to the JSON ROI when one is given and usable.

    Best-effort by design: a malformed ROI, or one whose crop contains no
    pixels, is logged as a warning and the uncropped image is returned.
    """
    if not roi:
        return image_cv
    try:
        roi_bbox = BoundingBox(**json.loads(roi))
        cropped = crop_image_by_roi(image_cv, roi_bbox)
    except (ValueError, TypeError) as e:
        # ValueError covers bad JSON, pydantic validation errors and rejected
        # crops; TypeError covers JSON that is not an object of string keys.
        logger.warning(f"Invalid ROI provided: {e}")
        return image_cv
    if cropped.size == 0:
        logger.warning("ROI selects an empty region - using full image")
        return image_cv
    return cropped
if __name__ == "__main__":
    # Local/dev entry point: serve the app on all interfaces. Port 7860 is
    # presumably the Hugging Face Spaces app port -- confirm against the Space config.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")