import argparse
from magiconnx import OnnxGraph
import numpy as np
import torch
from utils import IMG_SIZE, load_model
def parse_args(argv=None):
    """Parse the command-line options for the pkl -> ONNX conversion.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is what the
            script does when run from the command line.

    Returns:
        argparse.Namespace with ``ckpt_path`` and ``output_path`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="./model/best_model.pkl",
        help="Path to the model weights as pkl.",
    )
    # NOTE: the flag spellings are inconsistent ("--ckpt-path" vs
    # "--output_path"). They are kept as-is so existing invocations keep
    # working. Adding a "--output-path" alias would make the common
    # abbreviation "--output" ambiguous in argparse.
    parser.add_argument(
        "--output_path",
        type=str,
        default="./model/best_model.onnx",
        help="Path where the exported ONNX model is saved.",
    )
    return parser.parse_args(argv)


def build_dummy_inputs(*, batch_size=1, channel=3, img_size=None):
    """Return a random (batch, channel, height, width) tensor for ONNX tracing.

    Args:
        batch_size: Size of the first dimension (default 1).
        channel: Size of the second dimension (default 3).
        img_size: ``(width, height)`` pair. Defaults to ``IMG_SIZE`` from
            ``utils``, which is looked up at call time.

    Returns:
        torch.Tensor of shape ``(batch_size, channel, height, width)``
        filled by ``torch.rand`` (uniform values in ``[0, 1)``).
    """
    # Resolve the default lazily so callers passing img_size explicitly
    # never depend on the module-level constant.
    if img_size is None:
        img_size = IMG_SIZE
    # The size is given as (w, h), but the tensor is laid out (..., h, w).
    width, height = img_size
    return torch.rand(batch_size, channel, height, width)


def transfer_onnx(pkl_path, onnx_path):
    """Load the checkpoint at *pkl_path* and export it as ONNX to *onnx_path*.

    The model is traced with a random tensor from ``build_dummy_inputs()``
    using ONNX opset 11. The graph input is named "input" and the outputs
    are named "out_point_positions2D" and "out_point_positions3D".
    """
    net = load_model(pkl_path)
    # Put the model in inference mode before tracing. This presumably
    # matters because load_model returns a torch nn.Module; TODO confirm.
    net.eval()
    sample = build_dummy_inputs()
    export_outputs = ["out_point_positions2D", "out_point_positions3D"]
    torch.onnx.export(
        net,
        sample,
        onnx_path,
        input_names=["input"],
        output_names=export_outputs,
        opset_version=11,
    )


def fix_prelu(onnx_path):
    """Prepend a length-1 leading axis to every PRelu slope and save in place.

    Opens the ONNX file at *onnx_path*. For each ``PRelu`` node it takes the
    tensor referenced by the node's second input (the slope), applies
    ``np.expand_dims(..., axis=0)``, and overwrites *onnx_path* with the
    patched graph.

    NOTE(review): not idempotent. Every call adds another leading axis, so
    run it once per freshly exported file. The reason for the extra axis
    (presumably a rank requirement of a downstream converter) is not visible
    here; TODO confirm.
    """
    graph = OnnxGraph(onnx_path)
    for prelu in graph.get_nodes(op_type='PRelu'):
        # Assumes graph[name] returns an object with a writable ``.value``
        # numpy array (magiconnx API); verify against the library.
        slope = graph[prelu.inputs[1]]
        slope.value = np.expand_dims(slope.value, axis=0)
    graph.save(onnx_path)


if __name__ == "__main__":
    # Export the checkpoint first, then patch the PRelu slopes in the file
    # that was just written.
    args = parse_args()
    onnx_file = args.output_path
    transfer_onnx(args.ckpt_path, onnx_file)
    fix_prelu(onnx_file)