import sys
import numpy as np
from magiconnx import OnnxGraph
def create_slice_node(graph, idx):
    """Add a ``Slice`` node named ``Slice_<idx>`` to *graph* and return it.

    The node is created without inputs; the caller is expected to wire up
    the data input and the starts/ends/axes/steps operands.
    """
    return graph.add_node(f"Slice_{idx}", "Slice")
def create_concat_node(graph, idx, axis):
    """Add a ``Concat`` node named ``Concat_<idx>`` to *graph* and return it.

    *axis* is stored as the node's ``axis`` attribute.
    """
    node_name = f"Concat_{idx}"
    node_attrs = {'axis': axis}
    return graph.add_node(node_name, "Concat", node_attrs)
def add_add_tensor(graph):
    """Zero-pad constant operands of ``Add`` nodes from 197 to 208.

    For every ``Add`` node whose second input is a ``Constant`` or
    ``Initializer``, each dimension of that constant equal to 197 is grown
    to 208. The original values are kept at the start of each axis and the
    new positions are filled with zeros. The padded tensor is always
    float32. Constants without a 197-sized dimension are left untouched.

    Prints a message and returns when the graph has no ``Add`` nodes.
    """
    add_nodes = graph.get_nodes('Add')
    if add_nodes is None or len(add_nodes) == 0:
        print('There is no Add.')
        return
    for node in add_nodes:
        operand = graph[node.inputs[1]]
        if operand is None or operand.op_type not in ("Constant", "Initializer"):
            continue
        value = operand.value
        padded_shape = [208 if dim == 197 else dim for dim in value.shape]
        if padded_shape == list(value.shape):
            # Nothing to pad for this constant.
            continue
        padded = np.zeros(padded_shape, dtype=np.float32)
        # Copy the original data into the leading corner; works for any rank,
        # not only the 4-D case.
        padded[tuple(slice(0, dim) for dim in value.shape)] = value
        operand.value = padded
def refill_zero_gather(graph, batch_size):
    """Reset the padded rows after selected ``Gather`` nodes to zero.

    After the sequence length is padded from 197 to 208, the 11 extra rows
    may pick up non-zero values during computation and disturb later
    multiplications. For every ``Gather`` node whose index operand is the
    single value 1, a ``Slice`` (axis 2, range 0..197) followed by a
    ``Concat`` (axis 2) with a zero tensor of shape
    ``(batch_size, 12, 11, 64)`` is inserted after it.

    New nodes and initializers are named from a counter starting at 6000,
    advancing by 8 per rewritten ``Gather``.

    Prints a message and returns when the graph has no ``Gather`` nodes.
    """
    gather_nodes = graph.get_nodes('Gather')
    if gather_nodes is None or len(gather_nodes) == 0:
        print('There is no Gather.')
        return
    batch = int(batch_size)
    idx = 6000
    for node in gather_nodes:
        index_value = np.asarray(graph[node.inputs[1]].value)
        # Only single-element indices can equal 1; a bare `== 1` would raise
        # on a multi-element index array.
        if index_value.size != 1 or index_value.item() != 1:
            continue

        # Inserted first, so it ends up after the Slice added below:
        # Gather -> Slice -> Concat.
        concat = create_concat_node(graph, idx + 1, 2)
        graph.insert_node(node.name, concat, index=0, mode='after')
        zeros = np.zeros((batch, 12, 11, 64), dtype=np.float32)
        graph.add_initializer(str(idx + 2), zeros)
        concat.inputs.append(str(idx + 2))

        sli = create_slice_node(graph, idx + 3)
        graph.insert_node(node.name, sli, index=0, mode='after')
        slice_operands = (
            (str(idx + 4), [0]),    # starts
            (str(idx + 5), [197]),  # ends
            (str(idx + 6), [2]),    # axes
            (str(idx + 7), [1]),    # steps
        )
        for operand_name, operand_value in slice_operands:
            graph.add_initializer(operand_name, np.array(operand_value, dtype=np.int64))
        for operand_name, _ in slice_operands:
            sli.inputs.append(operand_name)
        idx += 8
def reshape_axis(graph: OnnxGraph) -> None:
    """Rewrite constant ``Reshape`` target shapes from 197 to 208.

    Once the concatenated tensor has been padded, any ``Reshape`` that still
    uses 197 would turn the tensor back into its original shape. For every
    ``Reshape`` node whose shape operand (second input) is a ``Constant`` or
    ``Initializer``, each entry equal to 197 is replaced with 208.

    Prints a message and returns when the graph has no ``Reshape`` nodes.
    """
    reshape_nodes = graph.get_nodes('Reshape')
    if reshape_nodes is None or len(reshape_nodes) == 0:
        print('There is no Reshape.')
        return
    for node in reshape_nodes:
        shape_node = graph[node.inputs[1]]
        if shape_node is None or shape_node.op_type not in ("Constant", "Initializer"):
            continue
        new_shape = shape_node.value.copy()
        # Boolean-mask assignment keeps the array's dtype unchanged.
        new_shape[new_shape == 197] = 208
        shape_node.value = new_shape
def add_concat_tensor(graph: OnnxGraph, batch_size: int) -> None:
    """Append a zero tensor as an extra input of the first ``Concat`` node.

    A float32 zero initializer of shape ``(batch_size, 11, 768)`` is
    registered under the name ``'1234'`` and appended to the inputs of the
    first ``Concat`` node returned by the graph. The intent is to pad the
    concatenated output so that its trailing dimensions are multiples of
    16, avoiding later data-format conversions.
    """
    first_concat_name = graph.get_nodes('Concat')[0].name
    pad_name = '1234'
    pad_shape = (int(batch_size), 11, 768)
    graph.add_initializer(pad_name, np.zeros(pad_shape, dtype=np.float32))
    graph[first_concat_name].inputs.append(pad_name)
def del_layernorm_transdata(graph: OnnxGraph, batch_size: int) -> bool:
    """Run all graph rewrites in order and report success.

    The steps are: update Reshape shapes, pad the first Concat, pad constant
    Add operands, then re-zero the padded rows after Gather nodes.
    Always returns ``True``.
    """
    reshape_axis(graph)
    batch = int(batch_size)
    add_concat_tensor(graph, batch)
    add_add_tensor(graph)
    refill_zero_gather(graph, batch)
    return True
def improve_model(path: str, new_path: str, batch_size: int) -> None:
    """Load the ONNX model at *path*, apply the rewrites, save to *new_path*.

    Prints a message and returns without saving when the source file does
    not exist or when the rewrite step reports failure.
    """
    import os

    # The result of a constructor call is never None, so test for the file
    # itself to make the "not exist" message reachable.
    if not os.path.isfile(path):
        print('onnx model not exist.')
        return None
    graph = OnnxGraph(path)
    ret = del_layernorm_transdata(graph, batch_size)
    if not ret:
        print('delete layernorm transdata failed.')
        return None
    graph.save(new_path)
def dump_data(path, save_path, batch_size: int):
    """Run ``graph.dump`` on the model at *path* with one random input.

    The input is a float32 array of shape ``(batch_size, 3, 224, 224)``
    drawn from a standard normal distribution; *save_path* is forwarded to
    ``graph.dump``.
    """
    graph = OnnxGraph(path)
    random_input = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)
    graph.dump([random_input], save_path)
if __name__ == '__main__':
    # Usage: <script> <onnx source path> <dest path> <batch_size>
    # Compare the count by value; `is not` against a literal tests identity.
    if len(sys.argv) != 4:
        print('only need 3 params, include: onnx source path | dest path | batch_size.')
    else:
        # Convert at the CLI boundary so improve_model receives an int.
        improve_model(sys.argv[1], sys.argv[2], int(sys.argv[3]))