Justin Chu commited on
Commit
884d307
·
1 Parent(s): 7149468

Update optimization script

Browse files
Files changed (1) hide show
  1. scripts/optimize.py +37 -1
scripts/optimize.py CHANGED
@@ -1,15 +1,51 @@
1
  import onnxscript
2
  import onnx_ir as ir
3
  import onnx_ir.passes.common
 
4
 
5
 
6
  class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
7
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  model = ir.load("model.onnx")
10
  onnxscript.optimizer.optimize(
11
  model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
12
  )
 
13
 
14
  onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
15
  model.ir_version = 10
 
1
  import onnxscript
2
  import onnx_ir as ir
3
  import onnx_ir.passes.common
4
+ import numpy as np
5
 
6
 
7
  class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
8
+ def pattern(self, op, x, dft_length):
9
+ x = op.Reshape(x, _allow_other_inputs=True)
10
+ dft = op.DFT(x, dft_length, _outputs=["dft_output"])
11
+ real_part = op.Slice(dft, [0], [1], [-1])
12
+ return op.Squeeze(real_part, [-1])
13
+
14
+ def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value):
15
+ # Get the DFT node attributes
16
+ dft_node = dft_output.producer()
17
+ assert dft_node is not None
18
+
19
+ dft_size = ir.convenience.get_const_tensor(dft_length).numpy().item()
20
+
21
+ # Create one-sided DFT matrix (only real part, DC to Nyquist)
22
+ # The real part of DFT is: Re(DFT[k]) = sum(x[n] * cos(2*pi*k*n/N))
23
+ # For one-sided DFT, we only need frequencies from 0 to Nyquist (dft_size//2 + 1)
24
+ num_freqs = dft_size // 2 + 1
25
+
26
+ # Vectorized creation of DFT matrix
27
+ k = np.arange(num_freqs, dtype=np.float32)[
28
+ :, np.newaxis
29
+ ] # Shape: (num_freqs, 1)
30
+ n = np.arange(dft_size, dtype=np.float32)[np.newaxis, :] # Shape: (1, dft_size)
31
+ dft_matrix = np.cos(
32
+ 2 * np.pi * k * n / dft_size
33
+ ) # Shape: (num_freqs, dft_size)
34
+
35
+ # Create constant node for the DFT matrix
36
+ dft_matrix = op.initializer(ir.tensor(dft_matrix), name=f"{x.name}_dft_matrix")
37
+
38
+ # DFT axis is already at the end, direct matrix multiplication
39
+ result = op.MatMul(x, dft_matrix)
40
+
41
+ return result
42
+
43
 
44
  model = ir.load("model.onnx")
45
  onnxscript.optimizer.optimize(
46
  model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
47
  )
48
+ onnxscript.rewriter.rewrite(model, [ReplaceDftWithMatMulRule().rule()])
49
 
50
  onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
51
  model.ir_version = 10