# Sompote's picture
# Upload 15 files
# 2039756 verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
from ultralytics import YOLO
import base64
import io
import yaml
from pathlib import Path
import logging
# Configure logging: set the root logger to INFO and create the module-level
# logger that the classes and functions in this file write to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class YOLOLicensePlateDetector:
    """YOLO-based license plate detector matching the original API.

    Bundles up to three YOLO models (plate detection, character reading and
    province detection) with the YAML lookup tables that translate model
    class names into plate characters and province names. A model whose
    weights file is missing stays ``None`` and the matching method returns
    its "nothing found" value (``None``, ``[]`` or ``"Unknown"``).
    """

    # Confidence threshold passed to every model call (same as original API).
    CONFIDENCE = 0.3

    def __init__(self, detect_model_path, char_model_path, province_model_path,
                 data_path, province_data_path, device):
        """Load the lookup tables, then every model whose weights file exists.

        Args:
            detect_model_path: Weights file of the plate detection model.
            char_model_path: Weights file of the character reading model.
            province_model_path: Weights file of the province model.
            data_path: YAML file providing ``char_mapping``.
            province_data_path: YAML file providing the province mapping.
            device: Stored on the instance; it is not passed to the models
                by this class.
        """
        self.device = device

        self.char_mapping = {}
        self.province_mapping = {}
        self._load_mappings(data_path)
        self._load_province_mappings(province_data_path)

        self.detect_model = self._load_model(
            detect_model_path, "License plate detection model")
        self.char_model = self._load_model(
            char_model_path, "Character reading model")
        self.province_model = self._load_model(
            province_model_path, "Province detection model")

    @staticmethod
    def _load_model(model_path, label):
        """Return a YOLO model for ``model_path``, or None if it is missing."""
        if not model_path or not Path(model_path).exists():
            return None
        model = YOLO(str(model_path))
        logger.info(f"{label} loaded: {model_path}")
        return model

    @staticmethod
    def _stringify_keys(mapping):
        """Return a copy of ``mapping`` with ``str`` keys ({} if not a dict).

        YAML parses unquoted keys such as ``0:`` as integers, while every
        lookup in this class uses ``str(class_name)``.
        """
        if not isinstance(mapping, dict):
            return {}
        return {str(key): value for key, value in mapping.items()}

    @staticmethod
    def _best_index(boxes):
        """Return the index of the most confident box in ``boxes``."""
        return int(boxes.conf.argmax())

    @staticmethod
    def _as_bgr_array(image):
        """Return ``image`` as a BGR array; non-PIL inputs pass through.

        Ultralytics interprets numpy inputs as BGR, so a PIL image (RGB) has
        to be channel-swapped rather than merely wrapped in ``np.array``.
        """
        if isinstance(image, Image.Image):
            return cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        return image

    def _load_mappings(self, data_path):
        """Load the class-name -> character table from ``data_path``.

        Digits "0"-"9" are always present; when the file is missing or
        cannot be read the table consists of the digits only.
        """
        digits = {str(i): str(i) for i in range(10)}
        try:
            path = Path(data_path)
            if not path.exists():
                logger.warning(f"Data file not found: {data_path}")
                self.char_mapping = digits
                return

            with open(path, 'r', encoding='utf-8') as f:
                data = yaml.safe_load(f)
            if not isinstance(data, dict):
                # Empty file or unexpected top-level type.
                data = {}

            # Keep keys as strings: lookups use str(class_name).
            mapping = self._stringify_keys(data.get('char_mapping'))
            for digit, char in digits.items():
                mapping.setdefault(digit, char)
            self.char_mapping = mapping

            logger.info(f"Loaded {len(self.char_mapping)} character mappings")
            logger.info(f"Sample mappings: {dict(list(self.char_mapping.items())[:5])}")
        except Exception as e:
            # Best effort: a broken config must not prevent start-up.
            logger.error(f"Error loading mappings: {e}")
            self.char_mapping = digits

    def _load_province_mappings(self, province_data_path):
        """Load the province table from data_province.yaml (original API).

        Uses the ``char_mapping`` section when present, otherwise builds an
        index -> name table from ``names`` (list or index-keyed dict). Falls
        back to ``{"0": "Unknown"}`` when neither is available.
        """
        fallback = {"0": "Unknown"}
        try:
            path = Path(province_data_path)
            if not path.exists():
                logger.warning(f"Province data file not found: {province_data_path}")
                self.province_mapping = fallback
                return

            with open(path, 'r', encoding='utf-8') as f:
                data = yaml.safe_load(f)
            if not isinstance(data, dict):
                data = {}

            if 'char_mapping' in data:
                self.province_mapping = self._stringify_keys(data['char_mapping'])
                logger.info("✅ Province mapping loaded from data_province.yaml")
                logger.info(f" Loaded {len(self.province_mapping)} province mappings")
                logger.info(f" Sample: {dict(list(self.province_mapping.items())[:3])}")
            elif 'names' in data:
                names = data['names']
                if isinstance(names, dict):
                    # ``names`` given as {index: name}.
                    self.province_mapping = self._stringify_keys(names)
                else:
                    self.province_mapping = {
                        str(i): name for i, name in enumerate(names)
                    }
                logger.info("✅ Province mapping created from names")
                logger.info(f" Created {len(self.province_mapping)} province mappings")
            else:
                self.province_mapping = fallback
                logger.warning("No province mapping found in data_province.yaml")
        except Exception as e:
            # Best effort: a broken config must not prevent start-up.
            logger.error(f"Error loading province mappings: {e}")
            self.province_mapping = fallback

    def map_class_to_char(self, class_name):
        """Map a YOLO class name to its character, or '?' when unmapped."""
        return self.char_mapping.get(str(class_name), '?')

    def map_class_to_province(self, class_name):
        """Map a YOLO class name to its province, or "Unknown" when unmapped."""
        return self.province_mapping.get(str(class_name), "Unknown")

    def detect_license_plate(self, vehicle_image):
        """Detect the most confident license plate in a vehicle image.

        Args:
            vehicle_image: PIL image, or an array in BGR channel order.

        Returns:
            ``{'image': BGR crop, 'bbox': [x1, y1, x2, y2], 'confidence':
            float}``, or ``None`` when no model is loaded, nothing is
            detected, the crop is empty, or inference fails.
        """
        if self.detect_model is None:
            return None
        try:
            vehicle_array = self._as_bgr_array(vehicle_image)
            results = self.detect_model(vehicle_array, conf=self.CONFIDENCE)
            for result in results or []:
                boxes = result.boxes
                if boxes is None or len(boxes) == 0:
                    continue

                best_box = boxes[self._best_index(boxes)]
                x1, y1, x2, y2 = (int(v) for v in best_box.xyxy[0].cpu().numpy())

                # Clamp to the image: a negative index would wrap around
                # when slicing instead of cropping.
                height, width = vehicle_array.shape[:2]
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(width, x2), min(height, y2)

                license_plate = vehicle_array[y1:y2, x1:x2]
                if license_plate.size == 0:
                    continue
                return {
                    'image': license_plate,
                    'bbox': [x1, y1, x2, y2],
                    'confidence': float(best_box.conf[0]),
                }
            return None
        except Exception as e:
            logger.exception(f"License plate detection error: {e}")
            return None

    def read_characters(self, license_plate_image):
        """Read the characters on a plate crop, ordered left to right.

        Returns:
            A list of dicts with ``char``, ``confidence``, ``bbox`` and
            ``center_x``, sorted by the left edge of each box; ``[]`` when
            no model is loaded or inference fails.
        """
        if self.char_model is None:
            return []
        try:
            img_array = self._as_bgr_array(license_plate_image)
            results = self.char_model(img_array, conf=self.CONFIDENCE)

            characters = []
            for result in results:
                boxes = result.boxes
                if boxes is None:
                    continue
                for box in boxes:
                    x1, y1, x2, y2 = (float(v) for v in box.xyxy[0].tolist())
                    # Two-step mapping like the original API:
                    # model class id -> class name -> character.
                    class_name = result.names[int(box.cls[0])]
                    characters.append({
                        'char': self.map_class_to_char(class_name),
                        'confidence': float(box.conf[0]),
                        'bbox': [x1, y1, x2, y2],
                        'center_x': (x1 + x2) / 2,
                    })

            # Left-to-right reading order (same as original API).
            characters.sort(key=lambda item: item['bbox'][0])
            return characters
        except Exception as e:
            logger.exception(f"Character reading error: {e}")
            return []

    def detect_province(self, license_plate_image):
        """Return the province of the most confident detection on a plate.

        Returns "Unknown" when no model is loaded, nothing is detected, the
        class is not in the province table, or inference fails.
        """
        if self.province_model is None:
            return "Unknown"
        try:
            img_array = self._as_bgr_array(license_plate_image)
            results = self.province_model(img_array, conf=self.CONFIDENCE)
            for result in results:
                boxes = result.boxes
                if boxes is None or len(boxes) == 0:
                    continue
                best_box = boxes[self._best_index(boxes)]
                # Two-step mapping like the original API:
                # model class id -> class name -> province.
                class_name = result.names[int(best_box.cls[0])]
                return self.map_class_to_province(class_name)
            return "Unknown"
        except Exception as e:
            logger.exception(f"Province detection error: {e}")
            return "Unknown"
class LicensePlateDetector:
    """Main license plate detection system.

    Locates model weights and YAML configs on disk, loads a YOLO object
    detector plus a ``YOLOLicensePlateDetector``, and processes an image by
    finding objects whose box centre lies inside a user-defined polygon and
    reading a license plate from each of them.
    """

    # Object-detection confidence used when the caller does not supply one.
    DEFAULT_CONFIDENCE = 0.25

    def __init__(self):
        """Resolve model/config paths and load whatever is available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Model locations, tried in order.
        base_paths = [Path("models"), Path("../models"), Path("./")]

        self.yolo_model_path = self._find_first(base_paths, ["yolo11s.pt", "yolov9.pt"])
        # Located but not loaded anywhere in this class.
        self.segment_model_path = self._find_first(base_paths, ["best_segment.pt"])
        self.detect_model_path = self._find_first(base_paths, ["detect1.pt"])
        self.char_model_path = self._find_first(base_paths, ["read_char.pt"])
        self.province_model_path = self._find_first(base_paths, ["best_province.pt"])

        # Config locations, tried in order; the first entry doubles as the
        # default path when no file is found.
        config_dirs = [
            Path("deploy_huggingface/config"),
            Path("config"),
            Path("../config"),
            Path("./"),
        ]
        self.data_path = (
            self._find_first(config_dirs, ["data.yaml"])
            or Path("deploy_huggingface/config/data.yaml")
        )
        self.province_data_path = (
            self._find_first(config_dirs, ["data_province.yaml"])
            or Path("deploy_huggingface/config/data_province.yaml")
        )

        self.yolo_model = None
        self.license_plate_detector = None
        self._load_models()

    @staticmethod
    def _find_first(directories, filenames):
        """Return the first existing ``directory / filename``, else None.

        Directories take precedence over filenames: every filename is tried
        in the first directory before moving on to the next directory.
        """
        for directory in directories:
            for filename in filenames:
                candidate = directory / filename
                if candidate.exists():
                    return candidate
        return None

    def _load_models(self):
        """Load all required models; failures are logged, not raised."""
        try:
            if self.yolo_model_path and self.yolo_model_path.exists():
                self.yolo_model = YOLO(str(self.yolo_model_path))
                logger.info("YOLO vehicle detection model loaded")
            else:
                logger.warning("YOLO vehicle detection model not found")

            self.license_plate_detector = YOLOLicensePlateDetector(
                detect_model_path=self.detect_model_path,
                char_model_path=self.char_model_path,
                province_model_path=self.province_model_path,
                data_path=self.data_path,
                province_data_path=self.province_data_path,
                device=self.device,
            )
        except Exception as e:
            logger.error(f"Error loading models: {e}")
            print(f"Warning: Some models failed to load: {e}")

    def point_in_polygon(self, point, polygon):
        """Return True if ``point`` lies inside ``polygon`` (ray casting).

        Args:
            point: ``(x, y)`` pair.
            polygon: Sequence of ``(x, y)`` vertices. Fewer than three
                vertices enclose no area, so the result is False.
        """
        n = len(polygon)
        if n < 3:
            return False

        x, y = point
        inside = False
        p1x, p1y = polygon[0]
        for i in range(1, n + 1):
            p2x, p2y = polygon[i % n]
            # The edge straddles the horizontal ray through y (which also
            # rules out horizontal edges, so the division below is safe).
            if min(p1y, p2y) < y <= max(p1y, p2y) and x <= max(p1x, p2x):
                if p1x == p2x:
                    inside = not inside
                else:
                    xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                    if x <= xinters:
                        inside = not inside
            p1x, p1y = p2x, p2y
        return inside

    def detect_objects_in_protection_area(self, image, protection_polygon, confidence=None):
        """Detect objects whose box centre is inside the protection area.

        Args:
            image: Image passed to the YOLO model.
            protection_polygon: Polygon vertices in image coordinates.
            confidence: Detection confidence threshold; ``None`` selects
                ``DEFAULT_CONFIDENCE``.

        Returns:
            A list of dicts with ``bbox``, ``confidence``, ``class`` and
            ``center``. Empty when no model is loaded; on an inference
            error, whatever was collected so far is returned.
        """
        results = []
        if self.yolo_model is None:
            return results

        conf = self.DEFAULT_CONFIDENCE if confidence is None else float(confidence)
        try:
            for detection in self.yolo_model(image, conf=conf):
                boxes = detection.boxes
                if boxes is None:
                    continue
                for box in boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    center_x = float((x1 + x2) / 2)
                    center_y = float((y1 + y2) / 2)
                    if not self.point_in_polygon((center_x, center_y), protection_polygon):
                        continue
                    results.append({
                        'bbox': [int(x1), int(y1), int(x2), int(y2)],
                        'confidence': float(box.conf[0]),
                        'class': detection.names[int(box.cls[0])],
                        'center': [center_x, center_y],
                    })
        except Exception as e:
            logger.exception(f"Object detection error: {e}")
        return results

    def detect_and_read_license_plate(self, vehicle_image):
        """Detect and read the license plate in a vehicle crop.

        Returns:
            ``(plate_image, text, province)``. ``plate_image`` is ``None``
            (with both strings "Unknown") when no plate is found or an error
            occurs. ``text`` is "Detected" when a plate was found but no
            character could be mapped.
        """
        not_found = (None, "Unknown", "Unknown")
        if self.license_plate_detector is None:
            return not_found
        if vehicle_image is None or getattr(vehicle_image, 'size', 1) == 0:
            # Empty crop: nothing to run the models on.
            return not_found
        try:
            plate_detection = self.license_plate_detector.detect_license_plate(vehicle_image)
            if plate_detection is None:
                return not_found
            plate_image = plate_detection['image']

            characters = self.license_plate_detector.read_characters(plate_image)
            # Join characters directly (same as original API).
            char_text = ''.join(item['char'] for item in characters)
            # Nothing read, or only unmapped '?' placeholders.
            if not char_text.strip('?'):
                char_text = "Detected"

            province = self.license_plate_detector.detect_province(plate_image)
            return plate_image, char_text, province
        except Exception as e:
            logger.exception(f"License plate detection and reading error: {e}")
            return not_found

    def process_image(self, image, protection_points, confidence=None):
        """Process the entire image for license plate detection.

        Args:
            image: PIL image (any mode) or a BGR array.
            protection_points: Polygon vertices; fewer than three yields an
                empty result.
            confidence: Object-detection confidence threshold; ``None``
                selects ``DEFAULT_CONFIDENCE``.

        Returns:
            Dict with ``detected_objects``, ``annotated_image`` (BGR array
            or ``None``) and ``license_plates``.
        """
        results = {
            'detected_objects': [],
            'annotated_image': None,
            'license_plates': [],
        }
        if len(protection_points) < 3:
            return results

        try:
            if isinstance(image, Image.Image):
                # convert("RGB") first: RGBA/L/P images would otherwise make
                # the 3-channel colour conversion fail.
                image_cv = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
            else:
                image_cv = image

            detected_objects = self.detect_objects_in_protection_area(
                image_cv, protection_points, confidence=confidence)

            for obj in detected_objects:
                x1, y1, x2, y2 = obj['bbox']
                # Clamp at zero: negative indices would wrap around.
                vehicle_image = image_cv[max(0, y1):max(0, y2), max(0, x1):max(0, x2)]

                plate_image, plate_text, province = self.detect_and_read_license_plate(vehicle_image)
                if plate_image is not None:
                    obj['license_plate'] = {
                        'text': plate_text,
                        'province': province,
                        'image': plate_image,
                    }
                    results['license_plates'].append({
                        'text': plate_text,
                        'province': province,
                        'image': plate_image,
                        'bbox': obj['bbox'],
                    })
                results['detected_objects'].append(obj)

            results['annotated_image'] = self.draw_annotations(
                image_cv, protection_points, results['detected_objects'])
        except Exception as e:
            logger.exception(f"Image processing error: {e}")
        return results

    @staticmethod
    def _draw_label(image, lines, x, y_top):
        """Draw ``lines`` as stacked text above (or just inside) a box top.

        ``cv2.putText`` renders a single line only, so each line is drawn
        separately. Characters outside the Hershey font (for example Thai
        plate text) cannot be rendered by OpenCV and show up as '?'.
        """
        font, scale, thickness, color = cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2, (255, 0, 0)
        (_, text_height), baseline = cv2.getTextSize("Ag", font, scale, thickness)
        line_height = text_height + baseline + 4

        # Stack upwards so the last line sits 10 px above the box.
        first_baseline = y_top - 10 - line_height * (len(lines) - 1)
        if first_baseline - text_height < 0:
            # No room above the box: start just below its top edge instead.
            first_baseline = y_top + text_height + 10

        for offset, text in enumerate(lines):
            cv2.putText(image, text, (x, first_baseline + offset * line_height),
                        font, scale, color, thickness)

    def draw_annotations(self, image, protection_points, detected_objects):
        """Return a copy of ``image`` with the zone and detections drawn."""
        annotated = image.copy()

        if len(protection_points) >= 3:
            points = np.array(protection_points, np.int32)
            cv2.polylines(annotated, [points], True, (0, 255, 0), 3)
            # Blend a filled copy over the image for a translucent zone.
            overlay = annotated.copy()
            cv2.fillPoly(overlay, [points], (0, 255, 0))
            cv2.addWeighted(overlay, 0.3, annotated, 0.7, 0, annotated)

        for obj in detected_objects:
            x1, y1, x2, y2 = obj['bbox']
            cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 0), 2)

            label_lines = [f"{obj['class']}: {obj['confidence']:.2f}"]
            plate = obj.get('license_plate')
            if plate is not None:
                label_lines.append(str(plate['text']))
                label_lines.append(str(plate['province']))
            self._draw_label(annotated, label_lines, x1, y1)
        return annotated


class LicensePlateApp:
    """Gradio app for license plate detection.

    Holds the detector and the protection-zone points selected so far.
    NOTE(review): the points live on this instance, so they are presumably
    shared by every session served by the same app object — confirm whether
    per-session state is needed.
    """

    def __init__(self):
        self.detector = LicensePlateDetector()
        self.protection_points = []
        self.uploaded_image = None

    def clear_points(self):
        """Clear all protection zone points; returns (image, status)."""
        self.protection_points = []
        return None, "Protection zone cleared. Upload an image and click to select new points."

    def add_point(self, image, evt: gr.SelectData):
        """Add the clicked point to the protection zone.

        Returns:
            ``(image with the zone drawn, status message)``.
        """
        if image is None:
            return None, "Please upload an image first."

        x, y = evt.index[0], evt.index[1]
        self.protection_points.append([x, y])

        img_with_zone = self.draw_protection_zone(image)
        status = f"Added point ({x}, {y}). Total points: {len(self.protection_points)}"
        if len(self.protection_points) >= 3:
            status += " (Ready to detect - you have enough points for a polygon)"
        return img_with_zone, status

    def draw_protection_zone(self, image):
        """Draw the selected points, their outline and (3+ points) the fill.

        Returns ``image`` unchanged when no point has been selected yet.
        """
        points = self.protection_points
        if not points:
            return image

        img_array = np.array(image)

        # Closed outline; needs at least two points to form a segment.
        if len(points) >= 2:
            for i, start in enumerate(points):
                end = points[(i + 1) % len(points)]
                cv2.line(img_array, tuple(start), tuple(end), (0, 255, 0), 2)

        # Vertices are drawn from the first click on, so the user gets
        # immediate visual feedback.
        for point in points:
            cv2.circle(img_array, tuple(point), 5, (255, 0, 0), -1)

        if len(points) >= 3:
            polygon = np.array(points, np.int32)
            overlay = img_array.copy()
            cv2.fillPoly(overlay, [polygon], (0, 255, 0))
            cv2.addWeighted(overlay, 0.3, img_array, 0.7, 0, img_array)

        return Image.fromarray(img_array)

    def detect_license_plates(self, image, confidence):
        """Run detection on ``image`` inside the selected protection zone.

        Args:
            image: Uploaded PIL image.
            confidence: Object-detection confidence threshold from the UI.

        Returns:
            ``(annotated PIL image or None, gallery items, summary text)``.
        """
        if image is None:
            return None, [], "Please upload an image first."
        if len(self.protection_points) < 3:
            return None, [], "Please select at least 3 points to define a protection zone."

        try:
            results = self.detector.process_image(
                image, self.protection_points, confidence=confidence)

            annotated_image = None
            if results['annotated_image'] is not None:
                annotated_image = Image.fromarray(
                    cv2.cvtColor(results['annotated_image'], cv2.COLOR_BGR2RGB))

            license_plates_gallery = []
            summary_parts = [
                "\n"
                "🔍 **Detection Results**\n"
                "📊 **Statistics:**\n"
                f"- Objects detected in protection area: {len(results['detected_objects'])}\n"
                f"- License plates found: {len(results['license_plates'])}\n"
                "🚗 **Detected Objects:**\n"
            ]

            for plate in results['license_plates']:
                if plate['image'] is None:
                    continue
                plate_pil = Image.fromarray(cv2.cvtColor(plate['image'], cv2.COLOR_BGR2RGB))
                caption = f"License: {plate['text']}\nProvince: {plate['province']}"
                license_plates_gallery.append((plate_pil, caption))
                summary_parts.append(
                    "\n"
                    f"- **Vehicle** (License Plate: {plate['text']})\n"
                    f"- Province: {plate['province']}\n"
                    f"- Location: {plate['bbox']}\n"
                )

            if not results['detected_objects']:
                summary_parts.append("\nNo objects detected in the protection zone.")

            return annotated_image, license_plates_gallery, "".join(summary_parts)
        except Exception as e:
            error_msg = f"Error processing image: {str(e)}"
            logger.exception(error_msg)
            return None, [], error_msg
def create_gradio_interface():
    """Create the Gradio interface.

    Builds a ``gr.Blocks`` UI around a single ``LicensePlateApp``: an upload
    column with the confidence slider and buttons, a clickable preview for
    selecting the protection zone, and a results section (annotated image,
    summary, plate gallery).

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    app = LicensePlateApp()

    intro_md = "\n".join([
        "",
        "# 🚗 License Plate Detection System",
        "AI-powered license plate detection and recognition for Thai vehicles",
        "## How to use:",
        "1. **Upload an image** with vehicles",
        "2. **Click on the image** to select protection zone points (minimum 3 points)",
        "3. **Adjust confidence** threshold if needed",
        "4. **Click \"Detect License Plates\"** to run detection",
        "5. **View results** including annotated image and detected license plates",
        "",
    ])

    guide_md = "\n".join([
        "",
        "**Step-by-step guide:**",
        "1. **Upload Image**: Click \"Upload Image\" and select an image with vehicles",
        "2. **Select Protection Zone**:",
        "- Click at least 3 points on the uploaded image to define a protection area",
        "- The area will be highlighted in green",
        "- You can click \"Clear Protection Zone\" to start over",
        "3. **Adjust Settings**: Use the confidence slider to control detection sensitivity",
        "4. **Run Detection**: Click \"Detect License Plates\" to process the image",
        "5. **View Results**:",
        "- See the annotated image with detected objects",
        "- View individual license plate crops in the gallery",
        "- Read the detection summary",
        "**Tips:**",
        "- Select protection zones around areas where vehicles might pass",
        "- Higher confidence values will detect fewer but more certain objects",
        "- The protection zone should be a polygon (minimum 3 points)",
        "",
    ])

    def on_upload(img):
        """Show the new image in the zone preview and start a fresh zone."""
        # Points picked on the previous image are not drawn on the new one,
        # so keeping them would run detection against an invisible polygon.
        app.protection_points = []
        return img, "Image uploaded. Click on the image to select protection zone points."

    with gr.Blocks(title="🚗 License Plate Detection System", theme=gr.themes.Soft()) as iface:
        gr.Markdown(intro_md)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📤 Input")

                input_image = gr.Image(
                    type="pil",
                    label="Upload Image",
                    interactive=True,
                )

                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.25,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = more strict detection",
                )

                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear Protection Zone", variant="secondary")
                    detect_btn = gr.Button("🔍 Detect License Plates", variant="primary")

                status_text = gr.Textbox(
                    label="Status",
                    value="Upload an image and click to select protection zone points.",
                    interactive=False,
                    lines=3,
                )

            with gr.Column(scale=2):
                gr.Markdown("### 🎯 Protection Zone Selection")
                gr.Markdown("Click on the image to add points for the protection zone (minimum 3 points)")

                # Non-interactive preview: it only reports click positions.
                zone_image = gr.Image(
                    type="pil",
                    label="Click to Select Protection Zone",
                    interactive=False,
                )

        gr.Markdown("### 📊 Results")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### 🖼️ Annotated Detection")
                result_image = gr.Image(
                    type="pil",
                    label="Detection Results",
                    interactive=False,
                )
            with gr.Column(scale=1):
                gr.Markdown("#### 📋 Detection Summary")
                summary_text = gr.Markdown()

        gr.Markdown("#### 🔢 Detected License Plates")
        license_plates_gallery = gr.Gallery(
            label="License Plates Found",
            show_label=True,
            elem_id="gallery",
            columns=4,
            rows=2,
            object_fit="contain",
            height="auto",
        )

        # Event handlers
        input_image.upload(
            fn=on_upload,
            inputs=[input_image],
            outputs=[zone_image, status_text],
        )
        # Points are always drawn onto the untouched upload, so overlays
        # from earlier clicks do not accumulate in the preview.
        zone_image.select(
            fn=app.add_point,
            inputs=[input_image],
            outputs=[zone_image, status_text],
        )
        clear_btn.click(
            fn=app.clear_points,
            outputs=[zone_image, status_text],
        )
        detect_btn.click(
            fn=app.detect_license_plates,
            inputs=[input_image, confidence_slider],
            outputs=[result_image, license_plates_gallery, summary_text],
        )

        # Examples and instructions
        gr.Markdown("### 📖 Instructions")
        gr.Markdown(guide_md)

    return iface
if __name__ == "__main__":
    # Script entry point: build the interface, then serve it on all network
    # interfaces at port 7860 with Gradio's share option enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    iface = create_gradio_interface()
    iface.launch(**launch_options)