from fairseq import hub_utils
import torch
import argparse
import os
# Operator export mode passed to torch.onnx.export: emit standard ONNX operators.
# Uses the public torch.onnx alias of the enum instead of the private
# torch._C._onnx module (same enum object, stable public name).
OPERATOR_EXPORT_TYPE = torch.onnx.OperatorExportTypes.ONNX
def pth2onnx(checkpoint_path, checkpoint_file, data_name_or_path, batch_size, pad_length, output_file,
             opset_version=11, verbose=True):
    """Convert a fairseq RoBERTa checkpoint (.pt) to an ONNX model with a fixed input shape.

    Args:
        checkpoint_path (str): directory containing the checkpoint.
        checkpoint_file (str): checkpoint file name, located in ``checkpoint_path``.
        data_name_or_path (str): data directory passed to fairseq's ``from_pretrained``.
        batch_size (int): batch dimension baked into the exported graph.
        pad_length (int): padded sentence length baked into the exported graph.
        output_file (str): path of the generated ONNX file. Missing parent
            directories are created.
        opset_version (int): ONNX opset version to target. Defaults to 11.
        verbose (bool): forwarded to ``torch.onnx.export``; when True the
            exported graph is printed. Defaults to True.
    """
    model = hub_utils.from_pretrained(
        checkpoint_path,
        checkpoint_file=checkpoint_file,
        data_name_or_path=data_name_or_path,
        bpe="gpt2",
        load_checkpoint_heads=True,
    )["models"][0]
    # Put the model in inference mode (e.g. dropout off) before tracing.
    model.eval()

    # The dummy tensor fixes the (batch_size, pad_length) int64 input shape of
    # the exported graph. NOTE(review): with fairseq's default dictionary,
    # index 1 is presumably <pad>, so this batch is all padding; values only
    # matter for data-dependent branches during tracing - confirm intended.
    dummy_input = torch.ones(batch_size, pad_length, dtype=torch.long)

    # torch.onnx.export writes directly to output_file and will not create
    # missing directories, so make sure the parent directory exists.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        input_names=["src_tokens"],
        output_names=["output"],
        operator_export_type=OPERATOR_EXPORT_TYPE,
        opset_version=opset_version,
        verbose=verbose,
    )
def main():
    """Parse command-line arguments and export the checkpoint to ONNX.

    The output file is written to
    ``<onnx_path>/roberta_base_seq<pad_length>_bs<batch_size>.onnx``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path',
                        default="checkpoint/", type=str, help='dir of pth')
    parser.add_argument('--checkpoint_file',
                        default="checkpoint.pt", type=str, help='pth name, locate in dir of pth')
    parser.add_argument('--onnx_path', default="onnx_file/",
                        type=str, help='dir of onnx')
    parser.add_argument('--data_name_or_path',
                        default="SST-2", type=str, help='dir of data')
    parser.add_argument('--pad_length', default=128, type=int,
                        help='fix the pad length of one sentence')
    parser.add_argument('--batch_size', default=1, type=int, help='batch size')
    args = parser.parse_args()

    # The default output dir may not exist yet; create it so the export can write.
    os.makedirs(args.onnx_path, exist_ok=True)
    onnx_output_path = os.path.join(
        args.onnx_path, "roberta_base_seq{}_bs{}.onnx".format(args.pad_length, args.batch_size))
    pth2onnx(args.checkpoint_path, args.checkpoint_file, args.data_name_or_path,
             args.batch_size, args.pad_length, onnx_output_path)


if __name__ == "__main__":
    main()