import sys
import numpy as np
from magiconnx import OnnxGraph
def fix_prelu(graph):
    """Prepend a size-1 leading axis to the slope tensor of every PRelu node.

    For each node returned by ``graph.get_nodes(op_type='PRelu')``:
    - the node's second input (the slope) is looked up via ``graph[name]``;
    - its ``value`` is replaced by ``np.expand_dims(value, axis=0)``.

    A slope referenced by several PRelu nodes is expanded only once.

    NOTE(review): the reason for the extra axis is not visible here. It is
    presumably so the slope rank matches what a downstream converter expects.
    TODO confirm against the target toolchain.

    Args:
        graph: a magiconnx ``OnnxGraph``, or any object that provides
            ``get_nodes(op_type=...)`` returning nodes with an ``inputs``
            sequence, and ``graph[name]`` returning an object with a writable
            ``value`` attribute.

    Returns:
        The same ``graph`` object, modified in place.
    """
    fixed_slopes = set()
    for node in graph.get_nodes(op_type='PRelu'):
        slope_name = node.inputs[1]
        # Several PRelu nodes may share one slope initializer. Expanding it
        # once per referencing node would stack multiple leading axes.
        if slope_name in fixed_slopes:
            continue
        fixed_slopes.add(slope_name)
        slope_para = graph[slope_name]
        slope_para.value = np.expand_dims(slope_para.value, axis=0)
    return graph
if __name__ == '__main__':
    # Usage: <script> INPUT_MODEL OUTPUT_MODEL
    # Load the ONNX model, expand every PRelu slope, and save to OUTPUT_MODEL.
    if len(sys.argv) < 3:
        # Fail with a usage hint instead of an IndexError traceback.
        sys.exit('usage: %s <input_model> <output_model>' % sys.argv[0])
    input_model = sys.argv[1]
    out_model = sys.argv[2]
    onnx_graph = OnnxGraph(input_model)
    onnx_graph = fix_prelu(onnx_graph)
    onnx_graph.save(out_model)