#!/usr/bin/env python3
"""BirdNET Audio Classification Script

This script loads a WAV file and uses the BirdNET ONNX model to predict bird
species. The model expects audio input of shape [batch_size, 144000]
(3 seconds at 48kHz).

Two modes are available:

* moving-window analysis of the whole recording (default), and
* analysis of the first 3 seconds only (``--single-window``).

Created using Copilot.
"""

from __future__ import annotations

import argparse
import os
import sys
from collections import defaultdict
from typing import TYPE_CHECKING

import numpy as np

# librosa and onnxruntime are slow to import, so they are imported inside the
# functions that actually use them. This keeps ``--help`` and argument/path
# validation fast; the names are only needed here for type annotations.
if TYPE_CHECKING:
    from collections.abc import Sequence

    import onnxruntime as ort

# Input geometry used throughout the script: 3-second windows at 48 kHz.
SAMPLE_RATE = 48000
WINDOW_SECONDS = 3.0
WINDOW_SAMPLES = 144000


def load_audio(
    file_path: str, target_sr: int = SAMPLE_RATE, duration: float = WINDOW_SECONDS
) -> np.ndarray:
    """
    Load and preprocess the start of an audio file for the BirdNET model.

    Args:
        file_path (str): Path to the audio file
        target_sr (int): Target sample rate (48kHz for BirdNET)
        duration (float): Duration in seconds (3.0 for BirdNET)

    Returns:
        np.ndarray: float32 array of exactly ``int(target_sr * duration)``
            samples (144000 with the defaults); shorter recordings are
            zero-padded at the end, longer ones are truncated.

    Raises:
        RuntimeError: If the file cannot be loaded or preprocessed.
    """
    import librosa

    try:
        audio, _ = librosa.load(file_path, sr=target_sr, duration=duration)

        # Force the exact length: pad short clips, cut anything beyond it.
        target_length = int(target_sr * duration)
        if len(audio) < target_length:
            audio = np.pad(audio, (0, target_length - len(audio)))
        else:
            audio = audio[:target_length]

        return audio.astype(np.float32)
    except Exception as e:
        raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e


def load_labels(labels_path: str) -> list[str]:
    """
    Load BirdNET species labels from the labels file.

    Each non-blank line is expected to look like
    ``Scientific name_Common Name``; the part after the first underscore is
    kept. Lines without an underscore are kept unchanged.

    Args:
        labels_path (str): Path to the labels file

    Returns:
        list[str]: List of species names, in file order

    Raises:
        RuntimeError: If the file cannot be read.
    """
    try:
        with open(labels_path, "r", encoding="utf-8") as f:
            entries = (line.strip() for line in f)
            return [
                entry.split("_", 1)[1] if "_" in entry else entry
                for entry in entries
                if entry
            ]
    except Exception as e:
        raise RuntimeError(f"Error loading labels file {labels_path}: {e}") from e


def load_audio_full(file_path: str, target_sr: int = SAMPLE_RATE) -> np.ndarray:
    """
    Load a full audio file for moving window analysis.

    Args:
        file_path (str): Path to the audio file
        target_sr (int): Target sample rate (48kHz for BirdNET)

    Returns:
        np.ndarray: Full audio as a float32 array

    Raises:
        RuntimeError: If the file cannot be loaded.
    """
    import librosa

    try:
        audio, _ = librosa.load(file_path, sr=target_sr)
        return audio.astype(np.float32)
    except Exception as e:
        raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e


def create_audio_windows(
    audio: np.ndarray,
    window_size: int = WINDOW_SAMPLES,
    overlap: float = 0.5,
    sample_rate: int = SAMPLE_RATE,
) -> tuple[np.ndarray, list[float]]:
    """
    Create overlapping windows from audio for analysis.

    A recording shorter than one window is zero-padded into a single window
    (the same treatment ``load_audio`` gives short clips) so that it is still
    analysed. For longer recordings, a trailing remainder that does not fill
    a whole window is dropped.

    Args:
        audio (np.ndarray): Full 1-D audio array
        window_size (int): Size of each window (144000 for 3 seconds at 48kHz)
        overlap (float): Overlap ratio, 0.0 <= overlap < 1.0 (0.5 = 50% overlap)
        sample_rate (int): Sample rate used to convert window start offsets
            into timestamps in seconds

    Returns:
        tuple[np.ndarray, list[float]]: (windows array of shape
            [num_windows, window_size], start timestamp of each window in
            seconds). Empty audio yields an array of shape [0, window_size]
            and an empty timestamp list.

    Raises:
        ValueError: If ``window_size`` or ``sample_rate`` is not positive, or
            ``overlap`` is outside [0.0, 1.0).
    """
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size}")
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate}")
    if not 0.0 <= overlap < 1.0:
        raise ValueError(f"overlap must satisfy 0.0 <= overlap < 1.0, got {overlap}")

    audio = np.asarray(audio)
    num_samples = len(audio)

    if num_samples == 0:
        # Keep the result 2-D so callers can rely on the window axis.
        return np.empty((0, window_size), dtype=audio.dtype), []

    if num_samples < window_size:
        padded = np.pad(audio, (0, window_size - num_samples))
        return padded[np.newaxis, :], [0.0]

    # An overlap very close to 1.0 would truncate to a zero step; advance by
    # at least one sample so the range below is always valid.
    step_size = max(1, int(window_size * (1 - overlap)))
    starts = range(0, num_samples - window_size + 1, step_size)

    windows = np.array([audio[start : start + window_size] for start in starts])
    timestamps = [start / sample_rate for start in starts]
    return windows, timestamps


def load_onnx_model(model_path: str) -> ort.InferenceSession:
    """
    Load ONNX model for inference.

    Args:
        model_path (str): Path to the ONNX model file

    Returns:
        ort.InferenceSession: Loaded ONNX model session

    Raises:
        RuntimeError: If the session cannot be created.
    """
    import onnxruntime as ort

    try:
        return ort.InferenceSession(model_path)
    except Exception as e:
        raise RuntimeError(f"Error loading ONNX model {model_path}: {e}") from e


def predict_audio(session: ort.InferenceSession, audio_data: np.ndarray) -> np.ndarray:
    """
    Run inference on audio data using the ONNX model.

    Args:
        session (ort.InferenceSession): ONNX model session
        audio_data (np.ndarray): Audio data of shape [144000] or [batch, 144000]

    Returns:
        np.ndarray: First output of the model for the given input

    Raises:
        RuntimeError: If inference fails.
    """
    try:
        # A single clip needs a leading batch dimension.
        if audio_data.ndim == 1:
            input_data = np.expand_dims(audio_data, axis=0)
        else:
            input_data = audio_data

        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: input_data})
        return outputs[0]
    except Exception as e:
        raise RuntimeError(f"Error during model inference: {e}") from e


def predict_audio_batch(
    session: ort.InferenceSession,
    windows_batch: np.ndarray,
    batch_size: int = 128,
    show_progress: bool = True,
) -> np.ndarray:
    """
    Run inference on batches of audio windows for better performance.

    Args:
        session (ort.InferenceSession): ONNX model session
        windows_batch (np.ndarray): Array of windows, shape [num_windows, 144000]
        batch_size (int): Number of windows to process in each batch (>= 1)
        show_progress (bool): Whether to show progress updates

    Returns:
        np.ndarray: All predictions concatenated, shape [num_windows, num_classes].
            When there are no windows the model is not invoked and an empty
            array of shape [0, 0] is returned.

    Raises:
        RuntimeError: If ``batch_size`` is not positive or inference fails.
    """
    try:
        num_windows = len(windows_batch)
        if num_windows == 0:
            # Nothing to run; np.concatenate would fail on an empty list.
            return np.empty((0, 0), dtype=np.float32)
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        input_name = session.get_inputs()[0].name
        all_predictions = []

        batch_starts = range(0, num_windows, batch_size)
        for batch_num, start_idx in enumerate(batch_starts, start=1):
            end_idx = min(start_idx + batch_size, num_windows)

            # Report the first batch and then every fifth one.
            if show_progress and (batch_num % 5 == 0 or batch_num == 1):
                progress = (end_idx / num_windows) * 100
                print(
                    f" Batch {batch_num}: processing windows {start_idx + 1}-{end_idx} ({progress:.1f}%)"
                )

            outputs = session.run(None, {input_name: windows_batch[start_idx:end_idx]})
            all_predictions.append(outputs[0])

        return np.concatenate(all_predictions, axis=0)
    except Exception as e:
        raise RuntimeError(f"Error during batch model inference: {e}") from e


def analyze_detections(
    all_predictions: np.ndarray,
    timestamps: list[float],
    labels: list[str],
    confidence_threshold: float = 0.1,
) -> dict[str, list[dict[str, float | int]]]:
    """
    Analyze predictions across all windows and summarize detections.

    Args:
        all_predictions (np.ndarray): Predictions from all windows,
            shape [num_windows, num_classes]
        timestamps (list[float]): Timestamps for each window
        labels (list[str]): Species labels
        confidence_threshold (float): Minimum confidence for detection
            (scores must be strictly greater than this value)

    Returns:
        dict[str, list[dict[str, float | int]]]: Mapping from species name to
            its detections, each a dict with ``timestamp``, ``confidence``
            and ``window`` (window index). Class indices without a label are
            reported as ``"Class <index>"``.
    """
    detections = defaultdict(list)

    for window_idx, (scores, timestamp) in enumerate(zip(all_predictions, timestamps)):
        for class_idx in np.where(scores > confidence_threshold)[0]:
            if class_idx < len(labels):
                species_name = labels[class_idx]
            else:
                species_name = f"Class {class_idx}"
            detections[species_name].append(
                {
                    "timestamp": timestamp,
                    "confidence": float(scores[class_idx]),
                    "window": window_idx,
                }
            )

    return dict(detections)


def _positive_int(value: str) -> int:
    """argparse ``type=`` callable: parse *value* as an integer >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def _overlap_ratio(value: str) -> float:
    """argparse ``type=`` callable: parse *value* as a ratio in [0.0, 1.0)."""
    try:
        ratio = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {value!r}") from None
    # An overlap of 1.0 would never advance the window; NaN fails this test too.
    if not 0.0 <= ratio < 1.0:
        raise argparse.ArgumentTypeError(
            f"must satisfy 0.0 <= overlap < 1.0, got {value}"
        )
    return ratio


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser."""
    parser = argparse.ArgumentParser(
        description="BirdNET Audio Classification with Moving Window"
    )
    parser.add_argument("audio_file", help="Path to the WAV audio file")
    parser.add_argument(
        "--model", default="model.onnx", help="Path to the ONNX model file"
    )
    parser.add_argument(
        "--labels",
        default="BirdNET_GLOBAL_6K_V2.4_Labels.txt",
        help="Path to the labels file",
    )
    parser.add_argument(
        "--top-k",
        type=_positive_int,
        default=5,
        help="Number of top predictions to show per window",
    )
    parser.add_argument(
        "--overlap",
        type=_overlap_ratio,
        default=0.5,
        help="Window overlap ratio (0.0 <= overlap < 1.0)",
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.1,
        help="Minimum confidence threshold for detections",
    )
    parser.add_argument(
        "--batch-size",
        type=_positive_int,
        default=128,
        help="Batch size for inference (default: 128)",
    )
    parser.add_argument(
        "--single-window",
        action="store_true",
        help="Analyze only first 3 seconds (single window)",
    )
    return parser


def _run_single_window(
    session: ort.InferenceSession, labels: list[str], args: argparse.Namespace
) -> int:
    """Classify the first 3 seconds of the file and print the top-k classes."""
    print(f"Loading first 3 seconds of audio file: {args.audio_file}")
    audio_data = load_audio(args.audio_file)
    print(f"Audio loaded successfully. Shape: {audio_data.shape}")

    print("Running inference on single window...")
    predictions = np.array(predict_audio(session, audio_data))
    scores = predictions[0] if predictions.ndim > 1 else predictions

    # Reverse first, then slice: a negative-slice form would return every
    # class for a count of zero.
    top_indices = np.argsort(scores)[::-1][: args.top_k]

    print(f"\nTop {args.top_k} predictions for first 3 seconds:")
    for rank, idx in enumerate(top_indices, start=1):
        confidence = float(scores[idx])
        species_name = labels[idx] if idx < len(labels) else f"Class {idx}"
        print(f"{rank:2d}. {species_name}: {confidence:.6f}")
    return 0


def _run_moving_window(
    session: ort.InferenceSession, labels: list[str], args: argparse.Namespace
) -> int:
    """Classify the whole file window by window and print a detection summary."""
    print(f"Loading full audio file: {args.audio_file}")
    full_audio = load_audio_full(args.audio_file)
    audio_duration = len(full_audio) / SAMPLE_RATE
    print(f"Audio loaded successfully. Duration: {audio_duration:.2f} seconds")

    print(f"Creating windows with {args.overlap * 100:.0f}% overlap...")
    windows, timestamps = create_audio_windows(full_audio, overlap=args.overlap)
    if len(windows) == 0:
        print("Error: Audio file contains no samples.")
        return 1
    print(f"Created {len(windows)} windows of 3 seconds each")

    print(
        f"Running batch inference on {len(windows)} windows (batch size: {args.batch_size})..."
    )
    num_batches = (len(windows) + args.batch_size - 1) // args.batch_size
    print(f"Processing {num_batches} batches...")

    all_predictions = predict_audio_batch(session, windows, args.batch_size)
    print(f"Completed batch inference on {len(windows)} windows")

    print(f"Analyzing detections with confidence threshold {args.confidence}...")
    detections = analyze_detections(
        all_predictions, timestamps, labels, args.confidence
    )

    # Species with the strongest single detection come first.
    sorted_species = sorted(
        detections.items(),
        key=lambda item: max(det["confidence"] for det in item[1]),
        reverse=True,
    )

    print("\n=== DETECTION SUMMARY ===")
    print(f"Audio duration: {audio_duration:.2f} seconds")
    print(f"Windows analyzed: {len(windows)}")
    print(
        f"Species detected (>{args.confidence:.2f} confidence): {len(sorted_species)}"
    )

    if not sorted_species:
        print(f"No detections found above confidence threshold {args.confidence}")
        return 0

    print("\nTop detections:")
    for species, detections_list in sorted_species[: args.top_k]:
        max_conf = max(det["confidence"] for det in detections_list)
        first_detection = min(det["timestamp"] for det in detections_list)
        last_detection = max(det["timestamp"] for det in detections_list)

        print(f"\n{species}")
        print(f" Max confidence: {max_conf:.6f}")
        print(f" Detections: {len(detections_list)}")
        print(f" Time range: {first_detection:.1f}s - {last_detection:.1f}s")

        # Show the three strongest detections for this species.
        strongest = sorted(
            detections_list, key=lambda det: det["confidence"], reverse=True
        )[:3]
        for det in strongest:
            print(f" {det['timestamp']:6.1f}s: {det['confidence']:.6f}")
    return 0


def main(argv: Sequence[str] | None = None) -> int:
    """
    Command-line entry point.

    Args:
        argv (Sequence[str] | None): Argument list to parse; ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        int: Process exit status, 0 on success and 1 on any error. Invalid
            command-line arguments make argparse exit with status 2 instead.
    """
    args = _build_parser().parse_args(argv)

    # Fail early, with a readable message, when an input file is missing.
    required_files = (
        ("Audio", args.audio_file),
        ("Model", args.model),
        ("Labels", args.labels),
    )
    for description, path in required_files:
        if not os.path.exists(path):
            print(f"Error: {description} file '{path}' not found.")
            return 1

    try:
        print(f"Loading labels from: {args.labels}")
        labels = load_labels(args.labels)
        print(f"Loaded {len(labels)} species labels")

        print(f"Loading ONNX model: {args.model}")
        session = load_onnx_model(args.model)

        input_info = session.get_inputs()[0]
        output_info = session.get_outputs()[0]
        print(f"Model input: {input_info.name}, shape: {input_info.shape}")
        print(f"Model output: {output_info.name}, shape: {output_info.shape}")

        if args.single_window:
            return _run_single_window(session, labels, args)
        return _run_moving_window(session, labels, args)
    except Exception as e:
        # Top-level boundary: report any failure and signal it via exit status.
        print(f"Error: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())