import os
import torch
import torch.onnx
import argparse
from transformer.modeling import TinyBertForSequenceClassification
def make_input(args):
"""make the input data to create a model"""
eval_batch_size = args.eval_batch_size
max_seq_length = args.max_seq_length
org_input_ids = torch.ones(eval_batch_size, max_seq_length).long()
org_token_type_ids = torch.ones(eval_batch_size, max_seq_length).long()
org_input_mask = torch.ones(eval_batch_size, max_seq_length).long()
return (org_input_ids, org_token_type_ids, org_input_mask)
def convert(args):
"""convert the files into data"""
model = TinyBertForSequenceClassification.from_pretrained(args.input_model, num_labels = 2)
model.eval()
org_input = make_input(args)
input_names = ['input_ids', 'segment_ids', 'input_mask']
output_names = ['output']
OPERATOR_EXPORT_TYPE = torch._C._onnx.OperatorExportTypes.ONNX
torch.onnx.export(model, org_input, args.output_file, export_params = True,
input_names=input_names, output_names=output_names,
operator_export_type=OPERATOR_EXPORT_TYPE,
opset_version=11, verbose=True)
def main():
"""change the pth files into onnx"""
parser = argparse.ArgumentParser()
parser.add_argument("--input_model",
default=None,
type=str,
required=True,
help="The model(e.g. SST-2 distilled model)dir.")
parser.add_argument("--output_file",
default=None,
type=str,
required=True,
help="The output file of onnx. File name or dir is available.")
parser.add_argument("--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--max_seq_length",
default=64,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument("--eval_batch_size",
default=1,
type=int,
help="Total batch size for eval.")
parser.add_argument("--do_lower_case",
action='store_true',
help="Set this flag if you are using an uncased model.")
args = parser.parse_args()
os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
convert(args)
if __name__ == "__main__":
main()