File size: 8,377 Bytes
0a7e5ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e300623
 
0a7e5ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""HF Dots.OCR Text Extraction Endpoint

This FastAPI application provides a Hugging Face Space endpoint for Dots.OCR
text extraction with ROI support and standardized field extraction schema.
"""

import logging
import time
import uuid
import json
import re
from typing import List, Optional, Dict, Any
from contextlib import asynccontextmanager

import cv2
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import torch
from PIL import Image
import io
import base64

# Dots.OCR imports. The package is optional: without it the service falls back
# to a mock OCR backend (see lifespan / extract_text).
try:
    from dots_ocr import DotsOCR
    DOTS_OCR_AVAILABLE = True
except ImportError:
    DOTS_OCR_AVAILABLE = False
    # Log through a named logger rather than the root-level logging.warning():
    # the latter implicitly runs logging.basicConfig() with default settings,
    # which turns the later basicConfig(level=logging.INFO) into a no-op and
    # hides every INFO message. With no handlers configured yet, this message
    # still reaches stderr via logging's last-resort handler.
    logging.getLogger(__name__).warning("Dots.OCR not available - using mock implementation")
    logging.warning("Dots.OCR not available - using mock implementation")

# Import local field extraction utilities
from field_extraction import FieldExtractor

# Configure process-wide logging at INFO level.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers. A module-level logging.<level>() call that runs before this line
# configures the root logger with defaults (WARNING level) and would silently
# keep all INFO messages below hidden.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global OCR backend handle, assigned by `lifespan` at startup: a DotsOCR
# instance when the model loaded, otherwise the string "mock". It is None until
# startup has run, which extract_text reports as 503 "Model not loaded".
dots_ocr_model: Optional[Any] = None


class BoundingBox(BaseModel):
    """Normalized bounding box coordinates.

    Each coordinate is a fraction of the image width (x) or height (y) in [0, 1].
    """
    # NOTE(review): only each coordinate's [0, 1] range is validated; nothing
    # enforces x1 < x2 or y1 < y2, so a reversed or zero-area box passes
    # validation and crop_image_by_roi turns it into an empty crop.
    x1: float = Field(..., ge=0.0, le=1.0, description="Top-left x coordinate")
    y1: float = Field(..., ge=0.0, le=1.0, description="Top-left y coordinate")
    x2: float = Field(..., ge=0.0, le=1.0, description="Bottom-right x coordinate")
    y2: float = Field(..., ge=0.0, le=1.0, description="Bottom-right y coordinate")


class ExtractedField(BaseModel):
    """Individual extracted field with confidence and source."""
    # Standardized field name; presumably mirrors the attribute name on
    # ExtractedFields that holds this object — confirm against FieldExtractor.
    field_name: str = Field(..., description="Standardized field name")
    # Extracted text; optional, so it may be None.
    value: Optional[str] = Field(None, description="Extracted field value")
    # Required confidence score, constrained to [0, 1] by ge/le.
    confidence: float = Field(..., ge=0.0, le=1.0, description="Extraction confidence")
    # Which extraction stage produced the value (e.g. "ocr").
    source: str = Field(..., description="Extraction source (e.g., 'ocr')")


class ExtractedFields(BaseModel):
    """All extracted fields from identity document."""
    # Every field is optional and defaults to None, so any subset may be present.
    # extract_text validates the output of FieldExtractor.extract_fields against
    # this model when it builds an OCRDetection.

    # Document identification
    document_number: Optional[ExtractedField] = None
    document_type: Optional[ExtractedField] = None
    issuing_country: Optional[ExtractedField] = None
    issuing_authority: Optional[ExtractedField] = None
    # Holder identity
    surname: Optional[ExtractedField] = None
    given_names: Optional[ExtractedField] = None
    nationality: Optional[ExtractedField] = None
    date_of_birth: Optional[ExtractedField] = None
    gender: Optional[ExtractedField] = None
    place_of_birth: Optional[ExtractedField] = None
    # Validity dates
    date_of_issue: Optional[ExtractedField] = None
    date_of_expiry: Optional[ExtractedField] = None
    # Additional identifiers; the meaning of the optional_data_* slots is
    # defined by FieldExtractor (not visible here) — confirm there.
    personal_number: Optional[ExtractedField] = None
    optional_data_1: Optional[ExtractedField] = None
    optional_data_2: Optional[ExtractedField] = None


class MRZData(BaseModel):
    """Machine Readable Zone data."""
    # All parsed values are plain optional strings. Their exact formats (e.g.
    # date layout, country-code style) are whatever FieldExtractor.extract_mrz
    # produces — not visible here, confirm there.
    document_type: Optional[str] = Field(None, description="MRZ document type (TD1|TD2|TD3)")
    issuing_country: Optional[str] = Field(None, description="Issuing country code")
    surname: Optional[str] = Field(None, description="Surname from MRZ")
    given_names: Optional[str] = Field(None, description="Given names from MRZ")
    document_number: Optional[str] = Field(None, description="Document number from MRZ")
    nationality: Optional[str] = Field(None, description="Nationality code from MRZ")
    date_of_birth: Optional[str] = Field(None, description="Date of birth from MRZ")
    gender: Optional[str] = Field(None, description="Gender from MRZ")
    date_of_expiry: Optional[str] = Field(None, description="Date of expiry from MRZ")
    personal_number: Optional[str] = Field(None, description="Personal number from MRZ")
    # Unparsed MRZ text as recognized.
    raw_mrz: Optional[str] = Field(None, description="Raw MRZ text")
    # Defaults to 0.0 when not set; constrained to [0, 1].
    confidence: float = Field(0.0, ge=0.0, le=1.0, description="MRZ extraction confidence")


class OCRDetection(BaseModel):
    """Single OCR detection result."""
    # None when no MRZ data was returned by FieldExtractor.extract_mrz.
    mrz_data: Optional[MRZData] = Field(None, description="MRZ data if detected")
    # Always required; individual fields inside may still be None.
    extracted_fields: ExtractedFields = Field(..., description="Extracted field data")


class OCRResponse(BaseModel):
    """OCR API response."""
    # Random UUID4 string generated per request in extract_text.
    request_id: str = Field(..., description="Unique request identifier")
    # extract_text currently always sets this to "image".
    media_type: str = Field(..., description="Media type processed")
    # Elapsed seconds measured inside extract_text (excludes framework overhead).
    processing_time: float = Field(..., description="Processing time in seconds")
    # extract_text currently always returns exactly one detection.
    detections: List[OCRDetection] = Field(..., description="List of OCR detections")


# FieldExtractor (field and MRZ parsing) is imported from the local field_extraction module.


def crop_image_by_roi(image: np.ndarray, roi: BoundingBox) -> np.ndarray:
    """Return the part of ``image`` covered by a normalized ROI.

    Normalized coordinates are scaled by the image width/height and truncated
    to integer pixels. Each edge is clamped to the image, and the right/bottom
    edges are never allowed before the left/top edges, so a degenerate box
    yields an empty (zero-area) slice rather than an error.

    Args:
        image: Array whose first two axes are height and width.
        roi: Box with normalized ``x1``, ``y1``, ``x2``, ``y2`` attributes.

    Returns:
        A slice (view) of ``image`` for the clamped region.
    """
    height, width = image.shape[:2]

    def _to_pixel(fraction: float, extent: int, floor: int = 0) -> int:
        # Scale, then clamp into [floor, extent].
        return max(floor, min(int(fraction * extent), extent))

    left = _to_pixel(roi.x1, width)
    top = _to_pixel(roi.y1, height)
    right = _to_pixel(roi.x2, width, left)
    bottom = _to_pixel(roi.y2, height, top)

    return image[top:bottom, left:right]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the OCR backend at startup and log at shutdown.

    Sets the module-level ``dots_ocr_model`` to a ``DotsOCR`` instance when the
    package imported and its constructor succeeds. Otherwise it sets the value
    to the string ``"mock"`` so the service still starts (development mode).

    Args:
        app: The FastAPI application (unused; passed in by FastAPI).
    """
    global dots_ocr_model

    logger.info("Loading Dots.OCR model...")
    if not DOTS_OCR_AVAILABLE:
        logger.warning("Dots.OCR not available - using mock implementation")
        dots_ocr_model = "mock"
    else:
        try:
            dots_ocr_model = DotsOCR()
        except Exception as e:
            # Deliberately non-fatal so development setups without a working
            # model still start. logger.exception keeps the traceback, which
            # logger.error discarded.
            # NOTE(review): in mock mode extract_text returns canned OCR text
            # with HTTP 200; consider failing hard in production deployments.
            logger.exception(f"Failed to load Dots.OCR model: {e}")
            dots_ocr_model = "mock"
        else:
            logger.info("Dots.OCR model loaded successfully")

    yield

    logger.info("Shutting down Dots.OCR endpoint...")


# ASGI application. `lifespan` runs model loading before requests are served.
app = FastAPI(
    lifespan=lifespan,
    title="KYB Dots.OCR Text Extraction",
    version="1.0.0",
    description="Dots.OCR for identity document text extraction with ROI support",
)


@app.get("/health")
async def health_check():
    """Report liveness, the API version and which OCR backend is serving.

    ``status`` and ``version`` are unchanged for existing clients. The added
    ``ocr_backend`` key makes a silent fallback to the mock backend visible:

    - ``"dots_ocr"``: the real model is serving requests.
    - ``"mock"``: canned OCR text is being returned.
    - ``"unloaded"``: startup has not assigned a backend yet.
    """
    if dots_ocr_model is None:
        backend = "unloaded"
    elif DOTS_OCR_AVAILABLE and dots_ocr_model != "mock":
        # Same condition extract_text uses to decide between real and mock OCR.
        backend = "dots_ocr"
    else:
        backend = "mock"
    # Use the app's configured version rather than a duplicated literal.
    return {"status": "healthy", "version": app.version, "ocr_backend": backend}


def _decode_upload_to_bgr(image_data: bytes) -> np.ndarray:
    """Decode uploaded image bytes into a 3-channel BGR array for OpenCV.

    Args:
        image_data: Raw bytes of the uploaded file.

    Returns:
        A ``(height, width, 3)`` uint8 array in BGR channel order.

    Raises:
        HTTPException: 400 when Pillow cannot decode the bytes as an image.
    """
    try:
        with Image.open(io.BytesIO(image_data)) as image:
            # Normalize the mode first: grayscale ("L"), palette ("P"), RGBA and
            # CMYK uploads lack the 3 channels that COLOR_RGB2BGR requires.
            rgb_image = image.convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError (not an image) and truncated-file errors are
        # OSError subclasses: this is a bad upload, not a server fault.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
    return cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)


def _apply_roi(image_cv: np.ndarray, roi: Optional[str]) -> np.ndarray:
    """Crop ``image_cv`` to an optional ROI passed as a JSON object string.

    ROI handling is best-effort. A malformed ROI, or one that crops to zero
    pixels, is logged as a warning and the full image is returned so the
    request is still processed.
    """
    if not roi:
        return image_cv
    try:
        roi_bbox = BoundingBox(**json.loads(roi))
    except (ValueError, TypeError) as e:
        # ValueError covers json.JSONDecodeError and pydantic's ValidationError.
        # TypeError covers JSON that is not an object (e.g. a list or a number).
        logger.warning(f"Invalid ROI provided: {e}")
        return image_cv
    cropped = crop_image_by_roi(image_cv, roi_bbox)
    if cropped.size == 0:
        # Degenerate boxes (e.g. x2 <= x1 or y2 <= y1) crop to zero pixels.
        # Passing that to the OCR model would fail, so keep the full image.
        logger.warning(f"Invalid ROI provided: empty region {roi_bbox}")
        return image_cv
    return cropped


@app.post("/v1/id/ocr", response_model=OCRResponse)
async def extract_text(
    file: UploadFile = File(..., description="Image file to process"),
    roi: Optional[str] = Form(None, description="ROI coordinates as JSON string")
):
    """Extract text and structured identity fields from an uploaded image.

    Args:
        file: Uploaded image in any format and mode Pillow can decode.
        roi: Optional JSON object with normalized ``x1``, ``y1``, ``x2``, ``y2``
            (see BoundingBox). An invalid or empty ROI is ignored with a
            warning and the whole image is processed.

    Returns:
        An OCRResponse containing a single OCRDetection.

    Raises:
        HTTPException: 503 if the model was never loaded, 400 if the upload is
            not a decodable image, 500 for any other processing failure.
    """
    if dots_ocr_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # perf_counter is monotonic, unlike time.time(), so the elapsed time
    # cannot go negative or jump when the wall clock is adjusted.
    start_time = time.perf_counter()
    request_id = str(uuid.uuid4())

    try:
        image_data = await file.read()
        image_cv = _apply_roi(_decode_upload_to_bgr(image_data), roi)

        # NOTE(review): model inference is synchronous and blocks the event
        # loop. Consider fastapi.concurrency.run_in_threadpool once the model
        # is confirmed safe for concurrent calls.
        if DOTS_OCR_AVAILABLE and dots_ocr_model != "mock":
            ocr_results = dots_ocr_model(image_cv)
            ocr_text = " ".join(result.text for result in ocr_results)
        else:
            # Mock implementation for development
            ocr_text = "MOCK OCR TEXT - Document Number: NLD123456789 Surname: MULDER Given Names: THOMAS"
            logger.info("Using mock OCR implementation")

        detection = OCRDetection(
            mrz_data=FieldExtractor.extract_mrz(ocr_text),
            extracted_fields=FieldExtractor.extract_fields(ocr_text),
        )

        return OCRResponse(
            request_id=request_id,
            media_type="image",
            processing_time=time.perf_counter() - start_time,
            detections=[detection],
        )

    except HTTPException:
        # Client errors raised above (e.g. 400 for a bad upload) must not be
        # rewrapped as 500 by the catch-all below.
        raise
    except Exception as e:
        logger.exception(f"OCR extraction failed: {e}")
        raise HTTPException(status_code=500, detail=f"OCR extraction failed: {str(e)}") from e


if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces at port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")