| """HF YOLO-E Detection Endpoint | |
| This FastAPI application provides a Hugging Face Space endpoint for YOLO-E | |
| document detection with European document classification, ML-based orientation | |
| detection, and video processing capabilities. | |
| """ | |
| import logging | |
| import time | |
| import uuid | |
| import json | |
| import os | |
| from typing import List, Optional, Dict, Any, Tuple | |
| from contextlib import asynccontextmanager | |
| import cv2 | |
| import numpy as np | |
| from fastapi import FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field | |
| from enum import Enum | |
| import torch | |
| from ultralytics import YOLOE | |
| from PIL import Image | |
| import io | |
| import base64 | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Global model instances | |
| yolo_model = None | |
| orientation_classifier = None | |
| class_mapping = {} | |
| # Selected inference device string (e.g., 'cuda:0', 'mps', or 'cpu') | |
| yolo_device: str = "cpu" | |
# Load class mapping from config
def load_class_mapping():
    """Load class mapping from labels.json configuration."""
    global class_mapping
    try:
        # Try to load from config directory
        config_path = os.path.join(os.path.dirname(__file__), "config", "labels.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
            class_mapping = config.get("classes", {})
        else:
            # Fallback to default mapping
            class_mapping = {
                "0": "id_front",
                "1": "id_back",
                "2": "driver_license",
                "3": "passport",
                "4": "mrz"
            }
        logger.info(f"Loaded class mapping: {class_mapping}")
    except Exception as e:
        logger.warning(f"Failed to load class mapping: {e}")
        class_mapping = {
            "0": "id_front",
            "1": "id_back",
            "2": "driver_license",
            "3": "passport",
            "4": "mrz"
        }
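# For reference, a minimal config/labels.json matching the fallback above might look
# like the following (assumed format; the actual file shipped with the Space may differ):
#
# {
#     "classes": {
#         "0": "id_front",
#         "1": "id_back",
#         "2": "driver_license",
#         "3": "passport",
#         "4": "mrz"
#     }
# }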
# Document type mapping for European documents
DOCUMENT_TYPE_MAPPING = {
    "id_front": "identity_card",
    "id_back": "identity_card",
    "driver_license": "driver_license",
    "passport": "passport",
    "mrz": "identity_card"  # MRZ typically indicates ID card back
}


class DocumentType(str, Enum):
    """Detected document types for European documents."""
    IDENTITY_CARD = "identity_card"
    PASSPORT = "passport"
    DRIVER_LICENSE = "driver_license"
    RESIDENCE_PERMIT = "residence_permit"
    UNKNOWN = "unknown"


class Orientation(str, Enum):
    """Document orientation classification."""
    FRONT = "front"
    BACK = "back"
    UNKNOWN = "unknown"
class BoundingBox(BaseModel):
    """Normalized bounding box coordinates."""
    x1: float = Field(..., ge=0.0, le=1.0, description="Top-left x coordinate")
    y1: float = Field(..., ge=0.0, le=1.0, description="Top-left y coordinate")
    x2: float = Field(..., ge=0.0, le=1.0, description="Bottom-right x coordinate")
    y2: float = Field(..., ge=0.0, le=1.0, description="Bottom-right y coordinate")


class QualityMetrics(BaseModel):
    """Quality assessment metrics."""
    sharpness: float = Field(..., ge=0.0, le=1.0, description="Image sharpness score")
    glare_score: float = Field(..., ge=0.0, le=1.0, description="Glare detection score")
    coverage: float = Field(..., ge=0.0, le=1.0, description="Document coverage percentage")
    brightness: Optional[float] = Field(None, ge=0.0, le=1.0, description="Overall brightness")
    contrast: Optional[float] = Field(None, ge=0.0, le=1.0, description="Image contrast")


class TrackingInfo(BaseModel):
    """Tracking information for video processing."""
    track_id: Optional[str] = Field(None, description="Unique track identifier")
    tracking_confidence: Optional[float] = Field(None, description="Tracking confidence")
    track_age: Optional[int] = Field(None, description="Track age in frames")
    is_tracked: bool = Field(False, description="Whether object is being tracked")
    tracker_type: Optional[str] = Field(None, description="Tracker type used")


class DetectionMetadata(BaseModel):
    """Additional detection metadata."""
    class_name: str = Field(..., description="Detected class name")
    original_coordinates: List[float] = Field(..., description="Original pixel coordinates")
    mask_used: bool = Field(False, description="Whether segmentation mask was used")


class DocumentDetection(BaseModel):
    """Single document detection result."""
    document_type: DocumentType = Field(..., description="Type of detected document")
    orientation: Orientation = Field(..., description="Document orientation (front/back)")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Detection confidence")
    bounding_box: BoundingBox = Field(..., description="Normalized bounding box")
    quality: QualityMetrics = Field(..., description="Quality assessment metrics")
    tracking: TrackingInfo = Field(..., description="Tracking information")
    crop_data: Optional[str] = Field(None, description="Base64 encoded crop data")
    metadata: DetectionMetadata = Field(..., description="Additional metadata")


class DetectionResponse(BaseModel):
    """Detection API response."""
    request_id: str = Field(..., description="Unique request identifier")
    media_type: str = Field(..., description="Media type processed")
    processing_time: float = Field(..., description="Processing time in seconds")
    detections: List[DocumentDetection] = Field(..., description="List of detections")
    frame_count: Optional[int] = Field(None, description="Number of frames processed (video only)")
class QualityAssessor:
    """Enhanced quality assessment for document images."""

    @staticmethod
    def calculate_sharpness(image: np.ndarray) -> float:
        """Calculate image sharpness using Laplacian variance."""
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
            # Normalize to 0-1 range (empirically determined)
            return min(laplacian_var / 1000.0, 1.0)
        except Exception:
            return 0.5

    @staticmethod
    def calculate_glare_score(image: np.ndarray) -> float:
        """Calculate glare score using brightness thresholding."""
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            # Apply Gaussian blur to reduce noise
            blurred = cv2.GaussianBlur(gray, (5, 5), 0)
            # Find bright pixels (above the 90th percentile)
            threshold_value = np.percentile(blurred, 90)
            bright_pixels = blurred > threshold_value
            # Calculate percentage of bright pixels
            bright_ratio = np.sum(bright_pixels) / bright_pixels.size
            return min(bright_ratio, 1.0)
        except Exception:
            return 0.5

    @staticmethod
    def calculate_coverage(image: np.ndarray, bbox: BoundingBox) -> float:
        """Calculate document coverage within the bounding box."""
        try:
            h, w = image.shape[:2]
            x1 = int(bbox.x1 * w)
            y1 = int(bbox.y1 * h)
            x2 = int(bbox.x2 * w)
            y2 = int(bbox.y2 * h)
            # Calculate area ratio
            bbox_area = (x2 - x1) * (y2 - y1)
            total_area = w * h
            return min(bbox_area / total_area, 1.0)
        except Exception:
            return 0.5

    @staticmethod
    def calculate_brightness(image: np.ndarray) -> float:
        """Calculate overall image brightness."""
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            mean_brightness = np.mean(gray) / 255.0
            return float(mean_brightness)
        except Exception:
            return 0.5

    @staticmethod
    def calculate_contrast(image: np.ndarray) -> float:
        """Calculate image contrast using standard deviation."""
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            std_dev = np.std(gray)
            # Normalize to 0-1 scale (typical std dev range: 0-128)
            contrast = min(std_dev / 64.0, 1.0)
            return float(contrast)
        except Exception:
            return 0.5

    @staticmethod
    def assess_quality(image: np.ndarray, bbox: BoundingBox) -> QualityMetrics:
        """Assess all quality metrics for a document image."""
        return QualityMetrics(
            sharpness=QualityAssessor.calculate_sharpness(image),
            glare_score=QualityAssessor.calculate_glare_score(image),
            coverage=QualityAssessor.calculate_coverage(image, bbox),
            brightness=QualityAssessor.calculate_brightness(image),
            contrast=QualityAssessor.calculate_contrast(image)
        )
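    # Example of the sharpness normalization above (illustrative numbers only):
    # a Laplacian variance of 450 maps to a sharpness of 0.45, and anything above
    # 1000 saturates at 1.0 (the 1000.0 divisor is the empirical scale noted above).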
class OrientationClassifier:
    """ML-based orientation classification for European documents."""

    def __init__(self, yolo_model: Optional[YOLOE] = None):
        """Initialize the orientation classifier."""
        self.yolo_model = yolo_model

    def classify_orientation(self, image: np.ndarray, class_name: str) -> Orientation:
        """Classify document orientation using multiple methods.

        Args:
            image: Document image as a numpy array
            class_name: Detected class name from YOLO-E

        Returns:
            Document orientation classification
        """
        try:
            # Method 1: Class-based classification (most reliable)
            class_orientation = self._classify_by_class(class_name)
            if class_orientation != Orientation.UNKNOWN:
                return class_orientation
            # Method 2: Portrait-based classification
            if self.yolo_model is not None:
                portrait_orientation = self._classify_by_portrait(image)
                if portrait_orientation != Orientation.UNKNOWN:
                    return portrait_orientation
            # Method 3: Heuristic-based classification
            return self._classify_by_heuristics(image)
        except Exception as e:
            logger.warning(f"Orientation classification failed: {e}")
            return Orientation.UNKNOWN

    def _classify_by_class(self, class_name: str) -> Orientation:
        """Classify orientation based on the detected class."""
        if class_name in ["id_front", "passport"]:
            return Orientation.FRONT
        elif class_name in ["id_back", "mrz"]:
            return Orientation.BACK
        elif class_name == "driver_license":
            # Driver licenses can be front or back; additional analysis is needed
            return Orientation.UNKNOWN
        else:
            return Orientation.UNKNOWN

    def _classify_by_portrait(self, image: np.ndarray) -> Orientation:
        """Classify orientation based on portrait/face detection."""
        if self.yolo_model is None:
            return Orientation.UNKNOWN
        try:
            # Detect faces/portraits using YOLO-E
            results = self.yolo_model(image, verbose=False)
            if not results or len(results) == 0:
                return Orientation.UNKNOWN
            # Collect confident detections that may correspond to a portrait
            face_detections = []
            for result in results:
                if hasattr(result, 'boxes') and result.boxes is not None:
                    for conf in result.boxes.conf:
                        if conf >= 0.5:  # Confidence threshold for face detection
                            face_detections.append(float(conf))
            if face_detections:
                # Strong face detection suggests the front of the document
                max_confidence = max(face_detections)
                if max_confidence > 0.5:
                    return Orientation.FRONT
            return Orientation.UNKNOWN
        except Exception as e:
            logger.warning(f"Portrait-based classification failed: {e}")
            return Orientation.UNKNOWN
    def _classify_by_heuristics(self, image: np.ndarray) -> Orientation:
        """Classify orientation using image analysis heuristics."""
        try:
            # Convert to grayscale
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image
            height, width = gray.shape
            # Heuristic 1: Text density analysis
            text_density = self._analyze_text_density(gray)
            # Heuristic 2: Symmetry analysis
            symmetry_score = self._analyze_symmetry(gray)
            # Heuristic 3: Edge analysis
            edge_score = self._analyze_edges(gray)
            # Combine heuristics with weights
            combined_score = (
                text_density * 0.4 +
                symmetry_score * 0.3 +
                edge_score * 0.3
            )
            # Threshold-based classification
            if combined_score > 0.6:
                return Orientation.BACK
            elif combined_score < 0.4:
                return Orientation.FRONT
            else:
                return Orientation.UNKNOWN
        except Exception as e:
            logger.warning(f"Heuristic classification failed: {e}")
            return Orientation.UNKNOWN
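    # Worked example of the weighted combination above (illustrative numbers only):
    # text_density = 0.8, symmetry_score = 0.6, edge_score = 0.5
    #   combined_score = 0.8 * 0.4 + 0.6 * 0.3 + 0.5 * 0.3 = 0.32 + 0.18 + 0.15 = 0.65
    # 0.65 > 0.6, so this frame would be classified as Orientation.BACK.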
    def _analyze_text_density(self, gray_image: np.ndarray) -> float:
        """Analyze text density in the image."""
        try:
            # Apply adaptive thresholding to find text regions
            thresh = cv2.adaptiveThreshold(
                gray_image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
            )
            # Remove small noise
            kernel = np.ones((3, 3), np.uint8)
            cleaned = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
            # Calculate text density
            text_pixels = np.sum(cleaned > 0)
            total_pixels = cleaned.size
            density = text_pixels / total_pixels
            # Normalize to 0-1 range
            normalized_density = min(density * 5.0, 1.0)
            return float(normalized_density)
        except Exception:
            return 0.5

    def _analyze_symmetry(self, gray_image: np.ndarray) -> float:
        """Analyze image symmetry."""
        try:
            height, width = gray_image.shape
            # Split image into left and right halves
            mid = width // 2
            left_half = gray_image[:, :mid]
            right_half = cv2.flip(gray_image[:, -mid:], 1)
            # Ensure same size for comparison
            min_width = min(left_half.shape[1], right_half.shape[1])
            left_half = left_half[:, :min_width]
            right_half = right_half[:, :min_width]
            # Calculate correlation coefficient
            correlation = np.corrcoef(left_half.flatten(), right_half.flatten())[0, 1]
            # Convert to symmetry score
            symmetry = (correlation + 1.0) / 2.0
            return float(symmetry)
        except Exception:
            return 0.5

    def _analyze_edges(self, gray_image: np.ndarray) -> float:
        """Analyze edge patterns for orientation clues."""
        try:
            # Detect edges
            edges = cv2.Canny(gray_image, 50, 150)
            # Divide image into regions
            height, width = edges.shape
            regions = {
                'top_left': edges[:height//2, :width//2],
                'top_right': edges[:height//2, width//2:],
                'bottom_left': edges[height//2:, :width//2],
                'bottom_right': edges[height//2:, width//2:],
                'center': edges[height//3:2*height//3, width//3:2*width//3]
            }
            # Calculate edge density in each region
            edge_densities = {}
            for region_name, region in regions.items():
                edge_densities[region_name] = np.sum(region > 0) / region.size
            # Front documents often have more edges in the center (portrait)
            # Back documents often have more edges in the corners (text, MRZ)
            center_density = edge_densities['center']
            corner_density = (
                edge_densities['top_left'] +
                edge_densities['top_right'] +
                edge_densities['bottom_left'] +
                edge_densities['bottom_right']
            ) / 4.0
            # Higher corner density suggests the back of a document
            if corner_density > center_density:
                return min(corner_density / center_density * 0.5, 1.0)
            else:
                return max(0.0, 1.0 - (center_density / max(corner_density, 0.01)) * 0.5)
        except Exception:
            return 0.5
class VideoProcessor:
    """Video processing utilities for frame extraction and quality-based selection."""

    def __init__(self, sample_fps: float = 2.0):
        """Initialize video processor.

        Args:
            sample_fps: Frames per second to sample from the video
        """
        self.sample_fps = sample_fps

    def extract_frames(self, video_path: str) -> List[Tuple[np.ndarray, float]]:
        """Extract frames from video at the specified sampling rate.

        Args:
            video_path: Path to video file

        Returns:
            List of (frame, timestamp) tuples
        """
        frames = []
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_interval = max(1, int(fps / self.sample_fps))
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                timestamp = frame_count / fps
                frames.append((frame.copy(), timestamp))
            frame_count += 1
        cap.release()
        logger.info(f"Extracted {len(frames)} frames from video")
        return frames

    def extract_frames_from_bytes(self, video_data: bytes) -> List[Tuple[np.ndarray, float]]:
        """Extract frames from video bytes.

        Args:
            video_data: Video file as bytes

        Returns:
            List of (frame, timestamp) tuples
        """
        # Write video data to a temporary file
        import tempfile
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
            tmp_file.write(video_data)
            tmp_path = tmp_file.name
        try:
            frames = self.extract_frames(tmp_path)
            logger.info(f"Extracted {len(frames)} frames from video bytes")
        except Exception as e:
            logger.error(f"Failed to extract frames from video: {e}")
            frames = []
        finally:
            # Clean up the temporary file
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
        return frames
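    # Example of the sampling arithmetic above (illustrative numbers only): for a
    # 30 fps source sampled at sample_fps=2.0, frame_interval = max(1, int(30 / 2.0)) = 15,
    # so frames 0, 15, 30, ... are kept with timestamps 0.0 s, 0.5 s, 1.0 s, ...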
class SimpleTracker:
    """Simple tracking system for video processing."""

    def __init__(self):
        """Initialize the tracker."""
        self.track_counter = 0
        self.active_tracks = {}  # track_id -> track_info
        self.track_threshold = 0.3  # IoU threshold for track association

    def update_tracks(self, detections: List[DocumentDetection], frame_idx: int) -> List[DocumentDetection]:
        """Update tracks for the current frame's detections.

        Args:
            detections: List of detections in the current frame
            frame_idx: Current frame index

        Returns:
            List of detections with updated tracking info
        """
        if not detections:
            return detections
        # Simple tracking: assign track IDs based on position similarity
        for detection in detections:
            track_id = self._assign_track_id(detection, frame_idx)
            detection.tracking = TrackingInfo(
                track_id=track_id,
                tracking_confidence=0.8,  # Default confidence
                track_age=frame_idx - self.active_tracks.get(track_id, {}).get('first_seen', frame_idx),
                is_tracked=True,
                tracker_type="simple_position_based"
            )
        return detections

    def _assign_track_id(self, detection: DocumentDetection, frame_idx: int) -> str:
        """Assign a track ID to a detection based on position similarity."""
        bbox = detection.bounding_box
        # Check for existing tracks with a similar position
        for track_id, track_info in self.active_tracks.items():
            if self._calculate_iou(bbox, track_info['last_bbox']) > self.track_threshold:
                # Update existing track
                track_info['last_bbox'] = bbox
                track_info['last_seen'] = frame_idx
                return track_id
        # Create a new track
        self.track_counter += 1
        track_id = f"track_{self.track_counter:03d}"
        self.active_tracks[track_id] = {
            'first_seen': frame_idx,
            'last_seen': frame_idx,
            'last_bbox': bbox
        }
        return track_id

    def _calculate_iou(self, bbox1: BoundingBox, bbox2: BoundingBox) -> float:
        """Calculate Intersection over Union (IoU) between two bounding boxes."""
        # Calculate intersection
        x1 = max(bbox1.x1, bbox2.x1)
        y1 = max(bbox1.y1, bbox2.y1)
        x2 = min(bbox1.x2, bbox2.x2)
        y2 = min(bbox1.y2, bbox2.y2)
        if x2 <= x1 or y2 <= y1:
            return 0.0
        intersection = (x2 - x1) * (y2 - y1)
        # Calculate union
        area1 = (bbox1.x2 - bbox1.x1) * (bbox1.y2 - bbox1.y1)
        area2 = (bbox2.x2 - bbox2.x1) * (bbox2.y2 - bbox2.y1)
        union = area1 + area2 - intersection
        return intersection / union if union > 0 else 0.0
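    # Worked IoU example (illustrative numbers only): for normalized boxes
    # A = (0.0, 0.0, 0.5, 0.5) and B = (0.25, 0.25, 0.75, 0.75):
    #   intersection = 0.25 * 0.25 = 0.0625
    #   union        = 0.25 + 0.25 - 0.0625 = 0.4375
    #   IoU          = 0.0625 / 0.4375 ≈ 0.143, below the 0.3 association threshold,
    # so these two detections would start separate tracks.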
class QualitySelector:
    """Quality-based frame selection for video processing."""

    def __init__(self, quality_threshold: float = 0.7):
        """Initialize quality selector.

        Args:
            quality_threshold: Minimum quality score threshold
        """
        self.quality_threshold = quality_threshold

    def select_best_detections(
        self,
        detections_by_frame: List[List[DocumentDetection]]
    ) -> List[DocumentDetection]:
        """Select the highest quality detection for each unique document.

        Args:
            detections_by_frame: List of detection lists, one per frame

        Returns:
            List of best quality detections
        """
        if not detections_by_frame:
            return []
        # Group detections by unique document identifier
        unique_detections = self._group_detections_by_document(detections_by_frame)
        # Select the best quality detection for each group
        best_detections = []
        for doc_id, detection_group in unique_detections.items():
            best_detection = self._select_best_detection(detection_group)
            if best_detection:
                best_detections.append(best_detection)
                logger.debug(f"Selected best detection for {doc_id}")
        logger.info(f"Selected {len(best_detections)} best quality detections")
        return best_detections

    def _group_detections_by_document(
        self,
        detections_by_frame: List[List[DocumentDetection]]
    ) -> Dict[str, List[DocumentDetection]]:
        """Group detections by unique document identifier."""
        document_groups = {}
        for frame_idx, frame_detections in enumerate(detections_by_frame):
            for detection in frame_detections:
                # Create a unique document identifier based on type and position
                doc_id = self._create_document_id(detection)
                if doc_id not in document_groups:
                    document_groups[doc_id] = []
                document_groups[doc_id].append(detection)
        return document_groups

    def _create_document_id(self, detection: DocumentDetection) -> str:
        """Create a unique identifier for a document detection."""
        # Use document type and position for grouping
        bbox = detection.bounding_box
        position_hash = f"{bbox.x1:.3f}_{bbox.y1:.3f}_{bbox.x2:.3f}_{bbox.y2:.3f}"
        return f"{detection.document_type.value}_{position_hash}"

    def _select_best_detection(self, detection_group: List[DocumentDetection]) -> Optional[DocumentDetection]:
        """Select the best quality detection from a group."""
        if not detection_group:
            return None
        # Calculate a composite quality score for each detection and sort
        detection_scores = []
        for detection in detection_group:
            quality_score = self._calculate_composite_quality_score(detection)
            detection_scores.append((detection, quality_score))
        # Sort by quality score (descending)
        detection_scores.sort(key=lambda x: x[1], reverse=True)
        return detection_scores[0][0]

    def _calculate_composite_quality_score(self, detection: DocumentDetection) -> float:
        """Calculate a composite quality score for a detection."""
        quality = detection.quality
        # Weighted combination of quality metrics
        weights = {
            'sharpness': 0.3,
            'glare_score': 0.2,  # Inverted - lower glare is better
            'coverage': 0.2,
            'brightness': 0.15,
            'contrast': 0.15
        }
        score = 0.0
        total_weight = 0.0
        for metric, weight in weights.items():
            if hasattr(quality, metric):
                value = getattr(quality, metric)
                if value is not None:
                    # Invert glare score (lower is better)
                    if metric == 'glare_score':
                        value = 1.0 - value
                    score += value * weight
                    total_weight += weight
        if total_weight > 0:
            return score / total_weight
        return 0.5  # Default if no metrics are available
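    # Worked example of the composite score above (illustrative numbers only):
    # sharpness=0.8, glare_score=0.2 (inverted to 0.8), coverage=0.6,
    # brightness=0.5, contrast=0.7
    #   score = 0.8*0.3 + 0.8*0.2 + 0.6*0.2 + 0.5*0.15 + 0.7*0.15
    #         = 0.24 + 0.16 + 0.12 + 0.075 + 0.105 = 0.70
    # All metrics were available, so total_weight = 1.0 and the composite score is 0.70.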
def normalize_bbox(bbox: List[float], img_width: int, img_height: int) -> BoundingBox:
    """Normalize bounding box coordinates to the [0, 1] range."""
    x1, y1, x2, y2 = bbox
    return BoundingBox(
        x1=x1 / img_width,
        y1=y1 / img_height,
        x2=x2 / img_width,
        y2=y2 / img_height
    )


def classify_document_type(class_id: int) -> DocumentType:
    """Classify document type based on the detected class ID."""
    global class_mapping, DOCUMENT_TYPE_MAPPING
    # Get class name from mapping
    class_name = class_mapping.get(str(class_id), "unknown")
    # Map to document type
    doc_type = DOCUMENT_TYPE_MAPPING.get(class_name, "unknown")
    try:
        return DocumentType(doc_type)
    except ValueError:
        return DocumentType.UNKNOWN


def get_class_name(class_id: int) -> str:
    """Get class name from class ID."""
    global class_mapping
    return class_mapping.get(str(class_id), "unknown")
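# Example of the two-step mapping above (using the fallback class mapping):
# class_id 4 -> "mrz" via class_mapping -> "identity_card" via DOCUMENT_TYPE_MAPPING,
# so classify_document_type(4) returns DocumentType.IDENTITY_CARD.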
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for model loading."""
    global yolo_model, orientation_classifier, yolo_device
    logger.info("Loading YOLO-E model and initializing components...")
    try:
        # Load class mapping
        load_class_mapping()
        # Select device (prefer CUDA on HF GPU instances; otherwise CPU)
        # Why: deployment targets Linux GPU; macOS MPS is not relevant here.
        yolo_device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logger.info(f"Selected device: {yolo_device} (cuda_available={torch.cuda.is_available()})")
        # Log detailed device/runtime information for observability
        try:
            if yolo_device.startswith("cuda"):
                # Query active CUDA device details to confirm the GPU runtime
                device_index = torch.cuda.current_device()
                device_name = torch.cuda.get_device_name(device_index)
                cc_major, cc_minor = torch.cuda.get_device_capability(device_index)
                logger.info(
                    "CUDA device info: name=%s index=%s capability=%s.%s torch=%s cuda_runtime=%s",
                    device_name,
                    device_index,
                    cc_major,
                    cc_minor,
                    torch.__version__,
                    getattr(torch.version, "cuda", "unknown"),
                )
            else:
                logger.info("CPU runtime active: torch=%s", torch.__version__)
        except Exception as device_log_err:
            # Avoid startup failure if device metadata is unavailable
            logger.warning(f"Device info logging failed: {device_log_err}")
        # Load YOLO-E model (yolo11 variant)
        yolo_model = YOLOE("yolo11n.pt")  # Use nano for faster inference
        # Move the model to the device when the API is available; fall back to the underlying .model.
        try:
            # Preferred: Ultralytics model interface
            _ = yolo_model.to(yolo_device)
        except Exception:
            try:
                # Fallback: underlying PyTorch module
                _ = yolo_model.model.to(yolo_device)  # type: ignore[attr-defined]
            except Exception:
                # If neither works, rely on per-call device selection below
                logger.warning("Could not move model to device at load time; will set device per call")
        logger.info("YOLO-E model loaded successfully")
        # Initialize orientation classifier with the YOLO model
        orientation_classifier = OrientationClassifier(yolo_model)
        logger.info("Orientation classifier initialized")
        # Optional warm-up on GPU to trigger lazy CUDA init and JITs
        try:
            dummy = np.zeros((640, 640, 3), dtype=np.uint8)
            # Use a very low confidence and no verbose output to minimize overhead
            _ = yolo_model(dummy, conf=0.01, verbose=False, device=yolo_device)
            if yolo_device.startswith("cuda"):
                torch.cuda.synchronize()
            logger.info("Warm-up inference completed")
        except Exception as warmup_err:
            logger.warning(f"Warm-up skipped due to: {warmup_err}")
    except Exception as e:
        logger.error(f"Failed to load models: {e}")
        raise
    yield
    logger.info("Shutting down YOLO-E endpoint...")
app = FastAPI(
    title="KYB YOLO-E European Document Detection",
    description="Enhanced YOLO-E for European identity document detection with ML-based orientation classification and video processing",
    version="2.0.0",
    lifespan=lifespan
)


@app.get("/health")  # Route path assumed; the original decorator was not preserved
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "version": "2.0.0"}
@app.post("/detect")  # Route path assumed; the original decorator was not preserved
async def detect_documents(
    file: UploadFile = File(..., description="Image file to process"),
    min_confidence: float = Form(0.25, ge=0.0, le=1.0, description="Minimum confidence threshold"),
    return_crops: bool = Form(False, description="Whether to return cropped images")
):
    """Detect European identity documents in the uploaded image."""
    if yolo_model is None or orientation_classifier is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    start_time = time.time()
    request_id = str(uuid.uuid4())
    try:
        # Read and validate image (convert to RGB so grayscale/RGBA uploads do not break cvtColor)
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        img_height, img_width = image_cv.shape[:2]
        # Run YOLO-E detection on the selected device
        results = yolo_model(image_cv, conf=min_confidence, device=yolo_device, verbose=False)
        detections = []
        for result in results:
            if result.boxes is not None:
                for box in result.boxes:
                    # Extract detection data
                    conf = float(box.conf[0])
                    if conf < min_confidence:
                        continue
                    # Get class ID and name
                    class_id = int(box.cls[0])
                    class_name = get_class_name(class_id)
                    # Get bounding box coordinates
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    bbox = normalize_bbox([x1, y1, x2, y2], img_width, img_height)
                    # Classify document type
                    document_type = classify_document_type(class_id)
                    # Determine orientation using the ML-based classifier
                    orientation = orientation_classifier.classify_orientation(image_cv, class_name)
                    # Assess quality
                    quality = QualityAssessor.assess_quality(image_cv, bbox)
                    # Prepare crop data if requested
                    crop_data = None
                    if return_crops:
                        crop_img = image_cv[int(y1):int(y2), int(x1):int(x2)]
                        _, buffer = cv2.imencode('.jpg', crop_img)
                        crop_data = base64.b64encode(buffer).decode('utf-8')
                    # Create detection
                    detection = DocumentDetection(
                        document_type=document_type,
                        orientation=orientation,
                        confidence=conf,
                        bounding_box=bbox,
                        quality=quality,
                        tracking=TrackingInfo(
                            track_id=None,
                            tracking_confidence=None,
                            track_age=None,
                            is_tracked=False,
                            tracker_type=None
                        ),
                        crop_data=crop_data,
                        metadata=DetectionMetadata(
                            class_name=class_name,
                            original_coordinates=[float(x1), float(y1), float(x2), float(y2)],
                            mask_used=False
                        )
                    )
                    detections.append(detection)
        processing_time = time.time() - start_time
        return DetectionResponse(
            request_id=request_id,
            media_type="image",
            processing_time=processing_time,
            detections=detections,
            frame_count=None
        )
    except Exception as e:
        logger.error(f"Detection failed: {e}")
        raise HTTPException(status_code=500, detail=f"Detection failed: {str(e)}")
@app.post("/detect/video")  # Route path assumed; the original decorator was not preserved
async def detect_documents_video(
    file: UploadFile = File(..., description="Video file to process"),
    min_confidence: float = Form(0.25, ge=0.0, le=1.0, description="Minimum confidence threshold"),
    sample_fps: float = Form(2.0, ge=0.1, le=30.0, description="Video sampling rate in frames per second"),
    return_crops: bool = Form(False, description="Whether to return cropped images"),
    max_detections: int = Form(10, ge=1, le=100, description="Maximum number of detections to return")
):
    """Detect European identity documents in an uploaded video with quality-based frame selection."""
    if yolo_model is None or orientation_classifier is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    start_time = time.time()
    request_id = str(uuid.uuid4())
    try:
        # Read video data
        video_data = await file.read()
        # Initialize video processor, quality selector, and tracker
        video_processor = VideoProcessor(sample_fps=sample_fps)
        quality_selector = QualitySelector()
        tracker = SimpleTracker()
        # Extract frames from the video
        frames = video_processor.extract_frames_from_bytes(video_data)
        if not frames:
            logger.error("No frames extracted from video")
            raise HTTPException(status_code=400, detail="No frames extracted from video")
        logger.info(f"Processing {len(frames)} frames from video")
        # Process each frame
        detections_by_frame = []
        for frame_idx, (frame, timestamp) in enumerate(frames):
            frame_detections = []
            # Run YOLO-E detection on the selected device
            results = yolo_model(frame, conf=min_confidence, device=yolo_device, verbose=False)
            for result in results:
                if result.boxes is not None:
                    for box in result.boxes:
                        # Extract detection data
                        conf = float(box.conf[0])
                        if conf < min_confidence:
                            continue
                        # Get class ID and name
                        class_id = int(box.cls[0])
                        class_name = get_class_name(class_id)
                        # Get bounding box coordinates
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                        img_height, img_width = frame.shape[:2]
                        bbox = normalize_bbox([x1, y1, x2, y2], img_width, img_height)
                        # Classify document type
                        document_type = classify_document_type(class_id)
                        # Determine orientation using the ML-based classifier
                        orientation = orientation_classifier.classify_orientation(frame, class_name)
                        # Assess quality
                        quality = QualityAssessor.assess_quality(frame, bbox)
                        # Prepare crop data if requested
                        crop_data = None
                        if return_crops:
                            crop_img = frame[int(y1):int(y2), int(x1):int(x2)]
                            _, buffer = cv2.imencode('.jpg', crop_img)
                            crop_data = base64.b64encode(buffer).decode('utf-8')
                        # Create detection
                        detection = DocumentDetection(
                            document_type=document_type,
                            orientation=orientation,
                            confidence=conf,
                            bounding_box=bbox,
                            quality=quality,
                            tracking=TrackingInfo(),  # Will be updated by the tracker
                            crop_data=crop_data,
                            metadata=DetectionMetadata(
                                class_name=class_name,
                                original_coordinates=[float(x1), float(y1), float(x2), float(y2)],
                                mask_used=False
                            )
                        )
                        frame_detections.append(detection)
            # Update tracks for this frame
            frame_detections = tracker.update_tracks(frame_detections, frame_idx)
            detections_by_frame.append(frame_detections)
        # Select the best quality detections
        best_detections = quality_selector.select_best_detections(detections_by_frame)
        # Limit to max_detections
        if len(best_detections) > max_detections:
            best_detections = best_detections[:max_detections]
        processing_time = time.time() - start_time
        return DetectionResponse(
            request_id=request_id,
            media_type="video",
            processing_time=processing_time,
            detections=best_detections,
            frame_count=len(frames)
        )
    except HTTPException:
        # Preserve intentional HTTP errors (e.g., 400 when no frames are extracted)
        raise
    except Exception as e:
        logger.error(f"Video detection failed: {e}")
        raise HTTPException(status_code=500, detail=f"Video detection failed: {str(e)}")
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)