"""BirdNET Audio Classification Script

This script loads a WAV file and uses the BirdNET ONNX model to predict bird species.

The model expects audio input of shape [batch_size, 144000] (3 seconds at 48kHz).

Created using Copilot.
"""

from __future__ import annotations

import argparse
import os
from collections import defaultdict

import librosa
import numpy as np
import onnxruntime as ort


def load_audio(
    file_path: str, target_sr: int = 48000, duration: float = 3.0
) -> np.ndarray:
    """
    Load and preprocess audio file for BirdNET model.

    Args:
        file_path (str): Path to the audio file
        target_sr (int): Target sample rate (48kHz for BirdNET)
        duration (float): Duration in seconds (3.0 for BirdNET)

    Returns:
        np.ndarray: Preprocessed audio array of shape [144000]
    """
    try:
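        # librosa resamples to target_sr and downmixes to mono by default;
        # duration= limits decoding to the first `duration` seconds.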
        audio, _ = librosa.load(file_path, sr=target_sr, duration=duration)
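
        # Zero-pad or truncate so the result is exactly target_sr * duration
        # samples (144000 for the default 3 s at 48 kHz).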
        target_length = int(target_sr * duration)
        if len(audio) < target_length:
            audio = np.pad(audio, (0, target_length - len(audio)))
        elif len(audio) > target_length:
            audio = audio[:target_length]

        return audio.astype(np.float32)
    except Exception as e:
        raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e


def load_labels(labels_path: str) -> list[str]:
    """
    Load BirdNET species labels from the labels file.

    Args:
        labels_path (str): Path to the labels file

    Returns:
        list[str]: List of species names
    """
    try:
        labels = []
        with open(labels_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
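                    # BirdNET label lines look like "Scientific name_Common name";
                    # keep only the common name for display.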
                    if "_" in line:
                        common_name = line.split("_", 1)[1]
                        labels.append(common_name)
                    else:
                        labels.append(line)
        return labels
    except Exception as e:
        raise RuntimeError(f"Error loading labels file {labels_path}: {e}") from e


def load_audio_full(file_path: str, target_sr: int = 48000) -> np.ndarray:
    """
    Load full audio file for moving window analysis.

    Args:
        file_path (str): Path to the audio file
        target_sr (int): Target sample rate (48kHz for BirdNET)

    Returns:
        np.ndarray: Full audio array
    """
    try:
        audio, _ = librosa.load(file_path, sr=target_sr)
        return audio.astype(np.float32)
    except Exception as e:
        raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e


def create_audio_windows(
    audio: np.ndarray, window_size: int = 144000, overlap: float = 0.5
) -> tuple[np.ndarray, list[float]]:
    """
    Create overlapping windows from audio for analysis.

    Args:
        audio (np.ndarray): Full audio array
        window_size (int): Size of each window (144000 for 3 seconds at 48kHz)
        overlap (float): Overlap ratio (0.5 = 50% overlap)

    Returns:
        tuple[np.ndarray, list[float]]: (windows array, timestamps in seconds)
    """
    # Guard against overlap == 1.0, which would give a step size of zero.
    step_size = max(1, int(window_size * (1 - overlap)))
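    # Example: window_size=144000 with overlap=0.5 gives step_size=72000
    # samples, i.e. a new window every 1.5 s of audio.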
    windows = []
    timestamps = []

    # Pad audio shorter than one window so at least one window is produced.
    if len(audio) < window_size:
        audio = np.pad(audio, (0, window_size - len(audio)))

    for start in range(0, len(audio) - window_size + 1, step_size):
        end = start + window_size
        window = audio[start:end]

        if len(window) == window_size:
            windows.append(window)
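            # Window start time in seconds; assumes BirdNET's fixed 48 kHz rate.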
            timestamps.append(start / 48000.0)

    return np.array(windows), timestamps


def load_onnx_model(model_path: str) -> ort.InferenceSession:
    """
    Load ONNX model for inference.

    Args:
        model_path (str): Path to the ONNX model file

    Returns:
        ort.InferenceSession: Loaded ONNX model session
    """
    try:
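        # Uses onnxruntime's default execution providers; pass
        # providers=["CPUExecutionProvider"] to pin inference to the CPU.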
        session = ort.InferenceSession(model_path)
        return session
    except Exception as e:
        raise RuntimeError(f"Error loading ONNX model {model_path}: {e}") from e


def predict_audio(session: ort.InferenceSession, audio_data: np.ndarray) -> np.ndarray:
    """
    Run inference on audio data using the ONNX model.

    Args:
        session (ort.InferenceSession): ONNX model session
        audio_data (np.ndarray): Audio data of shape [144000] or [batch, 144000]

    Returns:
        np.ndarray: Model predictions
    """
    try:
        # The model expects a batch dimension; add one for single windows.
        if audio_data.ndim == 1:
            input_data = np.expand_dims(audio_data, axis=0)
        else:
            input_data = audio_data

        input_name = session.get_inputs()[0].name
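
        # run(None, ...) returns every model output; this script assumes the
        # species scores are in the first output tensor.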
        outputs = session.run(None, {input_name: input_data})
        return outputs[0]
    except Exception as e:
        raise RuntimeError(f"Error during model inference: {e}") from e


def predict_audio_batch(
    session: ort.InferenceSession,
    windows_batch: np.ndarray,
    batch_size: int = 128,
    show_progress: bool = True,
) -> np.ndarray:
    """
    Run inference on batches of audio windows for better performance.

    Args:
        session (ort.InferenceSession): ONNX model session
        windows_batch (np.ndarray): Array of windows, shape [num_windows, 144000]
        batch_size (int): Number of windows to process in each batch
        show_progress (bool): Whether to show progress updates

    Returns:
        np.ndarray: All predictions concatenated, shape [num_windows, num_classes]
    """
    try:
        all_predictions = []
        num_windows = len(windows_batch)

        input_name = session.get_inputs()[0].name
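
        # Slice the windows into fixed-size batches and run each through the
        # model; the final batch may be smaller than batch_size.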
        batch_num = 0
        for start_idx in range(0, num_windows, batch_size):
            end_idx = min(start_idx + batch_size, num_windows)
            current_batch = windows_batch[start_idx:end_idx]
            batch_num += 1

            if show_progress and (batch_num % 5 == 0 or batch_num == 1):
                progress = (end_idx / num_windows) * 100
                print(
                    f" Batch {batch_num}: processing windows {start_idx + 1}-{end_idx} ({progress:.1f}%)"
                )

            outputs = session.run(None, {input_name: current_batch})
            batch_predictions = outputs[0]
            all_predictions.append(batch_predictions)

        return np.concatenate(all_predictions, axis=0)
    except Exception as e:
        raise RuntimeError(f"Error during batch model inference: {e}") from e


def analyze_detections(
    all_predictions: np.ndarray,
    timestamps: list[float],
    labels: list[str],
    confidence_threshold: float = 0.1,
) -> dict[str, list[dict[str, float | int]]]:
    """
    Analyze predictions across all windows and summarize detections.

    Args:
        all_predictions (np.ndarray): Predictions from all windows, shape [num_windows, num_classes]
        timestamps (list[float]): Timestamps for each window
        labels (list[str]): Species labels
        confidence_threshold (float): Minimum confidence for detection

    Returns:
        dict[str, list[dict[str, float | int]]]: Summary of detections with timestamps
    """
    detections = defaultdict(list)

    for i, (scores, timestamp) in enumerate(zip(all_predictions, timestamps)):
        # Indices of every class scoring above the threshold in this window.
        above_threshold = np.where(scores > confidence_threshold)[0]

        for idx in above_threshold:
            confidence = float(scores[idx])
            species_name = labels[idx] if idx < len(labels) else f"Class {idx}"

            detections[species_name].append(
                {"timestamp": timestamp, "confidence": confidence, "window": i}
            )

    return dict(detections)


def main() -> int:
    parser = argparse.ArgumentParser(
        description="BirdNET Audio Classification with Moving Window"
    )
    parser.add_argument("audio_file", help="Path to the WAV audio file")
    parser.add_argument(
        "--model", default="model.onnx", help="Path to the ONNX model file"
    )
    parser.add_argument(
        "--labels",
        default="BirdNET_GLOBAL_6K_V2.4_Labels.txt",
        help="Path to the labels file",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top predictions to show per window",
    )
    parser.add_argument(
        "--overlap", type=float, default=0.5, help="Window overlap ratio (0.0-1.0)"
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.1,
        help="Minimum confidence threshold for detections",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Batch size for inference (default: 128)",
    )
    parser.add_argument(
        "--single-window",
        action="store_true",
        help="Analyze only first 3 seconds (single window)",
    )

    args = parser.parse_args()

    # Validate inputs up front so the script fails fast with a clear message.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file '{args.audio_file}' not found.")
        return 1

    if not os.path.exists(args.model):
        print(f"Error: Model file '{args.model}' not found.")
        return 1

    if not os.path.exists(args.labels):
        print(f"Error: Labels file '{args.labels}' not found.")
        return 1

    try:
        print(f"Loading labels from: {args.labels}")
        labels = load_labels(args.labels)
        print(f"Loaded {len(labels)} species labels")

        print(f"Loading ONNX model: {args.model}")
        session = load_onnx_model(args.model)

        input_info = session.get_inputs()[0]
        output_info = session.get_outputs()[0]
        print(f"Model input: {input_info.name}, shape: {input_info.shape}")
        print(f"Model output: {output_info.name}, shape: {output_info.shape}")

        if args.single_window:
            print(f"Loading first 3 seconds of audio file: {args.audio_file}")
            audio_data = load_audio(args.audio_file)
            print(f"Audio loaded successfully. Shape: {audio_data.shape}")

            print("Running inference on single window...")
            predictions = predict_audio(session, audio_data)

            # Normalize to a 1-D score vector (drop the batch dimension if present).
            predictions = np.array(predictions)
            if predictions.ndim > 1:
                scores = predictions[0]
            else:
                scores = predictions
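
            # np.argsort sorts ascending, so take the last top_k indices and
            # reverse them to get descending confidence.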
            top_indices = np.argsort(scores)[-args.top_k :][::-1]

            print(f"\nTop {args.top_k} predictions for first 3 seconds:")
            for i, idx in enumerate(top_indices):
                confidence = float(scores[idx])
                species_name = labels[idx] if idx < len(labels) else f"Class {idx}"
                print(f"{i + 1:2d}. {species_name}: {confidence:.6f}")

        else:
            print(f"Loading full audio file: {args.audio_file}")
            full_audio = load_audio_full(args.audio_file)
            audio_duration = len(full_audio) / 48000.0
            print(f"Audio loaded successfully. Duration: {audio_duration:.2f} seconds")

            print(f"Creating windows with {args.overlap * 100:.0f}% overlap...")
            windows, timestamps = create_audio_windows(full_audio, overlap=args.overlap)
            print(f"Created {len(windows)} windows of 3 seconds each")

            print(
                f"Running batch inference on {len(windows)} windows (batch size: {args.batch_size})..."
            )
            num_batches = (len(windows) + args.batch_size - 1) // args.batch_size
            print(f"Processing {num_batches} batches...")

            all_predictions = predict_audio_batch(session, windows, args.batch_size)
            print(f"Completed batch inference on {len(windows)} windows")

            print(
                f"Analyzing detections with confidence threshold {args.confidence}..."
            )
            detections = analyze_detections(
                all_predictions, timestamps, labels, args.confidence
            )
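
            # Rank species by their single strongest detection.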
            sorted_species = sorted(
                detections.items(),
                key=lambda x: max(det["confidence"] for det in x[1]),
                reverse=True,
            )

            print("\n=== DETECTION SUMMARY ===")
            print(f"Audio duration: {audio_duration:.2f} seconds")
            print(f"Windows analyzed: {len(windows)}")
            print(
                f"Species detected (>{args.confidence:.2f} confidence): {len(sorted_species)}"
            )

            if sorted_species:
                print("\nTop detections:")
                for species, detections_list in sorted_species[: args.top_k]:
                    max_conf = max(det["confidence"] for det in detections_list)
                    num_detections = len(detections_list)
                    first_detection = min(det["timestamp"] for det in detections_list)
                    last_detection = max(det["timestamp"] for det in detections_list)

                    print(f"\n{species}")
                    print(f" Max confidence: {max_conf:.6f}")
                    print(f" Detections: {num_detections}")
                    print(
                        f" Time range: {first_detection:.1f}s - {last_detection:.1f}s"
                    )

                    # Show the three strongest detections with timestamps.
                    strong_detections = sorted(
                        detections_list, key=lambda x: x["confidence"], reverse=True
                    )[:3]
                    for det in strong_detections:
                        print(f" {det['timestamp']:6.1f}s: {det['confidence']:.6f}")
            else:
                print(
                    f"No detections found above confidence threshold {args.confidence}"
                )

        return 0

    except Exception as e:
        print(f"Error: {e}")
        return 1


if __name__ == "__main__":
    raise SystemExit(main())