# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
"""
OnnxOptimizer: graph-surgery passes for exported ONNX models.

Wraps an onnx-graphsurgeon graph and provides constant folding / shape
inference, a Resize fix for dynamic shapes, decomposition of
InstanceNormalization, and insertion of GroupNorm / LayerNorm / fMHA /
fMHCA plugin nodes.
"""
import os
from collections import OrderedDict
from copy import deepcopy

import numpy as np
import onnx
import torch
from onnx import shape_inference
import onnx_graphsurgeon as gs
from polygraphy.backend.onnx.loader import fold_constants
from onnx_graphsurgeon.logger.logger import G_LOGGER

from .symbolic import CustomOpSymbolicRegistry


class OnnxOptimizer():
    """Apply in-place rewrite passes to an ONNX model via onnx-graphsurgeon.

    The model is held as ``self.graph`` (a ``gs.Graph``).  Each pass mutates
    that graph; the rewrite passes finish with :meth:`cleanup` and most of
    them return how many rewrites they performed.
    """

    def __init__(
        self,
        input,
        severity=G_LOGGER.INFO
    ):
        """Load the model and set the onnx-graphsurgeon log level.

        Args:
            input: path to an ``.onnx`` file, or an already-loaded ONNX model
                that is handed directly to ``gs.import_onnx``.
            severity: severity assigned to the global ``G_LOGGER``.
        """
        if isinstance(input, str):
            onnx_graph = self.load_onnx(input)
        else:
            onnx_graph = input
        self.graph = gs.import_onnx(onnx_graph)
        self.severity = severity
        self.set_severity(severity)

    def set_severity(self, severity):
        """Set the severity of the module-level onnx-graphsurgeon logger.

        ``G_LOGGER`` is a shared object imported from onnx-graphsurgeon, so
        this affects all onnx-graphsurgeon logging in the process.
        """
        G_LOGGER.severity = severity

    def load_onnx(self, onnx_path:str):
        """Load an ONNX model from ``onnx_path`` (asserts that the file exists)."""
        assert os.path.isfile(onnx_path), f"not found onnx file: {onnx_path}"
        onnx_graph = onnx.load(onnx_path)
        G_LOGGER.info(f"load onnx file: {onnx_path}")
        return onnx_graph

    def save_onnx(self, onnx_path:str):
        """Export the current graph and write it to ``onnx_path``."""
        onnx_graph = gs.export_onnx(self.graph)
        G_LOGGER.info(f"save onnx file: {onnx_path}")
        onnx.save(onnx_graph, onnx_path)

    def info(self, prefix=''):
        """Log a one-line size summary of the graph at VERBOSE severity."""
        G_LOGGER.verbose(f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs")

    def cleanup(self, return_onnx=False):
        """Remove unused nodes/tensors and topologically sort the graph.

        Args:
            return_onnx: if true, also return the exported ``onnx`` model.
        """
        self.graph.cleanup().toposort()
        if return_onnx:
            return gs.export_onnx(self.graph)

    def select_outputs(self, keep, names=None):
        """Keep only the graph outputs at indices ``keep`` (in that order).

        Args:
            keep: indices into the current ``self.graph.outputs``.
            names: optional new names for the first ``len(names)`` outputs
                of the new output list.
        """
        self.graph.outputs = [self.graph.outputs[o] for o in keep]
        if names:
            for i, name in enumerate(names):
                self.graph.outputs[i].name = name

    @staticmethod
    def _find_tensor_index(tensors, name, value, where):
        """Return the index of the last tensor in ``tensors`` matching ``name`` or ``value``.

        A tensor matches when ``name`` is a string equal to its name, or when
        ``value`` is given and compares equal to it.  ``value`` is only
        compared when it is not ``None``: onnx-graphsurgeon tensors compare by
        ``.name``, so comparing against ``None`` would fail.

        Raises:
            AssertionError: if nothing matches.
        """
        index = -1
        for i, tensor in enumerate(tensors):
            if isinstance(name, str) and tensor.name == name:
                index = i
            elif value is not None and tensor == value:
                index = i
        assert index >= 0, f"not found {name}({value}) in {where}"
        return index

    def find_node_input(self, node, name:str=None, value=None) -> int:
        """Return the position of the input of ``node`` selected by ``name`` or ``value``.

        Raises:
            AssertionError: if no input matches.
        """
        return self._find_tensor_index(node.inputs, name, value, "node.inputs")

    def find_node_output(self, node, name:str=None, value=None) -> int:
        """Return the position of the output of ``node`` selected by ``name`` or ``value``.

        Raises:
            AssertionError: if no output matches.
        """
        return self._find_tensor_index(node.outputs, name, value, "node.outputs")

    def common_opt(self, return_onnx=False):
        """Run registered custom passes, constant folding and shape inference.

        Steps:
            1. every function in ``CustomOpSymbolicRegistry._OPTIMIZER`` is
               called with this optimizer;
            2. the graph is cleaned up and constant-folded with polygraphy;
            3. ONNX shape inference is run on the folded model, which is then
               re-imported as ``self.graph``.

        Args:
            return_onnx: if true, return the folded, shape-inferred model.

        Raises:
            TypeError: if the folded model is larger than the 2GB limit
                checked here (presumably the protobuf size limit).
        """
        for fn in CustomOpSymbolicRegistry._OPTIMIZER:
            fn(self)
        self.cleanup()
        onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=False)
        if onnx_graph.ByteSize() > 2147483648:
            raise TypeError("ERROR: model size exceeds supported 2GB limit")
        else:
            onnx_graph = shape_inference.infer_shapes(onnx_graph)

        self.graph = gs.import_onnx(onnx_graph)
        self.cleanup()
        if return_onnx:
            return onnx_graph

    def resize_fix(self):
        """Rewrite scale-based ``Resize`` nodes into size-based ones.

        Every ``Resize`` with exactly three inputs (data, roi, scales) is
        replaced by a ``Resize`` whose ``roi``/``scales`` inputs are empty and
        whose fourth input (``sizes``) is computed at run time by the added
        subgraph::

            Shape(hw_src) -> Slice[2:4] --.
                                           Concat -> sizes
            Shape(bc_src) -> Slice[0:2] --'

        Batch/channel sizes come from the tensor feeding the Resize, and H/W
        from a tensor further down the graph.  This keeps the resize correct
        for dynamic-shape exports.

        Returns:
            Number of ``Resize`` nodes rewritten.
        """
        mResizeNodes = 0
        # Snapshot: new nodes are appended during the loop (none of them is a
        # 3-input Resize, so they would never match anyway).
        for node in list(self.graph.nodes):
            if node.op != "Resize" or len(node.inputs) != 3:
                continue
            name = node.name + "/"

            # NOTE(review): H/W are taken from input 1 of the node two
            # consumers below the Resize; this walk is tied to the exported
            # graph layout — confirm for new models.
            add_node = node.o().o().i(1)
            # Producer of the Resize data input: supplies batch/channel dims.
            div_node = node.i()

            shape_hw_out = gs.Variable(name=name + "shape_hw_out", dtype=np.int64, shape=[4])
            shape_hw = gs.Node(op="Shape", name=name+"shape_hw", inputs=[add_node.outputs[0]], outputs=[shape_hw_out])

            const_zero = gs.Constant(name=name + "const_zero", values=np.array([0], dtype=np.int64))
            const_two = gs.Constant(name=name + "const_two", values=np.array([2], dtype=np.int64))
            const_four = gs.Constant(name=name + "const_four", values=np.array([4], dtype=np.int64))

            # Slice(data, starts, ends, axes): elements [2:4] along axis 0.
            slice_hw_out = gs.Variable(name=name + "slice_hw_out", dtype=np.int64, shape=[2])
            slice_hw = gs.Node(op="Slice", name=name+"slice_hw", inputs=[shape_hw_out, const_two, const_four, const_zero], outputs=[slice_hw_out])

            shape_bc_out = gs.Variable(name=name + "shape_bc_out", dtype=np.int64, shape=[2])
            shape_bc = gs.Node(op="Shape", name=name+"shape_bc", inputs=[div_node.outputs[0]], outputs=[shape_bc_out])

            # Elements [0:2] along axis 0.
            slice_bc_out = gs.Variable(name=name + "slice_bc_out", dtype=np.int64, shape=[2])
            slice_bc = gs.Node(op="Slice", name=name+"slice_bc", inputs=[shape_bc_out, const_zero, const_two, const_zero], outputs=[slice_bc_out])

            concat_bchw_out = gs.Variable(name=name + "concat_bchw_out", dtype=np.int64, shape=[4])
            concat_bchw = gs.Node(op="Concat", name=name+"concat_bchw", attrs={"axis": 0}, inputs=[slice_bc_out, slice_hw_out], outputs=[concat_bchw_out])

            none_var = gs.Variable.empty()
            resize_bchw = gs.Node(op="Resize", name=name+"resize_bchw", attrs=node.attrs, inputs=[node.inputs[0], none_var, none_var, concat_bchw_out], outputs=[node.outputs[0]])

            self.graph.nodes.extend([shape_hw, slice_hw, shape_bc, slice_bc, concat_bchw, resize_bchw])

            # Detach the old Resize so cleanup() removes it.
            node.inputs = []
            node.outputs = []

            mResizeNodes += 1

        self.cleanup()
        return mResizeNodes

    def adjustAddNode(self):
        """Move a constant first operand of ``Add`` into the second slot.

        ``Add`` is commutative, so ``(const, x) -> (x, const)`` does not change
        results.  Having the bias const as the second input allows Gemm+BiasAdd
        fusion in TRT.

        Returns:
            Number of ``Add`` nodes whose inputs were swapped.
        """
        nAdjustAddNode = 0
        for node in self.graph.nodes:
            if node.op == "Add" and isinstance(node.inputs[0], gs.Constant):
                node.inputs = [node.inputs[1], node.inputs[0]]
                nAdjustAddNode += 1

        self.cleanup()
        return nAdjustAddNode

    @staticmethod
    def _constant_values(tensor):
        """Return the array held by ``tensor``.

        Accepts either a ``gs.Constant`` (initializer) or a variable produced
        by an ONNX ``Constant`` node carrying a ``value`` attribute.
        """
        if isinstance(tensor, gs.Constant):
            return tensor.values
        return tensor.inputs[0].attrs["value"].values

    def decompose_instancenorms(self):
        """Replace ``InstanceNormalization`` nodes with primitive ops.

        Each node becomes::

            mean = ReduceMean(x, axes=[-1])
            d    = x - mean
            var  = ReduceMean(d ** 2, axes=[-1])
            y    = d / Sqrt(var + epsilon) * scale + bias

        ``scale`` and ``bias`` are reshaped to ``(1, C, 1)``.  Reducing only
        over the last axis and broadcasting that way assumes the input is a
        rank-3 ``(N, C, L)`` tensor — TODO confirm for new callers.

        Returns:
            Number of ``InstanceNormalization`` nodes decomposed.
        """
        nRemoveInstanceNorm = 0
        for node in list(self.graph.nodes):
            if node.op != "InstanceNormalization":
                continue
            name = node.name + "/"
            input_tensor = node.inputs[0]
            output_tensor = node.outputs[0]

            mean_out = gs.Variable(name=name + "mean_out")
            mean_node = gs.Node(op="ReduceMean", name=name + "mean_node", attrs={"axes": [-1]}, inputs=[input_tensor], outputs=[mean_out])

            sub_out = gs.Variable(name=name + "sub_out")
            sub_node = gs.Node(op="Sub", name=name + "sub_node", attrs={}, inputs=[input_tensor, mean_out], outputs=[sub_out])

            pow_out = gs.Variable(name=name + "pow_out")
            pow_const = gs.Constant(name=name + "pow_const", values=np.array([2.0], dtype=np.float32))
            pow_node = gs.Node(op="Pow", name=name + "pow_node", attrs={}, inputs=[sub_out, pow_const], outputs=[pow_out])

            mean2_out = gs.Variable(name=name + "mean2_out")
            mean2_node = gs.Node(op="ReduceMean", name=name + "mean2_node", attrs={"axes": [-1]}, inputs=[pow_out], outputs=[mean2_out])

            # The ONNX operator defaults epsilon to 1e-5 when the attribute is omitted.
            epsilon = node.attrs.get("epsilon", 1e-5)
            epsilon_out = gs.Variable(name=name + "epsilon_out")
            epsilon_const = gs.Constant(name=name + "epsilon_const", values=np.array([epsilon], dtype=np.float32))
            epsilon_node = gs.Node(op="Add", name=name + "epsilon_node", attrs={}, inputs=[mean2_out, epsilon_const], outputs=[epsilon_out])

            sqrt_out = gs.Variable(name=name + "sqrt_out")
            sqrt_node = gs.Node(op="Sqrt", name=name + "sqrt_node", attrs={}, inputs=[epsilon_out], outputs=[sqrt_out])

            div_out = gs.Variable(name=name + "div_out")
            div_node = gs.Node(op="Div", name=name + "div_node", attrs={}, inputs=[sub_out, sqrt_out], outputs=[div_out])

            # (1, C, 1) for any channel count C (previously hard-coded to 32).
            scale = self._constant_values(node.inputs[1])
            bias = self._constant_values(node.inputs[2])
            constantScale = gs.Constant("InstanceNormScaleV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(scale.reshape(1, -1, 1)))
            constantBias = gs.Constant("InstanceBiasV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(bias.reshape(1, -1, 1)))

            mul_out = gs.Variable(name=name + "mul_out")
            mul_node = gs.Node(op="Mul", name=name + "mul_node", attrs={}, inputs=[div_out, constantScale], outputs=[mul_out])

            add_node = gs.Node(op="Add", name=name + "add_node", attrs={}, inputs=[mul_out, constantBias], outputs=[output_tensor])

            self.graph.nodes.extend([mean_node, sub_node, pow_node, mean2_node, epsilon_node, sqrt_node, div_node, mul_node, add_node])

            # Detach the original node so cleanup() removes it.
            node.inputs = []
            node.outputs = []

            nRemoveInstanceNorm += 1

        self.cleanup()
        return nRemoveInstanceNorm

    @staticmethod
    def _first_constant_values(node):
        """Return a float32 copy of the values of ``node``'s first ``gs.Constant`` input.

        Raises:
            ValueError: if ``node`` has no constant input.
        """
        index = [isinstance(i, gs.Constant) for i in node.inputs].index(True)
        return np.array(node.inputs[index].values, dtype=np.float32)

    @staticmethod
    def _first_consumer_op(node):
        """Return the op type of the first consumer of ``node``'s first output, or ``None``."""
        consumers = node.outputs[0].outputs if node.outputs else []
        return consumers[0].op if consumers else None

    def _replace_tensor_uses(self, old, new):
        """Point every node input and graph output that uses ``old`` at ``new``.

        Tensors are compared with ``==``, as the original rewiring did.  All
        occurrences are replaced, including a tensor used twice by one node.
        """
        for sub_node in self.graph.nodes:
            for i, inp in enumerate(sub_node.inputs):
                if inp == old:
                    sub_node.inputs[i] = new
        self.graph.outputs = [new if out == old else out for out in self.graph.outputs]

    @staticmethod
    def _is_groupnorm_pattern(node):
        """Return True if ``node`` is the Reshape that starts an exported GroupNorm subgraph.

        A graph that runs out of producers/consumers or has a non-constant
        operand where a constant is expected simply does not match.
        """
        try:
            return (
                # "node.outputs != []" is added for VAE
                node.op == "Reshape" and node.outputs != []
                and node.o().op == "ReduceMean" and node.o(1).op == "Sub" and node.o().o() == node.o(1)
                and node.o().o().o().o().o().o().o().o().o().o().o().op == "Mul"
                and node.o().o().o().o().o().o().o().o().o().o().o().o().op == "Add"
                and len(node.o().o().o().o().o().o().o().o().inputs[1].values.shape) == 3
            )
        except (IndexError, AttributeError):
            return False

    def insert_groupnorm_plugin(self):
        """Replace exported GroupNorm subgraphs with one ``GroupNorm`` plugin node.

        Gamma and beta come from the first constant input of the trailing
        ``Mul`` and ``Add``; epsilon comes from a constant inside the subgraph.
        If the ``Add`` feeds a ``Sigmoid`` (Swish), the Swish is folded into
        the plugin (``bSwish=1``), and a trailing ``Cast`` is absorbed as well.
        The plugin output is declared float16.

        Returns:
            Number of GroupNorm plugins inserted.
        """
        nGroupNormPlugin = 0
        for node in list(self.graph.nodes):
            if not self._is_groupnorm_pattern(node):
                continue
            inputTensor = node.inputs[0]

            gammaNode = node.o().o().o().o().o().o().o().o().o().o().o()
            gamma = self._first_constant_values(gammaNode)
            # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!
            constantGamma = gs.Constant("groupNormGamma-" + str(nGroupNormPlugin), np.ascontiguousarray(gamma.reshape(-1)))

            betaNode = gammaNode.o()
            beta = self._first_constant_values(betaNode)
            constantBeta = gs.Constant("groupNormBeta-" + str(nGroupNormPlugin), np.ascontiguousarray(beta.reshape(-1)))

            # Works for scalar (0-d) and 1-element epsilon constants alike.
            epsilon = float(np.asarray(node.o().o().o().o().o().inputs[1].values).reshape(-1)[0])

            if self._first_consumer_op(betaNode) == "Sigmoid":  # need Swish
                bSwish = True
                lastNode = betaNode.o().o()  # Mul node of Swish
            else:
                bSwish = False
                lastNode = betaNode
            # Cast node after Group Norm
            if self._first_consumer_op(lastNode) == "Cast":
                lastNode = lastNode.o()

            inputList = [inputTensor, constantGamma, constantBeta]
            groupNormV = gs.Variable("GroupNormV-" + str(nGroupNormPlugin), np.dtype(np.float16), inputTensor.shape)
            groupNormN = gs.Node("GroupNorm", "GroupNormN-" + str(nGroupNormPlugin), inputs=inputList, outputs=[groupNormV], attrs=OrderedDict([('epsilon', epsilon), ('bSwish', int(bSwish))]))
            self.graph.nodes.append(groupNormN)

            # Rewire consumers (and a graph output, if any) before detaching,
            # so no graph output is left without a producer.
            self._replace_tensor_uses(lastNode.outputs[0], groupNormV)
            node.inputs = []
            lastNode.outputs = []
            nGroupNormPlugin += 1

        self.cleanup()
        return nGroupNormPlugin

    @staticmethod
    def _is_layernorm_pattern(node):
        """Return True if ``node`` is the first ReduceMean of an exported LayerNorm subgraph.

        A graph that runs out of producers/consumers or has a non-constant
        operand where a constant is expected simply does not match.
        """
        try:
            return (
                node.op == 'ReduceMean'
                and node.o().op == 'Sub' and node.o().inputs[0] == node.inputs[0]
                and node.o().o(0).op == 'Pow' and node.o().o(1).op == 'Div'
                and node.o().o(0).o().op == 'ReduceMean'
                and node.o().o(0).o().o().op == 'Add'
                and node.o().o(0).o().o().o().op == 'Sqrt'
                and node.o().o(0).o().o().o().o().op == 'Div' and node.o().o(0).o().o().o().o() == node.o().o(1)
                and node.o().o(0).o().o().o().o().o().op == 'Mul'
                and node.o().o(0).o().o().o().o().o().o().op == 'Add'
                and len(node.o().o(0).o().o().o().o().o().inputs[1].values.shape) == 1
            )
        except (IndexError, AttributeError):
            return False

    def insert_layernorm_plugin(self):
        """Replace exported LayerNorm subgraphs with one ``LayerNorm`` plugin node.

        The subgraph is ``(x - mean) / Sqrt(mean((x - mean)^2) + eps) * gamma
        + beta``.  Gamma/beta come from the first constant input of the
        trailing ``Mul``/``Add``.  Epsilon is read from the constant operand of
        the ``+ eps`` node; 1e-5 (the previously hard-coded value) is used
        only when that operand is not a ``gs.Constant``.  The plugin output is
        declared float32.

        Returns:
            Number of LayerNorm plugins inserted.
        """
        nLayerNormPlugin = 0
        for node in list(self.graph.nodes):
            if not self._is_layernorm_pattern(node):
                continue

            if not node.inputs[0].inputs or node.i().op == "Add":
                inputTensor = node.inputs[0]  # CLIP (or an input with no producer)
            else:
                inputTensor = node.i().inputs[0]  # UNet and VAE

            gammaNode = node.o().o().o().o().o().o().o()
            gamma = self._first_constant_values(gammaNode)
            # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!
            constantGamma = gs.Constant("LayerNormGamma-" + str(nLayerNormPlugin), np.ascontiguousarray(gamma.reshape(-1)))

            betaNode = gammaNode.o()
            beta = self._first_constant_values(betaNode)
            constantBeta = gs.Constant("LayerNormBeta-" + str(nLayerNormPlugin), np.ascontiguousarray(beta.reshape(-1)))

            # Use the model's own epsilon (e.g. 1e-6) rather than assuming 1e-5.
            epsilon = 1.e-5
            for inp in node.o().o(0).o().o().inputs:
                if isinstance(inp, gs.Constant):
                    epsilon = float(np.asarray(inp.values).reshape(-1)[0])
                    break

            inputList = [inputTensor, constantGamma, constantBeta]
            layerNormV = gs.Variable("LayerNormV-" + str(nLayerNormPlugin), np.dtype(np.float32), inputTensor.shape)
            layerNormN = gs.Node("LayerNorm", "LayerNormN-" + str(nLayerNormPlugin), inputs=inputList, attrs=OrderedDict([('epsilon', epsilon)]), outputs=[layerNormV])
            self.graph.nodes.append(layerNormN)
            nLayerNormPlugin += 1

            if betaNode.outputs[0] in self.graph.outputs:
                index = self.graph.outputs.index(betaNode.outputs[0])
                self.graph.outputs[index] = layerNormV
            else:
                if self._first_consumer_op(betaNode) == "Cast":
                    lastNode = betaNode.o()
                else:
                    lastNode = betaNode
                self._replace_tensor_uses(lastNode.outputs[0], layerNormV)
                lastNode.outputs = []

        self.cleanup()
        return nLayerNormPlugin

    def fuse_kv(self, node_k, node_v, fused_kv_idx, heads, num_dynamic=0):
        """Merge the K and V projection MatMuls into one MatMul.

        The weights (``(C, H*D)`` each, MatMul layout) are interleaved per head
        so the fused output reads as ``[b, s_kv, h, 2, d]``.  The fused node
        writes K's output tensor; the consumers at index ``num_dynamic`` are
        rewired to it, and the inputs of the first ``num_dynamic`` consumers
        of each output (the ``Shape`` nodes counted by
        :meth:`mha_mhca_detected`) are cleared.

        Returns:
            The new fused ``MatMul`` node.
        """
        # Get weights of K
        weights_k = node_k.inputs[1].values
        # Get weights of V
        weights_v = node_v.inputs[1].values
        # Input number of channels to K and V
        C = weights_k.shape[0]
        # Number of heads
        H = heads
        # Dimension per head
        D = weights_k.shape[1] // H

        # Concat and interleave weights such that the output of fused KV GEMM has [b, s_kv, h, 2, d] shape
        weights_kv = np.dstack([weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 2 * H * D)

        # K and V have the same input
        input_tensor = node_k.inputs[0]
        # K and V must have the same output which we feed into fmha plugin
        output_tensor_k = node_k.outputs[0]
        # Create tensor
        constant_weights_kv = gs.Constant("Weights_KV_{}".format(fused_kv_idx), np.ascontiguousarray(weights_kv))

        # Create fused KV node
        fused_kv_node = gs.Node(op="MatMul", name="MatMul_KV_{}".format(fused_kv_idx), inputs=[input_tensor, constant_weights_kv], outputs=[output_tensor_k])
        self.graph.nodes.append(fused_kv_node)

        # Connect the output of fused node to the inputs of the nodes after K and V
        node_v.o(num_dynamic).inputs[0] = output_tensor_k
        node_k.o(num_dynamic).inputs[0] = output_tensor_k
        # Clearing a consumer's inputs removes it from the consumer list, so
        # o() reaches the next consumer on each iteration.
        for _ in range(num_dynamic):
            node_v.o().inputs.clear()
            node_k.o().inputs.clear()

        # Clear inputs and outputs of K and V to get these nodes cleared
        node_k.outputs.clear()
        node_v.outputs.clear()

        node_k.inputs.clear()
        node_v.inputs.clear()

        self.cleanup()
        return fused_kv_node

    def insert_fmhca(self, node_q, node_kv, final_tranpose, mhca_idx, heads, num_dynamic=0):
        """Replace a cross-attention subgraph with an ``fMHCA`` plugin node.

        The plugin takes the reshaped Q output and the fused KV output
        reshaped to ``[0, 0, heads, 2, dims_per_head]``, and writes the
        tensor previously produced by ``final_tranpose``.
        """
        # Get inputs and outputs for the fMHCA plugin
        # We take an output of reshape that follows the Q GEMM
        output_q = node_q.o(num_dynamic).o().inputs[0]
        output_kv = node_kv.o().inputs[0]
        output_final_tranpose = final_tranpose.outputs[0]

        # Clear the inputs of the nodes that follow the Q and KV GEMM
        # to delete these subgraphs (it will be substituted by fMHCA plugin).
        # The KV line is repeated on purpose: clearing a consumer's inputs
        # removes it from the consumer list, so the second call reaches the
        # next consumer.
        node_kv.outputs[0].outputs[0].inputs.clear()
        node_kv.outputs[0].outputs[0].inputs.clear()
        node_q.o(num_dynamic).o().inputs.clear()
        for i in range(num_dynamic):
            node_q.o(i).o().o(1).inputs.clear()

        weights_kv = node_kv.inputs[1].values
        dims_per_head = weights_kv.shape[1] // (heads * 2)

        # Reshape dims
        shape = gs.Constant("Shape_KV_{}".format(mhca_idx), np.ascontiguousarray(np.array([0, 0, heads, 2, dims_per_head], dtype=np.int64)))

        # Reshape output tensor
        output_reshape = gs.Variable("ReshapeKV_{}".format(mhca_idx), np.dtype(np.float16), None)
        # Reshape feeding the fMHCA plugin
        reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mhca_idx), inputs=[output_kv, shape], outputs=[output_reshape])
        # Insert node
        self.graph.nodes.append(reshape)

        # Create fMHCA plugin
        fmhca = gs.Node(op="fMHCA", name="fMHCA_{}".format(mhca_idx), inputs=[output_q, output_reshape], outputs=[output_final_tranpose])
        # Insert node
        self.graph.nodes.append(fmhca)

        # Connect input of fMHCA to output of Q GEMM
        node_q.o(num_dynamic).outputs[0] = output_q

        if num_dynamic > 0:
            # Feed the Reshape after the plugin with the runtime shape of Q's input.
            reshape2_input1_out = gs.Variable("Reshape2_fmhca{}_out".format(mhca_idx), np.dtype(np.int64), None)
            reshape2_input1_shape = gs.Node("Shape", "Reshape2_fmhca{}_shape".format(mhca_idx), inputs=[node_q.inputs[0]], outputs=[reshape2_input1_out])
            self.graph.nodes.append(reshape2_input1_shape)
            final_tranpose.o().inputs[1] = reshape2_input1_out

        # Clear outputs of transpose to get this subgraph cleared
        final_tranpose.outputs.clear()

        self.cleanup()

    def fuse_qkv(self, node_q, node_k, node_v, fused_qkv_idx, heads, num_dynamic=0):
        """Merge the Q, K and V projection MatMuls into one MatMul.

        The weights (``(C, H*D)`` each, MatMul layout) are interleaved per head
        so the fused output reads as ``[b, s, h, 3, d]``.  The fused node
        writes K's output tensor; the consumers at index ``num_dynamic`` are
        rewired to it, and the inputs of the first ``num_dynamic`` consumers
        of each output are cleared.

        Returns:
            The new fused ``MatMul`` node.
        """
        # Get weights of Q
        weights_q = node_q.inputs[1].values
        # Get weights of K
        weights_k = node_k.inputs[1].values
        # Get weights of V
        weights_v = node_v.inputs[1].values

        # Input number of channels to Q, K and V
        C = weights_k.shape[0]
        # Number of heads
        H = heads
        # Hidden dimension per head
        D = weights_k.shape[1] // H

        # Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape
        weights_qkv = np.dstack([weights_q.reshape(C, H, D), weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 3 * H * D)

        input_tensor = node_k.inputs[0]  # K and V have the same input
        # Q, K and V must have the same output which we feed into fmha plugin
        output_tensor_k = node_k.outputs[0]
        # Constant holding the interleaved QKV weights
        constant_weights_qkv = gs.Constant("Weights_QKV_{}".format(fused_qkv_idx), np.ascontiguousarray(weights_qkv))

        # Created a fused node
        fused_qkv_node = gs.Node(op="MatMul", name="MatMul_QKV_{}".format(fused_qkv_idx), inputs=[input_tensor, constant_weights_qkv], outputs=[output_tensor_k])
        self.graph.nodes.append(fused_qkv_node)

        # Connect the output of the fused node to the inputs of the nodes after Q, K and V
        node_q.o(num_dynamic).inputs[0] = output_tensor_k
        node_k.o(num_dynamic).inputs[0] = output_tensor_k
        node_v.o(num_dynamic).inputs[0] = output_tensor_k
        # Clearing a consumer's inputs removes it from the consumer list, so
        # o() reaches the next consumer on each iteration.
        for _ in range(num_dynamic):
            node_q.o().inputs.clear()
            node_k.o().inputs.clear()
            node_v.o().inputs.clear()

        # Clear inputs and outputs of Q, K and V to get these nodes cleared
        node_q.outputs.clear()
        node_k.outputs.clear()
        node_v.outputs.clear()

        node_q.inputs.clear()
        node_k.inputs.clear()
        node_v.inputs.clear()

        self.cleanup()
        return fused_qkv_node

    def insert_fmha(self, node_qkv, final_tranpose, mha_idx, heads, num_dynamic=0):
        """Replace a self-attention subgraph with an ``fMHA_V2`` plugin node.

        The fused QKV output is reshaped to ``[0, 0, heads, 3, dims_per_head]``
        and fed to the plugin, which writes the tensor previously produced by
        ``final_tranpose``.
        """
        # Get inputs and outputs for the fMHA plugin
        output_qkv = node_qkv.o().inputs[0]
        output_final_tranpose = final_tranpose.outputs[0]

        # Clear the inputs of the nodes that follow the QKV GEMM
        # to delete these subgraphs (it will be substituted by fMHA plugin)
        node_qkv.outputs[0].outputs[2].inputs.clear()
        node_qkv.outputs[0].outputs[1].inputs.clear()
        node_qkv.outputs[0].outputs[0].inputs.clear()

        weights_qkv = node_qkv.inputs[1].values
        dims_per_head = weights_qkv.shape[1] // (heads * 3)

        # Reshape dims
        shape = gs.Constant("Shape_QKV_{}".format(mha_idx), np.ascontiguousarray(np.array([0, 0, heads, 3, dims_per_head], dtype=np.int64)))

        # Reshape output tensor
        output_shape = gs.Variable("ReshapeQKV_{}".format(mha_idx), np.dtype(np.float16), None)
        # Reshape feeding the fMHA plugin
        reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mha_idx), inputs=[output_qkv, shape], outputs=[output_shape])
        # Insert node
        self.graph.nodes.append(reshape)

        # Create fMHA plugin
        fmha = gs.Node(op="fMHA_V2", name="fMHA_{}".format(mha_idx), inputs=[output_shape], outputs=[output_final_tranpose])
        # Insert node
        self.graph.nodes.append(fmha)

        if num_dynamic > 0:
            # Feed the Reshape after the plugin with the runtime shape of the QKV input.
            reshape2_input1_out = gs.Variable("Reshape2_{}_out".format(mha_idx), np.dtype(np.int64), None)
            reshape2_input1_shape = gs.Node("Shape", "Reshape2_{}_shape".format(mha_idx), inputs=[node_qkv.inputs[0]], outputs=[reshape2_input1_out])
            self.graph.nodes.append(reshape2_input1_shape)
            final_tranpose.o().inputs[1] = reshape2_input1_out

        # Clear outputs of transpose to get this subgraph cleared
        final_tranpose.outputs.clear()

        self.cleanup()

    def mha_mhca_detected(self, node, mha):
        """Check whether ``node`` is the V projection of an attention block.

        Starting from a candidate V ``MatMul``, the walk goes down to the
        ``Softmax(Q*K^T) * V`` MatMul and back up to the Q and K projection
        MatMuls, checking the op types of the exported attention subgraph.
        A walk that runs off the graph (missing producer/consumer) is
        reported as "not detected".

        Args:
            node: candidate V projection node.
            mha: ``True`` for self-attention (V's input must be produced by an
                ``Add``), ``False`` for cross-attention (V's input must have
                no producer node).

        Returns:
            ``(detected, num_dynamic_q, num_dynamic_kv, node_q, node_k,
            node_v, final_tranpose)``; ``(False, 0, 0, None, None, None,
            None)`` when nothing is detected.  ``num_dynamic_kv`` is the
            number (up to 3) of leading ``Shape`` consumers of V's output.
        """
        not_detected = (False, 0, 0, None, None, None, None)
        try:
            if node.op == "MatMul" and len(node.outputs) == 1 and \
                ((mha and len(node.inputs[0].inputs) > 0 and node.i().op == "Add") or \
                (not mha and len(node.inputs[0].inputs) == 0)):

                if node.o().op == 'Shape':
                    if node.o(1).op == 'Shape':
                        num_dynamic_kv = 3 if node.o(2).op == 'Shape' else 2
                    else:
                        num_dynamic_kv = 1
                    # For Cross-Attention, if batch axis is dynamic (in QKV), assume H*W (in Q) is dynamic as well
                    num_dynamic_q = num_dynamic_kv if mha else num_dynamic_kv + 1
                else:
                    num_dynamic_kv = 0
                    num_dynamic_q = 0

                o = node.o(num_dynamic_kv)
                # "len(node.outputs) == 1" above makes sure we are not in an already fused node
                if o.op == "Reshape" and \
                    o.o().op == "Transpose" and \
                    o.o().o().op == "Reshape" and \
                    o.o().o().o().op == "MatMul" and \
                    o.o().o().o().i(0).op == "Softmax" and \
                    o.o().o().o().i(1).op == "Reshape" and \
                    o.o().o().o().i(0).i().op == "Mul" and \
                    o.o().o().o().i(0).i().i().op == "MatMul" and \
                    o.o().o().o().i(0).i().i().i(0).op == "Reshape" and \
                    o.o().o().o().i(0).i().i().i(1).op == "Transpose" and \
                    o.o().o().o().i(0).i().i().i(1).i().op == "Reshape" and \
                    o.o().o().o().i(0).i().i().i(1).i().i().op == "Transpose" and \
                    o.o().o().o().i(0).i().i().i(1).i().i().i().op == "Reshape" and \
                    o.o().o().o().i(0).i().i().i(1).i().i().i().i().op == "MatMul" and \
                    node.name != o.o().o().o().i(0).i().i().i(1).i().i().i().i().name:
                    node_q = o.o().o().o().i(0).i().i().i(0).i().i().i()
                    node_k = o.o().o().o().i(0).i().i().i(1).i().i().i().i()
                    node_v = node
                    final_tranpose = o.o().o().o().o(num_dynamic_q).o()
                    # Sanity check to make sure that the graph looks like expected
                    if node_q.op == "MatMul" and final_tranpose.op == "Transpose":
                        return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose
        except IndexError:
            # The walk ran off the graph: the structure does not match.
            return not_detected
        return not_detected

    def fuse_kv_insert_fmhca(self, heads, mhca_index, sm):
        """Fuse K/V and insert an fMHCA plugin for the first cross-attention found.

        Args:
            heads: number of attention heads.
            mhca_index: index used to name the new nodes/tensors.
            sm: GPU compute capability (e.g. 75); see the SM75 skip below.

        Returns:
            True if one block was rewritten, False if none was found.
        """
        nodes = self.graph.nodes
        # Iterate over graph and search for MHCA pattern.  The last node is
        # never a candidate (guard against out-of-bounds walks).
        for node in nodes[:-1]:
            # Get anchor nodes for fusion and fMHCA plugin insertion if the MHCA is detected
            detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \
                self.mha_mhca_detected(node, mha=False)
            if detected:
                assert num_dynamic_q == 0 or num_dynamic_q == num_dynamic_kv + 1
                # On SM75, skip the fMHCA plugin when the dim per head (Q weight
                # width // heads) is 160.
                if sm == 75 and node_q.inputs[1].shape[1] // heads == 160:
                    continue
                # Fuse K and V GEMMS
                node_kv = self.fuse_kv(node_k, node_v, mhca_index, heads, num_dynamic_kv)
                # Insert fMHCA plugin
                self.insert_fmhca(node_q, node_kv, final_tranpose, mhca_index, heads, num_dynamic_q)
                return True
        return False

    def fuse_qkv_insert_fmha(self, heads, mha_index):
        """Fuse Q/K/V and insert an fMHA plugin for the first self-attention found.

        Returns:
            True if one block was rewritten, False if none was found.
        """
        nodes = self.graph.nodes
        # Iterate over graph and search for MHA pattern.  The last node is
        # never a candidate (guard against out-of-bounds walks).
        for node in nodes[:-1]:
            # Get anchor nodes for fusion and fMHA plugin insertion if the MHA is detected
            detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \
                self.mha_mhca_detected(node, mha=True)
            if detected:
                assert num_dynamic_q == num_dynamic_kv
                # Fuse Q, K and V GEMMS
                node_qkv = self.fuse_qkv(node_q, node_k, node_v, mha_index, heads, num_dynamic_kv)
                # Insert fMHA plugin
                self.insert_fmha(node_qkv, final_tranpose, mha_index, heads, num_dynamic_kv)
                return True
        return False

    def insert_fmhca_plugin(self, num_heads, sm):
        """Rewrite cross-attention blocks into fMHCA plugins until none remain.

        Returns:
            Number of fMHCA plugins inserted.
        """
        mhca_index = 0
        while self.fuse_kv_insert_fmhca(num_heads, mhca_index, sm):
            mhca_index += 1
        return mhca_index

    def insert_fmha_plugin(self, num_heads):
        """Rewrite self-attention blocks into fMHA plugins until none remain.

        Returns:
            Number of fMHA plugins inserted.
        """
        mha_index = 0
        while self.fuse_qkv_insert_fmha(num_heads, mha_index):
            mha_index += 1
        return mha_index