#!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 """ Script to compare the results of an ONNX model with a TFLite model given the same input. Optionally also compare with Tract runtime for ONNX. Created by Copilot. Usage: python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --input input.npy python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --rtol 1e-5 --atol 1e-5 python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract --benchmark """ import argparse import time import numpy as np import onnxruntime as ort import tensorflow as tf from typing import Dict, List, Tuple, Optional, Any try: import tract TRACT_AVAILABLE = True except ImportError: TRACT_AVAILABLE = False def load_onnx_model(onnx_path: str) -> ort.InferenceSession: """Load an ONNX model and return an inference session.""" print(f"Loading ONNX model from: {onnx_path}") session = ort.InferenceSession(onnx_path) return session def load_tflite_model(tflite_path: str) -> tf.lite.Interpreter: """Load a TFLite model and return an interpreter.""" print(f"Loading TFLite model from: {tflite_path}") interpreter = tf.lite.Interpreter(model_path=tflite_path) interpreter.allocate_tensors() return interpreter def load_tract_model(onnx_path: str) -> Optional[Any]: """Load an ONNX model using tract and return a runnable model.""" if not TRACT_AVAILABLE: print("Tract is not available. Install with: pip install tract") return None print(f"Loading ONNX model with tract from: {onnx_path}") model = tract.onnx().model_for_path(onnx_path).into_optimized().into_runnable() return model def get_onnx_model_info(session: ort.InferenceSession) -> Tuple[List, List]: """Get input and output information from ONNX model.""" inputs = session.get_inputs() outputs = session.get_outputs() print("\nONNX Model Information:") print("Inputs:") for inp in inputs: print(f" - Name: {inp.name}, Shape: {inp.shape}, Type: {inp.type}") print("Outputs:") for out in outputs: print(f" - Name: {out.name}, Shape: {out.shape}, Type: {out.type}") return inputs, outputs def get_tflite_model_info(interpreter: tf.lite.Interpreter) -> Tuple[List, List]: """Get input and output information from TFLite model.""" input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() print("\nTFLite Model Information:") print("Inputs:") for inp in input_details: print(f" - Name: {inp['name']}, Shape: {inp['shape']}, Type: {inp['dtype']}") print("Outputs:") for out in output_details: print(f" - Name: {out['name']}, Shape: {out['shape']}, Type: {out['dtype']}") return input_details, output_details def generate_random_inputs(onnx_inputs: List, seed: int = 42) -> Dict[str, np.ndarray]: """Generate random inputs based on ONNX model input specs.""" np.random.seed(seed) inputs = {} print("\nGenerating random inputs:") for inp in onnx_inputs: # Handle dynamic dimensions shape = [] for dim in inp.shape: if isinstance(dim, str) or dim is None or dim < 0: # Default to 1 for dynamic dimensions shape.append(1) else: shape.append(dim) # Generate random data based on type if "float" in inp.type.lower(): data = np.random.randn(*shape).astype(np.float32) elif "int64" in inp.type.lower(): data = np.random.randint(0, 100, size=shape).astype(np.int64) elif "int32" in inp.type.lower(): data = np.random.randint(0, 100, size=shape).astype(np.int32) else: # Default to float32 data = np.random.randn(*shape).astype(np.float32) inputs[inp.name] = data print(f" - {inp.name}: shape={data.shape}, dtype={data.dtype}") return inputs def load_inputs_from_file(input_path: str) -> Dict[str, np.ndarray]: """Load inputs from a numpy file (.npy or .npz).""" print(f"\nLoading inputs from: {input_path}") if input_path.endswith(".npz"): data = np.load(input_path) inputs = {key: data[key] for key in data.files} elif input_path.endswith(".npy"): data = np.load(input_path) # Assume single input inputs = {"input": data} else: raise ValueError("Input file must be .npy or .npz format") for name, value in inputs.items(): print(f" - {name}: shape={value.shape}, dtype={value.dtype}") return inputs def run_onnx_model( session: ort.InferenceSession, inputs: Dict[str, np.ndarray] ) -> List[np.ndarray]: """Run inference on ONNX model.""" print("\nRunning ONNX model inference...") outputs = session.run(None, inputs) return outputs def run_tflite_model( interpreter: tf.lite.Interpreter, inputs: Dict[str, np.ndarray], input_details: List ) -> List[np.ndarray]: """Run inference on TFLite model.""" print("Running TFLite model inference...") # Set input tensors for i, detail in enumerate(input_details): # Try to match by name or use order input_data = None if detail["name"] in inputs: input_data = inputs[detail["name"]] elif len(inputs) == 1: # If only one input, use it input_data = list(inputs.values())[0] elif i < len(inputs): # Use by order input_data = list(inputs.values())[i] else: raise ValueError(f"Cannot match input for TFLite input {detail['name']}") # Ensure correct dtype if input_data.dtype != detail["dtype"]: input_data = input_data.astype(detail["dtype"]) interpreter.set_tensor(detail["index"], input_data) # Run inference interpreter.invoke() # Get output tensors output_details = interpreter.get_output_details() outputs = [] for detail in output_details: outputs.append(interpreter.get_tensor(detail["index"])) return outputs def run_tract_model(model: Any, inputs: Dict[str, np.ndarray]) -> List[np.ndarray]: """Run inference on tract model.""" if model is None: return [] print("Running tract model inference...") # Convert inputs to list (tract expects a list of tensors) input_list = list(inputs.values()) # Run inference outputs = model.run(input_list) # Convert outputs to numpy arrays result = [] for output in outputs: result.append(output.to_numpy()) return result def benchmark_onnx_model( session: ort.InferenceSession, inputs: Dict[str, np.ndarray], num_runs: int = 100, warmup_runs: int = 10, ) -> Dict[str, float]: """Benchmark ONNX model inference speed.""" print(f"\nBenchmarking ONNX model ({warmup_runs} warmup + {num_runs} test runs)...") # Warmup runs for _ in range(warmup_runs): session.run(None, inputs) # Timed runs times = [] for _ in range(num_runs): start = time.perf_counter() session.run(None, inputs) end = time.perf_counter() times.append((end - start) * 1000) # Convert to ms return { "mean": np.mean(times), "median": np.median(times), "std": np.std(times), "min": np.min(times), "max": np.max(times), } def benchmark_tflite_model( interpreter: tf.lite.Interpreter, inputs: Dict[str, np.ndarray], input_details: List, num_runs: int = 100, warmup_runs: int = 10, ) -> Dict[str, float]: """Benchmark TFLite model inference speed.""" print(f"Benchmarking TFLite model ({warmup_runs} warmup + {num_runs} test runs)...") # Prepare inputs def set_inputs(): for i, detail in enumerate(input_details): input_data = None if detail["name"] in inputs: input_data = inputs[detail["name"]] elif len(inputs) == 1: input_data = list(inputs.values())[0] elif i < len(inputs): input_data = list(inputs.values())[i] else: raise ValueError( f"Cannot match input for TFLite input {detail['name']}" ) if input_data.dtype != detail["dtype"]: input_data = input_data.astype(detail["dtype"]) interpreter.set_tensor(detail["index"], input_data) # Warmup runs for _ in range(warmup_runs): set_inputs() interpreter.invoke() # Timed runs times = [] for _ in range(num_runs): set_inputs() start = time.perf_counter() interpreter.invoke() end = time.perf_counter() times.append((end - start) * 1000) # Convert to ms return { "mean": np.mean(times), "median": np.median(times), "std": np.std(times), "min": np.min(times), "max": np.max(times), } def benchmark_tract_model( model: Any, inputs: Dict[str, np.ndarray], num_runs: int = 100, warmup_runs: int = 10, ) -> Optional[Dict[str, float]]: """Benchmark tract model inference speed.""" if model is None: return None print(f"Benchmarking tract model ({warmup_runs} warmup + {num_runs} test runs)...") # Convert inputs to list input_list = list(inputs.values()) # Warmup runs for _ in range(warmup_runs): model.run(input_list) # Timed runs times = [] for _ in range(num_runs): start = time.perf_counter() model.run(input_list) end = time.perf_counter() times.append((end - start) * 1000) # Convert to ms return { "mean": np.mean(times), "median": np.median(times), "std": np.std(times), "min": np.min(times), "max": np.max(times), } def print_benchmark_results( onnx_stats: Dict[str, float], tflite_stats: Dict[str, float], tract_stats: Optional[Dict[str, float]] = None, ) -> None: """Print benchmark comparison results.""" print("\n" + "=" * 80) print("BENCHMARK RESULTS") print("=" * 80) print("\nONNX Model:") print(f" Mean: {onnx_stats['mean']:.3f} ms") print(f" Median: {onnx_stats['median']:.3f} ms") print(f" Std: {onnx_stats['std']:.3f} ms") print(f" Min: {onnx_stats['min']:.3f} ms") print(f" Max: {onnx_stats['max']:.3f} ms") print("\nTFLite Model:") print(f" Mean: {tflite_stats['mean']:.3f} ms") print(f" Median: {tflite_stats['median']:.3f} ms") print(f" Std: {tflite_stats['std']:.3f} ms") print(f" Min: {tflite_stats['min']:.3f} ms") print(f" Max: {tflite_stats['max']:.3f} ms") if tract_stats: print("\nTract Model:") print(f" Mean: {tract_stats['mean']:.3f} ms") print(f" Median: {tract_stats['median']:.3f} ms") print(f" Std: {tract_stats['std']:.3f} ms") print(f" Min: {tract_stats['min']:.3f} ms") print(f" Max: {tract_stats['max']:.3f} ms") print("\nComparison:") speedup = tflite_stats["mean"] / onnx_stats["mean"] if speedup > 1: print(f" ONNX Runtime is {speedup:.2f}x faster than TFLite") else: print(f" TFLite is {1 / speedup:.2f}x faster than ONNX Runtime") print(f" Difference: {abs(onnx_stats['mean'] - tflite_stats['mean']):.3f} ms") if tract_stats: speedup_tract = tflite_stats["mean"] / tract_stats["mean"] if speedup_tract > 1: print(f" Tract is {speedup_tract:.2f}x faster than TFLite") else: print(f" TFLite is {1 / speedup_tract:.2f}x faster than Tract") print(f" Difference: {abs(tract_stats['mean'] - tflite_stats['mean']):.3f} ms") speedup_ort = onnx_stats["mean"] / tract_stats["mean"] if speedup_ort > 1: print(f" Tract is {speedup_ort:.2f}x faster than ONNX Runtime") else: print(f" ONNX Runtime is {1 / speedup_ort:.2f}x faster than Tract") print(f" Difference: {abs(tract_stats['mean'] - onnx_stats['mean']):.3f} ms") print("=" * 80) def compare_outputs( onnx_outputs: List[np.ndarray], tflite_outputs: List[np.ndarray], tract_outputs: Optional[List[np.ndarray]] = None, rtol: float = 1e-5, atol: float = 1e-5, ) -> bool: """Compare outputs from ONNX, TFLite, and optionally Tract models.""" print("\n" + "=" * 80) print("COMPARISON RESULTS") print("=" * 80) if len(onnx_outputs) != len(tflite_outputs): print( f"❌ Number of outputs differs: ONNX={len(onnx_outputs)}, TFLite={len(tflite_outputs)}" ) return False if tract_outputs and len(onnx_outputs) != len(tract_outputs): print( f"❌ Number of outputs differs: ONNX={len(onnx_outputs)}, Tract={len(tract_outputs)}" ) return False all_match = True for i, (onnx_out, tflite_out) in enumerate(zip(onnx_outputs, tflite_outputs)): tract_out = tract_outputs[i] if tract_outputs else None print(f"\nOutput {i}:") print(f" ONNX Runtime shape: {onnx_out.shape}, dtype: {onnx_out.dtype}") print(f" TFLite shape: {tflite_out.shape}, dtype: {tflite_out.dtype}") if tract_out is not None: print(f" Tract shape: {tract_out.shape}, dtype: {tract_out.dtype}") if onnx_out.shape != tflite_out.shape: print(" ❌ Shape mismatch between ONNX and TFLite!") all_match = False continue if tract_out is not None and onnx_out.shape != tract_out.shape: print(" ❌ Shape mismatch between ONNX and Tract!") all_match = False continue # Convert to same dtype for comparison if onnx_out.dtype != tflite_out.dtype: print(" ⚠️ Different dtypes, converting to float32 for comparison") onnx_out = onnx_out.astype(np.float32) tflite_out = tflite_out.astype(np.float32) if tract_out is not None and onnx_out.dtype != tract_out.dtype: tract_out = tract_out.astype(np.float32) # Compute statistics - ONNX vs TFLite print("\n ONNX Runtime vs TFLite:") diff = np.abs(onnx_out - tflite_out) max_diff = np.max(diff) mean_diff = np.mean(diff) is_close = np.allclose(onnx_out, tflite_out, rtol=rtol, atol=atol) print(f" Max difference: {max_diff:.10f}") print(f" Mean difference: {mean_diff:.10f}") print(f" Relative tolerance: {rtol}") print(f" Absolute tolerance: {atol}") if is_close: print(" ✅ Outputs match within tolerance") else: print(" ❌ Outputs do NOT match within tolerance") all_match = False # Show some sample values print("\n Sample values (first 5 elements):") flat_onnx = onnx_out.flatten()[:5] flat_tflite = tflite_out.flatten()[:5] for j, (o, t) in enumerate(zip(flat_onnx, flat_tflite)): print( f" [{j}] ONNX: {o:.10f}, TFLite: {t:.10f}, Diff: {abs(o - t):.10f}" ) # Compute statistics - ONNX vs Tract if tract_out is not None: print("\n ONNX Runtime vs Tract:") diff_tract = np.abs(onnx_out - tract_out) max_diff_tract = np.max(diff_tract) mean_diff_tract = np.mean(diff_tract) is_close_tract = np.allclose(onnx_out, tract_out, rtol=rtol, atol=atol) print(f" Max difference: {max_diff_tract:.10f}") print(f" Mean difference: {mean_diff_tract:.10f}") if is_close_tract: print(" ✅ Outputs match within tolerance") else: print(" ❌ Outputs do NOT match within tolerance") all_match = False # Show some sample values print("\n Sample values (first 5 elements):") flat_onnx_tract = onnx_out.flatten()[:5] flat_tract = tract_out.flatten()[:5] for j, (o, tr) in enumerate(zip(flat_onnx_tract, flat_tract)): print( f" [{j}] ONNX: {o:.10f}, Tract: {tr:.10f}, Diff: {abs(o - tr):.10f}" ) # Compute statistics - TFLite vs Tract print("\n TFLite vs Tract:") diff_tflite_tract = np.abs(tflite_out - tract_out) max_diff_tflite_tract = np.max(diff_tflite_tract) mean_diff_tflite_tract = np.mean(diff_tflite_tract) is_close_tflite_tract = np.allclose( tflite_out, tract_out, rtol=rtol, atol=atol ) print(f" Max difference: {max_diff_tflite_tract:.10f}") print(f" Mean difference: {mean_diff_tflite_tract:.10f}") if is_close_tflite_tract: print(" ✅ Outputs match within tolerance") else: print(" ❌ Outputs do NOT match within tolerance") all_match = False print("\n" + "=" * 80) if all_match: print("✅ ALL OUTPUTS MATCH!") else: print("❌ SOME OUTPUTS DO NOT MATCH") print("=" * 80) return all_match def main(): parser = argparse.ArgumentParser( description="Compare ONNX and TFLite model outputs", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Compare with random inputs python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite # Compare with custom inputs from file python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --input input.npz # Compare with custom tolerances python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --rtol 1e-3 --atol 1e-3 # Save outputs for inspection python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --save-outputs # Benchmark execution speed python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark # Benchmark with custom number of runs python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark --num-runs 200 --warmup-runs 20 # Compare with tract runtime as well python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract # Benchmark all three runtimes python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract --benchmark """, ) parser.add_argument("--onnx", required=True, help="Path to ONNX model") parser.add_argument("--tflite", required=True, help="Path to TFLite model") parser.add_argument("--input", help="Path to input file (.npy or .npz)") parser.add_argument( "--rtol", type=float, default=1e-5, help="Relative tolerance (default: 1e-5)" ) parser.add_argument( "--atol", type=float, default=1e-5, help="Absolute tolerance (default: 1e-5)" ) parser.add_argument( "--seed", type=int, default=42, help="Random seed for input generation (default: 42)", ) parser.add_argument( "--save-outputs", action="store_true", help="Save outputs to files" ) parser.add_argument( "--benchmark", action="store_true", help="Benchmark execution speed of both models", ) parser.add_argument( "--num-runs", type=int, default=100, help="Number of benchmark runs (default: 100)", ) parser.add_argument( "--warmup-runs", type=int, default=10, help="Number of warmup runs (default: 10)", ) parser.add_argument( "--use-tract", action="store_true", help="Also test with tract ONNX runtime" ) args = parser.parse_args() # Load models onnx_session = load_onnx_model(args.onnx) tflite_interpreter = load_tflite_model(args.tflite) # Load tract model if requested tract_model = None if args.use_tract: if not TRACT_AVAILABLE: print( "\n⚠️ Warning: Tract is not installed. Install with: pip install tract" ) print("Continuing without tract comparison...\n") else: tract_model = load_tract_model(args.onnx) # Get model info onnx_inputs, onnx_outputs = get_onnx_model_info(onnx_session) tflite_input_details, tflite_output_details = get_tflite_model_info( tflite_interpreter ) # Prepare inputs if args.input: inputs = load_inputs_from_file(args.input) else: inputs = generate_random_inputs(onnx_inputs, seed=args.seed) # Run inference onnx_results = run_onnx_model(onnx_session, inputs) tflite_results = run_tflite_model(tflite_interpreter, inputs, tflite_input_details) tract_results = None if tract_model: tract_results = run_tract_model(tract_model, inputs) # Save outputs if requested if args.save_outputs: print("\nSaving outputs...") np.savez("onnx_outputs.npz", *onnx_results) np.savez("tflite_outputs.npz", *tflite_results) print(" - onnx_outputs.npz") print(" - tflite_outputs.npz") if tract_results: np.savez("tract_outputs.npz", *tract_results) print(" - tract_outputs.npz") # Compare results match = compare_outputs( onnx_results, tflite_results, tract_results, rtol=args.rtol, atol=args.atol ) # Benchmark if requested if args.benchmark: onnx_stats = benchmark_onnx_model( onnx_session, inputs, args.num_runs, args.warmup_runs ) tflite_stats = benchmark_tflite_model( tflite_interpreter, inputs, tflite_input_details, args.num_runs, args.warmup_runs, ) tract_stats = None if tract_model: tract_stats = benchmark_tract_model( tract_model, inputs, args.num_runs, args.warmup_runs ) print_benchmark_results(onnx_stats, tflite_stats, tract_stats) # Return exit code return 0 if match else 1 if __name__ == "__main__": exit(main())