from argparse import ArgumentParser
from functools import partial
import torch
from torch import nn
from mmocr.apis import init_detector
from mmcv.onnx import register_extra_symbolics
def _convert_batchnorm(module):
module_output = module
if isinstance(module, torch.nn.SyncBatchNorm):
module_output = torch.nn.BatchNorm2d(
module.num_features, module.eps, module.momentum,
module.affine, module.track_running_stats
)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, _convert_batchnorm(child))
del module
return module_output
def pytorch2onnx(model, output_file, opset_version=11):
device = torch.device(type='cpu')
model.to(device).eval()
_convert_batchnorm(model)
img_metas = [[{
'img_shape': (32, 100, 3),
'valid_ratio': 1.0,
'resize_shape': (32, 100, 3),
'img_norm_cfg': {
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225]
},
}]]
model.forward = partial(
model.forward,
img_metas=img_metas,
return_loss=False,
rescale=True)
register_extra_symbolics(opset_version)
dynamic_axes = {'input': {0: '-1'}, 'output': {0: '-1'}}
with torch.no_grad():
torch.onnx.export(
model,
torch.randn(1, 3, 32, 100),
output_file,
input_names=['input'],
output_names=['output'],
export_params=True,
keep_initializers_as_inputs=False,
verbose=False,
opset_version=opset_version,
dynamic_axes=dynamic_axes)
print(f'Successfully exported ONNX model: {output_file}')
def main():
parser = ArgumentParser(
description='Convert models from pytorch to ONNX')
parser.add_argument('--cfg-path', type=str, required=True,
help='path ot config file.')
parser.add_argument('--ckpt-path', type=str, required=True,
help='path to checkpoint.')
parser.add_argument('--onnx-path', type=str, required=True,
help='path to save onnx model.')
parser.add_argument('--opset-version', type=int, default=11,
help='ONNX opset version, default to 11.')
args = parser.parse_args()
device = torch.device(type='cpu')
model = init_detector(args.cfg_path, args.ckpt_path, device=device)
if hasattr(model, 'module'):
model = model.module
pytorch2onnx(model, args.onnx_path, args.opset_version)
if __name__ == '__main__':
main()