import sys
import onnx


def fix_onnx_node_names(model_path, output_path):
    """Rename tensors that share a name with the node that produces them.

    For every node in the top-level graph, each output tensor whose name is
    identical to the node's own name is renamed from ``X`` to ``new_X``.
    Every reference to a renamed tensor is then updated so the graph stays
    connected: node inputs anywhere in the top-level graph, graph outputs,
    and ``value_info`` entries. Node names themselves are left unchanged.
    Subgraphs (e.g. bodies of ``If``/``Loop`` nodes) are not visited.

    Args:
        model_path: Path of the ONNX model to read.
        output_path: Path where the modified model is written.
    """
    model = onnx.load(model_path)
    graph = model.graph

    # Pass 1: rename producer outputs and remember old -> new names.
    # Empty names denote omitted optional outputs and must never be renamed.
    renamed = {}
    for node in graph.node:
        for i, output_name in enumerate(node.output):
            if output_name and output_name == node.name:
                new_name = f"new_{output_name}"
                node.output[i] = new_name
                renamed[output_name] = new_name

    # Pass 2: rewire every consumer of a renamed tensor, not only the node
    # that happens to follow its producer in the node list.
    for node in graph.node:
        for i, input_name in enumerate(node.input):
            if input_name in renamed:
                node.input[i] = renamed[input_name]

    # Rename graph outputs / value_info only when their producer's output
    # was actually renamed; renaming anything else would leave a dangling
    # reference and produce an invalid model.
    for value in (*graph.output, *graph.value_info):
        if value.name in renamed:
            value.name = renamed[value.name]

    onnx.save(model, output_path)

if __name__ == '__main__':
    # Fail with a usage message instead of a bare IndexError traceback
    # when the two required positional arguments are missing.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} INPUT_MODEL OUTPUT_MODEL")
    input_model = sys.argv[1]
    out_model = sys.argv[2]
    fix_onnx_node_names(input_model, out_model)