from auto_optimizer import OnnxGraph
import numpy as np
import argparse
class OnnxModel:
    """Hold an ``OnnxGraph`` parsed from a file and apply fixed graph edits.

    All edits are hard-coded pattern matches on op types and tensor names;
    they mutate ``self.graph`` in place.
    """

    def __init__(self, path):
        """Parse the ONNX file at ``path`` into ``self.graph``."""
        self.graph = OnnxGraph.parse(path)

    def remove_useless_node(self):
        """Run the four clean-up passes, in their original order.

        1. Replace the Reshape fed by the ``"masks"`` tensor with a Cast.
        2. Remove the Reshape nodes wrapped around an Add in front of Softmax.
        3. Patch constant-fed ConstantOfShape nodes and a consumer's input.
        4. Drop a constant ``[1]`` entry from four-input Concat nodes.
        """
        self._replace_mask_reshape_with_cast()
        self._drop_reshapes_around_softmax()
        self._patch_constant_of_shape()
        self._drop_unit_concat_input()

    def _replace_mask_reshape_with_cast(self):
        """Swap each Reshape whose first input is ``"masks"`` for a Cast."""
        for node in self.graph.get_nodes("Reshape"):
            if node.inputs[0] != "masks":
                continue
            # Consumers must be looked up before the Reshape is replaced.
            # The original calls them "unsqueeze" nodes; the op type is not
            # checked here, and at least one consumer is assumed to exist.
            followers = self.graph.get_next_nodes(node.outputs[0])
            # to=1 is FLOAT in ONNX's TensorProto.DataType enum.
            cast_node = self.graph.add_node("cast_mask", "Cast", attrs={"to": 1})
            self.graph[node.name] = cast_node
            self.graph.remove(followers[0].name)

    def _drop_reshapes_around_softmax(self):
        """Remove both Reshapes in a ``Reshape -> <node> -> Reshape -> Softmax`` chain."""
        for node in self.graph.get_nodes("Softmax"):
            pre_node = self.graph.get_prev_node(node.inputs[0])
            # A Softmax fed directly by a graph input or constant has no
            # producer node; skip it instead of failing on ``None.op_type``.
            if pre_node is None or pre_node.op_type != "Reshape":
                continue
            # Named ``add_node`` in the original; its op type is not verified.
            add_node = self.graph.get_prev_node(pre_node.inputs[0])
            reshape_node = self.graph.get_prev_node(add_node.inputs[0])
            # NOTE(review): the {0: 0} mapping is passed through unchanged;
            # presumably it reconnects input 0 to output 0 - confirm against
            # the auto_optimizer documentation.
            self.graph.remove(pre_node.name, {0: 0})
            self.graph.remove(reshape_node.name, {0: 0})

    def _patch_constant_of_shape(self):
        """Rewrite the shape of ConstantOfShape nodes that have no producer node."""
        for node in self.graph.get_nodes("ConstantOfShape"):
            if self.graph.get_prev_node(node.inputs[0]):
                # The shape comes from another node, not a stored value.
                continue
            self.graph[node.inputs[0]].value = np.array([3], dtype=np.int64)
            # The second consumer (index 1) is the one patched; the original
            # names it ``where_node`` but its op type is not checked.
            where_node = self.graph.get_next_nodes(node.outputs[0])[1]
            shape_name = where_node.inputs[2]
            # Drop the leading element of that consumer's third input value.
            self.graph[shape_name].value = self.graph[shape_name].value[1:]

    def _drop_unit_concat_input(self):
        """Remove input 1 of four-input Concat nodes when it is a constant ``[1]``."""
        for node in self.graph.get_nodes("Concat"):
            if len(node.inputs) != 4 or "Constant" not in node.inputs[1]:
                continue
            if self._holds_single_one(self.graph[node.inputs[1]].value):
                node.inputs = [node.inputs[0], node.inputs[2], node.inputs[3]]

    @staticmethod
    def _holds_single_one(value):
        """Return True when ``value`` contains exactly one element equal to 1.

        A plain ``value == [1]`` yields an array, whose truth value raises
        ``ValueError`` as soon as ``value`` has more than one element.
        Flattening first keeps the result identical for every single-element
        input while returning False for anything larger.
        """
        return bool(np.array_equal(np.ravel(value), [1]))

    def remove_overflow_node(self):
        """Remove every ReduceMax node together with its first consumer.

        Each ReduceMax is assumed to have at least one consumer; the consumer
        list is read before the ReduceMax itself is removed.
        """
        for node in self.graph.get_nodes("ReduceMax"):
            next_nodes = self.graph.get_next_nodes(node.outputs[0])
            self.graph.remove(node.name)
            self.graph.remove(next_nodes[0].name)

    def save_model(self, output):
        """Write the current graph to the path ``output``."""
        self.graph.save(output)
def parse_arguments():
    """Parse the command line and return the resulting namespace.

    Two optional string flags are recognised:

    * ``--model_path``: ONNX model to read (default ``rpn.onnx``).
    * ``--output_path``: where the modified model is written
      (default ``modify.onnx``).
    """
    cli = argparse.ArgumentParser()
    # (flag, help text, default value) for every supported option.
    option_table = (
        ("--model_path", "onnx model path", "rpn.onnx"),
        ("--output_path", "onnx model path after modify", "modify.onnx"),
    )
    for flag, description, fallback in option_table:
        cli.add_argument(flag, type=str, help=description, default=fallback)
    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point: load the model, apply both edits, write the result.
    cli_args = parse_arguments()
    model = OnnxModel(cli_args.model_path)
    # Order matters: the overflow pass runs before the clean-up pass.
    for apply_pass in (model.remove_overflow_node, model.remove_useless_node):
        apply_pass()
    model.save_model(cli_args.output_path)