import gradio as gr
import torch
import torch.nn as nn
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
from ultralytics import YOLO
import base64
import io
import yaml
from pathlib import Path
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
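
# Runtime dependencies implied by the imports above (usual PyPI package names;
# version pins are left to the deployer): gradio, torch, torchvision, pillow,
# opencv-python, numpy, ultralytics, pyyaml.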


class YOLOLicensePlateDetector:
    """YOLO-based license plate detector matching the original API"""

    def __init__(self, detect_model_path, char_model_path, province_model_path, data_path, province_data_path, device):
        self.device = device

        # Load character mapping from data.yaml and province mapping from data_province.yaml
        self.char_mapping = {}
        self.province_mapping = {}
        self._load_mappings(data_path)
        self._load_province_mappings(province_data_path)

        # Load YOLO models
        self.detect_model = None
        self.char_model = None
        self.province_model = None

        if detect_model_path and Path(detect_model_path).exists():
            self.detect_model = YOLO(str(detect_model_path))
            logger.info(f"License plate detection model loaded: {detect_model_path}")

        if char_model_path and Path(char_model_path).exists():
            self.char_model = YOLO(str(char_model_path))
            logger.info(f"Character reading model loaded: {char_model_path}")

        if province_model_path and Path(province_model_path).exists():
            self.province_model = YOLO(str(province_model_path))
            logger.info(f"Province detection model loaded: {province_model_path}")

    def _load_mappings(self, data_path):
        """Load character mappings from data.yaml"""
        try:
            if Path(data_path).exists():
                with open(data_path, 'r', encoding='utf-8') as f:
                    data = yaml.safe_load(f)

                # Load character mapping - keep keys as strings!
                self.char_mapping = data.get('char_mapping', {})

                # Add digit mapping for class names "0"-"9"
                for i in range(10):
                    class_name = str(i)
                    if class_name not in self.char_mapping:
                        self.char_mapping[class_name] = str(i)

                logger.info(f"Loaded {len(self.char_mapping)} character mappings")
                logger.info(f"Sample mappings: {dict(list(self.char_mapping.items())[:5])}")
            else:
                logger.warning(f"Data file not found: {data_path}")
                # Default mappings
                self.char_mapping = {str(i): str(i) for i in range(10)}  # "0"-"9"
        except Exception as e:
            logger.error(f"Error loading mappings: {e}")
            self.char_mapping = {str(i): str(i) for i in range(10)}
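
    # Illustrative data.yaml shape assumed by _load_mappings (keys are YOLO class
    # names, values are the characters they map to; the actual entries depend on
    # the trained read_char.pt model, so the second key below is hypothetical):
    #
    #   char_mapping:
    #     "0": "0"
    #     "<class_name>": "<character>"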

    def _load_province_mappings(self, province_data_path):
        """Load province mappings from data_province.yaml (matching original API)"""
        try:
            if Path(province_data_path).exists():
                with open(province_data_path, 'r', encoding='utf-8') as f:
                    data = yaml.safe_load(f)

                # Load province mapping from char_mapping section (like original API)
                if 'char_mapping' in data:
                    self.province_mapping = data['char_mapping']
                    logger.info("✅ Province mapping loaded from data_province.yaml")
                    logger.info(f"   Loaded {len(self.province_mapping)} province mappings")
                    logger.info(f"   Sample: {dict(list(self.province_mapping.items())[:3])}")
                elif 'names' in data:
                    # Fallback: create mapping from names if no explicit mapping
                    self.province_mapping = {str(i): name for i, name in enumerate(data['names'])}
                    logger.info("✅ Province mapping created from names")
                    logger.info(f"   Created {len(self.province_mapping)} province mappings")
                else:
                    self.province_mapping = {"0": "Unknown"}
                    logger.warning("No province mapping found in data_province.yaml")
            else:
                logger.warning(f"Province data file not found: {province_data_path}")
                self.province_mapping = {"0": "Unknown"}
        except Exception as e:
            logger.error(f"Error loading province mappings: {e}")
            self.province_mapping = {"0": "Unknown"}

    def map_class_to_char(self, class_name):
        """Map YOLO class name to character (matching original API)"""
        return self.char_mapping.get(str(class_name), '?')

    def map_class_to_province(self, class_name):
        """Map YOLO class name to province (matching original API)"""
        return self.province_mapping.get(str(class_name), "Unknown")
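
    # Illustrative lookups (class names are hypothetical): with the built-in digit
    # fallback, digits map to themselves, and unknown class names fall back to the
    # defaults, e.g.
    #   map_class_to_char("3")               -> "3"
    #   map_class_to_char("no_such_class")   -> "?"
    #   map_class_to_province("no_such_class") -> "Unknown"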

    def detect_license_plate(self, vehicle_image):
        """Detect license plate in vehicle image using YOLO"""
        if self.detect_model is None:
            return None

        try:
            # Run license plate detection with confidence 0.3 (same as original API)
            results = self.detect_model(vehicle_image, conf=0.3)

            if not results or len(results) == 0:
                return None

            # Get the first (highest confidence) license plate detection
            for result in results:
                boxes = result.boxes
                if boxes is not None and len(boxes) > 0:
                    # Get the highest confidence detection
                    best_box = boxes[0]
                    x1, y1, x2, y2 = best_box.xyxy[0].cpu().numpy().astype(int)
                    confidence = best_box.conf[0].cpu().numpy()

                    # Crop license plate region
                    if isinstance(vehicle_image, Image.Image):
                        vehicle_array = np.array(vehicle_image)
                    else:
                        vehicle_array = vehicle_image

                    license_plate = vehicle_array[y1:y2, x1:x2]

                    return {
                        'image': license_plate,
                        'bbox': [x1, y1, x2, y2],
                        'confidence': float(confidence)
                    }

            return None
        except Exception as e:
            logger.error(f"License plate detection error: {e}")
            return None

    def read_characters(self, license_plate_image):
        """Read characters from license plate using YOLO (matching original API)"""
        if self.char_model is None:
            return []

        try:
            # Ensure image is in correct format
            if isinstance(license_plate_image, Image.Image):
                img_array = np.array(license_plate_image)
            else:
                img_array = license_plate_image

            # Run character detection with confidence 0.3 (same as original API)
            results = self.char_model(img_array, conf=0.3)

            characters = []
            for result in results:
                boxes = result.boxes
                if boxes is not None:
                    for box in boxes:
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                        confidence = box.conf[0].cpu().numpy()
                        class_id = int(box.cls[0].cpu().numpy())

                        # Two-step mapping like original API:
                        # 1. Get class name from model
                        class_name = result.names[class_id]
                        # 2. Map class name to character
                        char = self.map_class_to_char(class_name)

                        characters.append({
                            'char': char,
                            'confidence': float(confidence),
                            'bbox': [float(x1), float(y1), float(x2), float(y2)],
                            'center_x': float((x1 + x2) / 2)
                        })

            # Sort characters by x-position (left to right) - same as original API
            characters.sort(key=lambda x: x['bbox'][0])
            return characters
        except Exception as e:
            logger.error(f"Character reading error: {e}")
            return []

    def detect_province(self, license_plate_image):
        """Detect province from license plate"""
        if self.province_model is None:
            return "Unknown"

        try:
            # Ensure image is in correct format
            if isinstance(license_plate_image, Image.Image):
                img_array = np.array(license_plate_image)
            else:
                img_array = license_plate_image

            # Run province detection with confidence 0.3 (same as original API)
            results = self.province_model(img_array, conf=0.3)

            for result in results:
                boxes = result.boxes
                if boxes is not None and len(boxes) > 0:
                    # Get highest confidence detection
                    best_box = boxes[0]
                    class_id = int(best_box.cls[0].cpu().numpy())
                    confidence = best_box.conf[0].cpu().numpy()

                    # Two-step mapping like original API:
                    # 1. Get class name from model
                    class_name = result.names[class_id]
                    # 2. Map class name to province
                    province = self.map_class_to_province(class_name)

                    return province

            return "Unknown"
        except Exception as e:
            logger.error(f"Province detection error: {e}")
            return "Unknown"


class LicensePlateDetector:
    """Main license plate detection system"""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Model paths - try multiple locations
        base_paths = [Path("models"), Path("../models"), Path("./")]

        # Find YOLO models
        self.yolo_model_path = None
        self.segment_model_path = None

        for base_dir in base_paths:
            if (base_dir / "yolo11s.pt").exists():
                self.yolo_model_path = base_dir / "yolo11s.pt"
                break
            elif (base_dir / "yolov9.pt").exists():
                self.yolo_model_path = base_dir / "yolov9.pt"
                break

        for base_dir in base_paths:
            if (base_dir / "best_segment.pt").exists():
                self.segment_model_path = base_dir / "best_segment.pt"
                break

        # Find license plate detection model (detect1.pt)
        self.detect_model_path = None
        detect_model_names = ["detect1.pt"]
        for base_dir in base_paths:
            for model_name in detect_model_names:
                if (base_dir / model_name).exists():
                    self.detect_model_path = base_dir / model_name
                    break
            if self.detect_model_path:
                break

        # Find character reading model (read_char.pt)
        self.char_model_path = None
        char_model_names = ["read_char.pt"]
        for base_dir in base_paths:
            for model_name in char_model_names:
                if (base_dir / model_name).exists():
                    self.char_model_path = base_dir / model_name
                    break
            if self.char_model_path:
                break

        # Find province recognition model
        self.province_model_path = None
        province_model_names = ["best_province.pt"]
        for base_dir in base_paths:
            for model_name in province_model_names:
                if (base_dir / model_name).exists():
                    self.province_model_path = base_dir / model_name
                    break
            if self.province_model_path:
                break

        # Find data.yaml file (for character mapping)
        config_paths = [
            Path("deploy_huggingface/config/data.yaml"),
            Path("config/data.yaml"),
            Path("../config/data.yaml"),
            Path("./data.yaml")
        ]
        self.data_path = None
        for config_path in config_paths:
            if config_path.exists():
                self.data_path = config_path
                break
        if self.data_path is None:
            self.data_path = Path("deploy_huggingface/config/data.yaml")  # Use default

        # Find data_province.yaml file (for province mapping)
        province_config_paths = [
            Path("deploy_huggingface/config/data_province.yaml"),
            Path("config/data_province.yaml"),
            Path("../config/data_province.yaml"),
            Path("./data_province.yaml")
        ]
        self.province_data_path = None
        for config_path in province_config_paths:
            if config_path.exists():
                self.province_data_path = config_path
                break
        if self.province_data_path is None:
            self.province_data_path = Path("deploy_huggingface/config/data_province.yaml")  # Use default

        # Initialize models
        self.yolo_model = None
        self.license_plate_detector = None
        self._load_models()

    def _load_models(self):
        """Load all required models"""
        try:
            # YOLO vehicle detection model
            if self.yolo_model_path and self.yolo_model_path.exists():
                self.yolo_model = YOLO(str(self.yolo_model_path))
                logger.info("YOLO vehicle detection model loaded")
            else:
                logger.warning("YOLO vehicle detection model not found")

            # YOLO-based license plate detector
            self.license_plate_detector = YOLOLicensePlateDetector(
                detect_model_path=self.detect_model_path,
                char_model_path=self.char_model_path,
                province_model_path=self.province_model_path,
                data_path=self.data_path,
                province_data_path=self.province_data_path,
                device=self.device
            )
        except Exception as e:
            logger.error(f"Error loading models: {e}")
            print(f"Warning: Some models failed to load: {e}")

    def point_in_polygon(self, point, polygon):
        """Check if a point is inside a polygon"""
        x, y = point
        n = len(polygon)
        inside = False

        p1x, p1y = polygon[0]
        for i in range(1, n + 1):
            p2x, p2y = polygon[i % n]
            if y > min(p1y, p2y):
                if y <= max(p1y, p2y):
                    if x <= max(p1x, p2x):
                        if p1y != p2y:
                            xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                        if p1x == p2x or x <= xinters:
                            inside = not inside
            p1x, p1y = p2x, p2y

        return inside
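
    # Sanity check (hypothetical coordinates): the centre of an axis-aligned
    # square is reported as inside by this ray-casting test:
    #   self.point_in_polygon((50, 50), [(0, 0), (100, 0), (100, 100), (0, 100)])  # -> True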

    def detect_objects_in_protection_area(self, image, protection_polygon, conf=0.25):
        """Detect objects in the protection area"""
        results = []

        if self.yolo_model is None:
            return results

        try:
            # Run YOLO detection (conf defaults to 0.25; the UI slider overrides it)
            detections = self.yolo_model(image, conf=conf)

            for detection in detections:
                boxes = detection.boxes
                if boxes is not None:
                    for box in boxes:
                        # Get bounding box coordinates
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                        center_x = (x1 + x2) / 2
                        center_y = (y1 + y2) / 2

                        # Check if center point is in protection area
                        if self.point_in_polygon((center_x, center_y), protection_polygon):
                            confidence = box.conf[0].cpu().numpy()
                            class_id = int(box.cls[0].cpu().numpy())
                            class_name = detection.names[class_id]

                            results.append({
                                'bbox': [int(x1), int(y1), int(x2), int(y2)],
                                'confidence': float(confidence),
                                'class': class_name,
                                'center': [center_x, center_y]
                            })
        except Exception as e:
            logger.error(f"Object detection error: {e}")

        return results

    def detect_and_read_license_plate(self, vehicle_image):
        """Detect and read license plate from vehicle image using YOLO"""
        if self.license_plate_detector is None:
            return None, "Unknown", "Unknown"

        try:
            # Step 1: Detect license plate in vehicle image
            plate_detection = self.license_plate_detector.detect_license_plate(vehicle_image)
            if plate_detection is None:
                return None, "Unknown", "Unknown"

            plate_image = plate_detection['image']

            # Step 2: Read characters from license plate
            characters = self.license_plate_detector.read_characters(plate_image)

            # Step 3: Assemble character text (exactly like original API)
            if characters:
                # Join characters directly (same as original API)
                char_text = ''.join([char['char'] for char in characters])
                # Only show "Detected" if all characters are unknown
                if not char_text or char_text.replace('?', '') == '':
                    char_text = "Detected"
            else:
                char_text = "Detected"  # License plate detected but no characters read

            # Step 4: Detect province
            province = self.license_plate_detector.detect_province(plate_image)

            return plate_image, char_text, province
        except Exception as e:
            logger.error(f"License plate detection and reading error: {e}")
            return None, "Unknown", "Unknown"

    def process_image(self, image, protection_points, conf=0.25):
        """Process the entire image for license plate detection"""
        results = {
            'detected_objects': [],
            'annotated_image': None,
            'license_plates': []
        }

        if len(protection_points) < 3:
            return results

        try:
            # Convert PIL to OpenCV format
            if isinstance(image, Image.Image):
                image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            else:
                image_cv = image

            # Detect objects in protection area
            detected_objects = self.detect_objects_in_protection_area(image_cv, protection_points, conf=conf)

            # Process each detected object (vehicle)
            for obj in detected_objects:
                # Crop vehicle image
                x1, y1, x2, y2 = obj['bbox']
                vehicle_image = image_cv[y1:y2, x1:x2]

                # Detect and read license plate from vehicle
                plate_image, plate_text, province = self.detect_and_read_license_plate(vehicle_image)

                if plate_image is not None:
                    obj['license_plate'] = {
                        'text': plate_text,
                        'province': province,
                        'image': plate_image
                    }
                    results['license_plates'].append({
                        'text': plate_text,
                        'province': province,
                        'image': plate_image,
                        'bbox': obj['bbox']
                    })

                results['detected_objects'].append(obj)

            # Create annotated image
            annotated_image = self.draw_annotations(image_cv, protection_points, results['detected_objects'])
            results['annotated_image'] = annotated_image
        except Exception as e:
            logger.error(f"Image processing error: {e}")

        return results

    def draw_annotations(self, image, protection_points, detected_objects):
        """Draw annotations on the image"""
        annotated = image.copy()

        # Draw protection zone
        if len(protection_points) >= 3:
            points = np.array(protection_points, np.int32)
            cv2.polylines(annotated, [points], True, (0, 255, 0), 3)

            # Fill with transparency
            overlay = annotated.copy()
            cv2.fillPoly(overlay, [points], (0, 255, 0))
            cv2.addWeighted(overlay, 0.3, annotated, 0.7, 0, annotated)

        # Draw detected objects
        for obj in detected_objects:
            x1, y1, x2, y2 = obj['bbox']

            # Draw bounding box
            cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 0), 2)

            # Draw label; cv2.putText does not handle newlines, so render one line per row
            # (note: Hershey fonts only cover ASCII, so non-ASCII text such as Thai
            # province names may render as '?')
            label_lines = [f"{obj['class']}: {obj['confidence']:.2f}"]
            if 'license_plate' in obj:
                label_lines.append(obj['license_plate']['text'])
                label_lines.append(obj['license_plate']['province'])
            for idx, line in enumerate(label_lines):
                y_offset = y1 - 10 - 25 * (len(label_lines) - 1 - idx)
                cv2.putText(annotated, line, (x1, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)

        return annotated


class LicensePlateApp:
    """Gradio app for license plate detection"""

    def __init__(self):
        self.detector = LicensePlateDetector()
        self.protection_points = []
        self.uploaded_image = None

    def clear_points(self):
        """Clear all protection zone points"""
        self.protection_points = []
        return None, "Protection zone cleared. Upload an image and click to select new points."

    def add_point(self, image, evt: gr.SelectData):
        """Add a point to the protection zone when user clicks on image"""
        if image is None:
            return None, "Please upload an image first."

        x, y = evt.index[0], evt.index[1]
        self.protection_points.append([x, y])

        # Draw the protection zone on the image
        img_with_zone = self.draw_protection_zone(image)

        status = f"Added point ({x}, {y}). Total points: {len(self.protection_points)}"
        if len(self.protection_points) >= 3:
            status += " (Ready to detect - you have enough points for a polygon)"

        return img_with_zone, status

    def draw_protection_zone(self, image):
        """Draw the protection zone on the image"""
        if len(self.protection_points) < 2:
            return image

        # Convert PIL to numpy array
        img_array = np.array(image)

        # Draw lines between consecutive points
        for i in range(len(self.protection_points)):
            start_point = tuple(self.protection_points[i])
            end_point = tuple(self.protection_points[(i + 1) % len(self.protection_points)])
            cv2.line(img_array, start_point, end_point, (0, 255, 0), 2)

        # Draw points
        for point in self.protection_points:
            cv2.circle(img_array, tuple(point), 5, (255, 0, 0), -1)

        # If we have 3+ points, draw a filled polygon with transparency
        if len(self.protection_points) >= 3:
            points = np.array(self.protection_points, np.int32)
            overlay = img_array.copy()
            cv2.fillPoly(overlay, [points], (0, 255, 0))
            cv2.addWeighted(overlay, 0.3, img_array, 0.7, 0, img_array)

        return Image.fromarray(img_array)

    def detect_license_plates(self, image, confidence):
        """Process image for license plate detection"""
        if image is None:
            return None, [], "Please upload an image first."

        if len(self.protection_points) < 3:
            return None, [], "Please select at least 3 points to define a protection zone."

        try:
            # Process the image using the confidence threshold from the UI slider
            results = self.detector.process_image(image, self.protection_points, conf=confidence)

            # Prepare results for display
            annotated_image = None
            if results['annotated_image'] is not None:
                annotated_image = Image.fromarray(cv2.cvtColor(results['annotated_image'], cv2.COLOR_BGR2RGB))

            # Format license plates for gallery
            license_plates_gallery = []

            summary_text = f"""
🔍 **Detection Results**

📊 **Statistics:**
- Objects detected in protection area: {len(results['detected_objects'])}
- License plates found: {len(results['license_plates'])}

🚗 **Detected Objects:**
"""

            for plate in results['license_plates']:
                if plate['image'] is not None:
                    plate_pil = Image.fromarray(cv2.cvtColor(plate['image'], cv2.COLOR_BGR2RGB))
                    caption = f"License: {plate['text']}\nProvince: {plate['province']}"
                    license_plates_gallery.append((plate_pil, caption))

                    summary_text += f"""
- **Vehicle** (License Plate: {plate['text']})
  - Province: {plate['province']}
  - Location: {plate['bbox']}
"""

            if len(results['detected_objects']) == 0:
                summary_text += "\nNo objects detected in the protection zone."

            return annotated_image, license_plates_gallery, summary_text
        except Exception as e:
            error_msg = f"Error processing image: {str(e)}"
            logger.error(error_msg)
            return None, [], error_msg


def create_gradio_interface():
    """Create the Gradio interface"""
    app = LicensePlateApp()

    with gr.Blocks(title="🚗 License Plate Detection System", theme=gr.themes.Soft()) as iface:
        gr.Markdown("""
# 🚗 License Plate Detection System

AI-powered license plate detection and recognition for Thai vehicles

## How to use:
1. **Upload an image** with vehicles
2. **Click on the image** to select protection zone points (minimum 3 points)
3. **Adjust confidence** threshold if needed
4. **Click "Detect License Plates"** to run detection
5. **View results** including annotated image and detected license plates
""")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📤 Input")

                # Image upload
                input_image = gr.Image(
                    type="pil",
                    label="Upload Image",
                    interactive=True
                )

                # Confidence slider
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.25,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = more strict detection"
                )

                # Control buttons
                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear Protection Zone", variant="secondary")
                    detect_btn = gr.Button("🔍 Detect License Plates", variant="primary")

                # Status display
                status_text = gr.Textbox(
                    label="Status",
                    value="Upload an image and click to select protection zone points.",
                    interactive=False,
                    lines=3
                )

            with gr.Column(scale=2):
                gr.Markdown("### 🎯 Protection Zone Selection")
                gr.Markdown("Click on the image to add points for the protection zone (minimum 3 points)")

                # Image with protection zone
                zone_image = gr.Image(
                    type="pil",
                    label="Click to Select Protection Zone",
                    interactive=False
                )

        gr.Markdown("### 📊 Results")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### 🖼️ Annotated Detection")
                result_image = gr.Image(
                    type="pil",
                    label="Detection Results",
                    interactive=False
                )

            with gr.Column(scale=1):
                gr.Markdown("#### 📋 Detection Summary")
                summary_text = gr.Markdown()

        gr.Markdown("#### 🔢 Detected License Plates")
        license_plates_gallery = gr.Gallery(
            label="License Plates Found",
            show_label=True,
            elem_id="gallery",
            columns=4,
            rows=2,
            object_fit="contain",
            height="auto"
        )

        # Event handlers
        input_image.upload(
            fn=lambda img: (img, "Image uploaded. Click on the image to select protection zone points."),
            inputs=[input_image],
            outputs=[zone_image, status_text]
        )

        zone_image.select(
            fn=app.add_point,
            inputs=[input_image],
            outputs=[zone_image, status_text]
        )

        clear_btn.click(
            fn=app.clear_points,
            outputs=[zone_image, status_text]
        )

        detect_btn.click(
            fn=app.detect_license_plates,
            inputs=[input_image, confidence_slider],
            outputs=[result_image, license_plates_gallery, summary_text]
        )

        # Examples and instructions
        gr.Markdown("### 📝 Instructions")
        gr.Markdown("""
**Step-by-step guide:**

1. **Upload Image**: Click "Upload Image" and select an image with vehicles
2. **Select Protection Zone**:
   - Click at least 3 points on the uploaded image to define a protection area
   - The area will be highlighted in green
   - You can click "Clear Protection Zone" to start over
3. **Adjust Settings**: Use the confidence slider to control detection sensitivity
4. **Run Detection**: Click "Detect License Plates" to process the image
5. **View Results**:
   - See the annotated image with detected objects
   - View individual license plate crops in the gallery
   - Read the detection summary

**Tips:**
- Select protection zones around areas where vehicles might pass
- Higher confidence values will detect fewer but more certain objects
- The protection zone should be a polygon (minimum 3 points)
""")

    return iface


if __name__ == "__main__":
    # Create and launch the interface
    iface = create_gradio_interface()
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )