import argparse
from magiconnx import OnnxGraph
import numpy as np
import torch
from utils import IMG_SIZE, load_model


def parse_args(argv=None):
    """Parse the command-line options of the pkl -> ONNX export script.

    Both options accept a hyphen and an underscore spelling
    (``--ckpt-path`` / ``--ckpt_path`` and ``--output-path`` /
    ``--output_path``) so the two flags follow one consistent convention.
    The parsed attributes are always ``ckpt_path`` and ``output_path``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``ckpt_path`` and ``output_path`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt-path",
        "--ckpt_path",
        dest="ckpt_path",
        type=str,
        default="./model/best_model.pkl",
        help="Path to the model weights as pkl.",
    )
    parser.add_argument(
        "--output_path",
        "--output-path",
        dest="output_path",
        type=str,
        default="./model/best_model.onnx",
        help="Path to the model weights save as onnx.",
    )
    return parser.parse_args(argv)


def build_dummy_inputs(*, batch_size=1, channel=3, img_size=None):
    """Build a random tensor used to trace the model during ONNX export.

    Args:
        batch_size: Leading (batch) dimension. Defaults to 1.
        channel: Second (channel) dimension. Defaults to 3.
        img_size: ``(w, h)`` pair. Defaults to ``IMG_SIZE`` from ``utils``.
            The pair is unpacked width first, height second.

    Returns:
        A ``torch.Tensor`` of shape ``(batch_size, channel, h, w)`` filled with
        uniform random values in [0, 1) (``torch.rand``).
    """
    if img_size is None:
        img_size = IMG_SIZE
    w, h = img_size
    # torch.rand takes sizes in (N, C, H, W) order, so height comes before width.
    return torch.rand(batch_size, channel, h, w)


def transfer_onnx(pkl_path, onnx_path):
    """Load the checkpoint at ``pkl_path`` and export it to ONNX at ``onnx_path``.

    The model comes from ``load_model``, is put into eval mode, and is traced
    with a random tensor from ``build_dummy_inputs``. It is exported with
    opset 11, one graph input named ``input``, and two outputs named
    ``out_point_positions2D`` and ``out_point_positions3D``.
    """
    net = load_model(pkl_path)
    # Switch to inference mode before tracing (affects layers such as dropout
    # or batch-norm, if the model has any).
    net.eval()
    export_options = dict(
        input_names=['input'],
        output_names=["out_point_positions2D", "out_point_positions3D"],
        opset_version=11,
    )
    trace_input = build_dummy_inputs()
    torch.onnx.export(net, trace_input, onnx_path, **export_options)


def fix_prelu(onnx_path):
    """Prepend a leading axis to every PRelu slope tensor and save the graph in place.

    For each ``PRelu`` node, the tensor named by its second input (the slope)
    is replaced by ``np.expand_dims(value, axis=0)``. For example, a slope of
    shape ``(C, 1, 1)`` becomes ``(1, C, 1, 1)``.
    NOTE(review): presumably this is needed so the slope rank matches a 4-D
    activation for a downstream converter; confirm against the target toolchain.

    Each slope name is expanded at most once, even when several PRelu nodes
    share the same slope tensor.

    Args:
        onnx_path: Path of the ONNX file to read. It is overwritten with the
            modified graph.
    """
    graph = OnnxGraph(onnx_path)
    expanded = set()
    for node in graph.get_nodes(op_type='PRelu'):
        slope_name = node.inputs[1]
        # A slope tensor shared by several PRelu nodes must be expanded exactly
        # once. Otherwise every extra consumer would add one more leading axis.
        if slope_name in expanded:
            continue
        expanded.add(slope_name)
        slope_para = graph[slope_name]
        slope_para.value = np.expand_dims(slope_para.value, axis=0)

    graph.save(onnx_path)


def main():
    """Command-line entry point: export the checkpoint to ONNX, then patch its PRelu slopes in place."""
    cli_args = parse_args()
    onnx_file = cli_args.output_path
    transfer_onnx(cli_args.ckpt_path, onnx_file)
    fix_prelu(onnx_file)


if __name__ == "__main__":
    main()