import argparse
from typing import List
import numpy as np
from auto_optimizer import OnnxGraph
from auto_optimizer.graph_refactor.interface.base_node import Initializer
def parse_args(argv=None):
    """Parse the command-line options of this ONNX rewriting script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; ``None`` (the default) keeps the usual
            command-line behaviour.

    Returns:
        argparse.Namespace with ``input_file``, ``output_file`` and
        ``model_size`` (``'base'`` or ``'large'``) attributes.
    """
    parser = argparse.ArgumentParser(description="fix albert onnx")
    # The input is loaded with OnnxGraph.parse, so it is an ONNX file.
    parser.add_argument("--input_file", type=str, required=True,
                        help="path to the input onnx model")
    parser.add_argument("--output_file", type=str, required=True,
                        help="path to save onnx model")
    parser.add_argument("--model_size", type=str, default='base',
                        help="model_size of bert", choices=['base', 'large'])
    return parser.parse_args(argv)
def get_config(graph):
    """Return the first two dimensions of the graph's first input.

    ``main`` unpacks the pair as (batch size, sequence length). Whatever
    the graph stores in the shape (ints, or possibly symbolic names for
    dynamic dims -- not checked here) is returned unchanged.
    """
    first_input_shape = graph.inputs[0].shape
    return first_input_shape[0], first_input_shape[1]
def fix_attention_lnqkv(graph, qkv_start_node, hidden_num=None):
    """Rewrite the Q/K/V projection branch fed by ``qkv_start_node`` in place.

    Two edits are made to ``graph``:

    1. A ``Reshape`` to ``[-1, hidden_num]`` is inserted in front of
       ``qkv_start_node`` on its non-initializer input.
    2. ``qkv_start_node``'s output feeds three ``MatMul -> Add -> Reshape ->
       Transpose`` chains. The Transpose whose ``perm`` differs from the
       other two (presumably the key branch) is switched to the shared perm,
       and a compensating Transpose is inserted after it so the overall
       permutation of that branch is unchanged.

    Args:
        graph: OnnxGraph edited in place.
        qkv_start_node: Node whose first output feeds the three MatMuls.
        hidden_num: Width used by the inserted Reshape; defaults to the
            module-level ``HIDDEN_NUM``.

    Raises:
        ValueError: if fewer than three MatMul consumers exist, or if the
            three Transpose perms are all different (no unique key perm).
    """
    if hidden_num is None:
        hidden_num = HIDDEN_NUM

    reshape_before_add = graph.add_node(
        f"Reshape_before_{qkv_start_node.name}",
        "Reshape"
    )
    reshape_init = graph.add_initializer(
        f"{reshape_before_add.name}_value",
        np.array([-1, hidden_num], dtype="int64")
    )
    # Attach the Reshape to the data input, not to the constant (initializer) one.
    if graph.get_node(qkv_start_node.inputs[0], node_type=Initializer):
        data_index = 1
    else:
        data_index = 0
    graph.insert_node(qkv_start_node.name, reshape_before_add,
                      refer_index=data_index, mode="before")
    reshape_before_add.inputs.append(reshape_init.name)

    next_nodes = graph.get_next_nodes(qkv_start_node.outputs[0])
    matmul_nodes = [n for n in next_nodes if n.op_type == "MatMul"]
    if len(matmul_nodes) < 3:
        raise ValueError(
            f"expected 3 MatMul consumers of {qkv_start_node.name}, "
            f"found {len(matmul_nodes)}"
        )

    # Follow each MatMul -> Add -> Reshape -> Transpose chain.
    transposes = []
    for matmul_node in matmul_nodes[:3]:
        add_node = graph.get_next_nodes(matmul_node.outputs[0])[0]
        reshape_node = graph.get_next_nodes(add_node.outputs[0])[0]
        transposes.append(graph.get_next_nodes(reshape_node.outputs[0])[0])
    perms = [list(t.attrs.get('perm', [1])) for t in transposes]

    # The odd one out is the perm that occurs exactly once. This choice does
    # not depend on the order in which the consumers are returned.
    unique = [i for i, p in enumerate(perms) if perms.count(p) == 1]
    if not unique:
        # All three perms already agree: nothing to align.
        return
    if len(unique) != 1:
        raise ValueError(
            f"cannot identify the key Transpose after {qkv_start_node.name}: "
            f"perms {perms}"
        )
    key_idx = unique[0]
    key_transpose = transposes[key_idx]
    key_perm = perms[key_idx]
    query_perm = perms[(key_idx + 1) % 3]

    key_transpose.attrs["perm"] = query_perm
    # After the first Transpose (perm=query_perm), original input axis
    # key_perm[i] sits at position query_perm.index(key_perm[i]). So this
    # second perm makes the pair equivalent to the original key_perm.
    new_perm = [query_perm.index(axis) for axis in key_perm]
    new_transpose = graph.add_node(
        name=f"{key_transpose.name}_after",
        op_type="Transpose",
        attrs={"perm": new_perm}
    )
    graph.insert_node(key_transpose.name, new_transpose, mode="after")
def fix_attention_score(graph, softmax_node, bs, seq_len, hidden_num=None):
    """Rewrite the attention-score subgraph around ``softmax_node`` in place.

    Edits made to ``graph``:

    * The Reshape reached via Softmax -> MatMul -> Transpose -> Reshape gets
      its shape initializer set to ``[-1, hidden_num]``.
    * The constant ``Div`` feeding the Add in front of Softmax is replaced by
      a ``Mul`` with the reciprocal constant (float32).
    * An ``Expand`` to ``[bs, 1, seq_len, seq_len]`` is inserted on the
      Add's other input.

    Args:
        graph: OnnxGraph edited in place.
        softmax_node: The Softmax node of one attention block.
        bs: Batch dimension used for the Expand shape.
        seq_len: Sequence length used for the Expand shape.
        hidden_num: Width for the output Reshape; defaults to the
            module-level ``HIDDEN_NUM``.

    Raises:
        ValueError: if the Div's divisor (second input) is not an initializer.
    """
    if hidden_num is None:
        hidden_num = HIDDEN_NUM

    matmul_node = graph.get_next_nodes(softmax_node.outputs[0])[0]
    transpose_node = graph.get_next_nodes(matmul_node.outputs[0])[0]
    reshape_node = graph.get_next_nodes(transpose_node.outputs[0])[0]
    reshape_init = graph[reshape_node.inputs[1]]
    reshape_init.value = np.array([-1, hidden_num], dtype="int64")

    # Find which input of the Add in front of Softmax comes from the Div.
    add_node = graph.get_prev_node(softmax_node.inputs[0])
    prev_node = graph.get_prev_node(add_node.inputs[0])
    if prev_node.op_type == "Div":
        div_node = prev_node
        refer_index = 0
    else:
        div_node = graph.get_prev_node(add_node.inputs[1])
        refer_index = 1
    # The remaining Add input (presumably the attention mask) gets the Expand.
    other_index = 1 - refer_index

    # Only Div(x, c) == Mul(x, 1/c); a constant numerator cannot be rewritten
    # this way, so look for the initializer on the divisor input only.
    div_init = graph.get_node(div_node.inputs[1], node_type=Initializer)
    if not div_init:
        raise ValueError(
            f"{div_node.name}: divisor input {div_node.inputs[1]} is not an initializer"
        )

    mul_node = graph.add_node(
        f"bert_Mul_before_{add_node.name}",
        "Mul",
    )
    mul_init_value = np.array(1 / div_init.value, dtype="float32")
    mul_init = graph.add_initializer(
        f"{mul_node.name}_value",
        mul_init_value
    )
    graph.insert_node(add_node.name, mul_node, refer_index=refer_index, mode="before")
    mul_node.inputs.append(mul_init.name)
    graph.remove(div_node.name)

    expand_node = graph.add_node(
        f"Expand_before_{add_node.name}",
        "Expand"
    )
    expand_init = graph.add_initializer(
        f"{expand_node.name}_value",
        np.array([bs, 1, seq_len, seq_len], dtype="int64")
    )
    graph.insert_node(add_node.name, expand_node, refer_index=other_index, mode="before")
    expand_node.inputs.append(expand_init.name)
def main(graph):
    """Apply every attention rewrite of this script to ``graph`` in place.

    Steps, in order:
      1. For each Add whose first output has exactly four consumers
         (checked just before each call, after the previous rewrites),
         call ``fix_attention_lnqkv``.
      2. For each Softmax, call ``fix_attention_score``.
      3. Insert a Reshape to ``[bs, seq_len, HIDDEN_NUM]`` in front of the
         second-to-last Add, on its non-initializer input.
    """
    bs, seq_len = get_config(graph)

    for candidate in graph.get_nodes("Add"):
        if len(graph.get_next_nodes(candidate.outputs[0])) != 4:
            continue
        fix_attention_lnqkv(graph, candidate)

    for softmax in graph.get_nodes("Softmax"):
        fix_attention_score(graph, softmax, bs, seq_len)

    # NOTE(review): the target Add is chosen purely by position (second to
    # last); confirm this matches the intended node for each exported model.
    target_add = graph.get_nodes(op_type="Add")[-2]
    last_reshape = graph.add_node("last_reshape", "Reshape")
    shape_init = graph.add_initializer(
        f"{last_reshape.name}_value",
        np.array([bs, seq_len, HIDDEN_NUM], dtype="int64"),
    )
    # Insert on the data input rather than on a constant initializer input.
    data_index = 1 if graph.get_node(target_add.inputs[0], node_type=Initializer) else 0
    graph.insert_node(target_add.name, last_reshape, refer_index=data_index, mode="before")
    last_reshape.inputs.append(shape_init.name)
if __name__ == "__main__":
    # Script entry: parse CLI options, pick the hidden size, rewrite, save.
    args = parse_args()
    # "base" -> 768; any other accepted size ("large") -> 1024.
    HIDDEN_NUM = {"base": 768}.get(args.model_size, 1024)
    onnx_graph = OnnxGraph.parse(args.input_file)
    main(onnx_graph)
    onnx_graph.infershape()
    onnx_graph.save(args.output_file)