#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Compare the outputs of an ONNX model and a TFLite model on the same input.
Optionally also compare against the tract ONNX runtime.
Created by Copilot.
Usage:
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --input input.npy
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --rtol 1e-5 --atol 1e-5
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract --benchmark
"""
import argparse
import sys
import time
import numpy as np
import onnxruntime as ort
import tensorflow as tf
from typing import Dict, List, Tuple, Optional, Any
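# tract (https://github.com/sonos/tract) is optional; it is only exercised when
# --use-tract is passed on the command line.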
try:
import tract
TRACT_AVAILABLE = True
except ImportError:
TRACT_AVAILABLE = False
def load_onnx_model(onnx_path: str) -> ort.InferenceSession:
"""Load an ONNX model and return an inference session."""
print(f"Loading ONNX model from: {onnx_path}")
session = ort.InferenceSession(onnx_path)
return session
def load_tflite_model(tflite_path: str) -> tf.lite.Interpreter:
"""Load a TFLite model and return an interpreter."""
print(f"Loading TFLite model from: {tflite_path}")
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
return interpreter
def load_tract_model(onnx_path: str) -> Optional[Any]:
"""Load an ONNX model using tract and return a runnable model."""
if not TRACT_AVAILABLE:
print("Tract is not available. Install with: pip install tract")
return None
print(f"Loading ONNX model with tract from: {onnx_path}")
model = tract.onnx().model_for_path(onnx_path).into_optimized().into_runnable()
return model
def get_onnx_model_info(session: ort.InferenceSession) -> Tuple[List, List]:
"""Get input and output information from ONNX model."""
inputs = session.get_inputs()
outputs = session.get_outputs()
print("\nONNX Model Information:")
print("Inputs:")
for inp in inputs:
print(f" - Name: {inp.name}, Shape: {inp.shape}, Type: {inp.type}")
print("Outputs:")
for out in outputs:
print(f" - Name: {out.name}, Shape: {out.shape}, Type: {out.type}")
return inputs, outputs
def get_tflite_model_info(interpreter: tf.lite.Interpreter) -> Tuple[List, List]:
"""Get input and output information from TFLite model."""
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("\nTFLite Model Information:")
print("Inputs:")
for inp in input_details:
print(f" - Name: {inp['name']}, Shape: {inp['shape']}, Type: {inp['dtype']}")
print("Outputs:")
for out in output_details:
print(f" - Name: {out['name']}, Shape: {out['shape']}, Type: {out['dtype']}")
return input_details, output_details
def generate_random_inputs(onnx_inputs: List, seed: int = 42) -> Dict[str, np.ndarray]:
"""Generate random inputs based on ONNX model input specs."""
np.random.seed(seed)
inputs = {}
print("\nGenerating random inputs:")
for inp in onnx_inputs:
# Handle dynamic dimensions
shape = []
for dim in inp.shape:
if isinstance(dim, str) or dim is None or dim < 0:
# Default to 1 for dynamic dimensions
shape.append(1)
else:
shape.append(dim)
# Generate random data based on type
if "float" in inp.type.lower():
data = np.random.randn(*shape).astype(np.float32)
elif "int64" in inp.type.lower():
data = np.random.randint(0, 100, size=shape).astype(np.int64)
elif "int32" in inp.type.lower():
data = np.random.randint(0, 100, size=shape).astype(np.int32)
else:
# Default to float32
data = np.random.randn(*shape).astype(np.float32)
inputs[inp.name] = data
print(f" - {inp.name}: shape={data.shape}, dtype={data.dtype}")
return inputs
def load_inputs_from_file(input_path: str) -> Dict[str, np.ndarray]:
"""Load inputs from a numpy file (.npy or .npz)."""
print(f"\nLoading inputs from: {input_path}")
if input_path.endswith(".npz"):
data = np.load(input_path)
inputs = {key: data[key] for key in data.files}
elif input_path.endswith(".npy"):
data = np.load(input_path)
        # Assume a single input; the key "input" must match the ONNX model's input name
inputs = {"input": data}
else:
raise ValueError("Input file must be .npy or .npz format")
for name, value in inputs.items():
print(f" - {name}: shape={value.shape}, dtype={value.dtype}")
return inputs
def run_onnx_model(
session: ort.InferenceSession, inputs: Dict[str, np.ndarray]
) -> List[np.ndarray]:
"""Run inference on ONNX model."""
print("\nRunning ONNX model inference...")
outputs = session.run(None, inputs)
return outputs
def run_tflite_model(
interpreter: tf.lite.Interpreter, inputs: Dict[str, np.ndarray], input_details: List
) -> List[np.ndarray]:
"""Run inference on TFLite model."""
print("Running TFLite model inference...")
# Set input tensors
for i, detail in enumerate(input_details):
# Try to match by name or use order
input_data = None
if detail["name"] in inputs:
input_data = inputs[detail["name"]]
elif len(inputs) == 1:
# If only one input, use it
input_data = list(inputs.values())[0]
elif i < len(inputs):
# Use by order
input_data = list(inputs.values())[i]
else:
raise ValueError(f"Cannot match input for TFLite input {detail['name']}")
# Ensure correct dtype
if input_data.dtype != detail["dtype"]:
input_data = input_data.astype(detail["dtype"])
interpreter.set_tensor(detail["index"], input_data)
# Run inference
interpreter.invoke()
# Get output tensors
output_details = interpreter.get_output_details()
outputs = []
for detail in output_details:
outputs.append(interpreter.get_tensor(detail["index"]))
return outputs
def run_tract_model(model: Any, inputs: Dict[str, np.ndarray]) -> List[np.ndarray]:
"""Run inference on tract model."""
if model is None:
return []
print("Running tract model inference...")
# Convert inputs to list (tract expects a list of tensors)
input_list = list(inputs.values())
# Run inference
outputs = model.run(input_list)
# Convert outputs to numpy arrays
result = []
for output in outputs:
result.append(output.to_numpy())
return result
def benchmark_onnx_model(
session: ort.InferenceSession,
inputs: Dict[str, np.ndarray],
num_runs: int = 100,
warmup_runs: int = 10,
) -> Dict[str, float]:
"""Benchmark ONNX model inference speed."""
print(f"\nBenchmarking ONNX model ({warmup_runs} warmup + {num_runs} test runs)...")
# Warmup runs
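    # A few untimed runs absorb one-time initialization cost so it is excluded from the stats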
for _ in range(warmup_runs):
session.run(None, inputs)
# Timed runs
times = []
for _ in range(num_runs):
start = time.perf_counter()
session.run(None, inputs)
end = time.perf_counter()
times.append((end - start) * 1000) # Convert to ms
return {
"mean": np.mean(times),
"median": np.median(times),
"std": np.std(times),
"min": np.min(times),
"max": np.max(times),
}
def benchmark_tflite_model(
interpreter: tf.lite.Interpreter,
inputs: Dict[str, np.ndarray],
input_details: List,
num_runs: int = 100,
warmup_runs: int = 10,
) -> Dict[str, float]:
"""Benchmark TFLite model inference speed."""
print(f"Benchmarking TFLite model ({warmup_runs} warmup + {num_runs} test runs)...")
# Prepare inputs
def set_inputs():
for i, detail in enumerate(input_details):
input_data = None
if detail["name"] in inputs:
input_data = inputs[detail["name"]]
elif len(inputs) == 1:
input_data = list(inputs.values())[0]
elif i < len(inputs):
input_data = list(inputs.values())[i]
else:
raise ValueError(
f"Cannot match input for TFLite input {detail['name']}"
)
if input_data.dtype != detail["dtype"]:
input_data = input_data.astype(detail["dtype"])
interpreter.set_tensor(detail["index"], input_data)
# Warmup runs
for _ in range(warmup_runs):
set_inputs()
interpreter.invoke()
# Timed runs
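    # Input tensors are re-set on every iteration, but outside the timed region,
    # so only interpreter.invoke() is measured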
times = []
for _ in range(num_runs):
set_inputs()
start = time.perf_counter()
interpreter.invoke()
end = time.perf_counter()
times.append((end - start) * 1000) # Convert to ms
return {
"mean": np.mean(times),
"median": np.median(times),
"std": np.std(times),
"min": np.min(times),
"max": np.max(times),
}
def benchmark_tract_model(
model: Any,
inputs: Dict[str, np.ndarray],
num_runs: int = 100,
warmup_runs: int = 10,
) -> Optional[Dict[str, float]]:
"""Benchmark tract model inference speed."""
if model is None:
return None
print(f"Benchmarking tract model ({warmup_runs} warmup + {num_runs} test runs)...")
# Convert inputs to list
input_list = list(inputs.values())
# Warmup runs
for _ in range(warmup_runs):
model.run(input_list)
# Timed runs
times = []
for _ in range(num_runs):
start = time.perf_counter()
model.run(input_list)
end = time.perf_counter()
times.append((end - start) * 1000) # Convert to ms
return {
"mean": np.mean(times),
"median": np.median(times),
"std": np.std(times),
"min": np.min(times),
"max": np.max(times),
}
def print_benchmark_results(
onnx_stats: Dict[str, float],
tflite_stats: Dict[str, float],
tract_stats: Optional[Dict[str, float]] = None,
) -> None:
"""Print benchmark comparison results."""
print("\n" + "=" * 80)
print("BENCHMARK RESULTS")
print("=" * 80)
print("\nONNX Model:")
print(f" Mean: {onnx_stats['mean']:.3f} ms")
print(f" Median: {onnx_stats['median']:.3f} ms")
print(f" Std: {onnx_stats['std']:.3f} ms")
print(f" Min: {onnx_stats['min']:.3f} ms")
print(f" Max: {onnx_stats['max']:.3f} ms")
print("\nTFLite Model:")
print(f" Mean: {tflite_stats['mean']:.3f} ms")
print(f" Median: {tflite_stats['median']:.3f} ms")
print(f" Std: {tflite_stats['std']:.3f} ms")
print(f" Min: {tflite_stats['min']:.3f} ms")
print(f" Max: {tflite_stats['max']:.3f} ms")
if tract_stats:
print("\nTract Model:")
print(f" Mean: {tract_stats['mean']:.3f} ms")
print(f" Median: {tract_stats['median']:.3f} ms")
print(f" Std: {tract_stats['std']:.3f} ms")
print(f" Min: {tract_stats['min']:.3f} ms")
print(f" Max: {tract_stats['max']:.3f} ms")
print("\nComparison:")
speedup = tflite_stats["mean"] / onnx_stats["mean"]
if speedup > 1:
print(f" ONNX Runtime is {speedup:.2f}x faster than TFLite")
else:
print(f" TFLite is {1 / speedup:.2f}x faster than ONNX Runtime")
print(f" Difference: {abs(onnx_stats['mean'] - tflite_stats['mean']):.3f} ms")
if tract_stats:
speedup_tract = tflite_stats["mean"] / tract_stats["mean"]
if speedup_tract > 1:
print(f" Tract is {speedup_tract:.2f}x faster than TFLite")
else:
print(f" TFLite is {1 / speedup_tract:.2f}x faster than Tract")
print(f" Difference: {abs(tract_stats['mean'] - tflite_stats['mean']):.3f} ms")
speedup_ort = onnx_stats["mean"] / tract_stats["mean"]
if speedup_ort > 1:
print(f" Tract is {speedup_ort:.2f}x faster than ONNX Runtime")
else:
print(f" ONNX Runtime is {1 / speedup_ort:.2f}x faster than Tract")
print(f" Difference: {abs(tract_stats['mean'] - onnx_stats['mean']):.3f} ms")
print("=" * 80)
def compare_outputs(
onnx_outputs: List[np.ndarray],
tflite_outputs: List[np.ndarray],
tract_outputs: Optional[List[np.ndarray]] = None,
rtol: float = 1e-5,
atol: float = 1e-5,
) -> bool:
"""Compare outputs from ONNX, TFLite, and optionally Tract models."""
print("\n" + "=" * 80)
print("COMPARISON RESULTS")
print("=" * 80)
if len(onnx_outputs) != len(tflite_outputs):
print(
f"❌ Number of outputs differs: ONNX={len(onnx_outputs)}, TFLite={len(tflite_outputs)}"
)
return False
if tract_outputs and len(onnx_outputs) != len(tract_outputs):
print(
f"❌ Number of outputs differs: ONNX={len(onnx_outputs)}, Tract={len(tract_outputs)}"
)
return False
all_match = True
for i, (onnx_out, tflite_out) in enumerate(zip(onnx_outputs, tflite_outputs)):
tract_out = tract_outputs[i] if tract_outputs else None
print(f"\nOutput {i}:")
print(f" ONNX Runtime shape: {onnx_out.shape}, dtype: {onnx_out.dtype}")
print(f" TFLite shape: {tflite_out.shape}, dtype: {tflite_out.dtype}")
if tract_out is not None:
print(f" Tract shape: {tract_out.shape}, dtype: {tract_out.dtype}")
if onnx_out.shape != tflite_out.shape:
print(" ❌ Shape mismatch between ONNX and TFLite!")
all_match = False
continue
if tract_out is not None and onnx_out.shape != tract_out.shape:
print(" ❌ Shape mismatch between ONNX and Tract!")
all_match = False
continue
# Convert to same dtype for comparison
if onnx_out.dtype != tflite_out.dtype:
print(" ⚠️ Different dtypes, converting to float32 for comparison")
onnx_out = onnx_out.astype(np.float32)
tflite_out = tflite_out.astype(np.float32)
if tract_out is not None and onnx_out.dtype != tract_out.dtype:
tract_out = tract_out.astype(np.float32)
# Compute statistics - ONNX vs TFLite
print("\n ONNX Runtime vs TFLite:")
diff = np.abs(onnx_out - tflite_out)
max_diff = np.max(diff)
mean_diff = np.mean(diff)
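        # np.allclose passes when |onnx - tflite| <= atol + rtol * |tflite| elementwise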
is_close = np.allclose(onnx_out, tflite_out, rtol=rtol, atol=atol)
print(f" Max difference: {max_diff:.10f}")
print(f" Mean difference: {mean_diff:.10f}")
print(f" Relative tolerance: {rtol}")
print(f" Absolute tolerance: {atol}")
if is_close:
print(" ✅ Outputs match within tolerance")
else:
print(" ❌ Outputs do NOT match within tolerance")
all_match = False
# Show some sample values
print("\n Sample values (first 5 elements):")
flat_onnx = onnx_out.flatten()[:5]
flat_tflite = tflite_out.flatten()[:5]
for j, (o, t) in enumerate(zip(flat_onnx, flat_tflite)):
print(
f" [{j}] ONNX: {o:.10f}, TFLite: {t:.10f}, Diff: {abs(o - t):.10f}"
)
# Compute statistics - ONNX vs Tract
if tract_out is not None:
print("\n ONNX Runtime vs Tract:")
diff_tract = np.abs(onnx_out - tract_out)
max_diff_tract = np.max(diff_tract)
mean_diff_tract = np.mean(diff_tract)
is_close_tract = np.allclose(onnx_out, tract_out, rtol=rtol, atol=atol)
print(f" Max difference: {max_diff_tract:.10f}")
print(f" Mean difference: {mean_diff_tract:.10f}")
if is_close_tract:
print(" ✅ Outputs match within tolerance")
else:
print(" ❌ Outputs do NOT match within tolerance")
all_match = False
# Show some sample values
print("\n Sample values (first 5 elements):")
flat_onnx_tract = onnx_out.flatten()[:5]
flat_tract = tract_out.flatten()[:5]
for j, (o, tr) in enumerate(zip(flat_onnx_tract, flat_tract)):
print(
f" [{j}] ONNX: {o:.10f}, Tract: {tr:.10f}, Diff: {abs(o - tr):.10f}"
)
# Compute statistics - TFLite vs Tract
print("\n TFLite vs Tract:")
diff_tflite_tract = np.abs(tflite_out - tract_out)
max_diff_tflite_tract = np.max(diff_tflite_tract)
mean_diff_tflite_tract = np.mean(diff_tflite_tract)
is_close_tflite_tract = np.allclose(
tflite_out, tract_out, rtol=rtol, atol=atol
)
print(f" Max difference: {max_diff_tflite_tract:.10f}")
print(f" Mean difference: {mean_diff_tflite_tract:.10f}")
if is_close_tflite_tract:
print(" ✅ Outputs match within tolerance")
else:
print(" ❌ Outputs do NOT match within tolerance")
all_match = False
print("\n" + "=" * 80)
if all_match:
print("✅ ALL OUTPUTS MATCH!")
else:
print("❌ SOME OUTPUTS DO NOT MATCH")
print("=" * 80)
return all_match
def main():
parser = argparse.ArgumentParser(
description="Compare ONNX and TFLite model outputs",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Compare with random inputs
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite
# Compare with custom inputs from file
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --input input.npz
# Compare with custom tolerances
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --rtol 1e-3 --atol 1e-3
# Save outputs for inspection
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --save-outputs
# Benchmark execution speed
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark
# Benchmark with custom number of runs
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark --num-runs 200 --warmup-runs 20
# Compare with tract runtime as well
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract
# Benchmark all three runtimes
python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract --benchmark
""",
)
parser.add_argument("--onnx", required=True, help="Path to ONNX model")
parser.add_argument("--tflite", required=True, help="Path to TFLite model")
parser.add_argument("--input", help="Path to input file (.npy or .npz)")
parser.add_argument(
"--rtol", type=float, default=1e-5, help="Relative tolerance (default: 1e-5)"
)
parser.add_argument(
"--atol", type=float, default=1e-5, help="Absolute tolerance (default: 1e-5)"
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Random seed for input generation (default: 42)",
)
parser.add_argument(
"--save-outputs", action="store_true", help="Save outputs to files"
)
parser.add_argument(
"--benchmark",
action="store_true",
help="Benchmark execution speed of both models",
)
parser.add_argument(
"--num-runs",
type=int,
default=100,
help="Number of benchmark runs (default: 100)",
)
parser.add_argument(
"--warmup-runs",
type=int,
default=10,
help="Number of warmup runs (default: 10)",
)
parser.add_argument(
"--use-tract", action="store_true", help="Also test with tract ONNX runtime"
)
args = parser.parse_args()
# Load models
onnx_session = load_onnx_model(args.onnx)
tflite_interpreter = load_tflite_model(args.tflite)
# Load tract model if requested
tract_model = None
if args.use_tract:
if not TRACT_AVAILABLE:
print(
"\n⚠️ Warning: Tract is not installed. Install with: pip install tract"
)
print("Continuing without tract comparison...\n")
else:
tract_model = load_tract_model(args.onnx)
# Get model info
onnx_inputs, onnx_outputs = get_onnx_model_info(onnx_session)
tflite_input_details, tflite_output_details = get_tflite_model_info(
tflite_interpreter
)
# Prepare inputs
if args.input:
inputs = load_inputs_from_file(args.input)
else:
inputs = generate_random_inputs(onnx_inputs, seed=args.seed)
# Run inference
onnx_results = run_onnx_model(onnx_session, inputs)
tflite_results = run_tflite_model(tflite_interpreter, inputs, tflite_input_details)
tract_results = None
if tract_model:
tract_results = run_tract_model(tract_model, inputs)
# Save outputs if requested
if args.save_outputs:
print("\nSaving outputs...")
np.savez("onnx_outputs.npz", *onnx_results)
np.savez("tflite_outputs.npz", *tflite_results)
print(" - onnx_outputs.npz")
print(" - tflite_outputs.npz")
if tract_results:
np.savez("tract_outputs.npz", *tract_results)
print(" - tract_outputs.npz")
# Compare results
match = compare_outputs(
onnx_results, tflite_results, tract_results, rtol=args.rtol, atol=args.atol
)
# Benchmark if requested
if args.benchmark:
onnx_stats = benchmark_onnx_model(
onnx_session, inputs, args.num_runs, args.warmup_runs
)
tflite_stats = benchmark_tflite_model(
tflite_interpreter,
inputs,
tflite_input_details,
args.num_runs,
args.warmup_runs,
)
tract_stats = None
if tract_model:
tract_stats = benchmark_tract_model(
tract_model, inputs, args.num_runs, args.warmup_runs
)
print_benchmark_results(onnx_stats, tflite_stats, tract_stats)
# Return exit code
return 0 if match else 1
if __name__ == "__main__":
    sys.exit(main())