"""Rewrite and re-export an ONNX model as ``birdnet.onnx``.

The script loads ``model.onnx``, applies three pattern rewrites (DFT -> MatMul,
Split -> Gather/constant, Cast -> Identity), widens int32 initializers to
int64, rewires four named Reshape nodes, optimizes, and saves the result.
"""

import onnxscript
import onnx_ir as ir
import onnx_ir.passes.common
import numpy as np

# Size limit handed to the optimizer for both inputs and outputs (1 GiB).
_SIZE_LIMIT = 1024 * 1024 * 1024


class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace ``Reshape -> DFT -> Slice -> Squeeze`` with a single ``MatMul``.

    The matched Slice keeps index 0 of the last DFT axis (the real part), so
    the whole chain equals a multiplication by a constant cosine matrix.

    NOTE(review): the rewrite builds a one-sided matrix (``N // 2 + 1``
    frequencies) without inspecting the DFT node's attributes - confirm every
    matched DFT is one-sided, forward, and transforms the last signal axis.
    """

    def pattern(self, op, x, dft_length):
        reshaped = op.Reshape(x, _allow_other_inputs=True)
        dft = op.DFT(reshaped, dft_length, _outputs=["dft_output"])
        real_part = op.Slice(dft, [0], [1], [-1])
        return op.Squeeze(real_part, [-1])

    def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value):
        """Emit ``MatMul(x, cos_matrix)`` in place of the matched chain.

        Args:
            op: Rewriter context used to create nodes and initializers.
            x: Value bound to the pattern parameter ``x``.
            dft_length: DFT length input; must be a compile-time constant.
            dft_output: Output of the matched DFT node (bound by the pattern,
                not needed to build the replacement).

        Returns:
            The MatMul output value.

        Raises:
            ValueError: If ``dft_length`` is not a constant or is not positive.
        """
        length_tensor = ir.convenience.get_const_tensor(dft_length)
        if length_tensor is None:
            raise ValueError(
                "ReplaceDftWithMatMulRule requires a constant dft_length, "
                f"but {dft_length.name!r} is not a constant"
            )
        dft_size = int(length_tensor.numpy().item())
        if dft_size <= 0:
            raise ValueError(f"dft_length must be positive, got {dft_size}")

        # Real part of the DFT: Re(DFT[k]) = sum_n x[n] * cos(2*pi*k*n/N).
        # One-sided: only frequencies 0 .. Nyquist are produced.
        num_freqs = dft_size // 2 + 1

        # Build the matrix with exact integer phase reduction and float64
        # trigonometry, then cast once. Evaluating cos() on float32 phases of
        # magnitude ~pi*N loses several significant digits in the coefficients.
        n = np.arange(dft_size, dtype=np.int64)[:, np.newaxis]  # (dft_size, 1)
        k = np.arange(num_freqs, dtype=np.int64)[np.newaxis, :]  # (1, num_freqs)
        phase = (n * k) % dft_size  # cos is 2*pi periodic, so this is exact
        cos_matrix = np.cos(2.0 * np.pi * phase / dft_size).astype(
            np.float32
        )  # (dft_size, num_freqs)

        # Store the matrix as a graph initializer.
        dft_matrix = op.initializer(ir.tensor(cos_matrix), name=f"{x.name}_dft_matrix")

        # DFT axis is already at the end, direct matrix multiplication
        return op.MatMul(x, dft_matrix)


class ReplaceSplit(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace a two-output ``Split`` with ``Gather(x, [0])`` and a constant.

    The second output becomes the fixed value ``[144000]``.

    NOTE(review): presumably ``x`` is a shape vector ``[batch, samples]`` and
    144000 is the fixed sample count - confirm against the model.
    """

    def pattern(self, op, x):
        return op.Split(
            x, _allow_other_inputs=True, _outputs=["split_out_1", "split_out_2"]
        )

    def rewrite(self, op, x: ir.Value, **kwargs):
        zero = op.initializer(ir.tensor(np.array([0], dtype=np.int64)), "zero")
        batch_size = op.Gather(x, zero)
        # Created as int32 here; the script later widens every int32
        # initializer in the graph to int64.
        sample_size = op.initializer(
            ir.tensor(np.array([144000], dtype=np.int32)), "sample_size"
        )
        return batch_size, sample_size


class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace each matched ``Cast`` node with ``Identity``."""

    def pattern(self, op, x):
        return op.Cast(x)

    def rewrite(self, op, x: ir.Value, **kwargs):
        return op.Identity(x)


model = ir.load("model.onnx")

# Set dynamic axes
model.graph.inputs[0].shape = ir.Shape(["batch", 144000])
model.graph.outputs[0].shape = ir.Shape(["batch", 6522])

onnxscript.rewriter.rewrite(
    model,
    [ReplaceDftWithMatMulRule().rule(), ReplaceSplit().rule(), RemoveCast().rule()],
)

# Change all int32 initializers to int64.
# Iterate over a snapshot because the initializer mapping is mutated below.
initializers = list(model.graph.initializers.values())
for initializer in initializers:
    if initializer.dtype == ir.DataType.INT32:
        int32_array = initializer.const_value.numpy()
        int64_array = int32_array.astype(np.int64)
        new_initializer = ir.val(initializer.name, const_value=ir.tensor(int64_array))
        model.graph.initializers.pop(initializer.name)
        model.graph.initializers.add(new_initializer)
        initializer.replace_all_uses_with(new_initializer)

onnxscript.optimizer.optimize(
    model, input_size_limit=_SIZE_LIMIT, output_size_limit=_SIZE_LIMIT
)


# Remove Slice-Reshape
def remove_slice_reshape(model: ir.Model):
    """Give four named Reshape nodes constant target shapes.

    The two ``Reshape_1`` nodes are additionally rewired to read their data
    input straight from the output of ``model/MEL_SPEC1/Mul``.

    NOTE(review): the MEL_SPEC2 reshape is also fed from MEL_SPEC1's Mul -
    confirm this is intended.

    Args:
        model: Model whose graph is modified in place.
    """
    graph = model.graph
    mul_node = graph.node("model/MEL_SPEC1/Mul")

    # (reshape node name, shape initializer name, target shape, feed from Mul?)
    targets = (
        ("model/MEL_SPEC1/stft/frame/Reshape_1", "first_shape", [-1, 72000, 2], True),
        ("model/MEL_SPEC2/stft/frame/Reshape_1", "second_shape", [-1, 18000, 8], True),
        ("model/MEL_SPEC1/stft/frame/Reshape_4", "third_shape", [-1, 511, 2048], False),
        ("model/MEL_SPEC2/stft/frame/Reshape_4", "fourth_shape", [-1, 511, 1024], False),
    )

    # Phase 1: look up each node and register its constant shape initializer.
    planned = []
    for node_name, shape_name, dims, feed_from_mul in targets:
        reshape_node = graph.node(node_name)
        shape_value = ir.val(
            shape_name, const_value=ir.tensor(dims, dtype=ir.DataType.INT64)
        )
        graph.initializers.add(shape_value)
        planned.append((reshape_node, shape_value, feed_from_mul))

    # Phase 2: replace with Mul-Reshape-Gather
    for reshape_node, shape_value, feed_from_mul in planned:
        if feed_from_mul:
            reshape_node.replace_input_with(0, mul_node.outputs[0])
        reshape_node.replace_input_with(1, shape_value)


remove_slice_reshape(model)

# Run DCE again
onnxscript.optimizer.optimize(
    model, input_size_limit=_SIZE_LIMIT, output_size_limit=_SIZE_LIMIT
)

onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)

model.graph.inputs[0].name = "input"
model.graph.outputs[0].name = "output"

model.ir_version = 10
model.producer_name = "onnx-ir"
model.graph.name = "BirdNET-v2.4"

ir.save(model, "birdnet.onnx")