|
|
|
|
|
"""BirdNET Real-Time Audio Classification Script |
|
|
|
|
|
This script captures audio from the microphone and uses the BirdNET ONNX model
to predict bird species in real-time with continuous display updates.
|
|
|
|
|
Created using Copilot. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import numpy as np |
|
|
import sounddevice as sd |
|
|
import onnxruntime as ort |
|
|
import argparse |
|
|
import os |
|
|
import time |
|
|
import threading |
|
|
from collections import deque |
|
|
from datetime import datetime |
|
|
import queue |
|
|
|
|
|
|
|
|
class RealTimeBirdDetector:
    """Real-time bird detection using microphone input.

    Audio blocks delivered by the sound device callback are queued, appended
    to a rolling sample buffer by a worker thread, and the most recent
    analysis window is run through the ONNX model. A second thread redraws
    the console with the current top detections and recent history.
    """

    # Width, in characters, of the confidence bar drawn next to a detection.
    _BAR_WIDTH = 20

    # Number of history entries shown in the "Recent Activity" section.
    _RECENT_COUNT = 10

    def __init__(
        self,
        model_path: str = "model.onnx",
        labels_path: str = "BirdNET_GLOBAL_6K_V2.4_Labels.txt",
        sample_rate: int = 48000,
        window_duration: float = 3.0,
        confidence_threshold: float = 0.1,
        top_k: int = 5,
        update_interval: float = 1.0,
    ):
        """
        Initialize the real-time bird detector.

        Args:
            model_path: Path to the ONNX model file
            labels_path: Path to the species labels file
            sample_rate: Audio sample rate (48kHz for BirdNET)
            window_duration: Duration of each analysis window in seconds
            confidence_threshold: Minimum confidence for detections
            top_k: Number of top predictions to display
            update_interval: How often to update predictions (seconds)

        Raises:
            ValueError: If ``sample_rate * window_duration`` is not positive.
            RuntimeError: If the model or the labels file cannot be loaded.
        """
        self.model_path = model_path
        self.labels_path = labels_path
        self.sample_rate = sample_rate
        self.window_duration = window_duration
        self.window_size = int(sample_rate * window_duration)
        self.confidence_threshold = confidence_threshold
        self.top_k = top_k
        self.update_interval = update_interval

        # A zero-length window would send empty input to the model on every
        # chunk, so reject it up front with a clear message.
        if self.window_size <= 0:
            raise ValueError(
                "sample_rate * window_duration must yield at least one sample"
            )

        # Rolling sample buffer (holds up to two windows) and the hand-off
        # queue filled by the audio callback.
        self.audio_buffer = deque(maxlen=self.window_size * 2)
        self.audio_queue = queue.Queue()

        # Detection state shared with the display thread.
        self.current_detections = []
        self.detection_history = deque(maxlen=100)
        self.running = False

        self._load_model()
        self._load_labels()

    def _load_model(self) -> None:
        """Load the ONNX model and cache the name of its first input.

        Raises:
            RuntimeError: If the inference session cannot be created; the
                original exception is attached as ``__cause__``.
        """
        try:
            print(f"Loading ONNX model: {self.model_path}")
            self.session = ort.InferenceSession(self.model_path)

            input_info = self.session.get_inputs()[0]
            output_info = self.session.get_outputs()[0]
            # Cached so the inference path does not query the session
            # metadata for every analysis window.
            self._input_name = input_info.name
            print(f"Model input: {input_info.name}, shape: {input_info.shape}")
            print(f"Model output: {output_info.name}, shape: {output_info.shape}")
        except Exception as e:
            raise RuntimeError(
                f"Error loading ONNX model {self.model_path}: {e}"
            ) from e

    def _load_labels(self) -> None:
        """Load species labels from file.

        Blank lines are skipped. When a line contains an underscore, only the
        text after the first underscore is kept (stored as the common name);
        otherwise the whole line is used.

        Raises:
            RuntimeError: If the labels file cannot be read; the original
                exception is attached as ``__cause__``.
        """
        try:
            print(f"Loading labels from: {self.labels_path}")
            labels = []
            with open(self.labels_path, "r", encoding="utf-8") as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue
                    _, separator, common_name = line.partition("_")
                    labels.append(common_name if separator else line)
            self.labels = labels
            print(f"Loaded {len(self.labels)} species labels")
        except Exception as e:
            raise RuntimeError(
                f"Error loading labels file {self.labels_path}: {e}"
            ) from e

    def _audio_callback(
        self, indata: np.ndarray, frames: int, time_info, status
    ) -> None:
        """Callback function for audio input.

        Down-mixes the incoming block to a 1-D array and queues it for the
        processing thread.

        Args:
            indata: Block of recorded samples; multi-dimensional input is
                averaged over axis 1.
            frames: Number of frames in the block (unused).
            time_info: Timing information from the audio backend (unused).
            status: Truthy when the backend reports a problem; it is printed.
        """
        if status:
            print(f"Audio status: {status}")

        # Both branches allocate a new array, so the queued chunk does not
        # alias the caller's buffer and no additional copy is needed.
        if indata.ndim > 1:
            audio_data = indata.mean(axis=1)
        else:
            audio_data = indata.flatten()

        self.audio_queue.put(audio_data)

    def _drain_pending_chunks(self) -> None:
        """Move every chunk already waiting in the queue into the buffer."""
        while True:
            try:
                chunk = self.audio_queue.get_nowait()
            except queue.Empty:
                return
            self.audio_buffer.extend(chunk)

    def _latest_window(self):
        """Return the newest ``window_size`` samples, or None if too few.

        Returns:
            A float32 array of length ``window_size`` taken from the end of
            the rolling buffer, or ``None`` while the buffer is still filling.
        """
        buffered = len(self.audio_buffer)
        if buffered < self.window_size:
            return None
        samples = np.fromiter(self.audio_buffer, dtype=np.float32, count=buffered)
        return samples[-self.window_size:]

    def _process_audio_buffer(self) -> None:
        """Process audio data from the queue until ``running`` is cleared.

        Each iteration waits briefly for a chunk, folds it and any backlog
        into the rolling buffer, and analyzes the most recent window.
        """
        while self.running:
            try:
                audio_chunk = self.audio_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            try:
                self.audio_buffer.extend(audio_chunk)
                # If inference is slower than the capture rate, chunks pile
                # up in the queue. Folding the whole backlog in before the
                # next analysis keeps detections tied to the newest audio
                # instead of drifting further behind real time.
                self._drain_pending_chunks()

                window_data = self._latest_window()
                if window_data is not None:
                    self._analyze_audio_window(window_data)
            except Exception as e:
                print(f"Error processing audio: {e}")

    def _run_model(self, audio_data: np.ndarray) -> np.ndarray:
        """Run one window through the model and return its score vector.

        Args:
            audio_data: 1-D array of samples for one analysis window.

        Returns:
            The first model output; when that output has more than one
            dimension, its first row.
        """
        # The window is passed with a leading axis of length 1.
        input_data = np.expand_dims(audio_data.astype(np.float32), axis=0)
        outputs = self.session.run(None, {self._input_name: input_data})
        predictions = np.asarray(outputs[0])
        return predictions[0] if predictions.ndim > 1 else predictions

    def _scores_to_detections(self, scores: np.ndarray) -> list:
        """Convert a score vector into the ranked top-k detection records.

        Args:
            scores: 1-D array of per-class scores.

        Returns:
            Up to ``top_k`` dicts with ``species``, ``confidence`` and
            ``timestamp`` keys, ordered by descending confidence. Only
            classes scoring strictly above ``confidence_threshold`` are
            included; classes without a label are named ``Class <index>``.
        """
        scores = np.asarray(scores)
        candidates = np.flatnonzero(scores > self.confidence_threshold)
        # Stable sort on the negated scores: descending confidence, with
        # ties kept in ascending class-index order.
        ranking = np.argsort(-scores[candidates], kind="stable")
        timestamp = datetime.now()

        detections = []
        for class_index in candidates[ranking][: self.top_k]:
            class_index = int(class_index)
            if class_index < len(self.labels):
                species_name = self.labels[class_index]
            else:
                species_name = f"Class {class_index}"
            detections.append(
                {
                    "species": species_name,
                    "confidence": float(scores[class_index]),
                    "timestamp": timestamp,
                }
            )
        return detections

    def _analyze_audio_window(self, audio_data: np.ndarray) -> None:
        """Analyze a single audio window.

        Updates ``current_detections`` with the top results for this window
        and appends them to ``detection_history``. Errors are printed rather
        than raised.

        Args:
            audio_data: 1-D array of samples for one analysis window.
        """
        try:
            detections = self._scores_to_detections(self._run_model(audio_data))
            # Rebinding the attribute (rather than mutating the list) lets
            # the display thread keep iterating over the list it already has.
            self.current_detections = detections
            if detections:
                self.detection_history.extend(detections)
        except Exception as e:
            print(f"Error during inference: {e}")

    @staticmethod
    def _confidence_bar(confidence: float) -> str:
        """Return a fixed-width text bar for ``confidence``.

        The value is clamped to the range 0..1 first, so scores outside that
        range cannot make the bar longer or shorter than ``_BAR_WIDTH``.
        """
        width = RealTimeBirdDetector._BAR_WIDTH
        ratio = min(1.0, max(0.0, confidence))
        filled = int(ratio * width)
        return "█" * filled + "░" * (width - filled)

    def _render_screen(self) -> str:
        """Build the text of one display refresh."""
        # Take one snapshot of each shared structure so the header count and
        # the rows below it always describe the same set of detections.
        current = self.current_detections
        recent = list(self.detection_history)[-self._RECENT_COUNT:]

        lines = [
            "🎤 BirdNET Real-Time Detection",
            "=" * 50,
            f"Listening... (Confidence > {self.confidence_threshold:.2f})",
            f"Time: {datetime.now().strftime('%H:%M:%S')}",
            "",
        ]

        if current:
            lines.append(f"🐦 Current Detections (Top {len(current)}):")
            lines.append("-" * 40)
            for i, detection in enumerate(current, 1):
                confidence = detection["confidence"]
                species = detection["species"]
                bar = self._confidence_bar(confidence)
                lines.append(f"{i:2d}. {species}")
                lines.append(f" {bar} {confidence:.4f}")
        else:
            lines.append("🔍 No detections above threshold...")
        lines.append("")

        if recent:
            lines.append("📊 Recent Activity (Last 10):")
            lines.append("-" * 40)
            # Newest entries first.
            for detection in reversed(recent):
                timestamp = detection["timestamp"].strftime("%H:%M:%S")
                species = detection["species"]
                confidence = detection["confidence"]
                lines.append(f"{timestamp} - {species} ({confidence:.3f})")

        lines.append("")
        lines.append("Press Ctrl+C to stop")
        return "\n".join(lines)

    def _display_results(self) -> None:
        """Continuously display detection results.

        Clears the terminal and prints the current state every
        ``update_interval`` seconds until ``running`` is cleared.
        """
        while self.running:
            try:
                os.system("clear" if os.name == "posix" else "cls")
                print(self._render_screen())
                time.sleep(self.update_interval)
            except KeyboardInterrupt:
                break
            except Exception as e:
                print(f"Display error: {e}")

    def start_detection(self) -> bool:
        """Start real-time detection.

        Starts the processing and display threads, opens the microphone
        stream, and blocks until ``running`` is cleared or Ctrl+C is pressed.

        Returns:
            True when detection ended normally (stopped or interrupted),
            False when an error ended it. The error is printed, not raised.
        """
        succeeded = True
        try:
            print("Starting real-time bird detection...")
            print(f"Sample rate: {self.sample_rate} Hz")
            print(f"Window size: {self.window_duration} seconds")
            print(f"Confidence threshold: {self.confidence_threshold}")
            print("Press Ctrl+C to stop\n")

            self.running = True

            # Worker that turns queued chunks into detections.
            audio_thread = threading.Thread(
                target=self._process_audio_buffer, daemon=True
            )
            audio_thread.start()

            # Worker that redraws the console.
            display_thread = threading.Thread(
                target=self._display_results, daemon=True
            )
            display_thread.start()

            # Blocks of 0.1 s of mono float32 audio are handed to the callback.
            with sd.InputStream(
                callback=self._audio_callback,
                channels=1,
                samplerate=self.sample_rate,
                blocksize=int(self.sample_rate * 0.1),
                dtype=np.float32,
            ):
                print("🎤 Microphone active - listening for birds...")

                # Keep the stream open until asked to stop.
                try:
                    while self.running:
                        time.sleep(0.1)
                except KeyboardInterrupt:
                    pass
        except Exception as e:
            print(f"Error during detection: {e}")
            succeeded = False
        finally:
            self.running = False
            print("\n🛑 Detection stopped.")
        return succeeded

    def stop_detection(self) -> None:
        """Stop detection by clearing the ``running`` flag."""
        self.running = False
|
|
|
|
|
|
|
|
def main() -> int: |
|
|
"""Main function for real-time detection.""" |
|
|
parser = argparse.ArgumentParser( |
|
|
description="BirdNET Real-Time Audio Classification" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--model", default="model.onnx", help="Path to the ONNX model file" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--labels", |
|
|
default="BirdNET_GLOBAL_6K_V2.4_Labels.txt", |
|
|
help="Path to the labels file", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--confidence", |
|
|
type=float, |
|
|
default=0.1, |
|
|
help="Minimum confidence threshold for detections (default: 0.1)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--top-k", |
|
|
type=int, |
|
|
default=5, |
|
|
help="Number of top predictions to show (default: 5)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--update-interval", |
|
|
type=float, |
|
|
default=1.0, |
|
|
help="Display update interval in seconds (default: 1.0)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--list-devices", action="store_true", help="List available audio input devices" |
|
|
) |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
if args.list_devices: |
|
|
print("Available audio input devices:") |
|
|
print(sd.query_devices()) |
|
|
return 0 |
|
|
|
|
|
|
|
|
if not os.path.exists(args.model): |
|
|
print(f"Error: Model file '{args.model}' not found.") |
|
|
return 1 |
|
|
|
|
|
if not os.path.exists(args.labels): |
|
|
print(f"Error: Labels file '{args.labels}' not found.") |
|
|
return 1 |
|
|
|
|
|
try: |
|
|
|
|
|
detector = RealTimeBirdDetector( |
|
|
model_path=args.model, |
|
|
labels_path=args.labels, |
|
|
confidence_threshold=args.confidence, |
|
|
top_k=args.top_k, |
|
|
update_interval=args.update_interval, |
|
|
) |
|
|
|
|
|
|
|
|
detector.start_detection() |
|
|
|
|
|
return 0 |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error: {str(e)}") |
|
|
return 1 |
|
|
|
|
|
|
|
|
# Script entry point. SystemExit is raised directly because the `exit`
# helper is installed by the `site` module and is not always available.
if __name__ == "__main__":
    raise SystemExit(main())
|
|
|