import sys
import torch
from model import Backbone
def pth2onnx(input_file, output_file, batch_size):
    """Convert a ``Backbone`` PyTorch checkpoint into an ONNX model file.

    The network is built as ``Backbone(num_layers=100, drop_ratio=0.6,
    mode='ir_se')`` on the CPU, its weights are loaded from ``input_file``,
    and it is traced with a random dummy input of shape
    ``(batch_size, 3, 112, 112)``. Dimension 0 of both the ``image`` input
    and the ``features`` output is exported as a dynamic axis.

    Args:
        input_file: path to the ``.pth`` weights. Either a bare state dict or
            a dict that stores the state dict under the ``'model'`` key.
        output_file: path the ONNX model is written to.
        batch_size: batch size of the dummy input used for tracing.
    """
    device = torch.device('cpu')
    model = Backbone(num_layers=100, drop_ratio=0.6, mode='ir_se').to(device)

    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code; only convert checkpoints from a trusted source.
    ckpt = torch.load(input_file, map_location='cpu')
    # Some checkpoints wrap the weights under a 'model' key; otherwise the
    # loaded object is treated as the state dict itself. The isinstance check
    # avoids a confusing error when the file holds something other than a dict.
    if isinstance(ckpt, dict) and 'model' in ckpt:
        state_dict = ckpt['model']
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)
    model.eval()

    input_names = ["image"]
    output_names = ["features"]
    # Give the dynamic batch dimension a symbolic name. Using the numeric batch
    # size as the name (e.g. '16') makes it look like a fixed dimension.
    dynamic_axes = {'image': {0: 'batch_size'}, 'features': {0: 'batch_size'}}
    dummy_input = torch.randn(batch_size, 3, 112, 112, device=device)
    torch.onnx.export(model,
                      dummy_input,
                      output_file,
                      input_names=input_names,
                      dynamic_axes=dynamic_axes,
                      output_names=output_names,
                      opset_version=11,
                      verbose=True)
    print("*************Convert to ONNX model file SUCCESS!*************")
if __name__ == '__main__':
    # Expected CLI: <input.pth> <output.onnx> <batch_size>
    if len(sys.argv) != 4:
        sys.exit(f"usage: {sys.argv[0]} <input.pth> <output.onnx> <batch_size>")
    pth2onnx(sys.argv[1], sys.argv[2], int(sys.argv[3]))