Justin Chu
committed on
Commit
·
884d307
1
Parent(s):
7149468
Update optimization script
Browse files- scripts/optimize.py +37 -1
scripts/optimize.py
CHANGED
|
@@ -1,15 +1,51 @@
|
|
| 1 |
import onnxscript
|
| 2 |
import onnx_ir as ir
|
| 3 |
import onnx_ir.passes.common
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
model = ir.load("model.onnx")
|
| 10 |
onnxscript.optimizer.optimize(
|
| 11 |
model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
|
| 12 |
)
|
|
|
|
| 13 |
|
| 14 |
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
|
| 15 |
model.ir_version = 10
|
|
|
|
| 1 |
import onnxscript
|
| 2 |
import onnx_ir as ir
|
| 3 |
import onnx_ir.passes.common
|
| 4 |
+
import numpy as np
|
| 5 |
|
| 6 |
|
| 7 |
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace ``Squeeze(Slice(DFT(Reshape(x), dft_length)))`` with a MatMul.

    The matched subgraph keeps only index 0 of the DFT output's last axis
    (``Slice`` with starts=[0], ends=[1], axes=[-1], then ``Squeeze``), which
    for the ONNX DFT operator is the real component.  For a real signal that
    component is ``Re(DFT[k]) = sum_n x[n] * cos(2*pi*k*n/N)``, i.e. a plain
    matrix product with a constant cosine matrix.

    Assumptions that the pattern itself does not verify (TODO confirm for the
    target model):
      * the Reshape only appends the trailing size-1 "real" axis, so the DFT
        axis is the last axis of ``x``;
      * the signal length along that axis equals ``dft_length`` (no implicit
        zero-padding or truncation by DFT);
      * ``x`` is float32 (the cosine matrix is emitted as float32).
    """

    def pattern(self, op, x, dft_length):
        """Describe the subgraph to match: Reshape -> DFT -> Slice -> Squeeze."""
        reshaped = op.Reshape(x, _allow_other_inputs=True)
        # Bind the DFT output so rewrite() can reach the DFT node's attributes.
        dft = op.DFT(reshaped, dft_length, _outputs=["dft_output"])
        real_part = op.Slice(dft, [0], [1], [-1])
        return op.Squeeze(real_part, [-1])

    @staticmethod
    def _int_attr(node, name: str, default: int) -> int:
        """Return integer attribute ``name`` of ``node``, or ``default`` if unset."""
        attr = node.attributes.get(name)
        return default if attr is None else int(attr.value)

    @staticmethod
    def _real_dft_matrix(dft_size: int, num_freqs: int) -> np.ndarray:
        """Build the float32 cosine matrix of shape ``(dft_size, num_freqs)``.

        Entry ``[n, k]`` is ``cos(2*pi*k*n/dft_size)``, laid out so that
        ``x @ matrix`` contracts over the signal axis of ``x``.
        """
        n = np.arange(dft_size, dtype=np.int64)[:, np.newaxis]  # (dft_size, 1)
        k = np.arange(num_freqs, dtype=np.int64)[np.newaxis, :]  # (1, num_freqs)
        # Reduce k*n modulo N in exact integer arithmetic and evaluate the
        # cosine in float64; only the final result is rounded to float32.
        phase = (n * k) % dft_size
        return np.cos(2.0 * np.pi * phase / dft_size).astype(np.float32)

    def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value):
        """Emit ``MatMul(x, cos_matrix)`` in place of the matched subgraph.

        Returns ``None`` (leaving the graph untouched) when the match cannot be
        rewritten safely: the DFT node is unavailable, it is an inverse DFT,
        or ``dft_length`` is not a positive compile-time constant.
        """
        dft_node = dft_output.producer()
        if dft_node is None:
            return None

        # The cosine formula below is only the forward transform.
        if self._int_attr(dft_node, "inverse", 0) != 0:
            return None

        length_tensor = ir.convenience.get_const_tensor(dft_length)
        if length_tensor is None:
            # dft_length is computed at runtime; no constant matrix can be built.
            return None
        dft_size = int(length_tensor.numpy().item())
        if dft_size <= 0:
            return None

        # One-sided DFT keeps DC..Nyquist (N // 2 + 1 bins); otherwise all N bins.
        onesided = self._int_attr(dft_node, "onesided", 0) != 0
        num_freqs = dft_size // 2 + 1 if onesided else dft_size

        weights = self._real_dft_matrix(dft_size, num_freqs)
        dft_matrix = op.initializer(ir.tensor(weights), name=f"{x.name}_dft_matrix")

        # (..., dft_size) @ (dft_size, num_freqs) -> (..., num_freqs)
        return op.MatMul(x, dft_matrix)
|
| 42 |
+
|
| 43 |
|
| 44 |
# Load the model to optimize from the working directory.
model = ir.load("model.onnx")

# 2**30 == 1024 * 1024 * 1024; the same limit is used for inputs and outputs.
_size_limit = 2**30
onnxscript.optimizer.optimize(
    model,
    input_size_limit=_size_limit,
    output_size_limit=_size_limit,
)

# Apply the DFT -> MatMul rewrite after the general optimizer has run.
_dft_rule = ReplaceDftWithMatMulRule().rule()
onnxscript.rewriter.rewrite(model, [_dft_rule])

# Instantiate the pass, then invoke it on the model.
_clear_pass = onnx_ir.passes.common.ClearMetadataAndDocStringPass()
_clear_pass(model)
model.ir_version = 10
|