import os
import sys
import torch
import argparse
sys.path.append('./DeepLearningExamples/PyTorch/Translation/GNMT/')
from seq2seq.data.tokenizer import Tokenizer
from seq2seq.models.gnmt import GNMT
import seq2seq.data.config as config
def run_pth2onnx(args):
device = torch.device('cpu' if not args.use_cuda else 'cuda')
checkpoint = torch.load(args.model, device)
tokenizer = Tokenizer()
tokenizer.set_state(checkpoint['tokenizer'])
dtype = {
'fp32': torch.FloatTensor,
'fp16': torch.HalfTensor
}
model_config = checkpoint['model_config']
model_config['batch_first'] = True
model_config['vocab_size'] = tokenizer.vocab_size
model_config['max_seq_len'] = args.max_seq_len
model = GNMT(**model_config)
model.load_state_dict(checkpoint['state_dict'])
model.type(dtype[args.math])
model = model.to(device)
model.eval()
input_encoder = torch.ones(1, args.max_seq_len, dtype=torch.int32, device=device)
input_enc_len = torch.tensor([args.max_seq_len], dtype=torch.int32, device=device)
bos = [[config.BOS]]
input_decoder = torch.tensor(bos, dtype=torch.int32, device=device).view(-1, 1)
input = (input_encoder, input_enc_len, input_decoder)
torch.onnx.export(model, input, args.onnx_path,
input_names=['input_encoder', 'input_enc_len', 'input_decoder'],
output_names=['translation'],
export_params=True,
verbose=False,
opset_version=13)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='./gnmt.pth')
parser.add_argument('--onnx_dir', type=str, default='./')
parser.add_argument('--math', default='fp32', choices=['fp16', 'fp32'])
parser.add_argument("--max_seq_len", type=int, default=30)
parser.add_argument('--use_cuda', action='store_true')
args = parser.parse_args()
args.onnx_path = "{}/gnmt_msl{}.onnx".format(args.onnx_dir, args.max_seq_len)
if not os.path.exists(args.onnx_dir):
os.makedirs(args.onnx_dir)
run_pth2onnx(args)
if __name__ == '__main__':
main()