"""
Script to compare the results of an ONNX model with a TFLite model given the same input.

Optionally also compare with the tract runtime for ONNX.

Created by Copilot.

Usage:
    python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite
    python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --input input.npy
    python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --rtol 1e-5 --atol 1e-5
    python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark
    python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract --benchmark
"""
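# Illustrative note (not used by the script): one way to build an .npz file for the
# --input option. The keys must match the ONNX input names this script prints; the
# name "input" and the shape below are placeholders, not taken from any real model.
#
#   import numpy as np
#   np.savez("input.npz", input=np.random.randn(1, 3, 224, 224).astype(np.float32))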
import argparse
import sys
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import onnxruntime as ort
import tensorflow as tf

try:
    import tract

    TRACT_AVAILABLE = True
except ImportError:
    TRACT_AVAILABLE = False


def load_onnx_model(onnx_path: str) -> ort.InferenceSession:
    """Load an ONNX model and return an inference session."""
    print(f"Loading ONNX model from: {onnx_path}")
    session = ort.InferenceSession(onnx_path)
    return session


def load_tflite_model(tflite_path: str) -> tf.lite.Interpreter:
    """Load a TFLite model and return an interpreter."""
    print(f"Loading TFLite model from: {tflite_path}")
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    return interpreter


def load_tract_model(onnx_path: str) -> Optional[Any]:
    """Load an ONNX model using tract and return a runnable model."""
    if not TRACT_AVAILABLE:
        print("Tract is not available. Install with: pip install tract")
        return None
    print(f"Loading ONNX model with tract from: {onnx_path}")
    model = tract.onnx().model_for_path(onnx_path).into_optimized().into_runnable()
    return model


def get_onnx_model_info(session: ort.InferenceSession) -> Tuple[List, List]:
    """Get input and output information from the ONNX model."""
    inputs = session.get_inputs()
    outputs = session.get_outputs()

    print("\nONNX Model Information:")
    print("Inputs:")
    for inp in inputs:
        print(f" - Name: {inp.name}, Shape: {inp.shape}, Type: {inp.type}")
    print("Outputs:")
    for out in outputs:
        print(f" - Name: {out.name}, Shape: {out.shape}, Type: {out.type}")

    return inputs, outputs


def get_tflite_model_info(interpreter: tf.lite.Interpreter) -> Tuple[List, List]:
    """Get input and output information from the TFLite model."""
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print("\nTFLite Model Information:")
    print("Inputs:")
    for inp in input_details:
        print(f" - Name: {inp['name']}, Shape: {inp['shape']}, Type: {inp['dtype']}")
    print("Outputs:")
    for out in output_details:
        print(f" - Name: {out['name']}, Shape: {out['shape']}, Type: {out['dtype']}")

    return input_details, output_details


def generate_random_inputs(onnx_inputs: List, seed: int = 42) -> Dict[str, np.ndarray]:
    """Generate random inputs based on the ONNX model's input specs."""
    np.random.seed(seed)
    inputs = {}

    print("\nGenerating random inputs:")
    for inp in onnx_inputs:
        # Resolve a concrete shape, substituting 1 for dynamic dimensions.
        shape = []
        for dim in inp.shape:
            if isinstance(dim, str) or dim is None or dim < 0:
                # Dynamic dimension (symbolic name, None, or negative): use 1.
                shape.append(1)
            else:
                shape.append(dim)

        # Pick a dtype matching the declared ONNX input type.
        if "float" in inp.type.lower():
            data = np.random.randn(*shape).astype(np.float32)
        elif "int64" in inp.type.lower():
            data = np.random.randint(0, 100, size=shape).astype(np.int64)
        elif "int32" in inp.type.lower():
            data = np.random.randint(0, 100, size=shape).astype(np.int32)
        else:
            # Fall back to float32 for any other type.
            data = np.random.randn(*shape).astype(np.float32)

        inputs[inp.name] = data
        print(f" - {inp.name}: shape={data.shape}, dtype={data.dtype}")

    return inputs


def load_inputs_from_file(input_path: str) -> Dict[str, np.ndarray]:
    """Load inputs from a numpy file (.npy or .npz)."""
    print(f"\nLoading inputs from: {input_path}")

    if input_path.endswith(".npz"):
        data = np.load(input_path)
        inputs = {key: data[key] for key in data.files}
    elif input_path.endswith(".npy"):
        data = np.load(input_path)
        # A bare .npy carries no tensor name; store it under a generic key. ONNX
        # Runtime feeds inputs by name, so this key must match the model's input
        # name (use an .npz with named arrays otherwise).
        inputs = {"input": data}
    else:
        raise ValueError("Input file must be in .npy or .npz format")

    for name, value in inputs.items():
        print(f" - {name}: shape={value.shape}, dtype={value.dtype}")

    return inputs


def run_onnx_model(
    session: ort.InferenceSession, inputs: Dict[str, np.ndarray]
) -> List[np.ndarray]:
    """Run inference on the ONNX model."""
    print("\nRunning ONNX model inference...")
    outputs = session.run(None, inputs)
    return outputs


def run_tflite_model(
    interpreter: tf.lite.Interpreter, inputs: Dict[str, np.ndarray], input_details: List
) -> List[np.ndarray]:
    """Run inference on the TFLite model."""
    print("Running TFLite model inference...")

    # Feed each TFLite input: match by name first, then fall back to position,
    # since TFLite tensor names rarely match the ONNX input names exactly.
    for i, detail in enumerate(input_details):
        input_data = None
        if detail["name"] in inputs:
            input_data = inputs[detail["name"]]
        elif len(inputs) == 1:
            # Single input: use it regardless of its name.
            input_data = list(inputs.values())[0]
        elif i < len(inputs):
            # Fall back to positional matching.
            input_data = list(inputs.values())[i]
        else:
            raise ValueError(f"Cannot match input for TFLite input {detail['name']}")

        # Cast to the dtype the interpreter expects.
        if input_data.dtype != detail["dtype"]:
            input_data = input_data.astype(detail["dtype"])

        interpreter.set_tensor(detail["index"], input_data)

    interpreter.invoke()

    # Collect all output tensors.
    output_details = interpreter.get_output_details()
    outputs = []
    for detail in output_details:
        outputs.append(interpreter.get_tensor(detail["index"]))

    return outputs


def run_tract_model(model: Any, inputs: Dict[str, np.ndarray]) -> List[np.ndarray]:
    """Run inference on the tract model."""
    if model is None:
        return []
    print("Running tract model inference...")

    # tract takes its inputs as a positional list.
    input_list = list(inputs.values())

    outputs = model.run(input_list)

    # Convert tract values back to numpy arrays.
    result = []
    for output in outputs:
        result.append(output.to_numpy())

    return result


def benchmark_onnx_model(
    session: ort.InferenceSession,
    inputs: Dict[str, np.ndarray],
    num_runs: int = 100,
    warmup_runs: int = 10,
) -> Dict[str, float]:
    """Benchmark ONNX model inference speed."""
    print(f"\nBenchmarking ONNX model ({warmup_runs} warmup + {num_runs} test runs)...")

    # Warmup runs (not timed).
    for _ in range(warmup_runs):
        session.run(None, inputs)

    # Timed runs, in milliseconds.
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        session.run(None, inputs)
        end = time.perf_counter()
        times.append((end - start) * 1000)

    return {
        "mean": np.mean(times),
        "median": np.median(times),
        "std": np.std(times),
        "min": np.min(times),
        "max": np.max(times),
    }


def benchmark_tflite_model(
    interpreter: tf.lite.Interpreter,
    inputs: Dict[str, np.ndarray],
    input_details: List,
    num_runs: int = 100,
    warmup_runs: int = 10,
) -> Dict[str, float]:
    """Benchmark TFLite model inference speed."""
    print(f"Benchmarking TFLite model ({warmup_runs} warmup + {num_runs} test runs)...")

    def set_inputs():
        # Same name-then-position matching as run_tflite_model.
        for i, detail in enumerate(input_details):
            input_data = None
            if detail["name"] in inputs:
                input_data = inputs[detail["name"]]
            elif len(inputs) == 1:
                input_data = list(inputs.values())[0]
            elif i < len(inputs):
                input_data = list(inputs.values())[i]
            else:
                raise ValueError(
                    f"Cannot match input for TFLite input {detail['name']}"
                )

            if input_data.dtype != detail["dtype"]:
                input_data = input_data.astype(detail["dtype"])

            interpreter.set_tensor(detail["index"], input_data)

    # Warmup runs (not timed).
    for _ in range(warmup_runs):
        set_inputs()
        interpreter.invoke()

    # Timed runs; only invoke() is measured, not input setup.
    times = []
    for _ in range(num_runs):
        set_inputs()
        start = time.perf_counter()
        interpreter.invoke()
        end = time.perf_counter()
        times.append((end - start) * 1000)

    return {
        "mean": np.mean(times),
        "median": np.median(times),
        "std": np.std(times),
        "min": np.min(times),
        "max": np.max(times),
    }


def benchmark_tract_model(
    model: Any,
    inputs: Dict[str, np.ndarray],
    num_runs: int = 100,
    warmup_runs: int = 10,
) -> Optional[Dict[str, float]]:
    """Benchmark tract model inference speed."""
    if model is None:
        return None
    print(f"Benchmarking tract model ({warmup_runs} warmup + {num_runs} test runs)...")

    # tract takes its inputs as a positional list.
    input_list = list(inputs.values())

    # Warmup runs (not timed).
    for _ in range(warmup_runs):
        model.run(input_list)

    # Timed runs, in milliseconds.
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        model.run(input_list)
        end = time.perf_counter()
        times.append((end - start) * 1000)

    return {
        "mean": np.mean(times),
        "median": np.median(times),
        "std": np.std(times),
        "min": np.min(times),
        "max": np.max(times),
    }


def print_benchmark_results(
    onnx_stats: Dict[str, float],
    tflite_stats: Dict[str, float],
    tract_stats: Optional[Dict[str, float]] = None,
) -> None:
    """Print benchmark comparison results."""
    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80)

    print("\nONNX Model:")
    print(f" Mean: {onnx_stats['mean']:.3f} ms")
    print(f" Median: {onnx_stats['median']:.3f} ms")
    print(f" Std: {onnx_stats['std']:.3f} ms")
    print(f" Min: {onnx_stats['min']:.3f} ms")
    print(f" Max: {onnx_stats['max']:.3f} ms")

    print("\nTFLite Model:")
    print(f" Mean: {tflite_stats['mean']:.3f} ms")
    print(f" Median: {tflite_stats['median']:.3f} ms")
    print(f" Std: {tflite_stats['std']:.3f} ms")
    print(f" Min: {tflite_stats['min']:.3f} ms")
    print(f" Max: {tflite_stats['max']:.3f} ms")

    if tract_stats:
        print("\nTract Model:")
        print(f" Mean: {tract_stats['mean']:.3f} ms")
        print(f" Median: {tract_stats['median']:.3f} ms")
        print(f" Std: {tract_stats['std']:.3f} ms")
        print(f" Min: {tract_stats['min']:.3f} ms")
        print(f" Max: {tract_stats['max']:.3f} ms")

    print("\nComparison:")
    speedup = tflite_stats["mean"] / onnx_stats["mean"]
    if speedup > 1:
        print(f" ONNX Runtime is {speedup:.2f}x faster than TFLite")
    else:
        print(f" TFLite is {1 / speedup:.2f}x faster than ONNX Runtime")
    print(f" Difference: {abs(onnx_stats['mean'] - tflite_stats['mean']):.3f} ms")

    if tract_stats:
        speedup_tract = tflite_stats["mean"] / tract_stats["mean"]
        if speedup_tract > 1:
            print(f" Tract is {speedup_tract:.2f}x faster than TFLite")
        else:
            print(f" TFLite is {1 / speedup_tract:.2f}x faster than Tract")
        print(f" Difference: {abs(tract_stats['mean'] - tflite_stats['mean']):.3f} ms")

        speedup_ort = onnx_stats["mean"] / tract_stats["mean"]
        if speedup_ort > 1:
            print(f" Tract is {speedup_ort:.2f}x faster than ONNX Runtime")
        else:
            print(f" ONNX Runtime is {1 / speedup_ort:.2f}x faster than Tract")
        print(f" Difference: {abs(tract_stats['mean'] - onnx_stats['mean']):.3f} ms")

    print("=" * 80)


def compare_outputs(
    onnx_outputs: List[np.ndarray],
    tflite_outputs: List[np.ndarray],
    tract_outputs: Optional[List[np.ndarray]] = None,
    rtol: float = 1e-5,
    atol: float = 1e-5,
) -> bool:
    """Compare outputs from the ONNX, TFLite, and optionally tract models."""
    print("\n" + "=" * 80)
    print("COMPARISON RESULTS")
    print("=" * 80)

    if len(onnx_outputs) != len(tflite_outputs):
        print(
            f"❌ Number of outputs differs: ONNX={len(onnx_outputs)}, TFLite={len(tflite_outputs)}"
        )
        return False

    if tract_outputs and len(onnx_outputs) != len(tract_outputs):
        print(
            f"❌ Number of outputs differs: ONNX={len(onnx_outputs)}, Tract={len(tract_outputs)}"
        )
        return False

    all_match = True
    for i, (onnx_out, tflite_out) in enumerate(zip(onnx_outputs, tflite_outputs)):
        tract_out = tract_outputs[i] if tract_outputs else None

        print(f"\nOutput {i}:")
        print(f" ONNX Runtime shape: {onnx_out.shape}, dtype: {onnx_out.dtype}")
        print(f" TFLite shape: {tflite_out.shape}, dtype: {tflite_out.dtype}")
        if tract_out is not None:
            print(f" Tract shape: {tract_out.shape}, dtype: {tract_out.dtype}")

        if onnx_out.shape != tflite_out.shape:
            print(" ❌ Shape mismatch between ONNX and TFLite!")
            all_match = False
            continue

        if tract_out is not None and onnx_out.shape != tract_out.shape:
            print(" ❌ Shape mismatch between ONNX and Tract!")
            all_match = False
            continue

        # Align dtypes before the numeric comparison.
        if onnx_out.dtype != tflite_out.dtype:
            print(" ⚠️ Different dtypes, converting to float32 for comparison")
            onnx_out = onnx_out.astype(np.float32)
            tflite_out = tflite_out.astype(np.float32)

        if tract_out is not None and onnx_out.dtype != tract_out.dtype:
            tract_out = tract_out.astype(np.float32)

        # ONNX Runtime vs TFLite.
        print("\n ONNX Runtime vs TFLite:")
        diff = np.abs(onnx_out - tflite_out)
        max_diff = np.max(diff)
        mean_diff = np.mean(diff)
        is_close = np.allclose(onnx_out, tflite_out, rtol=rtol, atol=atol)

        print(f" Max difference: {max_diff:.10f}")
        print(f" Mean difference: {mean_diff:.10f}")
        print(f" Relative tolerance: {rtol}")
        print(f" Absolute tolerance: {atol}")

        if is_close:
            print(" ✅ Outputs match within tolerance")
        else:
            print(" ❌ Outputs do NOT match within tolerance")
            all_match = False

        # Show a few values side by side for quick inspection.
        print("\n Sample values (first 5 elements):")
        flat_onnx = onnx_out.flatten()[:5]
        flat_tflite = tflite_out.flatten()[:5]
        for j, (o, t) in enumerate(zip(flat_onnx, flat_tflite)):
            print(
                f" [{j}] ONNX: {o:.10f}, TFLite: {t:.10f}, Diff: {abs(o - t):.10f}"
            )

        if tract_out is not None:
            # ONNX Runtime vs tract.
            print("\n ONNX Runtime vs Tract:")
            diff_tract = np.abs(onnx_out - tract_out)
            max_diff_tract = np.max(diff_tract)
            mean_diff_tract = np.mean(diff_tract)
            is_close_tract = np.allclose(onnx_out, tract_out, rtol=rtol, atol=atol)

            print(f" Max difference: {max_diff_tract:.10f}")
            print(f" Mean difference: {mean_diff_tract:.10f}")

            if is_close_tract:
                print(" ✅ Outputs match within tolerance")
            else:
                print(" ❌ Outputs do NOT match within tolerance")
                all_match = False

            print("\n Sample values (first 5 elements):")
            flat_onnx_tract = onnx_out.flatten()[:5]
            flat_tract = tract_out.flatten()[:5]
            for j, (o, tr) in enumerate(zip(flat_onnx_tract, flat_tract)):
                print(
                    f" [{j}] ONNX: {o:.10f}, Tract: {tr:.10f}, Diff: {abs(o - tr):.10f}"
                )

            # TFLite vs tract.
            print("\n TFLite vs Tract:")
            diff_tflite_tract = np.abs(tflite_out - tract_out)
            max_diff_tflite_tract = np.max(diff_tflite_tract)
            mean_diff_tflite_tract = np.mean(diff_tflite_tract)
            is_close_tflite_tract = np.allclose(
                tflite_out, tract_out, rtol=rtol, atol=atol
            )

            print(f" Max difference: {max_diff_tflite_tract:.10f}")
            print(f" Mean difference: {mean_diff_tflite_tract:.10f}")

            if is_close_tflite_tract:
                print(" ✅ Outputs match within tolerance")
            else:
                print(" ❌ Outputs do NOT match within tolerance")
                all_match = False

    print("\n" + "=" * 80)
    if all_match:
        print("✅ ALL OUTPUTS MATCH!")
    else:
        print("❌ SOME OUTPUTS DO NOT MATCH")
    print("=" * 80)

    return all_match


def main():
    parser = argparse.ArgumentParser(
        description="Compare ONNX and TFLite model outputs",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Compare with random inputs
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite

  # Compare with custom inputs from file
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --input input.npz

  # Compare with custom tolerances
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --rtol 1e-3 --atol 1e-3

  # Save outputs for inspection
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --save-outputs

  # Benchmark execution speed
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark

  # Benchmark with custom number of runs
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark --num-runs 200 --warmup-runs 20

  # Compare with the tract runtime as well
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract

  # Benchmark all three runtimes
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract --benchmark
""",
    )

    parser.add_argument("--onnx", required=True, help="Path to ONNX model")
    parser.add_argument("--tflite", required=True, help="Path to TFLite model")
    parser.add_argument("--input", help="Path to input file (.npy or .npz)")
    parser.add_argument(
        "--rtol", type=float, default=1e-5, help="Relative tolerance (default: 1e-5)"
    )
    parser.add_argument(
        "--atol", type=float, default=1e-5, help="Absolute tolerance (default: 1e-5)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for input generation (default: 42)",
    )
    parser.add_argument(
        "--save-outputs", action="store_true", help="Save outputs to files"
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Benchmark execution speed of both models",
    )
    parser.add_argument(
        "--num-runs",
        type=int,
        default=100,
        help="Number of benchmark runs (default: 100)",
    )
    parser.add_argument(
        "--warmup-runs",
        type=int,
        default=10,
        help="Number of warmup runs (default: 10)",
    )
    parser.add_argument(
        "--use-tract", action="store_true", help="Also test with the tract ONNX runtime"
    )

    args = parser.parse_args()

    # Load the models.
    onnx_session = load_onnx_model(args.onnx)
    tflite_interpreter = load_tflite_model(args.tflite)

    # Optionally load the same ONNX model with tract.
    tract_model = None
    if args.use_tract:
        if not TRACT_AVAILABLE:
            print(
                "\n⚠️ Warning: Tract is not installed. Install with: pip install tract"
            )
            print("Continuing without tract comparison...\n")
        else:
            tract_model = load_tract_model(args.onnx)

    # Print model I/O information.
    onnx_inputs, onnx_outputs = get_onnx_model_info(onnx_session)
    tflite_input_details, tflite_output_details = get_tflite_model_info(
        tflite_interpreter
    )

    # Load inputs from file or generate random ones.
    if args.input:
        inputs = load_inputs_from_file(args.input)
    else:
        inputs = generate_random_inputs(onnx_inputs, seed=args.seed)

    # Run inference with each runtime.
    onnx_results = run_onnx_model(onnx_session, inputs)
    tflite_results = run_tflite_model(tflite_interpreter, inputs, tflite_input_details)
    tract_results = None
    if tract_model:
        tract_results = run_tract_model(tract_model, inputs)

    # Optionally save the raw outputs for later inspection.
    if args.save_outputs:
        print("\nSaving outputs...")
        np.savez("onnx_outputs.npz", *onnx_results)
        np.savez("tflite_outputs.npz", *tflite_results)
        print(" - onnx_outputs.npz")
        print(" - tflite_outputs.npz")
        if tract_results:
            np.savez("tract_outputs.npz", *tract_results)
            print(" - tract_outputs.npz")

    # Compare the outputs numerically.
    match = compare_outputs(
        onnx_results, tflite_results, tract_results, rtol=args.rtol, atol=args.atol
    )

    # Optionally benchmark all runtimes.
    if args.benchmark:
        onnx_stats = benchmark_onnx_model(
            onnx_session, inputs, args.num_runs, args.warmup_runs
        )
        tflite_stats = benchmark_tflite_model(
            tflite_interpreter,
            inputs,
            tflite_input_details,
            args.num_runs,
            args.warmup_runs,
        )
        tract_stats = None
        if tract_model:
            tract_stats = benchmark_tract_model(
                tract_model, inputs, args.num_runs, args.warmup_runs
            )
        print_benchmark_results(onnx_stats, tflite_stats, tract_stats)

    # Exit code 0 when all outputs match, 1 otherwise (useful in CI).
    return 0 if match else 1


if __name__ == "__main__":
    sys.exit(main())