import os
import sys
import glob
import random
from pathlib import Path
from argparse import ArgumentParser
import numpy as np
import torch
from mmcv.onnx import register_extra_symbolics
from mmocr.apis import init_detector
sys.path.append('mmocr/tools')
from deployment.pytorch2onnx import _convert_batchnorm
def create_input_data(prep_dir, file_name=None):
    """Load one preprocessed sample to use as the dummy ONNX export input.

    The sample is read from three sub-directories of ``prep_dir``
    (``relations``, ``texts`` and ``mask``), each expected to contain a
    file with the same name. Each file is loaded with ``numpy.load`` and
    wrapped in a tensor via ``torch.from_numpy``.

    Args:
        prep_dir (str | Path): Directory holding the preprocessed data.
        file_name (str, optional): Name of the sample file to load. When
            omitted, one is picked at random from ``prep_dir/'texts'``.

    Returns:
        tuple: ``(relations, texts, mask)`` tensors, in that order.

    Raises:
        FileNotFoundError: If ``prep_dir/'texts'`` is missing or has no
            candidate sample files, or if a sub-directory lacks the chosen
            file.
    """
    prep_dir = Path(prep_dir)
    if file_name is None:
        texts_dir = prep_dir / 'texts'
        # Sort so the pick is reproducible under random.seed (directory
        # listing order is arbitrary). Skip sub-directories and hidden
        # files such as .DS_Store, which are not samples and would not
        # exist under the sibling directories.
        candidates = sorted(
            entry.name for entry in texts_dir.iterdir()
            if entry.is_file() and not entry.name.startswith('.'))
        if not candidates:
            raise FileNotFoundError(f'No sample files found in {texts_dir}')
        file_name = random.choice(candidates)
    return tuple(
        torch.from_numpy(np.load(prep_dir / input_name / file_name))
        for input_name in ('relations', 'texts', 'mask'))
def pytorch2onnx(config_file, checkpoint_file, prep_dir,
                 output_file, opset_version=12):
    """Export an MMOCR model to ONNX, tracing it on a preprocessed sample.

    Args:
        config_file (str): Model config passed to ``init_detector``.
        checkpoint_file (str): Checkpoint passed to ``init_detector``.
        prep_dir (str | Path): Directory of preprocessed inputs; one sample
            is loaded via ``create_input_data`` and used as the trace input.
        output_file (str | Path): Destination path of the ONNX model. Its
            parent directory is created if it does not exist.
        opset_version (int): ONNX opset to target. Also passed to
            ``register_extra_symbolics``.
    """
    device = torch.device('cpu')
    model = init_detector(config_file, checkpoint_file, device=device)
    # Unwrap a wrapper exposing ``.module`` (presumably a
    # (Distributed)DataParallel-style container) to get the bare model.
    if hasattr(model, 'module'):
        model = model.module
    model.to(device).eval()
    # NOTE(review): helper from mmocr's deployment tools; presumably swaps
    # SyncBatchNorm for exportable BatchNorm layers — confirm.
    _convert_batchnorm(model)

    dummy_input = create_input_data(prep_dir)
    dynamic_axes = {
        'relations': {0: 'num_texts', 1: 'num_texts'},
        'texts': {0: 'num_texts', 1: 'num_chars'},
        'mask': {0: 'num_texts', 1: 'num_chars'},
        # NOTE(review): reusing 'num_texts' for axis 1 declares both axes
        # equal; confirm this matches the shape forward_onnx returns.
        'nodes': {0: 'num_texts', 1: 'num_texts'},
        'edges': {0: 'num_edges'},
    }
    # Route tracing through the ONNX-specific forward implementations.
    model.forward = model.forward_onnx
    model.bbox_head.forward = model.bbox_head.forward_onnx
    register_extra_symbolics(opset_version)

    # torch.onnx.export opens output_file directly and does not create
    # missing parent directories.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            output_file,
            input_names=['relations', 'texts', 'mask'],
            output_names=['nodes', 'edges'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=False,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes,
        )
    print(f'Successfully exported ONNX model: {output_file}')
def main():
parser = ArgumentParser('Convert MMOCR models from pytorch to ONNX')
parser.add_argument('--config', type=str, help='config file.')
parser.add_argument('--checkpoint', type=str, help='checkpint file.')
parser.add_argument('--prep-dir', type=str, help='path to preprocessed data')
parser.add_argument('--onnx', type=str, help='path to save onnx model.')
parser.add_argument('--opset-version', type=int, default=12,
help='ONNX opset version.')
args = parser.parse_args()
pytorch2onnx(args.config, args.checkpoint, args.prep_dir,
args.onnx, opset_version=args.opset_version)
if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()