Spaces:
Paused
Paused
feat(api): fast FastAPI app + model loader refactor; add mock mode for tests

- Add pyproject + setuptools config and console entrypoint
- Implement enhanced field extraction + MRZ heuristics
- Add response builder with compatibility for legacy MRZ fields
- New preprocessing pipeline for PDFs/images
- HF Spaces GPU: cache ENV, optional flash-attn, configurable base image
- Add Make targets for Spaces GPU and local CPU
- Add httpx for TestClient; tests pass in mock mode
- Remove embedded model files and legacy app/modules
211e423
| #!/usr/bin/env python3 | |
| """API Endpoint Test Script for Dots.OCR | |
| This script tests the deployed Dots.OCR API endpoint using real ID card images. | |
| It can be used to validate the complete pipeline in a production environment. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import time | |
| import requests | |
| import logging | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional, List | |
| import argparse | |
# Root logging setup: INFO and above, each record prefixed with a
# timestamp and its severity level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by the tester class and the CLI entry point.
logger = logging.getLogger(__name__)
class DotsOCRAPITester:
    """Test client for the Dots.OCR API endpoint.

    Holds a ``requests.Session`` bound to a base URL and offers helpers to
    probe ``/health``, post images to ``/v1/id/ocr`` and check the shape of
    the JSON that comes back.
    """

    def __init__(self, base_url: str, timeout: int = 30):
        """Initialize the API tester.

        Args:
            base_url: Base URL of the deployed API (e.g., "http://localhost:7860").
                A trailing slash is stripped.
            timeout: Request timeout in seconds
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.session = requests.Session()
        # Set common headers
        self.session.headers.update({
            'User-Agent': 'DotsOCR-APITester/1.0'
        })

    def health_check(self) -> Dict[str, Any]:
        """Check API health status.

        Returns:
            The decoded JSON body of ``GET /health`` on success, otherwise
            ``{"error": <message>}``. Never raises.
        """
        try:
            response = self.session.get(
                f"{self.base_url}/health",
                timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Deliberately broad: transport errors, HTTP errors and invalid
            # JSON are all reported to the caller the same way.
            logger.error(f"Health check failed: {e}")
            return {"error": str(e)}

    def test_ocr_endpoint(
        self,
        image_path: str,
        roi: Optional[Dict[str, float]] = None,
        expected_fields: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Test the OCR endpoint with an image file.

        Args:
            image_path: Path to the image file
            roi: Optional ROI coordinates as {x1, y1, x2, y2}; sent as a
                JSON-encoded ``roi`` form field
            expected_fields: List of expected field names to validate

        Returns:
            Test results dictionary. On success it contains ``success``,
            ``request_time``, ``response``, ``validation``,
            ``field_validation`` and ``status_code``; on failure it contains
            ``success`` and ``error`` (plus ``status_code`` for request
            errors).

        Raises:
            OSError: If the image file cannot be opened.
        """
        logger.info(f"Testing OCR endpoint with {image_path}")

        # The file is opened outside the ``try`` so that a missing or
        # unreadable file still surfaces as OSError to the caller, while the
        # ``with`` guarantees the handle is closed on every exit path.
        with open(image_path, 'rb') as image_file:
            data = {}
            if roi:
                data['roi'] = json.dumps(roi)
                logger.info(f"Using ROI: {roi}")

            try:
                # Monotonic clock: immune to wall-clock adjustments.
                start_time = time.perf_counter()
                response = self.session.post(
                    f"{self.base_url}/v1/id/ocr",
                    files={'file': image_file},
                    data=data,
                    timeout=self.timeout
                )
                request_time = time.perf_counter() - start_time

                response.raise_for_status()
                result = response.json()

                validation_result = self._validate_response(result)
                field_validation = self._validate_expected_fields(result, expected_fields)

                return {
                    "success": True,
                    "request_time": request_time,
                    "response": result,
                    "validation": validation_result,
                    "field_validation": field_validation,
                    "status_code": response.status_code
                }
            except requests.exceptions.RequestException as e:
                logger.error(f"Request failed: {e}")
                return {
                    "success": False,
                    "error": str(e),
                    # e.response is None when no HTTP response was received.
                    "status_code": getattr(e.response, 'status_code', None)
                }
            except Exception as e:
                logger.error(f"Unexpected error: {e}")
                return {
                    "success": False,
                    "error": str(e)
                }

    def _validate_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the API response structure.

        Args:
            response: API response dictionary

        Returns:
            Dictionary with ``valid`` (bool), ``errors`` (list of str) and
            ``warnings`` (list of str). Errors make the response invalid;
            warnings do not.
        """
        validation = {
            "valid": True,
            "errors": [],
            "warnings": []
        }

        # Required top-level fields
        required_fields = ['request_id', 'media_type', 'processing_time', 'detections']
        for field in required_fields:
            if field not in response:
                validation["errors"].append(f"Missing required field: {field}")
                validation["valid"] = False

        # Validate detections
        if 'detections' in response:
            detections = response['detections']
            if not isinstance(detections, list):
                validation["errors"].append("detections must be a list")
                validation["valid"] = False
            else:
                for i, detection in enumerate(detections):
                    if not isinstance(detection, dict):
                        validation["errors"].append(f"detection {i} must be a dictionary")
                        validation["valid"] = False
                        continue
                    # Absent per-detection payloads are only warnings.
                    if 'extracted_fields' not in detection:
                        validation["warnings"].append(f"detection {i} missing extracted_fields")
                    if 'mrz_data' not in detection:
                        validation["warnings"].append(f"detection {i} missing mrz_data")

        # Validate processing time
        if 'processing_time' in response:
            processing_time = response['processing_time']
            # bool is a subclass of int, but a JSON true/false is not a duration.
            if isinstance(processing_time, bool) or not isinstance(processing_time, (int, float)):
                validation["errors"].append("processing_time must be a number")
                validation["valid"] = False
            elif processing_time < 0:
                validation["warnings"].append("processing_time is negative")

        return validation

    def _validate_expected_fields(
        self,
        response: Dict[str, Any],
        expected_fields: Optional[List[str]]
    ) -> Dict[str, Any]:
        """Validate that expected fields are present in the response.

        Every expected field is checked against every detection; a field
        counts as found only when its key exists in that detection's
        ``extracted_fields`` with a non-None value.

        Args:
            response: API response dictionary
            expected_fields: List of expected field names

        Returns:
            Dictionary with ``valid`` (True when nothing is missing),
            ``found_fields`` and ``missing_fields``. Entries are formatted as
            "<field> (detection <index>)"; when the response has no
            detections at all, each expected field is listed bare in
            ``missing_fields``.
        """
        if not expected_fields:
            return {"valid": True, "found_fields": [], "missing_fields": []}

        detections = response.get('detections', [])
        if not isinstance(detections, list):
            # Malformed payload; the structural problem is reported by
            # _validate_response, here it simply yields no detections.
            detections = []

        found_fields = []
        missing_fields = []

        if not detections:
            # Nothing was detected, so none of the expected fields can be present.
            missing_fields = list(expected_fields)

        for i, detection in enumerate(detections):
            extracted_fields = detection.get('extracted_fields') if isinstance(detection, dict) else None
            if not isinstance(extracted_fields, dict):
                # Non-dict detection or missing/None extracted_fields:
                # treat as a detection that carries no fields.
                extracted_fields = {}
            for field_name in expected_fields:
                if extracted_fields.get(field_name) is not None:
                    found_fields.append(f"{field_name} (detection {i})")
                else:
                    missing_fields.append(f"{field_name} (detection {i})")

        return {
            "valid": not missing_fields,
            "found_fields": found_fields,
            "missing_fields": missing_fields
        }

    def test_multiple_images(
        self,
        image_paths: List[str],
        roi: Optional[Dict[str, float]] = None,
        expected_fields: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Test multiple images and return aggregated results.

        Args:
            image_paths: List of image file paths
            roi: Optional ROI coordinates
            expected_fields: Optional list of expected field names, forwarded
                to the per-image test

        Returns:
            Aggregated test results: ``total_images``, ``successful_tests``,
            ``failed_tests``, ``success_rate``, ``average_processing_time``
            (mean request time of the successful tests) and ``results`` (one
            entry per input path, each tagged with ``image``).
        """
        logger.info(f"Testing {len(image_paths)} images")

        results = []
        successful_tests = 0
        total_processing_time = 0

        for image_path in image_paths:
            if not os.path.exists(image_path):
                logger.warning(f"Image not found: {image_path}")
                results.append({
                    "image": image_path,
                    "success": False,
                    "error": "File not found"
                })
                continue

            try:
                result = self.test_ocr_endpoint(image_path, roi, expected_fields)
            except OSError as e:
                # The path exists but cannot be opened (directory, permissions,
                # removed meanwhile); record the failure and keep going.
                logger.error(f"Could not read image {image_path}: {e}")
                result = {"success": False, "error": str(e)}

            results.append({
                "image": image_path,
                **result
            })

            if result.get("success", False):
                successful_tests += 1
                total_processing_time += result.get("request_time", 0)

        return {
            "total_images": len(image_paths),
            "successful_tests": successful_tests,
            "failed_tests": len(image_paths) - successful_tests,
            "success_rate": successful_tests / len(image_paths) if image_paths else 0,
            "average_processing_time": total_processing_time / successful_tests if successful_tests > 0 else 0,
            "results": results
        }
def main():
    """Main test function.

    Parses the command line, verifies the API is healthy, then posts the
    bundled test images (looked up next to this script) to the OCR endpoint
    and logs a summary. With one image available the single-image path is
    used and field validation is reported; with several, aggregated results
    are reported.

    Exits with status 1 when the ROI argument is not valid JSON, the health
    check fails, no test image exists, or any test fails.
    """
    parser = argparse.ArgumentParser(description="Test Dots.OCR API endpoint")
    parser.add_argument(
        "--url",
        default="http://localhost:7860",
        help="API base URL (default: http://localhost:7860)"
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=30,
        help="Request timeout in seconds (default: 30)"
    )
    parser.add_argument(
        "--roi",
        type=str,
        help="ROI coordinates as JSON string (e.g., '{\"x1\": 0.1, \"y1\": 0.1, \"x2\": 0.9, \"y2\": 0.9}')"
    )
    parser.add_argument(
        "--expected-fields",
        nargs="+",
        help="Expected field names to validate (e.g., document_number surname given_names)"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging"
    )
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Parse ROI if provided
    roi = None
    if args.roi:
        try:
            roi = json.loads(args.roi)
        except json.JSONDecodeError as e:
            logger.error(f"Invalid ROI JSON: {e}")
            sys.exit(1)

    # Initialize tester
    tester = DotsOCRAPITester(args.url, args.timeout)

    # Health check: there is no point posting images to a dead endpoint.
    logger.info("π Checking API health...")
    health = tester.health_check()
    if "error" in health:
        logger.error(f"β API health check failed: {health['error']}")
        sys.exit(1)
    logger.info(f"β API is healthy: {health}")

    # Test images, resolved relative to this script's directory
    test_images = [
        "tom_id_card_front.jpg",
        "tom_id_card_back.jpg"
    ]
    script_dir = Path(__file__).parent
    existing_images = []
    for image in test_images:
        image_path = script_dir / image
        if image_path.exists():
            existing_images.append(str(image_path))
        else:
            logger.warning(f"Test image not found: {image_path}")

    if not existing_images:
        logger.error("β No test images found")
        sys.exit(1)

    # Expected fields for validation
    expected_fields = args.expected_fields or [
        "document_number",
        "surname",
        "given_names",
        "nationality",
        "date_of_birth",
        "gender"
    ]

    # Run tests
    logger.info(f"π Starting API tests with {len(existing_images)} images...")

    if len(existing_images) == 1:
        # Single image test
        result = tester.test_ocr_endpoint(existing_images[0], roi, expected_fields)
        if not result["success"]:
            logger.error(f"β Single image test failed: {result.get('error', 'Unknown error')}")
            sys.exit(1)

        logger.info("β Single image test passed")
        logger.info(f"β±οΈ Processing time: {result['request_time']:.2f}s")

        # A request can succeed while the payload lacks a list-valued
        # "detections"; count defensively instead of raising here.
        response_body = result.get("response")
        detections = response_body.get("detections") if isinstance(response_body, dict) else None
        detection_count = len(detections) if isinstance(detections, list) else 0
        logger.info(f"π Detections: {detection_count}")

        # Surface structural problems found in the response.
        validation = result.get("validation", {})
        for problem in validation.get("errors", []):
            logger.warning(f"Response validation error: {problem}")
        for problem in validation.get("warnings", []):
            logger.warning(f"Response validation warning: {problem}")

        # Print field validation results
        field_validation = result.get("field_validation", {})
        if field_validation.get("found_fields"):
            logger.info(f"β Found fields: {', '.join(field_validation['found_fields'])}")
        if field_validation.get("missing_fields"):
            logger.warning(f"β οΈ Missing fields: {', '.join(field_validation['missing_fields'])}")
    else:
        # Multiple images test
        results = tester.test_multiple_images(existing_images, roi)

        logger.info("π Test Results:")
        logger.info(f" Total images: {results['total_images']}")
        logger.info(f" Successful: {results['successful_tests']}")
        logger.info(f" Failed: {results['failed_tests']}")
        logger.info(f" Success rate: {results['success_rate']:.1%}")
        logger.info(f" Average processing time: {results['average_processing_time']:.2f}s")

        # Print detailed results
        for result in results["results"]:
            image_name = Path(result["image"]).name
            if result["success"]:
                logger.info(f" β {image_name}: {result['request_time']:.2f}s")
            else:
                logger.error(f" β {image_name}: {result.get('error', 'Unknown error')}")

        if results["failed_tests"] > 0:
            sys.exit(1)

    logger.info("π All tests completed successfully!")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()