dots-ocr-idcard / scripts /test_api_endpoint.py
tommulder's picture
feat(api): fast FastAPI app + model loader refactor; add mock mode for tests

- Add pyproject + setuptools config and console entrypoint
- Implement enhanced field extraction + MRZ heuristics
- Add response builder with compatibility for legacy MRZ fields
- New preprocessing pipeline for PDFs/images
- HF Spaces GPU: cache ENV, optional flash-attn, configurable base image
- Add Make targets for Spaces GPU and local CPU
- Add httpx for TestClient; tests pass in mock mode
- Remove embedded model files and legacy app/modules
211e423
raw
history blame
13.8 kB
#!/usr/bin/env python3
"""API Endpoint Test Script for Dots.OCR
This script tests the deployed Dots.OCR API endpoint using real ID card images.
It can be used to validate the complete pipeline in a production environment.
"""
import os
import sys
import json
import time
import requests
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List
import argparse
# Root logging: INFO and above, each line prefixed with a timestamp and level.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)
class DotsOCRAPITester:
    """Test client for the Dots.OCR API endpoint.

    Holds a ``requests.Session`` bound to ``base_url`` and exposes helpers that
    call ``/health`` and ``/v1/id/ocr`` and sanity-check the returned JSON.
    """

    def __init__(self, base_url: str, timeout: int = 30):
        """Initialize the API tester.

        Args:
            base_url: Base URL of the deployed API (e.g., "http://localhost:7860").
                Any trailing slash is removed so endpoint paths can be appended.
            timeout: Request timeout in seconds, passed to every request.
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.session = requests.Session()
        # Set common headers
        self.session.headers.update({
            'User-Agent': 'DotsOCR-APITester/1.0'
        })

    def health_check(self) -> Dict[str, Any]:
        """Check API health status.

        Returns:
            The decoded JSON body of ``GET /health`` on success, or
            ``{"error": <message>}`` if the request, the HTTP status check or
            JSON decoding fails.
        """
        try:
            response = self.session.get(
                f"{self.base_url}/health",
                timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Boundary catch: callers inspect the "error" key rather than
            # handling exceptions themselves.
            logger.error(f"Health check failed: {e}")
            return {"error": str(e)}

    def test_ocr_endpoint(
        self,
        image_path: str,
        roi: Optional[Dict[str, float]] = None,
        expected_fields: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Test the OCR endpoint with an image file.

        Uploads the file as multipart field ``file`` to ``POST /v1/id/ocr``,
        optionally with a JSON-encoded ``roi`` form field, then validates the
        response structure and the expected fields.

        Args:
            image_path: Path to the image file.
            roi: Optional ROI coordinates as {x1, y1, x2, y2}.
            expected_fields: Field names that should be present (non-null) in
                each detection's ``extracted_fields``.

        Returns:
            On success: ``success`` (True), ``request_time`` (seconds),
            ``response``, ``validation``, ``field_validation`` and
            ``status_code``. On failure: ``success`` (False) and ``error``,
            plus ``status_code`` (possibly None) for request errors. A file that
            cannot be opened is reported as a failure rather than raised.
        """
        logger.info(f"Testing OCR endpoint with {image_path}")

        data: Dict[str, str] = {}
        if roi:
            data['roi'] = json.dumps(roi)
            logger.info(f"Using ROI: {roi}")

        try:
            # `with` closes the upload handle on every path, including when
            # the request itself raises.
            with open(image_path, 'rb') as image_file:
                # perf_counter is monotonic, so clock adjustments cannot skew
                # the measured latency.
                start_time = time.perf_counter()
                response = self.session.post(
                    f"{self.base_url}/v1/id/ocr",
                    files={'file': image_file},
                    data=data,
                    timeout=self.timeout
                )
                request_time = time.perf_counter() - start_time

            response.raise_for_status()
            result = response.json()

            return {
                "success": True,
                "request_time": request_time,
                "response": result,
                "validation": self._validate_response(result),
                "field_validation": self._validate_expected_fields(result, expected_fields),
                "status_code": response.status_code
            }
        except requests.exceptions.RequestException as e:
            # Must precede the OSError branch: RequestException subclasses IOError.
            logger.error(f"Request failed: {e}")
            return {
                "success": False,
                "error": str(e),
                # e.response may be None (e.g. connection errors), hence getattr.
                "status_code": getattr(e.response, 'status_code', None)
            }
        except OSError as e:
            logger.error(f"Could not read image {image_path}: {e}")
            return {
                "success": False,
                "error": str(e)
            }
        except Exception as e:
            logger.error(f"Unexpected error: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    def _validate_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the API response structure.

        Args:
            response: Decoded API response; expected to be a JSON object.

        Returns:
            ``{"valid": bool, "errors": [...], "warnings": [...]}``. Errors make
            the response invalid; warnings do not.
        """
        validation: Dict[str, Any] = {
            "valid": True,
            "errors": [],
            "warnings": []
        }

        def fail(message: str) -> None:
            # Record a hard error and mark the whole response invalid.
            validation["errors"].append(message)
            validation["valid"] = False

        # A list/string/null body would otherwise be probed with `in`
        # (substring/element checks) or crash with TypeError.
        if not isinstance(response, dict):
            fail("response must be a JSON object")
            return validation

        # Required fields
        for field in ('request_id', 'media_type', 'processing_time', 'detections'):
            if field not in response:
                fail(f"Missing required field: {field}")

        # Validate detections
        if 'detections' in response:
            detections = response['detections']
            if not isinstance(detections, list):
                fail("detections must be a list")
            else:
                for i, detection in enumerate(detections):
                    if not isinstance(detection, dict):
                        fail(f"detection {i} must be a dictionary")
                        continue
                    if 'extracted_fields' not in detection:
                        validation["warnings"].append(f"detection {i} missing extracted_fields")
                    if 'mrz_data' not in detection:
                        validation["warnings"].append(f"detection {i} missing mrz_data")

        # Validate processing time
        if 'processing_time' in response:
            processing_time = response['processing_time']
            # bool is a subclass of int, so reject it explicitly.
            if isinstance(processing_time, bool) or not isinstance(processing_time, (int, float)):
                fail("processing_time must be a number")
            elif processing_time < 0:
                validation["warnings"].append("processing_time is negative")

        return validation

    def _validate_expected_fields(
        self,
        response: Dict[str, Any],
        expected_fields: Optional[List[str]]
    ) -> Dict[str, Any]:
        """Validate that expected fields are present in every detection.

        A field counts as found in a detection when its ``extracted_fields``
        maps the name to a non-None value.

        Args:
            response: Decoded API response.
            expected_fields: Expected field names; empty/None disables the check.

        Returns:
            ``{"valid": bool, "found_fields": [...], "missing_fields": [...]}``
            with entries like ``"surname (detection 0)"``. When there are no
            detections at all, every expected field is reported as
            ``"<name> (no detections)"`` and the result is invalid.
        """
        if not expected_fields:
            return {"valid": True, "found_fields": [], "missing_fields": []}

        detections = response.get('detections') if isinstance(response, dict) else None
        if not isinstance(detections, list) or not detections:
            # Nothing to search, so every expected field is missing. Reporting
            # "valid" here would let an empty response pass the check.
            return {
                "valid": False,
                "found_fields": [],
                "missing_fields": [f"{name} (no detections)" for name in expected_fields]
            }

        found_fields: List[str] = []
        missing_fields: List[str] = []
        for i, detection in enumerate(detections):
            extracted = detection.get('extracted_fields') if isinstance(detection, dict) else None
            # Tolerate `"extracted_fields": null` and malformed detections.
            if not isinstance(extracted, dict):
                extracted = {}
            for field_name in expected_fields:
                bucket = found_fields if extracted.get(field_name) is not None else missing_fields
                bucket.append(f"{field_name} (detection {i})")

        return {
            "valid": not missing_fields,
            "found_fields": found_fields,
            "missing_fields": missing_fields
        }

    def test_multiple_images(
        self,
        image_paths: List[str],
        roi: Optional[Dict[str, float]] = None,
        expected_fields: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Test multiple images and return aggregated results.

        Args:
            image_paths: List of image file paths; missing files are recorded
                as failures without contacting the API.
            roi: Optional ROI coordinates applied to every image.
            expected_fields: Optional field names forwarded to
                :meth:`test_ocr_endpoint` for per-image field validation.

        Returns:
            Totals, success rate, mean request time over successful requests
            (0 if none), and a per-image ``results`` list.
        """
        logger.info(f"Testing {len(image_paths)} images")

        results: List[Dict[str, Any]] = []
        successful_tests = 0
        total_processing_time = 0.0

        for image_path in image_paths:
            if not os.path.exists(image_path):
                logger.warning(f"Image not found: {image_path}")
                results.append({
                    "image": image_path,
                    "success": False,
                    "error": "File not found"
                })
                continue

            result = self.test_ocr_endpoint(image_path, roi, expected_fields)
            results.append({"image": image_path, **result})

            if result.get("success", False):
                successful_tests += 1
                total_processing_time += result.get("request_time", 0)

        total_images = len(image_paths)
        return {
            "total_images": total_images,
            "successful_tests": successful_tests,
            "failed_tests": total_images - successful_tests,
            "success_rate": successful_tests / total_images if image_paths else 0,
            "average_processing_time": total_processing_time / successful_tests if successful_tests > 0 else 0,
            "results": results
        }
def main():
    """Command-line entry point for the Dots.OCR API smoke test.

    Parses arguments, checks ``/health``, then posts each test image to the
    OCR endpoint and logs a summary. Calls ``sys.exit(1)`` on invalid ROI
    JSON, a failed health check, no usable images, a failed request, or a
    response that fails structural validation.
    """
    parser = argparse.ArgumentParser(description="Test Dots.OCR API endpoint")
    parser.add_argument(
        "--url",
        default="http://localhost:7860",
        help="API base URL (default: http://localhost:7860)"
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=30,
        help="Request timeout in seconds (default: 30)"
    )
    parser.add_argument(
        "--roi",
        type=str,
        help="ROI coordinates as JSON string (e.g., '{\"x1\": 0.1, \"y1\": 0.1, \"x2\": 0.9, \"y2\": 0.9}')"
    )
    parser.add_argument(
        "--expected-fields",
        nargs="+",
        help="Expected field names to validate (e.g., document_number surname given_names)"
    )
    parser.add_argument(
        "--images",
        nargs="+",
        help="Image files to test (default: tom_id_card_front.jpg and "
             "tom_id_card_back.jpg next to this script)"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging"
    )
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Parse ROI if provided
    roi = None
    if args.roi:
        try:
            roi = json.loads(args.roi)
        except json.JSONDecodeError as e:
            logger.error(f"Invalid ROI JSON: {e}")
            sys.exit(1)
        # Valid JSON can still be a list/number/string; the ROI is documented
        # as an object {x1, y1, x2, y2}.
        if not isinstance(roi, dict):
            logger.error("Invalid ROI JSON: expected an object with keys x1, y1, x2, y2")
            sys.exit(1)

    tester = DotsOCRAPITester(args.url, args.timeout)

    # Health check
    logger.info("🔍 Checking API health...")
    health = tester.health_check()
    if isinstance(health, dict) and "error" in health:
        logger.error(f"❌ API health check failed: {health['error']}")
        sys.exit(1)
    logger.info(f"✅ API is healthy: {health}")

    # Resolve test images: explicit --images, otherwise the bundled samples.
    if args.images:
        candidate_paths = [Path(p) for p in args.images]
    else:
        script_dir = Path(__file__).parent
        candidate_paths = [
            script_dir / "tom_id_card_front.jpg",
            script_dir / "tom_id_card_back.jpg",
        ]

    existing_images = []
    for image_path in candidate_paths:
        # is_file (not exists) so a directory is not later passed to open().
        if image_path.is_file():
            existing_images.append(str(image_path))
        else:
            logger.warning(f"Test image not found: {image_path}")

    if not existing_images:
        logger.error("❌ No test images found")
        sys.exit(1)

    # Expected fields for validation
    expected_fields = args.expected_fields or [
        "document_number",
        "surname",
        "given_names",
        "nationality",
        "date_of_birth",
        "gender"
    ]

    logger.info(f"🚀 Starting API tests with {len(existing_images)} images...")

    if len(existing_images) == 1:
        # Single image test
        result = tester.test_ocr_endpoint(existing_images[0], roi, expected_fields)
        if not result["success"]:
            logger.error(f"❌ Single image test failed: {result.get('error', 'Unknown error')}")
            sys.exit(1)

        # A 200 response can still be structurally wrong; treat that as failure.
        validation = result.get("validation") or {}
        for warning in validation.get("warnings", []):
            logger.warning(f"⚠️ {warning}")
        if not validation.get("valid", True):
            logger.error(f"❌ Single image test failed: invalid response ({'; '.join(validation.get('errors', []))})")
            sys.exit(1)

        # Guarded lookup: the response may lack "detections" or not be an object.
        response = result.get("response")
        detections = response.get("detections") if isinstance(response, dict) else None
        detection_count = len(detections) if isinstance(detections, list) else 0

        logger.info("✅ Single image test passed")
        logger.info(f"⏱️ Processing time: {result['request_time']:.2f}s")
        logger.info(f"📄 Detections: {detection_count}")

        field_validation = result.get("field_validation", {})
        if field_validation.get("found_fields"):
            logger.info(f"✅ Found fields: {', '.join(field_validation['found_fields'])}")
        if field_validation.get("missing_fields"):
            logger.warning(f"⚠️ Missing fields: {', '.join(field_validation['missing_fields'])}")
    else:
        # Multiple images test
        results = tester.test_multiple_images(existing_images, roi)

        logger.info("📊 Test Results:")
        logger.info(f" Total images: {results['total_images']}")
        logger.info(f" Successful: {results['successful_tests']}")
        logger.info(f" Failed: {results['failed_tests']}")
        logger.info(f" Success rate: {results['success_rate']:.1%}")
        logger.info(f" Average processing time: {results['average_processing_time']:.2f}s")

        invalid_responses = 0
        for result in results["results"]:
            image_name = Path(result["image"]).name
            if not result["success"]:
                logger.error(f" ❌ {image_name}: {result.get('error', 'Unknown error')}")
                continue
            validation = result.get("validation") or {}
            if validation.get("valid", True):
                logger.info(f" ✅ {image_name}: {result['request_time']:.2f}s")
            else:
                invalid_responses += 1
                logger.error(f" ❌ {image_name}: invalid response ({'; '.join(validation.get('errors', []))})")

        if results["failed_tests"] > 0 or invalid_responses > 0:
            sys.exit(1)

    logger.info("🎉 All tests completed successfully!")
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported (e.g. by a
    # test runner).
    main()