from facenet_pytorch import InceptionResnetV1
import torch
import argparse
def FaceNet_pth2onnx(opt):
    """Export a facenet_pytorch ``InceptionResnetV1`` model to an ONNX file.

    Args:
        opt: Parsed options namespace. Reads:
            ``opt.pretrain`` -- passed as ``pretrained=`` to
            ``InceptionResnetV1`` (the CLI restricts it to
            ``casia-webface`` / ``vggface2``).
            ``opt.output_file`` -- destination path. If it is empty or names
            an existing directory (the CLI default is ``'.'``), the model is
            written as ``FaceNet.onnx`` inside that directory (or the current
            directory when empty).

    Returns:
        The path the ONNX model was written to.

    The exported graph has a single input named ``"image"`` and a single
    output named ``"class"``; axis 0 (the batch axis of the dummy input) of
    both is exported as a dynamic dimension.
    """
    import os

    default_name = 'FaceNet.onnx'

    model = InceptionResnetV1(pretrained=opt.pretrain)
    # Put the model in inference mode before tracing it for export.
    model.eval()

    input_names = ["image"]
    # NOTE(review): InceptionResnetV1 is constructed with its default
    # ``classify`` setting here, so this output is presumably the face
    # embedding rather than class logits -- confirm. The name "class" is kept
    # unchanged because downstream consumers may rely on it.
    output_names = ["class"]

    # torch.onnx.export needs a file path. The CLI default '.' (or any other
    # existing directory) would otherwise be passed straight through, so place
    # a default file name inside the directory instead.
    output_file = opt.output_file
    if not output_file:
        output_file = default_name
    elif os.path.isdir(output_file):
        output_file = os.path.join(output_file, default_name)

    # Mark axis 0 of the input and output as dynamic; '-1' is only the
    # symbolic dimension name recorded in the ONNX graph.
    dynamic_axes = {'image': {0: '-1'}, 'class': {0: '-1'}}
    # Shape (16, 3, 160, 160) is only used to trace the graph; the batch size
    # becomes dynamic via dynamic_axes above.
    dummy_input = torch.randn(16, 3, 160, 160)
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        verbose=True,
        opset_version=10,
    )
    return output_file
def parse_args(argv=None):
    """Parse command-line options for the FaceNet .pth -> ONNX export.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``pretrain``, ``model`` and
        ``output_file`` attributes.
    """
    parser = argparse.ArgumentParser()
    # Restrict to the two weight sets named in the help text so a typo is
    # reported as a usage error instead of failing inside model construction.
    parser.add_argument('--pretrain', type=str, default='vggface2',
                        choices=['casia-webface', 'vggface2'],
                        help='[casia-webface, vggface2]')
    # TODO(review): --model is parsed but not used by FaceNet_pth2onnx; the
    # exported weights come solely from --pretrain.
    parser.add_argument('--model', type=str, help='model path')
    parser.add_argument('--output_file', type=str, default='.', help='output path')
    return parser.parse_args(argv)


if __name__ == '__main__':
    FaceNet_pth2onnx(parse_args())