#!/usr/bin/env python3
# Origin: BirdNET-onnx / predict_audio.py on the Hugging Face Hub, uploaded by
# justinchuby with huggingface_hub (revision 7b7cd7f, verified; 15.1 kB).
"""BirdNET Audio Classification Script
This script loads a WAV file and uses the BirdNET ONNX model to predict bird species.
The model expects audio input of shape [batch_size, 144000] (3 seconds at 48kHz).
Created using Copilot.
"""
from __future__ import annotations
import numpy as np
import librosa
import onnxruntime as ort
import argparse
import os
from collections import defaultdict
def load_audio(
    file_path: str, target_sr: int = 48000, duration: float = 3.0
) -> np.ndarray:
    """
    Load the beginning of an audio file as one fixed-length clip.

    The file is decoded with ``librosa.load`` at ``target_sr`` and limited to
    the first ``duration`` seconds; the result is then zero-padded at the end
    or truncated so that it holds exactly ``int(target_sr * duration)``
    samples (144000 with the defaults).

    Args:
        file_path (str): Path to the audio file
        target_sr (int): Target sample rate (48kHz for BirdNET)
        duration (float): Duration in seconds (3.0 for BirdNET)

    Returns:
        np.ndarray: float32 audio array with ``int(target_sr * duration)``
            samples

    Raises:
        RuntimeError: If loading or preprocessing fails. The underlying
            exception is chained as ``__cause__``.
    """
    try:
        # The sample rate returned by librosa equals target_sr, so it is unused.
        audio, _ = librosa.load(file_path, sr=target_sr, duration=duration)

        target_length = int(target_sr * duration)
        missing = target_length - len(audio)
        if missing > 0:
            # Too short: append zeros (silence) up to the target length.
            audio = np.pad(audio, (0, missing))
        else:
            # Long enough: keep only the first target_length samples.
            audio = audio[:target_length]

        return audio.astype(np.float32)
    except Exception as e:
        # Chain the original error so the real cause stays in the traceback.
        raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e
def load_labels(labels_path: str) -> list[str]:
    """
    Load BirdNET species labels from the labels file.

    Each non-blank line is expected to look like
    ``"Scientific name_Common Name"``. Everything after the first underscore
    is kept as the label; a line without an underscore is kept whole. Blank
    lines are skipped and surrounding whitespace is stripped.

    Args:
        labels_path (str): Path to the labels file (read as UTF-8)

    Returns:
        list[str]: Species names, in file order

    Raises:
        RuntimeError: If the file cannot be opened or read. The underlying
            exception is chained as ``__cause__``.
    """
    try:
        labels = []
        with open(labels_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                entry = raw_line.strip()
                if not entry:
                    continue
                # partition() splits on the FIRST underscore only, so common
                # names that themselves contain underscores stay intact.
                _, separator, common_name = entry.partition("_")
                labels.append(common_name if separator else entry)
        return labels
    except Exception as e:
        # Chain the original error so the real cause stays in the traceback.
        raise RuntimeError(f"Error loading labels file {labels_path}: {e}") from e
def load_audio_full(file_path: str, target_sr: int = 48000) -> np.ndarray:
    """
    Load a complete audio file for moving window analysis.

    The file is decoded with ``librosa.load`` at ``target_sr``; no padding or
    truncation is applied.

    Args:
        file_path (str): Path to the audio file
        target_sr (int): Target sample rate (48kHz for BirdNET)

    Returns:
        np.ndarray: Full audio array as float32

    Raises:
        RuntimeError: If loading fails. The underlying exception is chained
            as ``__cause__``.
    """
    try:
        # The sample rate returned by librosa equals target_sr, so it is unused.
        audio, _ = librosa.load(file_path, sr=target_sr)
        return audio.astype(np.float32)
    except Exception as e:
        # Chain the original error so the real cause stays in the traceback.
        raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e
def create_audio_windows(
    audio: np.ndarray,
    window_size: int = 144000,
    overlap: float = 0.5,
    sample_rate: int = 48000,
) -> tuple[np.ndarray, list[float]]:
    """
    Create overlapping fixed-size windows from audio for analysis.

    Windows start every ``int(window_size * (1 - overlap))`` samples. A
    trailing remainder that does not fill a whole window is dropped. Audio
    that is non-empty but shorter than one window is zero-padded at the end
    to a single window, so short recordings still yield one analyzable window.

    Args:
        audio (np.ndarray): Full audio array (assumed 1-D, one value per sample)
        window_size (int): Size of each window (144000 for 3 seconds at 48kHz)
        overlap (float): Overlap ratio (0.5 = 50% overlap); must leave a step
            of at least one sample, i.e. be below 1.0
        sample_rate (int): Samples per second, used to convert window start
            offsets to seconds (48000 for BirdNET)

    Returns:
        tuple[np.ndarray, list[float]]: (windows array of shape
            [num_windows, window_size], start time of each window in seconds).
            Empty audio yields an array of shape [0, window_size] and an
            empty timestamp list.

    Raises:
        ValueError: If ``window_size`` or ``sample_rate`` is not positive, or
            if ``overlap`` is so large that windows would not advance.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be positive, got {window_size}")
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate}")

    step_size = int(window_size * (1 - overlap))
    if step_size < 1:
        # A zero step would make range() fail and a negative step would
        # silently produce no windows; reject both with a clear message.
        raise ValueError(
            f"overlap must be less than 1.0 so that windows advance, got {overlap}"
        )

    audio = np.asarray(audio)
    num_samples = len(audio)

    if num_samples == 0:
        # Keep the 2-D shape so downstream code can rely on it.
        return np.empty((0, window_size), dtype=audio.dtype), []

    if num_samples < window_size:
        # Too short for even one window: pad with trailing zeros (silence).
        padded = np.pad(audio, (0, window_size - num_samples))
        return np.array([padded]), [0.0]

    starts = range(0, num_samples - window_size + 1, step_size)
    windows = [audio[start:start + window_size] for start in starts]
    timestamps = [start / float(sample_rate) for start in starts]
    return np.array(windows), timestamps
def load_onnx_model(model_path: str) -> ort.InferenceSession:
    """
    Load ONNX model for inference.

    Args:
        model_path (str): Path to the ONNX model file

    Returns:
        ort.InferenceSession: Loaded ONNX model session

    Raises:
        RuntimeError: If the session cannot be created. The underlying
            exception is chained as ``__cause__``.
    """
    try:
        return ort.InferenceSession(model_path)
    except Exception as e:
        # Chain the original error so the real cause stays in the traceback.
        raise RuntimeError(f"Error loading ONNX model {model_path}: {e}") from e
def predict_audio(session: ort.InferenceSession, audio_data: np.ndarray) -> np.ndarray:
    """
    Run inference on audio data using the ONNX model.

    A 1-D input is treated as a single clip and given a leading batch
    dimension; any other input is passed to the model unchanged. The data is
    fed to the model's first input and the model's first output is returned.

    Args:
        session (ort.InferenceSession): ONNX model session
        audio_data (np.ndarray): Audio data of shape [144000] or [batch, 144000]

    Returns:
        np.ndarray: Model predictions (first output of the model)

    Raises:
        RuntimeError: If inference fails. The underlying exception is chained
            as ``__cause__``.
    """
    try:
        # Ensure we have a batch dimension.
        if audio_data.ndim == 1:
            input_data = np.expand_dims(audio_data, axis=0)
        else:
            input_data = audio_data

        # The input tensor name is read from the model rather than hard-coded.
        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: input_data})
        return outputs[0]
    except Exception as e:
        # Chain the original error so the real cause stays in the traceback.
        raise RuntimeError(f"Error during model inference: {e}") from e
def predict_audio_batch(
    session: ort.InferenceSession,
    windows_batch: np.ndarray,
    batch_size: int = 128,
    show_progress: bool = True,
) -> np.ndarray:
    """
    Run inference on batches of audio windows for better performance.

    Windows are fed to the model's first input in consecutive slices of at
    most ``batch_size`` and the first output of every run is concatenated
    along axis 0.

    Args:
        session (ort.InferenceSession): ONNX model session
        windows_batch (np.ndarray): Array of windows, shape [num_windows, 144000]
        batch_size (int): Number of windows to process in each batch; must be
            at least 1
        show_progress (bool): Whether to print a progress line for the first
            batch and every fifth batch

    Returns:
        np.ndarray: All predictions concatenated, shape [num_windows, num_classes].
            When there are no windows the model is not run and an empty
            float32 array of shape [0, 0] is returned.

    Raises:
        RuntimeError: If ``batch_size`` is not positive or inference fails.
            The underlying exception is chained as ``__cause__``.
    """
    try:
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        num_windows = len(windows_batch)
        if num_windows == 0:
            # Nothing to run; np.concatenate would reject an empty list.
            return np.empty((0, 0), dtype=np.float32)

        # The input tensor name is read from the model rather than hard-coded.
        input_name = session.get_inputs()[0].name

        all_predictions = []
        batch_starts = range(0, num_windows, batch_size)
        for batch_num, start_idx in enumerate(batch_starts, start=1):
            end_idx = min(start_idx + batch_size, num_windows)

            if show_progress and (batch_num % 5 == 0 or batch_num == 1):
                progress = (end_idx / num_windows) * 100
                print(
                    f" Batch {batch_num}: processing windows {start_idx + 1}-{end_idx} ({progress:.1f}%)"
                )

            outputs = session.run(None, {input_name: windows_batch[start_idx:end_idx]})
            all_predictions.append(outputs[0])

        return np.concatenate(all_predictions, axis=0)
    except Exception as e:
        # Chain the original error so the real cause stays in the traceback.
        raise RuntimeError(f"Error during batch model inference: {e}") from e
def analyze_detections(
    all_predictions: np.ndarray,
    timestamps: list[float],
    labels: list[str],
    confidence_threshold: float = 0.1,
) -> dict[str, list[dict[str, float | int]]]:
    """
    Analyze predictions across all windows and summarize detections.

    For every window, each class whose score is strictly greater than
    ``confidence_threshold`` is recorded under its species name. Class
    indices beyond the end of ``labels`` are reported as ``"Class <index>"``.

    Args:
        all_predictions (np.ndarray): Predictions from all windows, shape [num_windows, num_classes]
        timestamps (list[float]): Timestamps for each window; must contain one
            entry per row of ``all_predictions``
        labels (list[str]): Species labels
        confidence_threshold (float): Minimum confidence for detection (exclusive)

    Returns:
        dict[str, list[dict[str, float | int]]]: Species name mapped to its
            detections in window order; each detection holds ``"timestamp"``,
            ``"confidence"`` and ``"window"`` (the window index).

    Raises:
        ValueError: If the number of prediction rows differs from the number
            of timestamps.
    """
    if len(all_predictions) != len(timestamps):
        # zip() would silently drop the surplus and hide an upstream bug.
        raise ValueError(
            f"Got {len(all_predictions)} prediction rows but {len(timestamps)} timestamps"
        )

    detections = defaultdict(list)
    num_labels = len(labels)

    for window_idx, (scores, timestamp) in enumerate(zip(all_predictions, timestamps)):
        # Indices of all classes scoring above the threshold in this window.
        for class_idx in np.where(scores > confidence_threshold)[0]:
            if class_idx < num_labels:
                species_name = labels[class_idx]
            else:
                species_name = f"Class {class_idx}"
            detections[species_name].append(
                {
                    "timestamp": timestamp,
                    "confidence": float(scores[class_idx]),
                    "window": window_idx,
                }
            )

    # Return a plain dict so missing-key lookups by callers raise KeyError.
    return dict(detections)
def main(argv: list[str] | None = None) -> int:
    """
    Command-line entry point for BirdNET audio classification.

    Parses the arguments, checks that the audio, model and labels files
    exist, loads the labels and the ONNX model, and then either classifies
    the first 3 seconds (``--single-window``) or slides a 3-second window
    over the whole file and prints a per-species detection summary.

    Args:
        argv (list[str] | None): Argument list to parse instead of the
            process command line; ``None`` (the default) makes argparse read
            ``sys.argv[1:]``.

    Returns:
        int: 0 on success, 1 if an input file is missing, the audio yields no
            analysis window, or any step fails.
    """
    parser = argparse.ArgumentParser(
        description="BirdNET Audio Classification with Moving Window"
    )
    parser.add_argument("audio_file", help="Path to the WAV audio file")
    parser.add_argument(
        "--model", default="model.onnx", help="Path to the ONNX model file"
    )
    parser.add_argument(
        "--labels",
        default="BirdNET_GLOBAL_6K_V2.4_Labels.txt",
        help="Path to the labels file",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top predictions to show per window",
    )
    parser.add_argument(
        "--overlap", type=float, default=0.5, help="Window overlap ratio (0.0-1.0)"
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.1,
        help="Minimum confidence threshold for detections",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Batch size for inference (default: 128)",
    )
    parser.add_argument(
        "--single-window",
        action="store_true",
        help="Analyze only first 3 seconds (single window)",
    )
    args = parser.parse_args(argv)

    # Check that every input file exists before doing any expensive work.
    required_files = (
        ("Audio", args.audio_file),
        ("Model", args.model),
        ("Labels", args.labels),
    )
    for kind, path in required_files:
        if not os.path.exists(path):
            print(f"Error: {kind} file '{path}' not found.")
            return 1

    # A negative or zero --top-k must select nothing: with a negative-index
    # slice, "-0" would otherwise select every class.
    top_k = max(args.top_k, 0)

    try:
        # Load labels
        print(f"Loading labels from: {args.labels}")
        labels = load_labels(args.labels)
        print(f"Loaded {len(labels)} species labels")

        # Load ONNX model
        print(f"Loading ONNX model: {args.model}")
        session = load_onnx_model(args.model)

        # Print model info
        input_info = session.get_inputs()[0]
        output_info = session.get_outputs()[0]
        print(f"Model input: {input_info.name}, shape: {input_info.shape}")
        print(f"Model output: {output_info.name}, shape: {output_info.shape}")

        if args.single_window:
            # Single window analysis (original behavior)
            print(f"Loading first 3 seconds of audio file: {args.audio_file}")
            audio_data = load_audio(args.audio_file)
            print(f"Audio loaded successfully. Shape: {audio_data.shape}")

            print("Running inference on single window...")
            predictions = np.array(predict_audio(session, audio_data))

            # Drop the batch dimension if the model returned one.
            scores = predictions[0] if predictions.ndim > 1 else predictions

            # Highest scores first, then keep the leading top_k entries.
            top_indices = np.argsort(scores)[::-1][:top_k]
            print(f"\nTop {args.top_k} predictions for first 3 seconds:")
            for rank, idx in enumerate(top_indices, start=1):
                confidence = float(scores[idx])
                species_name = labels[idx] if idx < len(labels) else f"Class {idx}"
                print(f"{rank:2d}. {species_name}: {confidence:.6f}")
        else:
            # Moving window analysis
            print(f"Loading full audio file: {args.audio_file}")
            full_audio = load_audio_full(args.audio_file)
            # load_audio_full resamples to its default 48 kHz rate.
            audio_duration = len(full_audio) / 48000.0
            print(f"Audio loaded successfully. Duration: {audio_duration:.2f} seconds")

            # Create windows
            print(f"Creating windows with {args.overlap * 100:.0f}% overlap...")
            windows, timestamps = create_audio_windows(full_audio, overlap=args.overlap)
            print(f"Created {len(windows)} windows of 3 seconds each")

            if len(windows) == 0:
                # Batch inference cannot run on zero windows; explain why
                # instead of failing later with an opaque error.
                print(
                    "Error: Audio is too short to form a 3-second window; "
                    "use --single-window to analyze a zero-padded clip."
                )
                return 1

            # Run batch inference on all windows
            print(
                f"Running batch inference on {len(windows)} windows (batch size: {args.batch_size})..."
            )
            # Ceiling division: number of batches needed to cover all windows.
            num_batches = (len(windows) + args.batch_size - 1) // args.batch_size
            print(f"Processing {num_batches} batches...")

            # Use batch prediction for better performance
            all_predictions = predict_audio_batch(session, windows, args.batch_size)
            print(f"Completed batch inference on {len(windows)} windows")

            # Analyze detections across all windows
            print(
                f"Analyzing detections with confidence threshold {args.confidence}..."
            )
            detections = analyze_detections(
                all_predictions, timestamps, labels, args.confidence
            )

            # Sort species by maximum confidence
            sorted_species = sorted(
                detections.items(),
                key=lambda item: max(det["confidence"] for det in item[1]),
                reverse=True,
            )

            print("\n=== DETECTION SUMMARY ===")
            print(f"Audio duration: {audio_duration:.2f} seconds")
            print(f"Windows analyzed: {len(windows)}")
            print(
                f"Species detected (>{args.confidence:.2f} confidence): {len(sorted_species)}"
            )

            if sorted_species:
                print("\nTop detections:")
                for species, detections_list in sorted_species[:top_k]:
                    max_conf = max(det["confidence"] for det in detections_list)
                    num_detections = len(detections_list)
                    first_detection = min(det["timestamp"] for det in detections_list)
                    last_detection = max(det["timestamp"] for det in detections_list)

                    print(f"\n{species}")
                    print(f" Max confidence: {max_conf:.6f}")
                    print(f" Detections: {num_detections}")
                    print(
                        f" Time range: {first_detection:.1f}s - {last_detection:.1f}s"
                    )

                    # Show the three strongest detections for this species
                    strong_detections = sorted(
                        detections_list,
                        key=lambda det: det["confidence"],
                        reverse=True,
                    )[:3]
                    for det in strong_detections:
                        print(f" {det['timestamp']:6.1f}s: {det['confidence']:.6f}")
            else:
                print(
                    f"No detections found above confidence threshold {args.confidence}"
                )
        return 0
    except Exception as e:
        # Top-level boundary: report any failure and signal it via exit code.
        print(f"Error: {e}")
        return 1
if __name__ == "__main__":
    # The bare ``exit`` helper is injected by the ``site`` module for
    # interactive use and may be absent (e.g. ``python -S``); raising
    # SystemExit is always available and passes main()'s return value on as
    # the process exit status.
    raise SystemExit(main())