Spaces:
Paused
Paused
File size: 8,377 Bytes
0a7e5ec e300623 0a7e5ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
"""HF Dots.OCR Text Extraction Endpoint
This FastAPI application provides a Hugging Face Space endpoint for Dots.OCR
text extraction with ROI support and standardized field extraction schema.
"""
import logging
import time
import uuid
import json
import re
from typing import List, Optional, Dict, Any
from contextlib import asynccontextmanager
import cv2
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import torch
from PIL import Image
import io
import base64
# Configure logging BEFORE anything emits a record.  The module-level helper
# logging.warning() implicitly calls basicConfig() with default settings when
# the root logger has no handlers; a later basicConfig(level=INFO) would then
# be a silent no-op and every INFO message of this service would be dropped.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Dots.OCR imports: the package is optional, and when it is missing the
# service runs with a mock OCR implementation.
try:
    from dots_ocr import DotsOCR
    DOTS_OCR_AVAILABLE = True
except ImportError:
    DOTS_OCR_AVAILABLE = False
    logger.warning("Dots.OCR not available - using mock implementation")

# Import local field extraction utilities
from field_extraction import FieldExtractor

# Global model instance.  None until the lifespan handler has run; afterwards
# either a DotsOCR instance or the string "mock".
dots_ocr_model = None
class BoundingBox(BaseModel):
    """Normalized bounding box coordinates.

    Every coordinate is validated to lie in the closed interval [0.0, 1.0].
    crop_image_by_roi() scales x values by the image width and y values by
    the image height, so the values are fractions of the image dimensions.
    """
    # (x1, y1) is the top-left corner, (x2, y2) the bottom-right corner.
    # Only the [0, 1] range is validated here; the relative order of the
    # two corners is not checked by this model.
    x1: float = Field(..., ge=0.0, le=1.0, description="Top-left x coordinate")
    y1: float = Field(..., ge=0.0, le=1.0, description="Top-left y coordinate")
    x2: float = Field(..., ge=0.0, le=1.0, description="Bottom-right x coordinate")
    y2: float = Field(..., ge=0.0, le=1.0, description="Bottom-right y coordinate")
class ExtractedField(BaseModel):
    """Individual extracted field with confidence and source.

    field_name, confidence and source are required; value may be None.
    """
    field_name: str = Field(..., description="Standardized field name")
    # Optional: defaults to None when no value is supplied.
    value: Optional[str] = Field(None, description="Extracted field value")
    # Validated to lie in [0.0, 1.0].
    confidence: float = Field(..., ge=0.0, le=1.0, description="Extraction confidence")
    source: str = Field(..., description="Extraction source (e.g., 'ocr')")
class ExtractedFields(BaseModel):
    """All extracted fields from identity document.

    Every attribute is an optional ExtractedField that defaults to None, so
    an instance may carry any subset of the fields listed below.
    """
    # Document-level fields
    document_number: Optional[ExtractedField] = None
    document_type: Optional[ExtractedField] = None
    issuing_country: Optional[ExtractedField] = None
    issuing_authority: Optional[ExtractedField] = None
    # Holder fields
    surname: Optional[ExtractedField] = None
    given_names: Optional[ExtractedField] = None
    nationality: Optional[ExtractedField] = None
    date_of_birth: Optional[ExtractedField] = None
    gender: Optional[ExtractedField] = None
    place_of_birth: Optional[ExtractedField] = None
    # Validity dates
    date_of_issue: Optional[ExtractedField] = None
    date_of_expiry: Optional[ExtractedField] = None
    # Additional fields
    personal_number: Optional[ExtractedField] = None
    optional_data_1: Optional[ExtractedField] = None
    optional_data_2: Optional[ExtractedField] = None
class MRZData(BaseModel):
    """Machine Readable Zone data.

    All text attributes are optional strings defaulting to None; confidence
    defaults to 0.0 and is validated to lie in [0.0, 1.0].
    """
    document_type: Optional[str] = Field(None, description="MRZ document type (TD1|TD2|TD3)")
    issuing_country: Optional[str] = Field(None, description="Issuing country code")
    surname: Optional[str] = Field(None, description="Surname from MRZ")
    given_names: Optional[str] = Field(None, description="Given names from MRZ")
    document_number: Optional[str] = Field(None, description="Document number from MRZ")
    nationality: Optional[str] = Field(None, description="Nationality code from MRZ")
    # Dates are carried as plain strings; no date format is enforced by this model.
    date_of_birth: Optional[str] = Field(None, description="Date of birth from MRZ")
    gender: Optional[str] = Field(None, description="Gender from MRZ")
    date_of_expiry: Optional[str] = Field(None, description="Date of expiry from MRZ")
    personal_number: Optional[str] = Field(None, description="Personal number from MRZ")
    raw_mrz: Optional[str] = Field(None, description="Raw MRZ text")
    confidence: float = Field(0.0, ge=0.0, le=1.0, description="MRZ extraction confidence")
class OCRDetection(BaseModel):
    """Single OCR detection result.

    Pairs the optional MRZ data with the required structured fields.
    """
    # None when no MRZ data is supplied.
    mrz_data: Optional[MRZData] = Field(None, description="MRZ data if detected")
    extracted_fields: ExtractedFields = Field(..., description="Extracted field data")
class OCRResponse(BaseModel):
    """OCR API response.

    Used as the response_model of the POST /v1/id/ocr endpoint; all four
    attributes are required.
    """
    request_id: str = Field(..., description="Unique request identifier")
    media_type: str = Field(..., description="Media type processed")
    processing_time: float = Field(..., description="Processing time in seconds")
    detections: List[OCRDetection] = Field(..., description="List of OCR detections")
# FieldExtractor is now imported from the shared module
def crop_image_by_roi(image: np.ndarray, roi: BoundingBox) -> np.ndarray:
    """Crop an image to a normalized region of interest.

    Args:
        image: Array whose first two axes are (height, width).
        roi: Box exposing ``x1``, ``y1``, ``x2`` and ``y2`` as fractions of
            the image width (x) and height (y).

    Returns:
        The slice ``image[y1:y2, x1:x2]`` in pixel coordinates.  The result
        is empty only when the ROI covers less than one pixel along an axis.
    """
    height, width = image.shape[:2]

    # Accept the two corners in either order: an ROI given as
    # (bottom-right, top-left) describes the same rectangle and must not
    # silently collapse into an empty crop.
    left, right = sorted((roi.x1, roi.x2))
    top, bottom = sorted((roi.y1, roi.y2))

    def to_pixel(fraction: float, extent: int) -> int:
        """Scale a fraction to a pixel index clamped to [0, extent]."""
        return max(0, min(int(fraction * extent), extent))

    x1, x2 = to_pixel(left, width), to_pixel(right, width)
    y1, y2 = to_pixel(top, height), to_pixel(bottom, height)
    return image[y1:y2, x1:x2]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the global OCR model on startup and log on shutdown.

    After startup ``dots_ocr_model`` is either a ``DotsOCR`` instance or the
    string ``"mock"``.  The mock is used when the package is unavailable or
    when constructing the model raises, so startup never fails here.
    """
    global dots_ocr_model

    logger.info("Loading Dots.OCR model...")
    try:
        if not DOTS_OCR_AVAILABLE:
            logger.warning("Dots.OCR not available - using mock implementation")
            dots_ocr_model = "mock"
        else:
            # Load Dots.OCR model
            dots_ocr_model = DotsOCR()
            logger.info("Dots.OCR model loaded successfully")
    except Exception as exc:
        logger.error(f"Failed to load Dots.OCR model: {exc}")
        # Deliberately not re-raised: fall back to mock mode for development.
        dots_ocr_model = "mock"

    # The application serves requests while suspended at this yield.
    yield

    logger.info("Shutting down Dots.OCR endpoint...")
# ASGI application object.  The lifespan handler is passed in so the OCR
# model is set up when the application starts.
app = FastAPI(
    title="KYB Dots.OCR Text Extraction",
    description="Dots.OCR for identity document text extraction with ROI support",
    version="1.0.0",
    lifespan=lifespan
)
@app.get("/health")
async def health_check():
    """Return a static liveness payload.

    The response is constant; it does not inspect whether the OCR model has
    been loaded.
    """
    payload = {"status": "healthy", "version": "1.0.0"}
    return payload
@app.post("/v1/id/ocr", response_model=OCRResponse)
async def extract_text(
    file: UploadFile = File(..., description="Image file to process"),
    roi: Optional[str] = Form(None, description="ROI coordinates as JSON string")
):
    """Extract text from identity document image.

    Args:
        file: Uploaded image to run OCR on.
        roi: Optional JSON object with ``x1``, ``y1``, ``x2``, ``y2`` in
            [0, 1].  An unparsable, invalid or empty ROI is logged and
            ignored, and the full image is processed instead.

    Returns:
        OCRResponse containing a single detection with the structured
        fields and MRZ data produced by FieldExtractor.

    Raises:
        HTTPException: 503 if the model has not been initialised, 400 if the
            upload cannot be decoded as an image, 500 for any other failure.
    """
    if dots_ocr_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()
    request_id = str(uuid.uuid4())

    try:
        # Read and validate image
        image_data = await file.read()
        try:
            # Force 3-channel RGB: grayscale or palette images would
            # otherwise yield arrays that COLOR_RGB2BGR cannot convert.
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except OSError as e:
            # Pillow reports unreadable/truncated data as OSError subclasses;
            # that is a client error, not a server fault.
            raise HTTPException(
                status_code=400, detail=f"Invalid image file: {e}"
            ) from e
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        # Parse ROI if provided (best effort: fall back to the full image)
        if roi:
            try:
                roi_bbox = BoundingBox(**json.loads(roi))
                cropped = crop_image_by_roi(image_cv, roi_bbox)
                if cropped.size == 0:
                    # Never hand a zero-pixel image to the OCR model.
                    raise ValueError("ROI selects an empty region")
                image_cv = cropped
            except Exception as e:
                logger.warning(f"Invalid ROI provided: {e}")

        # Run OCR
        if DOTS_OCR_AVAILABLE and dots_ocr_model != "mock":
            # Use real Dots.OCR model
            ocr_results = dots_ocr_model(image_cv)
            ocr_text = " ".join(result.text for result in ocr_results)
        else:
            # Mock implementation for development
            ocr_text = "MOCK OCR TEXT - Document Number: NLD123456789 Surname: MULDER Given Names: THOMAS"
            logger.info("Using mock OCR implementation")

        # Extract structured fields and MRZ data from the recognised text
        extracted_fields = FieldExtractor.extract_fields(ocr_text)
        mrz_data = FieldExtractor.extract_mrz(ocr_text)

        detection = OCRDetection(
            mrz_data=mrz_data,
            extracted_fields=extracted_fields
        )

        processing_time = time.perf_counter() - start_time
        return OCRResponse(
            request_id=request_id,
            media_type="image",
            processing_time=processing_time,
            detections=[detection]
        )
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) instead of
        # rewrapping them as 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"OCR extraction failed: {e}")
        raise HTTPException(status_code=500, detail=f"OCR extraction failed: {str(e)}")
# Script entry point: serve the app directly with uvicorn.
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn
    # Listens on all interfaces, port 7860.
    # NOTE(review): 7860 is presumably the port the hosting Space expects — confirm.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|