import os
import sys

import torch
import torch.onnx

# Make the pytorch-ssd checkout importable. The original CWD-relative entry is
# kept so existing invocations behave exactly as before; the script-relative
# entry lets the script also be run from a different working directory.
sys.path.append(r"./pytorch-ssd")
sys.path.append(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "pytorch-ssd"))
from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd  # noqa: E402
def pytorch2onnx(ckpt_path, out_path, num_classes=21, batch_size=16,
                 input_size=300, opset_version=11):
    """Export a MobileNetV1-SSD checkpoint to an ONNX file.

    Builds the network with ``create_mobilenetv1_ssd(num_classes,
    is_test=True)``, loads the weights from *ckpt_path*, switches to eval
    mode and traces it with a random ``(batch_size, 3, input_size,
    input_size)`` tensor. Axis 0 of the input ``image`` and of the outputs
    ``scores`` and ``boxes`` is exported as a dynamic axis.

    Args:
        ckpt_path: Path of the checkpoint passed to ``net.load``.
        out_path: Destination path of the ONNX file.
        num_classes: Class count the checkpoint was trained with; defaults
            to 21, the value previously hard-coded here.
        batch_size: Batch size of the dummy tracing input. Axis 0 is
            exported as dynamic, so this only affects tracing.
        input_size: Side length of the square dummy input (previously a
            fixed 300). Presumably must match the resolution the SSD
            priors were built for — confirm before changing.
        opset_version: ONNX opset to target (previously a fixed 11).

    Raises:
        ValueError: If *ckpt_path* is empty/``None`` or *batch_size* < 1.
        FileNotFoundError: If *ckpt_path* is not an existing file.
    """
    import os  # local import keeps this function self-contained

    # Fail fast with a clear message instead of an opaque error from
    # net.load(None) / torch.load on a bad path.
    if not ckpt_path:
        raise ValueError(f"ckpt_path is required, got {ckpt_path!r}")
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"checkpoint not found: {ckpt_path}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    net = create_mobilenetv1_ssd(num_classes, is_test=True)
    print("begin to load model")
    net.load(ckpt_path)
    net.eval()

    input_names = ["image"]
    # Assumes the is_test network returns (scores, boxes) in this order —
    # TODO confirm against vision.ssd's forward().
    output_names = ["scores", "boxes"]
    # The symbolic name '-1' is kept unchanged because downstream tooling may
    # depend on it; a fresh dict per tensor avoids sharing one mutable object.
    dynamic_axes = {name: {0: '-1'} for name in input_names + output_names}

    dummy_input = torch.randn(batch_size, 3, input_size, input_size)
    print("begin to export")
    torch.onnx.export(net, dummy_input, out_path,
                      input_names=input_names,
                      output_names=output_names,
                      dynamic_axes=dynamic_axes,
                      opset_version=opset_version,
                      verbose=False)
    print("end export")
if __name__ == "__main__":
    # Command-line entry point: export --ckpt to --onnx.
    import argparse

    # The text is the help description; passing it positionally would make it
    # the program name (`prog`) shown in the usage line instead.
    parser = argparse.ArgumentParser(
        description='Pytorch model convert to ONNX')
    # Required: exporting without a checkpoint can only fail later.
    parser.add_argument('--ckpt', required=True,
                        help='input checkpoint file path')
    parser.add_argument('--onnx', default='out.onnx',
                        help='output onnx file path')
    args = parser.parse_args()
    pytorch2onnx(args.ckpt, args.onnx)