import sys
import numpy as np
from magiconnx import OnnxGraph


def fix_prelu(graph):
    """Prepend a leading axis to the slope tensor of every PRelu node.

    For each node of op type ``PRelu``, the tensor named by the node's second
    input (the slope) is looked up via ``graph[name]``. Its ``value`` is then
    replaced with ``np.expand_dims(value, axis=0)``, so a slope of shape ``S``
    becomes shape ``(1, *S)``.

    Args:
        graph: a magiconnx ``OnnxGraph``; it is modified in place. Presumably
            the slope is stored as an initializer whose ``value`` is a numpy
            array (TODO confirm for models where it comes from a Constant).

    Returns:
        The same ``graph`` object, so callers can reassign or chain.
    """
    # Several PRelu nodes may reference the same slope tensor. Remember which
    # tensor names were already reshaped so a shared slope is expanded exactly
    # once rather than once per consuming node.
    fixed_slopes = set()
    for node in graph.get_nodes(op_type='PRelu'):
        slope_name = node.inputs[1]
        if slope_name in fixed_slopes:
            continue
        fixed_slopes.add(slope_name)
        slope_para = graph[slope_name]
        slope_para.value = np.expand_dims(slope_para.value, axis=0)
    return graph


if __name__ == '__main__':
    # Usage: python <script> <input_model> <output_model>
    # Loads the ONNX model, reshapes every PRelu slope, and saves the result.
    if len(sys.argv) < 3:
        # Fail with a readable message instead of an IndexError traceback.
        sys.exit('usage: {} <input_model> <output_model>'.format(sys.argv[0]))
    input_model = sys.argv[1]
    out_model = sys.argv[2]
    onnx_graph = OnnxGraph(input_model)
    onnx_graph = fix_prelu(onnx_graph)
    onnx_graph.save(out_model)