import sys
from typing import Optional, List, Union
import numpy as np
from auto_optimizer import OnnxGraph, OnnxNode
def _find_producer(graph, node, op_type):
    """Return the first producer of one of *node*'s inputs with op type *op_type*, else None."""
    for input_name in node.inputs:
        producer = graph.get_prev_node(input_name)
        if producer and producer.op_type == op_type:
            return producer
    return None


def _find_consumer(graph, node, op_type):
    """Return the first consumer of one of *node*'s outputs with op type *op_type*, else None."""
    for output_name in node.outputs:
        for consumer in graph.get_next_nodes(output_name):
            if consumer.op_type == op_type:
                return consumer
    return None


def _chain_matches(graph, start_node, op_types, find_neighbor):
    """Walk from *start_node* one neighbor per entry of *op_types*; True if every step matches."""
    current = start_node
    for op_type in op_types:
        current = find_neighbor(graph, current, op_type)
        if current is None:
            return False
    return True


def pattern_select(
    graph: "OnnxGraph",
    candidate_nodes: Union[str, List["OnnxNode"]],
    preorders: Optional[List[str]] = None,
    successors: Optional[List[str]] = None
) -> List["OnnxNode"]:
    """Select candidate nodes whose neighborhood matches a linear op-type pattern.

    Args:
        graph: Graph queried through ``get_nodes``, ``get_prev_node`` and
            ``get_next_nodes``.
        candidate_nodes: Either an op type (the candidates are then
            ``graph.get_nodes(op_type)``) or a list of node objects.
        preorders: Op types expected upstream of the candidate, listed from
            farthest to nearest; the last entry is the direct producer.
        successors: Op types expected downstream of the candidate, listed
            from nearest to farthest; the first entry is a direct consumer.

    Returns:
        The candidates, in their original order, for which BOTH the
        predecessor chain and the successor chain match. A missing or empty
        chain matches trivially, so with no constraints every candidate is
        returned.

    Note:
        Matching is greedy: at each step the first neighbor with the
        requested op type is followed, and there is no backtracking.
    """
    preorders = preorders or []
    successors = successors or []
    if isinstance(candidate_nodes, str):
        candidate_nodes = graph.get_nodes(candidate_nodes)

    selected = []
    for node in candidate_nodes:
        # A failed predecessor chain rejects the node outright; it must not be
        # "rescued" by a successor chain that happens to match.
        if not _chain_matches(graph, node, preorders[::-1], _find_producer):
            continue
        if not _chain_matches(graph, node, successors, _find_consumer):
            continue
        selected.append(node)
    return selected
def insert_reshape_node(graph, anchor_node, dst_shape, mode='after'):
    """Insert a Reshape node next to *anchor_node* with a constant target shape.

    Args:
        graph: Graph to modify in place; must provide ``add_node``,
            ``add_initializer`` and ``insert_node``.
        anchor_node: Existing node the Reshape is placed relative to.
        dst_shape: Target shape, stored as an int64 initializer.
        mode: Forwarded to ``graph.insert_node`` and embedded in the names of
            the new node and initializer.
    """
    anchor_name = anchor_node.name
    reshape = graph.add_node(f"Reshape_{mode}_{anchor_name}", "Reshape")
    shape_init = graph.add_initializer(
        f"Reshape_init_{mode}_{anchor_name}",
        np.array(dst_shape, dtype=np.int64),
    )
    # Wire the node into the graph first, then attach the shape as an extra input.
    graph.insert_node(anchor_name, reshape, mode=mode)
    reshape.inputs.append(shape_init.name)
def fix_cpu(graph):
    """Retarget Cast nodes that feed an Add to int32 and give the Add an int32 constant.

    For every Cast node whose first consumer (of its first output) is an Add:

    * the Cast's ``to`` attribute is set to 6 (INT32 in ONNX's
      ``TensorProto.DataType``);
    * a scalar int32 initializer with value 1, named ``"<add name>_init"``,
      is added to the graph;
    * the Add's second input is replaced by that initializer.

    Cast nodes without any consumer (for example a Cast that only feeds a
    graph output) are skipped instead of raising ``IndexError``.

    Args:
        graph: Graph to modify in place.
    """
    for cast_node in graph.get_nodes(op_type="Cast"):
        consumers = graph.get_next_nodes(cast_node.outputs[0])
        if not consumers:
            continue
        # Only the first consumer decides whether this Cast is rewritten.
        next_node = consumers[0]
        if next_node.op_type != "Add":
            continue
        cast_node['to'] = 6
        inserted_add_init = graph.add_initializer(
            f"{next_node.name}_init",
            np.array(1, dtype='int32')
        )
        # NOTE(review): assumes the Cast output is the Add's input 0 — confirm.
        next_node.inputs[1] = inserted_add_init.name
def merge_axis(graph, seq, bs):
    """Insert three fixed-shape Reshape nodes at pattern-selected anchors.

    The steps run in the listed order, and each uses the first node returned
    by ``pattern_select``:

    1. after the first ``Add`` fed by a ``Gather``: reshape to ``[-1, 768]``;
    2. after the first ``Sub`` followed by a ``Mul``: reshape to
       ``[bs*seq, 1]``;
    3. before the first ``Gather`` followed by a ``Gemm``: reshape to
       ``[-1, seq, 768]``.

    Args:
        graph: Graph to modify in place.
        seq: Sequence length used in the target shapes.
        bs: Batch size used in the target shapes.
    """
    # (candidate op type, pattern constraint, target shape, insertion mode)
    steps = (
        ('Add', {'preorders': ['Gather']}, [-1, 768], 'after'),
        ('Sub', {'successors': ['Mul']}, [bs*seq, 1], 'after'),
        ('Gather', {'successors': ['Gemm']}, [-1, seq, 768], 'before'),
    )
    for op_type, constraint, shape, mode in steps:
        anchor = pattern_select(graph, op_type, **constraint)[0]
        insert_reshape_node(graph, anchor, shape, mode=mode)
def opt_attention(graph, seq, bs):
    """Rewrite the subgraph around every Softmax node in place.

    The function first removes the lowest- and highest-numbered Transpose
    nodes of the graph, then, for each Softmax node, walks to its neighbors
    by input/output position and rewires them: reshape targets become fixed
    shapes built from ``bs`` and ``seq``, a Where node is replaced by an Add,
    and a Cast -> Mul -> Expand chain is inserted behind the node that fed the
    Where's first input.

    Args:
        graph: Graph to modify in place.
        seq: Sequence length used in the fixed shapes.
        bs: Batch size used in the fixed shapes.

    NOTE(review): neighbors are located purely by position; the local variable
    names reflect the op types expected there, but no op type is checked. The
    constants 12, 64 and 768 are hard-coded — presumably 12 heads x 64 = 768
    hidden size; confirm against the model.
    """
    transpose_nodes = graph.get_nodes(op_type="Transpose")
    # Order by the integer that follows the first "_" in the node name
    # (e.g. "Transpose_12" -> 12); names without that form raise here.
    transpose_nodes = sorted(transpose_nodes,
                             key=lambda node : int(node.name.split("_")[1]))
    # Drop the first and the last Transpose in that order.
    graph.remove(transpose_nodes[0].name)
    graph.remove(transpose_nodes[-1].name)
    for softmax_node in graph.get_nodes(op_type="Softmax"):
        softmax_node['axis'] = -1
        # Walk upstream from the Softmax input, one producer at a time.
        reshape_node0 = graph.get_prev_node(softmax_node.inputs[0])
        where_node = graph.get_prev_node(reshape_node0.inputs[0])
        reshape_node1 = graph.get_prev_node(where_node.inputs[-1])
        matmul_node1 = graph.get_prev_node(reshape_node1.inputs[0])
        # Both operand branches of the first MatMul: producer of each input,
        # then the producer of that node's first input.
        transpose_node1_1 = graph.get_prev_node(matmul_node1.inputs[0])
        reshape_node2_1 = graph.get_prev_node(transpose_node1_1.inputs[0])
        transpose_node1_2 = graph.get_prev_node(matmul_node1.inputs[1])
        reshape_node2_2 = graph.get_prev_node(transpose_node1_2.inputs[0])
        # Overwrite the shape tensor that fed reshape_node1 and reuse it as the
        # shape input of the two operand-branch reshapes (and reshape_node3 below).
        dst_shape_name = reshape_node1.inputs[1]
        graph[dst_shape_name].value = np.array([bs, seq, 12, 64], dtype="int64")
        transpose_node1_1['perm'] = [0, 2, 1, 3]
        transpose_node1_2['perm'] = [0, 2, 1, 3]
        reshape_node2_1.inputs[1] = dst_shape_name
        reshape_node2_2.inputs[1] = dst_shape_name
        # The two reshapes between the MatMul and the Softmax are removed.
        graph.remove(reshape_node0.name)
        graph.remove(reshape_node1.name)
        # Extra Transpose swapping the last two axes, inserted at
        # transpose_node1_2 (insert_node's default mode; the node name suggests
        # "after" — TODO confirm).
        inserted_transpose_node = graph.add_node(
            f"Transpose_after_{transpose_node1_2.name}",
            "Transpose",
            attrs={
                "perm": [0, 1, 3, 2]
            }
        )
        graph.insert_node(transpose_node1_2.name, inserted_transpose_node)
        # Replace the Where node by an Add: remember the Where's first input
        # and its producer before the Where is removed.
        unsqueeze_node = graph.get_prev_node(where_node.inputs[0])
        where_ori_input0 = where_node.inputs[0]
        inserted_add_node = graph.add_node(
            where_node.name.replace("Where", "Add"),
            "Add"
        )
        graph.insert_node(matmul_node1.name, inserted_add_node)
        # The Where's former first input becomes an additional input of the Add.
        inserted_add_node.inputs.append(where_ori_input0)
        graph.remove(where_node.name)
        softmax_node.inputs[0] = inserted_add_node.outputs[0]
        # Chain inserted at unsqueeze_node: Cast -> Mul -> Expand, each inserted
        # at the previously inserted node.
        # `to: 1` is FLOAT in ONNX's TensorProto.DataType.
        inserted_cast_node = graph.add_node(
            f"Cast_after_{unsqueeze_node.name}",
            "Cast",
            attrs={
                'to': 1
            }
        )
        graph.insert_node(unsqueeze_node.name, inserted_cast_node)
        inserted_mul_node = graph.add_node(
            f"Mul_after_{unsqueeze_node.name}",
            "Mul"
        )
        graph.insert_node(inserted_cast_node.name, inserted_mul_node)
        # -65504 is the lowest finite float16 value; presumably this turns the
        # former Where condition into a large negative additive mask — TODO confirm.
        mul_init = graph.add_initializer(
            f"Mul_init_after_{unsqueeze_node.name}",
            np.array(-65504).astype("float32")
        )
        inserted_mul_node.inputs.append(mul_init.name)
        inserted_expand_node = graph.add_node(
            f"Expand_after_{unsqueeze_node.name}",
            "Expand"
        )
        graph.insert_node(inserted_mul_node.name, inserted_expand_node)
        expand_init = graph.add_initializer(
            f"Expand_init_after_{unsqueeze_node.name}",
            np.array([bs, 1, seq, seq]).astype("int64")
        )
        inserted_expand_node.inputs.append(expand_init.name)
        # Move the node feeding reshape_node2_1 to sit behind the MatMul:
        # the reshape now reads that node's own input, the node reads the
        # MatMul output, and the new Add reads the node's output.
        mul_node = graph.get_prev_node(reshape_node2_1.inputs[0])
        reshape_node2_1.inputs[0] = mul_node.inputs[0]
        mul_node.inputs[0] = matmul_node1.outputs[0]
        inserted_add_node.inputs[0] = mul_node.outputs[0]
        mul_node.name = "bert_" + mul_node.name
        # Downstream side: the first consumer of the Softmax and the branch
        # feeding its second input.
        matmul_node2 = graph.get_next_nodes(softmax_node.outputs[0])[0]
        transpose_node2 = graph.get_prev_node(matmul_node2.inputs[1])
        reshape_node3 = graph.get_prev_node(transpose_node2.inputs[0])
        transpose_node2['perm'] = [0, 2, 1, 3]
        reshape_node3.inputs[1] = dst_shape_name
        # First consumer of the second MatMul and the first consumer of that
        # node; the latter's shape tensor is overwritten with [-1, 768].
        transpose_node3 = graph.get_next_nodes(matmul_node2.outputs[0])[0]
        reshape_node4 = graph.get_next_nodes(transpose_node3.outputs[0])[0]
        transpose_node3 ['perm'] = [0, 2, 1, 3]
        graph[reshape_node4.inputs[1]].value = np.array([-1, 768], dtype='int64')
def main(argv=None):
    """Command-line entry point: load an ONNX model, rewrite it, and save it.

    Expected arguments (after the program name):
    ``<input_onnx> <save_onnx> <batch_size> <seq_len>``. Arguments beyond
    these four are ignored.

    Args:
        argv: Full argument vector including the program name; defaults to
            ``sys.argv``.

    Raises:
        SystemExit: With a usage message when fewer than four arguments are
            given.
        ValueError: When the batch size or sequence length is not an integer.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 5:
        prog = argv[0] if argv else "<script>"
        raise SystemExit(
            f"Usage: {prog} <input_onnx> <save_onnx> <batch_size> <seq_len>"
        )
    input_path = argv[1]
    save_path = argv[2]
    bs = int(argv[3])
    seq = int(argv[4])

    onnx_graph = OnnxGraph.parse(input_path)
    fix_cpu(onnx_graph)
    merge_axis(onnx_graph, seq, bs)
    opt_attention(onnx_graph, seq, bs)
    # Post-processing on the rewritten graph, in the original order.
    onnx_graph.update_map()
    onnx_graph.remove_unused_nodes()
    onnx_graph.infershape()
    onnx_graph.save(save_path)


if __name__ == '__main__':
    main()