|
|
import onnxscript |
|
|
import onnx_ir as ir |
|
|
import onnx_ir.passes.common |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace ``Reshape -> DFT -> Slice(real part) -> Squeeze`` with one MatMul.

    For a real signal of length ``N`` the real part of the forward DFT is
    ``Re X[k] = sum_n x[n] * cos(2*pi*k*n / N)``, so the matched subgraph can be
    computed as ``x @ C`` with the constant matrix ``C[n, k] = cos(2*pi*k*n/N)``.
    ``C`` is stored as a float32 initializer.  ``k`` runs over ``N // 2 + 1``
    bins when the DFT is one-sided and over all ``N`` bins otherwise, matching
    the number of bins the DFT itself would have produced.

    NOTE(review): the MatMul is applied to the tensor *before* the Reshape, so
    this assumes the Reshape only appends the size-1 real/complex axis and that
    the DFT runs along ``x``'s last axis — neither is verified here.
    """

    def pattern(self, op, x, dft_length):
        # ``x`` (the pattern variable) is the tensor fed into the Reshape;
        # rewrite() multiplies it directly, so the Reshape is dropped.
        x = op.Reshape(x, _allow_other_inputs=True)
        dft = op.DFT(x, dft_length, _outputs=["dft_output"])
        # Keep index 0 of the trailing axis (the real component) ...
        real_part = op.Slice(dft, [0], [1], [-1])
        # ... and drop that now size-1 axis.
        return op.Squeeze(real_part, [-1])

    def check(self, context, dft_length: ir.Value, dft_output: ir.Value, **_) -> bool:
        """Return True only when the cosine-matrix replacement is exact.

        Rejects matches whose ``dft_length`` is not a constant (the matrix size
        must be known when rewriting) and inverse DFTs, whose ``1/N`` scaling the
        matrix does not include.
        """
        if ir.convenience.get_const_tensor(dft_length) is None:
            return False
        dft_node = dft_output.producer()
        if dft_node is None:
            return False
        return self._int_attribute(dft_node, "inverse", 0) == 0

    def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value):
        """Emit ``MatMul(x, C)`` where ``C`` is the real-DFT cosine matrix."""
        dft_node = dft_output.producer()
        assert dft_node is not None  # guaranteed by check()

        dft_size = int(ir.convenience.get_const_tensor(dft_length).numpy().item())
        # ONNX DFT defaults to onesided=0 (all N bins); one-sided keeps N//2 + 1.
        onesided = self._int_attribute(dft_node, "onesided", 0) != 0
        num_freqs = dft_size // 2 + 1 if onesided else dft_size

        dft_matrix = op.initializer(
            ir.tensor(self._cosine_matrix(dft_size, num_freqs)),
            name=f"{x.name}_dft_matrix",
        )
        return op.MatMul(x, dft_matrix)

    @staticmethod
    def _cosine_matrix(dft_size: int, num_freqs: int) -> np.ndarray:
        """Return the float32 ``(dft_size, num_freqs)`` matrix ``cos(2*pi*k*n/N)``."""
        n = np.arange(dft_size, dtype=np.int64)[:, np.newaxis]
        k = np.arange(num_freqs, dtype=np.int64)[np.newaxis, :]
        # Reduce k*n modulo N exactly in integers before forming the angle.
        # Computing 2*pi*k*n/N directly in float32 rounds intermediates that
        # reach ~1e7 (for N = 2048), costing ~1e-4 rad on the largest angles.
        phase = (k * n) % dft_size
        return np.cos(2 * np.pi * phase / dft_size).astype(np.float32)

    @staticmethod
    def _int_attribute(node: ir.Node, name: str, default: int) -> int:
        """Read integer attribute ``name`` from ``node``, or ``default`` if unset."""
        attr = node.attributes.get(name)
        return default if attr is None else int(attr.value)
|
|
|
|
|
|
|
|
class ReplaceSplit(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace a two-output ``Split`` with ``Gather(x, [0])`` and a constant.

    The first Split output becomes element 0 of ``x`` (kept as a 1-element
    tensor). The second becomes the constant int32 tensor ``[144000]``.

    NOTE(review): presumably ``x`` is a shape vector ``[batch, 144000]``, which
    makes this equivalent to splitting it into ``[batch]`` and ``[144000]`` —
    confirm. The pattern matches *any* two-output Split, regardless of axis or
    split sizes (extra inputs are allowed), so every such Split in the graph is
    replaced.
    """

    # Value emitted as the second Split output.
    _SAMPLE_SIZE = 144000

    def pattern(self, op, x):
        split_outputs = op.Split(
            x,
            _allow_other_inputs=True,
            _outputs=["split_out_1", "split_out_2"],
        )
        return split_outputs

    def rewrite(self, op, x: ir.Value, **kwargs):
        index_zero = op.initializer(
            ir.tensor(np.array([0], dtype=np.int64)), name="zero"
        )
        leading = op.Gather(x, index_zero)
        trailing = op.initializer(
            ir.tensor(np.array([self._SAMPLE_SIZE], dtype=np.int32)),
            name="sample_size",
        )
        return leading, trailing
|
|
|
|
|
|
|
|
class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace every ``Cast`` node with an ``Identity`` on its input.

    The pattern does not inspect the Cast's target type, so *all* dtype
    conversions in the graph are discarded.

    NOTE(review): presumably this is only safe because the script later widens
    INT32 initializers to INT64 so the surrounding dtypes line up — confirm that
    no float/int Casts exist in the source model.
    """

    def pattern(self, op, x):
        cast_output = op.Cast(x)
        return cast_output

    def rewrite(self, op, x: ir.Value, **kwargs):
        passthrough = op.Identity(x)
        return passthrough
|
|
|
|
|
|
|
|
# Load the exported model and pin the I/O shapes: a symbolic "batch" dimension
# with 144000 input samples and 6522 output values per example.
model = ir.load("model.onnx")

model.graph.inputs[0].shape = ir.Shape(["batch", 144000])
model.graph.outputs[0].shape = ir.Shape(["batch", 6522])

# Apply the model-specific pattern rewrites defined above.
onnxscript.rewriter.rewrite(
    model,
    [
        ReplaceDftWithMatMulRule().rule(),
        ReplaceSplit().rule(),
        RemoveCast().rule(),
    ],
)


def _promote_int32_initializers(graph: ir.Graph) -> None:
    """Replace every INT32 initializer in ``graph`` with an INT64 copy.

    Each replacement keeps the original name and takes over all uses of the
    old value.
    """
    # Snapshot the INT32 entries first: the loop pops and re-adds initializers.
    int32_values = [
        value
        for value in graph.initializers.values()
        if value.dtype == ir.DataType.INT32
    ]
    for old_value in int32_values:
        widened = old_value.const_value.numpy().astype(np.int64)
        new_value = ir.val(old_value.name, const_value=ir.tensor(widened))
        graph.initializers.pop(old_value.name)
        graph.initializers.add(new_value)
        old_value.replace_all_uses_with(new_value)


_promote_int32_initializers(model.graph)

# Allow the optimizer to handle tensors up to 1 GiB.
_ONE_GIB = 1024 * 1024 * 1024
onnxscript.optimizer.optimize(
    model, input_size_limit=_ONE_GIB, output_size_limit=_ONE_GIB
)
|
|
|
|
|
|
|
|
|
|
|
def remove_slice_reshape(model: ir.Model):
    """Rewire four hard-coded STFT framing Reshape nodes.

    - ``model/MEL_SPEC{1,2}/stft/frame/Reshape_1`` are switched to read the
      output of ``model/MEL_SPEC1/Mul`` directly. They get new INT64 target
      shapes ``[-1, 72000, 2]`` and ``[-1, 18000, 8]`` respectively.
    - ``model/MEL_SPEC{1,2}/stft/frame/Reshape_4`` get new INT64 target shapes
      ``[-1, 511, 2048]`` and ``[-1, 511, 1024]`` respectively.

    The new shapes are added to the graph as initializers named
    ``first_shape`` .. ``fourth_shape``. Nodes that previously fed the rewired
    inputs are not deleted here.

    NOTE(review): the shape constants presumably assume a 144000-sample input
    (72000 * 2 == 18000 * 8 == 144000) — confirm if the input length changes.
    """
    mul_output = model.graph.node("model/MEL_SPEC1/Mul").outputs[0]

    # (Reshape node name, new shape initializer name, target shape,
    #  whether the data input is redirected to the Mul output)
    edits = (
        ("model/MEL_SPEC1/stft/frame/Reshape_1", "first_shape", [-1, 72000, 2], True),
        ("model/MEL_SPEC2/stft/frame/Reshape_1", "second_shape", [-1, 18000, 8], True),
        ("model/MEL_SPEC1/stft/frame/Reshape_4", "third_shape", [-1, 511, 2048], False),
        ("model/MEL_SPEC2/stft/frame/Reshape_4", "fourth_shape", [-1, 511, 1024], False),
    )

    # First look up every node and register every shape initializer ...
    pending = []
    for node_name, value_name, dims, takes_mul_output in edits:
        reshape = model.graph.node(node_name)
        shape_value = ir.val(
            value_name, const_value=ir.tensor(dims, dtype=ir.DataType.INT64)
        )
        model.graph.initializers.add(shape_value)
        pending.append((reshape, shape_value, takes_mul_output))

    # ... then rewire the Reshape inputs.
    for reshape, shape_value, takes_mul_output in pending:
        if takes_mul_output:
            reshape.replace_input_with(0, mul_output)
        reshape.replace_input_with(1, shape_value)
|
|
|
|
|
|
|
|
# Rewire the STFT framing Reshapes, then run the optimizer again over the
# rewired graph (tensor size limits of 1 GiB).
remove_slice_reshape(model)

onnxscript.optimizer.optimize(
    model,
    input_size_limit=1024**3,
    output_size_limit=1024**3,
)


def _finalize_for_export(m: ir.Model) -> None:
    """Strip metadata/doc strings and set the published names and versions."""
    onnx_ir.passes.common.ClearMetadataAndDocStringPass()(m)
    graph = m.graph
    graph.inputs[0].name = "input"
    graph.outputs[0].name = "output"
    m.ir_version = 10
    m.producer_name = "onnx-ir"
    graph.name = "BirdNET-v2.4"


_finalize_for_export(model)

ir.save(model, "birdnet.onnx")
|
|
|