from __future__ import print_function
import argparse
import os
import torch
import onnx, onnxruntime
import yaml
import numpy as np
from wenet.transformer.asr_model import init_asr_model
from wenet.transformer.decoder import TransformerDecoder, BiTransformerDecoder
from wenet.utils.checkpoint import load_checkpoint
def to_numpy(xx):
    """Convert a torch tensor to a numpy array.

    The tensor is detached from the autograd graph and moved to host memory
    first, so tensors that require grad or live on a non-CPU device are
    accepted as well (``.cpu()`` is a no-op for tensors already on the CPU).

    Args:
        xx: the ``torch.Tensor`` to convert.

    Returns:
        A ``numpy.ndarray`` holding the tensor's values.
    """
    return xx.detach().cpu().numpy()
# Script entry point: load an ASR model from a config + checkpoint and export
# its encoder and decoder as two separate ONNX files.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='export your script model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    # Despite its name, this value is used as a directory below: both
    # 'no_flash_encoder.onnx' and 'decoder.onnx' are written inside it.
    parser.add_argument('--output_onnx_file', required=True,
                        help='output directory for the exported onnx files')
    args = parser.parse_args()

    # Hide all CUDA devices so the export and the check below run on the CPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    # The exported files are joined onto this path, so it must exist.
    if not os.path.isdir(args.output_onnx_file):
        os.makedirs(args.output_onnx_file)

    with open(args.config, 'r') as fin:
        # NOTE(security): FullLoader can construct more than plain data types;
        # only pass config files from a trusted source (or switch to
        # yaml.safe_load if the configs never need custom tags).
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    model = init_asr_model(configs)
    print(model)
    load_checkpoint(model, args.checkpoint)
    model.eval()

    # ------------------------------------------------------------------
    # Encoder export
    # ------------------------------------------------------------------
    encoder = model.encoder
    # Dummy inputs used for tracing; presumably (batch, frames, feature_dim)
    # and the per-utterance frame counts -- TODO confirm against the encoder.
    xs = torch.randn(1, 131, 80, requires_grad=False)
    xs_lens = torch.tensor([131], dtype=torch.int32)
    onnx_encoder_path = os.path.join(args.output_onnx_file,
                                     'no_flash_encoder.onnx')
    torch.onnx.export(encoder,
                      (xs, xs_lens),
                      onnx_encoder_path,
                      export_params=True,
                      opset_version=11,
                      do_constant_folding=True,
                      input_names=['xs_input', 'xs_input_lens'],
                      output_names=['xs_output', 'masks_output'],
                      # Axes listed here stay variable-sized in the graph.
                      dynamic_axes={'xs_input': [1], 'xs_input_lens': [0],
                                    'xs_output': [1], 'masks_output': [2]},
                      verbose=True
                      )

    # Structural validation of the exported graph.
    onnx_model = onnx.load(onnx_encoder_path)
    onnx.checker.check_model(onnx_model)
    print("encoder onnx_model check pass!")

    # Numerical validation: run the same dummy inputs through ONNX Runtime
    # and through the PyTorch encoder, then compare the results.
    # The provider is given explicitly because ONNX Runtime builds that ship
    # more than the CPU provider refuse to pick one implicitly.
    ort_session = onnxruntime.InferenceSession(
        onnx_encoder_path, providers=['CPUExecutionProvider'])
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(xs),
                  ort_session.get_inputs()[1].name: to_numpy(xs_lens),
                  }
    ort_outs = ort_session.run(None, ort_inputs)
    y1, y2 = encoder(xs, xs_lens)
    for torch_out, ort_out in zip((y1, y2), ort_outs):
        # Cast to float32 so a non-float output (e.g. a mask) can be compared
        # with the same tolerance-based check.
        np.testing.assert_allclose(
            np.asarray(to_numpy(torch_out), dtype=np.float32),
            np.asarray(ort_out, dtype=np.float32),
            rtol=1e-03, atol=1e-05)
    print("Exported no flash encoder model has been tested with ONNXRuntime, and the result looks good!")

    # ------------------------------------------------------------------
    # Decoder export
    # ------------------------------------------------------------------
    decoder = model.decoder
    decoder.set_onnx_mode(True)
    onnx_decoder_path = os.path.join(args.output_onnx_file, 'decoder.onnx')
    # Dummy decoder inputs for tracing. NOTE(review): 256 and 4232 look like
    # the encoder output size and the vocabulary size of one specific model;
    # confirm they match the config being exported.
    memory = torch.randn(10, 131, 256)
    memory_mask = torch.ones(10, 1, 131).bool()
    ys_in_pad = torch.randint(0, 4232, (10, 50)).long()
    ys_in_lens = torch.tensor([13, 13, 13, 13, 13, 13, 13, 13, 50, 13],
                              dtype=torch.int32)
    r_ys_in_pad = torch.randint(0, 4232, (10, 50)).long()

    if isinstance(decoder, TransformerDecoder):
        torch.onnx.export(decoder,
                          (memory, memory_mask, ys_in_pad, ys_in_lens),
                          onnx_decoder_path,
                          export_params=True,
                          opset_version=12,
                          do_constant_folding=True,
                          input_names=['memory', 'memory_mask', 'ys_in_pad',
                                       'ys_in_lens'],
                          output_names=['l_x', 'r_x'],
                          dynamic_axes={'memory': [1], 'memory_mask': [2],
                                        'ys_in_pad': [1], 'ys_in_lens': [0]},
                          verbose=True
                          )
    elif isinstance(decoder, BiTransformerDecoder):
        print("BI mode")
        # The bidirectional decoder takes an extra right-to-left input.
        torch.onnx.export(decoder,
                          (memory, memory_mask, ys_in_pad, ys_in_lens,
                           r_ys_in_pad),
                          onnx_decoder_path,
                          export_params=True,
                          opset_version=11,
                          do_constant_folding=True,
                          input_names=['memory', 'memory_mask', 'ys_in_pad',
                                       'ys_in_lens', 'r_ys_in_pad'],
                          output_names=['l_x', 'r_x', 'olens'],
                          dynamic_axes={'memory': [1], 'memory_mask': [2],
                                        'ys_in_pad': [1], 'ys_in_lens': [0],
                                        'r_ys_in_pad': [1]},
                          verbose=True
                          )
    else:
        # Fail loudly instead of finishing without a decoder file.
        raise TypeError('unsupported decoder type for onnx export: {}'.format(
            type(decoder).__name__))