# app.py -- uploaded by dnth; commit 3cb6271 ("Update app.py", verified).
import colorsys
import os
import gradio as gr
import numpy as np
import onnxruntime as ort
import pandas as pd
from PIL import Image, ImageDraw
# Resolve bundled assets relative to this file rather than the current working
# directory, so the paths are absolute no matter where the app is launched from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# ONNX detection model handed to the demo by launch_demo().
MODEL_PATH = os.path.join(BASE_DIR, "models/deim-blood-cell-detection_small.onnx")
# Class-names text file handed to the demo by launch_demo().
CLASS_NAMES_PATH = os.path.join(BASE_DIR, "models/classes.txt")
def resize_with_aspect_ratio(image, size, interpolation=Image.BILINEAR):
    """Resize an image to fit a square canvas, preserving its aspect ratio.

    The image is scaled by a single factor so that it fits inside a
    ``size`` x ``size`` square, then pasted centred onto a new RGB canvas.
    The canvas is created without an explicit colour, so the padding uses
    Pillow's default fill.

    Args:
        image: PIL image to resize.
        size: Side length, in pixels, of the square output canvas.
        interpolation: Resampling filter passed to ``Image.resize``.

    Returns:
        tuple: ``(padded_image, ratio, pad_x, pad_y)`` where ``ratio`` is the
        scale factor that was applied and ``pad_x`` / ``pad_y`` are the
        offsets of the resized image inside the canvas.
    """
    original_width, original_height = image.size
    ratio = min(size / original_width, size / original_height)

    # Clamp to one pixel: for extreme aspect ratios the shorter side would
    # truncate to 0, and resizing to an empty size is an error.
    new_width = max(1, int(original_width * ratio))
    new_height = max(1, int(original_height * ratio))
    resized = image.resize((new_width, new_height), interpolation)

    # Centre the resized image; the same offsets are returned so callers can
    # map coordinates on the canvas back to the original image.
    pad_x = (size - new_width) // 2
    pad_y = (size - new_height) // 2
    canvas = Image.new("RGB", (size, size))
    canvas.paste(resized, (pad_x, pad_y))
    return canvas, ratio, pad_x, pad_y
def generate_colors(num_classes):
    """Build one RGB colour per class, with hues spread evenly.

    Every colour shares the same saturation (0.8) and value (0.9); only the
    hue changes, stepping by ``1 / num_classes`` from one class to the next.

    Args:
        num_classes: Number of colours to produce.

    Returns:
        list: ``num_classes`` tuples of three ints in the 0-255 range.
    """
    return [
        tuple(
            int(255 * channel)
            for channel in colorsys.hsv_to_rgb(index / num_classes, 0.8, 0.9)
        )
        for index in range(num_classes)
    ]
def draw(images, labels, boxes, scores, ratios, paddings, thrh=0.4, class_names=None):
    """Annotate images with the detections whose score exceeds ``thrh``.

    Box coordinates are mapped back onto each image by subtracting that
    image's padding and dividing by its resize ratio. Every kept detection
    gets a class-coloured rectangle and a filled caption showing the class
    name (or ``Class <index>``) together with that detection's own score.
    The images are drawn on in place.

    Args:
        images: Sequence of PIL images to annotate.
        labels: Per-image arrays of class indices.
        boxes: Per-image arrays of ``(x0, y0, x1, y1)`` boxes.
        scores: Per-image arrays of confidence scores. Like ``labels`` and
            ``boxes`` they must support boolean-mask indexing (e.g. NumPy
            arrays).
        ratios: Per-image resize ratio used when preparing the model input.
        paddings: Per-image ``(pad_w, pad_h)`` offsets added by the padding.
        thrh: Detections with ``score > thrh`` are drawn.
        class_names: Optional list of class names indexed by class index.

    Returns:
        list: The annotated images (the same objects that were passed in).
    """
    # One colour per class; without class names fall back to COCO's 91 classes.
    num_classes = len(class_names) if class_names else 91
    colors = generate_colors(num_classes)

    result_images = []
    for i, im in enumerate(images):
        canvas = ImageDraw.Draw(im)
        ratio = ratios[i]
        pad_w, pad_h = paddings[i]

        # Build the confidence mask once and apply it to all three arrays so
        # labels, boxes and scores stay aligned element by element.
        keep = scores[i] > thrh
        kept_labels = labels[i][keep]
        kept_boxes = boxes[i][keep]
        kept_scores = scores[i][keep]

        # Iterate the score alongside its own label and box; looking the score
        # up by label would show the first score of that class on every box.
        for lbl, bb, score in zip(kept_labels, kept_boxes, kept_scores):
            class_idx = int(lbl)
            color = colors[class_idx % len(colors)]
            # Convert RGB to hex for PIL
            hex_color = "#{:02x}{:02x}{:02x}".format(*color)

            # Undo the padding and resizing to get coordinates on this image.
            bb = [
                (bb[0] - pad_w) / ratio,
                (bb[1] - pad_h) / ratio,
                (bb[2] - pad_w) / ratio,
                (bb[3] - pad_h) / ratio,
            ]
            canvas.rectangle(bb, outline=hex_color, width=3)

            # Use the class name when the index is valid, else the raw index.
            # The lower bound stops a negative index wrapping to the list end.
            if class_names and 0 <= class_idx < len(class_names):
                label_text = f"{class_names[class_idx]} {score:.2f}"
            else:
                label_text = f"Class {class_idx} {score:.2f}"

            # Measure the caption so its background can be sized to fit.
            text_size = canvas.textbbox((0, 0), label_text, font=None)
            text_width = text_size[2] - text_size[0]
            text_height = text_size[3] - text_size[1]

            # Filled caption background sitting on top of the box's upper edge.
            canvas.rectangle(
                [bb[0], bb[1] - text_height - 4, bb[0] + text_width + 4, bb[1]],
                fill=hex_color,
            )

            # Pick black or white text from the weighted brightness of the
            # background colour so the caption stays readable.
            brightness = (color[0] * 299 + color[1] * 587 + color[2] * 114) / 1000
            text_color = "black" if brightness > 128 else "white"
            canvas.text(
                (bb[0] + 2, bb[1] - text_height - 2), text=label_text, fill=text_color
            )
        result_images.append(im)
    return result_images
def load_model(model_path):
    """
    Create a CPU ONNX Runtime inference session for a model file.

    Failures are reported through the return value instead of being raised.

    Args:
        model_path: Path to the ONNX model file
    Returns:
        tuple: (session, None) on success, (None, error_message) otherwise
    """
    try:
        # Print the model path to help debug deployment issues
        print(f"Loading model from: {model_path}")
        if os.path.exists(model_path):
            session = ort.InferenceSession(
                model_path, providers=["CPUExecutionProvider"]
            )
            print(f"Using device: {ort.get_device()}")
            outcome = (session, None)
        else:
            outcome = (None, f"Model file not found at: {model_path}")
    except Exception as exc:
        outcome = (None, f"Error creating inference session: {exc}")
    return outcome
def load_class_names(class_names_path):
    """
    Load class names from a text file.

    Each line is stripped of surrounding whitespace and becomes one class
    name, so a name's position in the list is its line position in the file.
    Blank lines at the end of the file are dropped; blank lines elsewhere are
    kept so the remaining names keep their positions.

    Args:
        class_names_path: Path to a text file with class names (one per line)
    Returns:
        list: Class names or None if the path is empty, missing or unreadable
    """
    if not class_names_path or not os.path.exists(class_names_path):
        return None
    try:
        # Decode explicitly instead of relying on the platform default, so
        # non-ASCII names read the same everywhere; "utf-8-sig" also accepts
        # a leading byte-order mark.
        with open(class_names_path, "r", encoding="utf-8-sig") as f:
            class_names = [line.strip() for line in f]
    except (OSError, ValueError) as e:
        # ValueError covers decoding failures (UnicodeDecodeError).
        print(f"Error loading class names: {e}")
        return None

    # Trailing blank lines would otherwise count as extra, empty classes.
    while class_names and not class_names[-1]:
        class_names.pop()

    print(f"Loaded {len(class_names)} class names")
    return class_names
def prepare_image(image, size=640):
    """
    Prepare an image for inference by converting it to RGB and resizing it.

    Args:
        image: Input image (PIL or numpy array)
        size: Side length of the square model input (defaults to 640)
    Returns:
        tuple: (resized_image, original_image, ratio, padding) where
        original_image is an RGB copy of the input and padding is
        (pad_w, pad_h)
    Raises:
        ValueError: If no image was supplied
    """
    if image is None:
        raise ValueError("No image provided")

    # Convert to PIL image if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Normalise the mode for every input, not only arrays. convert() returns
    # a new image, so later drawing on it leaves the caller's image untouched.
    image = image.convert("RGB")

    # Resize image while preserving aspect ratio
    resized_image, ratio, pad_w, pad_h = resize_with_aspect_ratio(image, size)
    return resized_image, image, ratio, (pad_w, pad_h)
def run_inference(session, image):
    """
    Feed a prepared image through the ONNX session.

    Args:
        session: ONNX runtime session
        image: Prepared PIL image
    Returns:
        The session's outputs, unpacked by the caller as
        (labels, boxes, scores)
    """
    width, height = image.size
    # The size of the image being fed is passed as an int64 (height, width)
    # row, one row per batch item.
    target_sizes = np.array([[height, width]], dtype=np.int64)

    # Scale pixel values to the 0-1 range, reorder HWC -> CHW, then add a
    # leading batch axis of length one.
    pixels = np.array(image, dtype=np.float32) / 255.0
    batch = np.expand_dims(pixels.transpose(2, 0, 1), axis=0)

    feed = {"images": batch, "orig_target_sizes": target_sizes}
    return session.run(output_names=None, input_feed=feed)
def count_objects(labels, scores, confidence_threshold, class_names):
    """
    Count detected objects by class.

    A detection is counted when its score is strictly greater than the
    threshold, which is the same comparison draw() uses, so the counts match
    the boxes that get drawn.

    Args:
        labels: Detection labels, one sequence per image
        scores: Detection confidence scores, one sequence per image
        confidence_threshold: Minimum confidence threshold (exclusive)
        class_names: List of class names, or None
    Returns:
        dict: Counts of objects by class, in order of first appearance;
        labels without a matching class name are keyed as "Class <index>"
    """
    object_counts = {}
    for label_batch, score_batch in zip(labels, scores):
        for label, score in zip(label_batch, score_batch):
            if score > confidence_threshold:
                class_idx = int(label)
                # The lower bound stops a negative label from silently
                # wrapping around to the end of the list.
                if class_names and 0 <= class_idx < len(class_names):
                    class_name = class_names[class_idx]
                else:
                    class_name = f"Class {class_idx}"
                object_counts[class_name] = object_counts.get(class_name, 0) + 1
    return object_counts
def create_status_message(object_counts):
    """
    Format the detection summary shown to the user.

    Args:
        object_counts: Dictionary of object counts by class
    Returns:
        str: A fixed header followed by one "- name: count" line per class,
        or a single "no objects" line when the counts are empty
    """
    header = "Detection completed successfully\n\nObjects detected:"
    if not object_counts:
        return header + "\n- No objects detected above confidence threshold"
    entries = [
        f"\n- {class_name}: {count}" for class_name, count in object_counts.items()
    ]
    return header + "".join(entries)
def create_bar_data(object_counts):
    """
    Turn the per-class counts into a table for the bar plot.

    Args:
        object_counts: Dictionary of object counts by class
    Returns:
        DataFrame: "Class" and "Count" columns, highest count first; a single
        placeholder row with a count of 0 when there is nothing to show
    """
    if not object_counts:
        return pd.DataFrame({"Class": ["No objects detected"], "Count": [0]})

    # Highest count first; classes with equal counts keep their input order.
    ranked = sorted(object_counts.items(), key=lambda pair: pair[1], reverse=True)
    classes, counts = zip(*ranked)
    return pd.DataFrame({"Class": list(classes), "Count": list(counts)})
# Inference sessions already created by predict(), keyed by model path, so the
# model file is loaded once per process instead of on every request.
_SESSION_CACHE = {}


def _get_session(model_path):
    """Return (session, error) for model_path, reusing a cached session."""
    session = _SESSION_CACHE.get(model_path)
    if session is not None:
        return session, None
    session, error = load_model(model_path)
    # Only successful loads are cached, so a failed load is retried next time.
    if session is not None:
        _SESSION_CACHE[model_path] = session
    return session, error


def predict(image, model_path, class_names_path, confidence_threshold):
    """
    Main prediction function that orchestrates the detection pipeline.

    The inference session for a given model path is created on first use and
    reused afterwards, so a model file replaced on disk is only picked up
    after a restart.
    NOTE(review): reuse across requests assumes a session may be shared by
    concurrent callers; confirm against the ONNX Runtime documentation.

    Args:
        image: Input image
        model_path: Path to ONNX model
        class_names_path: Path to class names file
        confidence_threshold: Detection confidence threshold
    Returns:
        tuple: (result_image, status_message, bar_data); on failure the image
        and bar data are None and the message describes the error
    """
    # Nothing to detect on; say so plainly instead of failing inside the
    # pipeline with an unrelated-looking error.
    if image is None:
        return None, "Please upload an image before running detection.", None

    # Load (or reuse) the model
    session, error = _get_session(model_path)
    if error:
        return None, error, None

    # Load class names
    class_names = load_class_names(class_names_path)

    try:
        # Prepare image
        resized_image, original_image, ratio, padding = prepare_image(image)

        # Run inference
        labels, boxes, scores = run_inference(session, resized_image)

        # Draw detections on the original image
        result_images = draw(
            [original_image],
            labels,
            boxes,
            scores,
            [ratio],
            [padding],
            thrh=confidence_threshold,
            class_names=class_names,
        )

        # Count objects by class
        object_counts = count_objects(labels, scores, confidence_threshold, class_names)

        # Create status message
        status_message = create_status_message(object_counts)

        # Create bar plot data
        bar_data = create_bar_data(object_counts)

        return result_images[0], status_message, bar_data
    except Exception as e:
        # Top-level boundary for the UI: report the failure as the status
        # message rather than letting the handler raise.
        return None, f"Error during inference: {e}", None
def build_interface(model_path, class_names_path, example_images=None):
    """
    Build the Gradio interface components.

    The page holds an input image, a confidence slider and a submit button,
    plus an output image, a status textbox and a bar plot of the counts.
    Clicking the button calls predict() with the image, the two paths and the
    slider value.

    NOTE(review): the nesting of the layout containers below was
    reconstructed from a copy of this file whose indentation was lost;
    confirm it against the deployed app.

    Args:
        model_path: Path to the ONNX model
        class_names_path: Path to the class names file
        example_images: List of example image paths
    Returns:
        gr.Blocks: The Gradio demo interface
    """
    with gr.Blocks(title="Blood Cell Detection") as demo:
        # Page heading and short description.
        gr.Markdown("# Blood Cell Detection")
        gr.Markdown("Upload an image to detect blood cells. The model can detect 3 types of blood cells: red blood cells, white blood cells and platelets.")
        gr.Markdown("Model is trained using DEIM-D-FINE model S.")
        with gr.Row():
            # Inputs: image, confidence threshold and the submit button.
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                confidence = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.4,
                    step=0.05,
                    label="Confidence Threshold",
                )
                submit_btn = gr.Button("Count Cells!", variant="primary")
            # Output: the annotated image.
            with gr.Column():
                output_image = gr.Image(type="pil", label="Detection Result")
        # Status text and per-class counts side by side.
        with gr.Row(equal_height=True):
            output_message = gr.Textbox(label="Status")
            count_plot = gr.BarPlot(
                x="Class",
                y="Count",
                title="Object Counts",
                tooltip=["Class", "Count"],
                height=300,
                orientation="h",
                label_title="Object Counts",
            )
        # Add examples component if example images are provided
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=input_image,
            )
        # Set up the click event inside the Blocks context.
        # The two paths are wrapped in gr.State so they reach predict() as
        # fixed inputs next to the user-controlled image and slider value.
        submit_btn.click(
            fn=predict,
            inputs=[
                input_image,
                gr.State(model_path),
                gr.State(class_names_path),
                confidence,
            ],
            outputs=[output_image, output_message, count_plot],
        )
        # Footer credit.
        with gr.Row():
            with gr.Column():
                gr.HTML("<div style='text-align: center; margin: 0 auto;'>Created by <a href='https://dicksonneoh.com' target='_blank'>Dickson Neoh</a>.</div>")
    return demo
def launch_demo():
    """
    Launch the Gradio demo with hardcoded model and class names paths.

    Example images are collected from the "examples" directory next to this
    file (.png, .jpg and .jpeg, matched case-insensitively). The directory is
    created when missing; if it cannot be created the demo still starts,
    just without examples.
    """
    examples_dir = os.path.join(BASE_DIR, "examples")

    # Create examples directory if it doesn't exist
    if not os.path.isdir(examples_dir):
        try:
            # exist_ok avoids failing if something else creates the directory
            # between the check above and this call.
            os.makedirs(examples_dir, exist_ok=True)
            print(f"Created examples directory at {examples_dir}")
        except OSError as e:
            # Examples are optional, so an unwritable location must not stop
            # the demo from starting.
            print(f"Could not create examples directory at {examples_dir}: {e}")

    # Get list of example images. os.listdir() returns entries in arbitrary
    # order, so sort to keep the example order stable between runs.
    example_images = []
    if os.path.isdir(examples_dir):
        example_images = sorted(
            os.path.join(examples_dir, f)
            for f in os.listdir(examples_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )
    print(f"Found {len(example_images)} example images")

    demo = build_interface(MODEL_PATH, CLASS_NAMES_PATH, example_images)
    demo.launch(share=False)  # Set share=True if you want to create a shareable link
# Start the demo only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    launch_demo()