BirdNET-onnx / scripts /optimize.py
justinchuby's picture
Simplify model (#2)
1f1b5fd verified
import onnxscript
import onnx_ir as ir
import onnx_ir.passes.common
import numpy as np
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace ``Reshape -> DFT -> Slice(real part) -> Squeeze`` with one MatMul.

    For a real-valued signal ``x`` of length ``N`` the real part of its DFT is
    ``Re(DFT[k]) = sum_n x[n] * cos(2*pi*k*n/N)``, so the matched chain equals
    ``x @ C`` with the constant matrix ``C[n, k] = cos(2*pi*k*n/N)``.

    The rewrite multiplies the *pre-Reshape* input ``x`` by ``C`` along its
    last axis.  NOTE(review): this assumes the Reshape only appends the
    trailing real/complex dimension of size 1, that the DFT runs over the
    last axis of ``x``, that ``x.shape[-1] == dft_length`` and that ``x`` is
    float32 (the matrix is float32) — confirm before reusing on another model.
    """

    def pattern(self, op, x, dft_length):
        x = op.Reshape(x, _allow_other_inputs=True)
        dft = op.DFT(x, dft_length, _outputs=["dft_output"])
        real_part = op.Slice(dft, [0], [1], [-1])
        return op.Squeeze(real_part, [-1])

    def check(self, context, dft_length: ir.Value, dft_output: ir.Value, **_) -> bool:
        """Only fire when the MatMul replacement is exact.

        Requires a statically known ``dft_length`` (the matrix is built at
        rewrite time) and a forward transform: ``inverse=1`` carries a 1/N
        scale that the cosine matrix does not model.
        """
        dft_node = dft_output.producer()
        if dft_node is None:
            return False
        if ir.convenience.get_const_tensor(dft_length) is None:
            return False
        return self._int_attr(dft_node, "inverse") == 0

    def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value):
        dft_node = dft_output.producer()
        assert dft_node is not None  # guaranteed by check()
        dft_size = int(ir.convenience.get_const_tensor(dft_length).numpy().item())
        # onesided=1 keeps only DC..Nyquist (N//2 + 1 bins); otherwise the DFT
        # returns all N bins.  ONNX defaults ``onesided`` to 0, so the output
        # width must follow the attribute to keep the replaced shape identical.
        if self._int_attr(dft_node, "onesided") != 0:
            num_freqs = dft_size // 2 + 1
        else:
            num_freqs = dft_size
        dft_matrix = op.initializer(
            ir.tensor(self._cosine_matrix(dft_size, num_freqs)),
            name=f"{x.name}_dft_matrix",
        )
        # The transform axis is x's last axis: (..., N) @ (N, num_freqs).
        return op.MatMul(x, dft_matrix)

    @staticmethod
    def _int_attr(node, name: str, default: int = 0) -> int:
        """Return integer attribute *name* of *node*, or *default* if absent."""
        attr = node.attributes.get(name)
        return default if attr is None else attr.as_int()

    @staticmethod
    def _cosine_matrix(dft_size: int, num_freqs: int) -> np.ndarray:
        """Return the ``(dft_size, num_freqs)`` float32 matrix ``cos(2*pi*k*n/N)``.

        ``k * n`` is reduced modulo ``N`` in exact integer arithmetic and the
        cosine is evaluated in float64 before the final cast.  Forming the
        phase ``2*pi*k*n/N`` directly from float32 vectors loses several
        1e-4 rad of phase accuracy when N = 2048, because ``2*pi*k*n`` then
        exceeds 2**23 and float32 can no longer resolve fractions.
        """
        n = np.arange(dft_size, dtype=np.int64)[:, np.newaxis]  # (N, 1)
        k = np.arange(num_freqs, dtype=np.int64)[np.newaxis, :]  # (1, F)
        phase_index = (n * k) % dft_size  # exact, in [0, N)
        angles = (2.0 * np.pi / dft_size) * phase_index  # float64
        return np.cos(angles).astype(np.float32)
class ReplaceSplit(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace any two-output ``Split`` with ``(Gather(x, [0]), [144000])``.

    The pattern matches *every* ``Split`` node that has exactly two outputs,
    whatever its other inputs are.  The first replacement output gathers
    element 0 of ``x``; the second is a constant int32 tensor ``[144000]``.

    NOTE(review): presumably ``x`` is the ``[batch, samples]`` shape vector of
    the audio input, and 144000 matches the input length hard-coded elsewhere
    in this script.  The initializer names "zero" and "sample_size" are fixed,
    so the rule is likely only safe when it fires once per model — confirm.
    """

    def pattern(self, op, x):
        split_outputs = ["split_out_1", "split_out_2"]
        return op.Split(x, _allow_other_inputs=True, _outputs=split_outputs)

    def rewrite(self, op, x: ir.Value, **kwargs):
        first_index = ir.tensor(np.array([0], dtype=np.int64))
        fixed_samples = ir.tensor(np.array([144000], dtype=np.int32))
        leading_dim = op.Gather(x, op.initializer(first_index, "zero"))
        return leading_dim, op.initializer(fixed_samples, "sample_size")
class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
    """Turn every ``Cast`` node into an ``Identity``, dropping the conversion.

    The pattern puts no constraint on the ``to`` attribute, so the tensor
    keeps its source dtype wherever a Cast used to be.  NOTE(review): this is
    only safe if downstream consumers accept the source dtype; this script
    appears to pair it with the int32 -> int64 initializer promotion — verify
    that the resulting model type-checks before reusing the rule elsewhere.
    """

    def pattern(self, op, x):
        cast_result = op.Cast(x)
        return cast_result

    def rewrite(self, op, x: ir.Value, **kwargs):
        passthrough = op.Identity(x)
        return passthrough
model = ir.load("model.onnx")

# Pin the fixed feature sizes and make the leading (batch) axis symbolic.
model.graph.inputs[0].shape = ir.Shape(["batch", 144000])
model.graph.outputs[0].shape = ir.Shape(["batch", 6522])

onnxscript.rewriter.rewrite(
    model,
    [
        rule_cls().rule()
        for rule_cls in (ReplaceDftWithMatMulRule, ReplaceSplit, RemoveCast)
    ],
)


def _promote_int32_initializers(graph) -> None:
    """Replace every INT32 initializer of *graph* with an INT64 copy.

    The copy keeps the original name, is appended to the initializer table
    in place of the old entry, and all uses of the old value are redirected
    to it.  Iterating over a snapshot allows popping/adding while looping.
    """
    for old_value in tuple(graph.initializers.values()):
        if old_value.dtype != ir.DataType.INT32:
            continue
        widened = ir.tensor(old_value.const_value.numpy().astype(np.int64))
        new_value = ir.val(old_value.name, const_value=widened)
        graph.initializers.pop(old_value.name)
        graph.initializers.add(new_value)
        old_value.replace_all_uses_with(new_value)


# Change all int32 initializers to int64
_promote_int32_initializers(model.graph)

# 1 << 30 bytes == 1 GiB: lets the optimizer fold/inspect large tensors.
onnxscript.optimizer.optimize(
    model, input_size_limit=1 << 30, output_size_limit=1 << 30
)
# Remove Slice-Reshape
def remove_slice_reshape(model: ir.Model):
    """Rewire the STFT framing Reshapes of both mel-spectrogram branches.

    The ``Reshape_1`` nodes of ``MEL_SPEC1`` and ``MEL_SPEC2`` are made to
    read their data directly from the output of ``model/MEL_SPEC1/Mul`` and
    get static target shapes ``[-1, 72000, 2]`` and ``[-1, 18000, 8]``.  The
    ``Reshape_4`` nodes keep their data input but get static target shapes
    ``[-1, 511, 2048]`` and ``[-1, 511, 1024]``.  Whatever fed the replaced
    inputs before is left unused, for a later optimizer pass to remove.

    The node names are specific to this BirdNET graph.  The four shape
    initializers are added under fixed names ("first_shape" ...
    "fourth_shape"), so presumably this can only be applied once per model.
    """
    graph = model.graph
    mul_node = graph.node("model/MEL_SPEC1/Mul")
    # (Reshape node, shape initializer name, static target shape,
    #  whether the Reshape's data input is rewired to the Mul output)
    plan = (
        ("model/MEL_SPEC1/stft/frame/Reshape_1", "first_shape", [-1, 72000, 2], True),
        ("model/MEL_SPEC2/stft/frame/Reshape_1", "second_shape", [-1, 18000, 8], True),
        ("model/MEL_SPEC1/stft/frame/Reshape_4", "third_shape", [-1, 511, 2048], False),
        ("model/MEL_SPEC2/stft/frame/Reshape_4", "fourth_shape", [-1, 511, 1024], False),
    )

    # Look up each Reshape and register its new shape constant, in order.
    pending = []
    for node_name, shape_name, dims, from_mul in plan:
        reshape = graph.node(node_name)
        shape_value = ir.val(
            shape_name, const_value=ir.tensor(dims, dtype=ir.DataType.INT64)
        )
        graph.initializers.add(shape_value)
        pending.append((reshape, shape_value, from_mul))

    # Replace with Mul-Reshape-Gather: rewire data (optional) and shape inputs.
    for reshape, shape_value, from_mul in pending:
        if from_mul:
            reshape.replace_input_with(0, mul_node.outputs[0])
        reshape.replace_input_with(1, shape_value)
remove_slice_reshape(model)

# Run DCE again: drop the nodes left unused after the Reshape rewiring.
# 1 << 30 bytes == 1 GiB size limit for tensors the optimizer may process.
onnxscript.optimizer.optimize(
    model,
    input_size_limit=1 << 30,
    output_size_limit=1 << 30,
)

# Strip doc strings / metadata, then give the graph stable public names.
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
model.graph.inputs[0].name, model.graph.outputs[0].name = "input", "output"

model.ir_version = 10
model.producer_name = "onnx-ir"
model.graph.name = "BirdNET-v2.4"

ir.save(model, "birdnet.onnx")