import gradio as gr
import torch
import torch.nn as nn
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
from ultralytics import YOLO
import base64
import io
import yaml
from pathlib import Path
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class YOLOLicensePlateDetector:
    """YOLO-based license plate detector matching the original API"""

    def __init__(self, detect_model_path, char_model_path, province_model_path,
                 data_path, province_data_path, device):
        self.device = device

        # Load character and province mappings from the YAML config files
        self.char_mapping = {}
        self.province_mapping = {}
        self._load_mappings(data_path)
        self._load_province_mappings(province_data_path)

        # Load YOLO models
        self.detect_model = None
        self.char_model = None
        self.province_model = None

        if detect_model_path and Path(detect_model_path).exists():
            self.detect_model = YOLO(str(detect_model_path))
            logger.info(f"License plate detection model loaded: {detect_model_path}")

        if char_model_path and Path(char_model_path).exists():
            self.char_model = YOLO(str(char_model_path))
            logger.info(f"Character reading model loaded: {char_model_path}")

        if province_model_path and Path(province_model_path).exists():
            self.province_model = YOLO(str(province_model_path))
            logger.info(f"Province detection model loaded: {province_model_path}")

    def _load_mappings(self, data_path):
        """Load character mappings from data.yaml"""
        try:
            if Path(data_path).exists():
                with open(data_path, 'r', encoding='utf-8') as f:
                    data = yaml.safe_load(f)

                # Load character mapping - keep keys as strings!
                self.char_mapping = data.get('char_mapping', {})

                # Add digit mapping for class names "0"-"9"
                for i in range(10):
                    class_name = str(i)
                    if class_name not in self.char_mapping:
                        self.char_mapping[class_name] = str(i)

                logger.info(f"Loaded {len(self.char_mapping)} character mappings")
                logger.info(f"Sample mappings: {dict(list(self.char_mapping.items())[:5])}")
            else:
                logger.warning(f"Data file not found: {data_path}")
                # Default mappings: digits "0"-"9" only
                self.char_mapping = {str(i): str(i) for i in range(10)}
        except Exception as e:
            logger.error(f"Error loading mappings: {e}")
            self.char_mapping = {str(i): str(i) for i in range(10)}

    def _load_province_mappings(self, province_data_path):
        """Load province mappings from data_province.yaml (matching original API)"""
        try:
            if Path(province_data_path).exists():
                with open(province_data_path, 'r', encoding='utf-8') as f:
                    data = yaml.safe_load(f)

                # Load province mapping from char_mapping section (like original API)
                if 'char_mapping' in data:
                    self.province_mapping = data['char_mapping']
                    logger.info("✅ Province mapping loaded from data_province.yaml")
                    logger.info(f"Loaded {len(self.province_mapping)} province mappings")
                    logger.info(f"Sample: {dict(list(self.province_mapping.items())[:3])}")
                elif 'names' in data:
                    # Fallback: create mapping from names if no explicit mapping
                    self.province_mapping = {str(i): name for i, name in enumerate(data['names'])}
                    logger.info("✅ Province mapping created from names")
                    logger.info(f"Created {len(self.province_mapping)} province mappings")
                else:
                    self.province_mapping = {"0": "Unknown"}
                    logger.warning("No province mapping found in data_province.yaml")
            else:
                logger.warning(f"Province data file not found: {province_data_path}")
                self.province_mapping = {"0": "Unknown"}
        except Exception as e:
            logger.error(f"Error loading province mappings: {e}")
            self.province_mapping = {"0": "Unknown"}
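
    # For reference, the structure these loaders expect is assumed to look
    # roughly like the sketch below (hypothetical keys and values; the real
    # files depend on how the YOLO models were trained):
    #
    #   data.yaml
    #     char_mapping:
    #       "kor_kai": "ก"        # class name -> rendered character
    #       "0": "0"
    #
    #   data_province.yaml
    #     char_mapping:
    #       "bangkok": "Bangkok"  # class name -> province label
    #     names: [...]            # fallback list used when char_mapping is absent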

    def map_class_to_char(self, class_name):
        """Map YOLO class name to character (matching original API)"""
        return self.char_mapping.get(str(class_name), '?')

    def map_class_to_province(self, class_name):
        """Map YOLO class name to province (matching original API)"""
        return self.province_mapping.get(str(class_name), "Unknown")

    def detect_license_plate(self, vehicle_image):
        """Detect license plate in vehicle image using YOLO"""
        if self.detect_model is None:
            return None

        try:
            # Run license plate detection with confidence 0.3 (same as original API)
            results = self.detect_model(vehicle_image, conf=0.3)

            if not results or len(results) == 0:
                return None

            # Get the first (highest confidence) license plate detection
            for result in results:
                boxes = result.boxes
                if boxes is not None and len(boxes) > 0:
                    # Get the highest confidence detection
                    best_box = boxes[0]
                    x1, y1, x2, y2 = best_box.xyxy[0].cpu().numpy().astype(int)
                    confidence = best_box.conf[0].cpu().numpy()

                    # Crop license plate region
                    if isinstance(vehicle_image, Image.Image):
                        vehicle_array = np.array(vehicle_image)
                    else:
                        vehicle_array = vehicle_image

                    license_plate = vehicle_array[y1:y2, x1:x2]

                    return {
                        'image': license_plate,
                        'bbox': [x1, y1, x2, y2],
                        'confidence': float(confidence)
                    }

            return None
        except Exception as e:
            logger.error(f"License plate detection error: {e}")
            return None

    def read_characters(self, license_plate_image):
        """Read characters from license plate using YOLO (matching original API)"""
        if self.char_model is None:
            return []

        try:
            # Ensure image is in correct format
            if isinstance(license_plate_image, Image.Image):
                img_array = np.array(license_plate_image)
            else:
                img_array = license_plate_image

            # Run character detection with confidence 0.3 (same as original API)
            results = self.char_model(img_array, conf=0.3)

            characters = []
            for result in results:
                boxes = result.boxes
                if boxes is not None:
                    for box in boxes:
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                        confidence = box.conf[0].cpu().numpy()
                        class_id = int(box.cls[0].cpu().numpy())

                        # Two-step mapping like original API:
                        # 1. Get class name from model
                        class_name = result.names[class_id]
                        # 2. Map class name to character
                        char = self.map_class_to_char(class_name)

                        characters.append({
                            'char': char,
                            'confidence': float(confidence),
                            'bbox': [float(x1), float(y1), float(x2), float(y2)],
                            'center_x': float((x1 + x2) / 2)
                        })

            # Sort characters by x-position (left to right) - same as original API
            characters.sort(key=lambda x: x['bbox'][0])
            return characters
        except Exception as e:
            logger.error(f"Character reading error: {e}")
            return []

    def detect_province(self, license_plate_image):
        """Detect province from license plate"""
        if self.province_model is None:
            return "Unknown"

        try:
            # Ensure image is in correct format
            if isinstance(license_plate_image, Image.Image):
                img_array = np.array(license_plate_image)
            else:
                img_array = license_plate_image

            # Run province detection with confidence 0.3 (same as original API)
            results = self.province_model(img_array, conf=0.3)

            for result in results:
                boxes = result.boxes
                if boxes is not None and len(boxes) > 0:
                    # Get highest confidence detection
                    best_box = boxes[0]
                    class_id = int(best_box.cls[0].cpu().numpy())
                    confidence = best_box.conf[0].cpu().numpy()

                    # Two-step mapping like original API:
                    # 1. Get class name from model
                    class_name = result.names[class_id]
                    # 2. Map class name to province
                    province = self.map_class_to_province(class_name)
                    return province

            return "Unknown"
        except Exception as e:
            logger.error(f"Province detection error: {e}")
            return "Unknown"
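

# A minimal standalone usage sketch for YOLOLicensePlateDetector, assuming the
# weights and config files below exist locally (hypothetical paths):
#
#   detector = YOLOLicensePlateDetector(
#       detect_model_path="models/detect1.pt",
#       char_model_path="models/read_char.pt",
#       province_model_path="models/best_province.pt",
#       data_path="config/data.yaml",
#       province_data_path="config/data_province.yaml",
#       device=torch.device("cpu"),
#   )
#   plate = detector.detect_license_plate(Image.open("car.jpg"))
#   if plate is not None:
#       text = ''.join(c['char'] for c in detector.read_characters(plate['image']))
#       province = detector.detect_province(plate['image'])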


class LicensePlateDetector:
    """Main license plate detection system"""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Model paths - try multiple locations
        base_paths = [Path("models"), Path("../models"), Path("./")]

        # Find YOLO models
        self.yolo_model_path = None
        self.segment_model_path = None
        for base_dir in base_paths:
            if (base_dir / "yolo11s.pt").exists():
                self.yolo_model_path = base_dir / "yolo11s.pt"
                break
            elif (base_dir / "yolov9.pt").exists():
                self.yolo_model_path = base_dir / "yolov9.pt"
                break

        for base_dir in base_paths:
            if (base_dir / "best_segment.pt").exists():
                self.segment_model_path = base_dir / "best_segment.pt"
                break

        # Find license plate detection model (detect1.pt)
        self.detect_model_path = None
        detect_model_names = ["detect1.pt"]
        for base_dir in base_paths:
            for model_name in detect_model_names:
                if (base_dir / model_name).exists():
                    self.detect_model_path = base_dir / model_name
                    break
            if self.detect_model_path:
                break

        # Find character reading model (read_char.pt)
        self.char_model_path = None
        char_model_names = ["read_char.pt"]
        for base_dir in base_paths:
            for model_name in char_model_names:
                if (base_dir / model_name).exists():
                    self.char_model_path = base_dir / model_name
                    break
            if self.char_model_path:
                break

        # Find province recognition model
        self.province_model_path = None
        province_model_names = ["best_province.pt"]
        for base_dir in base_paths:
            for model_name in province_model_names:
                if (base_dir / model_name).exists():
                    self.province_model_path = base_dir / model_name
                    break
            if self.province_model_path:
                break

        # Find data.yaml file (for character mapping)
        config_paths = [
            Path("deploy_huggingface/config/data.yaml"),
            Path("config/data.yaml"),
            Path("../config/data.yaml"),
            Path("./data.yaml")
        ]
        self.data_path = None
        for config_path in config_paths:
            if config_path.exists():
                self.data_path = config_path
                break
        if self.data_path is None:
            self.data_path = Path("deploy_huggingface/config/data.yaml")  # Use default

        # Find data_province.yaml file (for province mapping)
        province_config_paths = [
            Path("deploy_huggingface/config/data_province.yaml"),
            Path("config/data_province.yaml"),
            Path("../config/data_province.yaml"),
            Path("./data_province.yaml")
        ]
        self.province_data_path = None
        for config_path in province_config_paths:
            if config_path.exists():
                self.province_data_path = config_path
                break
        if self.province_data_path is None:
            self.province_data_path = Path("deploy_huggingface/config/data_province.yaml")  # Use default

        # Initialize models
        self.yolo_model = None
        self.license_plate_detector = None
        self._load_models()

    def _load_models(self):
        """Load all required models"""
        try:
            # YOLO vehicle detection model
            if self.yolo_model_path and self.yolo_model_path.exists():
                self.yolo_model = YOLO(str(self.yolo_model_path))
                logger.info("YOLO vehicle detection model loaded")
            else:
                logger.warning("YOLO vehicle detection model not found")

            # YOLO-based license plate detector
            self.license_plate_detector = YOLOLicensePlateDetector(
                detect_model_path=self.detect_model_path,
                char_model_path=self.char_model_path,
                province_model_path=self.province_model_path,
                data_path=self.data_path,
                province_data_path=self.province_data_path,
                device=self.device
            )
        except Exception as e:
            logger.error(f"Error loading models: {e}")
            print(f"Warning: Some models failed to load: {e}")
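
    # For reference, the search paths above assume a layout roughly like the
    # following (the file names come from the searches in __init__; the
    # directory names are illustrative and should match your checkout):
    #
    #   models/
    #     yolo11s.pt (or yolov9.pt)  - vehicle detection
    #     detect1.pt                 - license plate detection
    #     read_char.pt               - character recognition
    #     best_province.pt           - province recognition
    #   config/ (or deploy_huggingface/config/)
    #     data.yaml                  - character class mapping
    #     data_province.yaml         - province class mapping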

    def point_in_polygon(self, point, polygon):
        """Check if a point is inside a polygon (ray casting)"""
        x, y = point
        n = len(polygon)
        inside = False

        p1x, p1y = polygon[0]
        for i in range(1, n + 1):
            p2x, p2y = polygon[i % n]
            if y > min(p1y, p2y):
                if y <= max(p1y, p2y):
                    if x <= max(p1x, p2x):
                        if p1y != p2y:
                            xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                        if p1x == p2x or x <= xinters:
                            inside = not inside
            p1x, p1y = p2x, p2y

        return inside

    def detect_objects_in_protection_area(self, image, protection_polygon, conf_threshold=0.25):
        """Detect objects in the protection area"""
        results = []
        if self.yolo_model is None:
            return results

        try:
            # Run YOLO detection (threshold comes from the UI confidence slider)
            detections = self.yolo_model(image, conf=conf_threshold)

            for detection in detections:
                boxes = detection.boxes
                if boxes is not None:
                    for box in boxes:
                        # Get bounding box coordinates
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                        center_x = (x1 + x2) / 2
                        center_y = (y1 + y2) / 2

                        # Check if center point is in protection area
                        if self.point_in_polygon((center_x, center_y), protection_polygon):
                            confidence = box.conf[0].cpu().numpy()
                            class_id = int(box.cls[0].cpu().numpy())
                            class_name = detection.names[class_id]

                            results.append({
                                'bbox': [int(x1), int(y1), int(x2), int(y2)],
                                'confidence': float(confidence),
                                'class': class_name,
                                'center': [center_x, center_y]
                            })
        except Exception as e:
            logger.error(f"Object detection error: {e}")

        return results

    def detect_and_read_license_plate(self, vehicle_image):
        """Detect and read license plate from vehicle image using YOLO"""
        if self.license_plate_detector is None:
            return None, "Unknown", "Unknown"

        try:
            # Step 1: Detect license plate in vehicle image
            plate_detection = self.license_plate_detector.detect_license_plate(vehicle_image)
            if plate_detection is None:
                return None, "Unknown", "Unknown"

            plate_image = plate_detection['image']

            # Step 2: Read characters from license plate
            characters = self.license_plate_detector.read_characters(plate_image)

            # Step 3: Assemble character text (exactly like original API)
            if characters:
                # Join characters directly (same as original API)
                char_text = ''.join([char['char'] for char in characters])
                # Only show "Detected" if all characters are unknown
                if not char_text or char_text.replace('?', '') == '':
                    char_text = "Detected"
            else:
                char_text = "Detected"  # License plate detected but no characters read

            # Step 4: Detect province
            province = self.license_plate_detector.detect_province(plate_image)

            return plate_image, char_text, province
        except Exception as e:
            logger.error(f"License plate detection and reading error: {e}")
            return None, "Unknown", "Unknown"
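
    # Return contract of detect_and_read_license_plate, summarised for clarity
    # (derived from the code above; values shown are placeholders):
    #   plate_image -> cropped plate as a numpy array, or None if no plate found
    #   char_text   -> concatenated characters, "Detected" when a plate is found
    #                  but no character can be read, or "Unknown" on failure
    #   province    -> mapped province name, or "Unknown"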

    def process_image(self, image, protection_points, conf_threshold=0.25):
        """Process the entire image for license plate detection"""
        results = {
            'detected_objects': [],
            'annotated_image': None,
            'license_plates': []
        }

        if len(protection_points) < 3:
            return results

        try:
            # Convert PIL to OpenCV format
            if isinstance(image, Image.Image):
                image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            else:
                image_cv = image

            # Detect objects in protection area
            detected_objects = self.detect_objects_in_protection_area(
                image_cv, protection_points, conf_threshold)

            # Process each detected object (vehicle)
            for obj in detected_objects:
                # Crop vehicle image
                x1, y1, x2, y2 = obj['bbox']
                vehicle_image = image_cv[y1:y2, x1:x2]

                # Detect and read license plate from vehicle
                plate_image, plate_text, province = self.detect_and_read_license_plate(vehicle_image)

                if plate_image is not None:
                    obj['license_plate'] = {
                        'text': plate_text,
                        'province': province,
                        'image': plate_image
                    }
                    results['license_plates'].append({
                        'text': plate_text,
                        'province': province,
                        'image': plate_image,
                        'bbox': obj['bbox']
                    })

                results['detected_objects'].append(obj)

            # Create annotated image
            annotated_image = self.draw_annotations(image_cv, protection_points, results['detected_objects'])
            results['annotated_image'] = annotated_image
        except Exception as e:
            logger.error(f"Image processing error: {e}")

        return results

    def draw_annotations(self, image, protection_points, detected_objects):
        """Draw annotations on the image"""
        annotated = image.copy()

        # Draw protection zone
        if len(protection_points) >= 3:
            points = np.array(protection_points, np.int32)
            cv2.polylines(annotated, [points], True, (0, 255, 0), 3)

            # Fill with transparency
            overlay = annotated.copy()
            cv2.fillPoly(overlay, [points], (0, 255, 0))
            cv2.addWeighted(overlay, 0.3, annotated, 0.7, 0, annotated)

        # Draw detected objects
        for obj in detected_objects:
            x1, y1, x2, y2 = obj['bbox']

            # Draw bounding box
            cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 0), 2)

            # Draw label (cv2.putText cannot render "\n", so draw one line at a time)
            label_lines = [f"{obj['class']}: {obj['confidence']:.2f}"]
            if 'license_plate' in obj:
                label_lines.append(obj['license_plate']['text'])
                label_lines.append(obj['license_plate']['province'])
            for i, line in enumerate(label_lines):
                cv2.putText(annotated, line, (x1, y1 - 10 - 25 * (len(label_lines) - 1 - i)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)

        return annotated


class LicensePlateApp:
    """Gradio app for license plate detection"""

    def __init__(self):
        self.detector = LicensePlateDetector()
        self.protection_points = []
        self.uploaded_image = None

    def clear_points(self):
        """Clear all protection zone points"""
        self.protection_points = []
        return None, "Protection zone cleared. Upload an image and click to select new points."

    def add_point(self, image, evt: gr.SelectData):
        """Add a point to the protection zone when the user clicks on the image"""
        if image is None:
            return None, "Please upload an image first."

        x, y = evt.index[0], evt.index[1]
        self.protection_points.append([x, y])

        # Draw the protection zone on the image
        img_with_zone = self.draw_protection_zone(image)

        status = f"Added point ({x}, {y}). Total points: {len(self.protection_points)}"
        if len(self.protection_points) >= 3:
            status += " (Ready to detect - you have enough points for a polygon)"

        return img_with_zone, status

    def draw_protection_zone(self, image):
        """Draw the protection zone on the image"""
        if len(self.protection_points) < 2:
            return image

        # Convert PIL to numpy array
        img_array = np.array(image)

        # Draw lines between consecutive points
        for i in range(len(self.protection_points)):
            start_point = tuple(self.protection_points[i])
            end_point = tuple(self.protection_points[(i + 1) % len(self.protection_points)])
            cv2.line(img_array, start_point, end_point, (0, 255, 0), 2)

        # Draw points
        for point in self.protection_points:
            cv2.circle(img_array, tuple(point), 5, (255, 0, 0), -1)

        # If we have 3+ points, draw a filled polygon with transparency
        if len(self.protection_points) >= 3:
            points = np.array(self.protection_points, np.int32)
            overlay = img_array.copy()
            cv2.fillPoly(overlay, [points], (0, 255, 0))
            cv2.addWeighted(overlay, 0.3, img_array, 0.7, 0, img_array)

        return Image.fromarray(img_array)
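
    # Click flow, in short: every click on the zone image appends one [x, y]
    # point via add_point(); once three or more points exist the polygon is
    # closed and filled, and detect_license_plates() below will accept it.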

    def detect_license_plates(self, image, confidence):
        """Process image for license plate detection"""
        if image is None:
            return None, [], "Please upload an image first."

        if len(self.protection_points) < 3:
            return None, [], "Please select at least 3 points to define a protection zone."

        try:
            # Process the image using the threshold from the confidence slider
            results = self.detector.process_image(image, self.protection_points, confidence)

            # Prepare results for display
            annotated_image = None
            if results['annotated_image'] is not None:
                annotated_image = Image.fromarray(cv2.cvtColor(results['annotated_image'], cv2.COLOR_BGR2RGB))

            # Format license plates for gallery
            license_plates_gallery = []
            summary_text = f"""
🔍 **Detection Results**

📊 **Statistics:**
- Objects detected in protection area: {len(results['detected_objects'])}
- License plates found: {len(results['license_plates'])}

🚗 **Detected Objects:**
"""

            for plate in results['license_plates']:
                if plate['image'] is not None:
                    plate_pil = Image.fromarray(cv2.cvtColor(plate['image'], cv2.COLOR_BGR2RGB))
                    caption = f"License: {plate['text']}\nProvince: {plate['province']}"
                    license_plates_gallery.append((plate_pil, caption))

                    summary_text += f"""
- **Vehicle** (License Plate: {plate['text']})
  - Province: {plate['province']}
  - Location: {plate['bbox']}
"""

            if len(results['detected_objects']) == 0:
                summary_text += "\nNo objects detected in the protection zone."

            return annotated_image, license_plates_gallery, summary_text
        except Exception as e:
            error_msg = f"Error processing image: {str(e)}"
            logger.error(error_msg)
            return None, [], error_msg


def create_gradio_interface():
    """Create the Gradio interface"""
    app = LicensePlateApp()

    with gr.Blocks(title="🚗 License Plate Detection System", theme=gr.themes.Soft()) as iface:
        gr.Markdown("""
# 🚗 License Plate Detection System

AI-powered license plate detection and recognition for Thai vehicles

## How to use:
1. **Upload an image** with vehicles
2. **Click on the image** to select protection zone points (minimum 3 points)
3. **Adjust confidence** threshold if needed
4. **Click "Detect License Plates"** to run detection
5. **View results** including annotated image and detected license plates
""")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📤 Input")

                # Image upload
                input_image = gr.Image(
                    type="pil",
                    label="Upload Image",
                    interactive=True
                )

                # Confidence slider
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.25,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = more strict detection"
                )

                # Control buttons
                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear Protection Zone", variant="secondary")
                    detect_btn = gr.Button("🔍 Detect License Plates", variant="primary")

                # Status display
                status_text = gr.Textbox(
                    label="Status",
                    value="Upload an image and click to select protection zone points.",
                    interactive=False,
                    lines=3
                )

            with gr.Column(scale=2):
                gr.Markdown("### 🎯 Protection Zone Selection")
                gr.Markdown("Click on the image to add points for the protection zone (minimum 3 points)")

                # Image with protection zone
                zone_image = gr.Image(
                    type="pil",
                    label="Click to Select Protection Zone",
                    interactive=False
                )

        gr.Markdown("### 📊 Results")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### 🖼️ Annotated Detection")
                result_image = gr.Image(
                    type="pil",
                    label="Detection Results",
                    interactive=False
                )

            with gr.Column(scale=1):
                gr.Markdown("#### 📋 Detection Summary")
                summary_text = gr.Markdown()

        gr.Markdown("#### 🔢 Detected License Plates")
        license_plates_gallery = gr.Gallery(
            label="License Plates Found",
            show_label=True,
            elem_id="gallery",
            columns=4,
            rows=2,
            object_fit="contain",
            height="auto"
        )

        # Event handlers
        input_image.upload(
            fn=lambda img: (img, "Image uploaded. Click on the image to select protection zone points."),
            inputs=[input_image],
            outputs=[zone_image, status_text]
        )

        zone_image.select(
            fn=app.add_point,
            inputs=[input_image],
            outputs=[zone_image, status_text]
        )

        clear_btn.click(
            fn=app.clear_points,
            outputs=[zone_image, status_text]
        )

        detect_btn.click(
            fn=app.detect_license_plates,
            inputs=[input_image, confidence_slider],
            outputs=[result_image, license_plates_gallery, summary_text]
        )

        # Examples and instructions
        gr.Markdown("### 📖 Instructions")
        gr.Markdown("""
**Step-by-step guide:**

1. **Upload Image**: Click "Upload Image" and select an image with vehicles
2. **Select Protection Zone**:
   - Click at least 3 points on the uploaded image to define a protection area
   - The area will be highlighted in green
   - You can click "Clear Protection Zone" to start over
3. **Adjust Settings**: Use the confidence slider to control detection sensitivity
4. **Run Detection**: Click "Detect License Plates" to process the image
5. **View Results**:
   - See the annotated image with detected objects
   - View individual license plate crops in the gallery
   - Read the detection summary

**Tips:**
- Select protection zones around areas where vehicles might pass
- Higher confidence values will detect fewer but more certain objects
- The protection zone should be a polygon (minimum 3 points)
""")

    return iface


if __name__ == "__main__":
    # Create and launch the interface
    iface = create_gradio_interface()
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
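
# Note: share=True asks Gradio to create a temporary public link; for a
# local-only run it is assumed you would instead launch with something like
# iface.launch(server_name="127.0.0.1", server_port=7860, share=False).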