import argparse
from enum import Enum
import onnx
import torch
from model.model import parsingNet
class ModelType(Enum):
    """Selects which dataset-specific configuration ``ModelConfig`` loads."""

    # The integer values are what the ``--model-type`` CLI option accepts;
    # the script turns them back into members with ``ModelType(value)``.
    TUSIMPLE = 0
    CULANE = 1
class ModelConfig():
    """Per-dataset settings used to build the lane-detection network.

    The constructor fills in four attributes:

    * ``img_w`` / ``img_h`` -- image width/height for the dataset
      (presumably the source image resolution -- confirm against training
      code).
    * ``griding_num`` -- used as ``griding_num + 1`` for the first entry of
      the network's ``cls_dim`` when the model is built.
    * ``cls_num_per_lane`` -- used as the second entry of ``cls_dim``.
    """

    def __init__(self, model_type):
        """Load the settings for *model_type*.

        Args:
            model_type: a ``ModelType`` member or its integer value
                (0 = TUSIMPLE, 1 = CULANE).

        Raises:
            ValueError: if *model_type* is not a valid ``ModelType`` value.
        """
        # Coerce explicitly so an unrecognised value (e.g. the raw int 0 or a
        # typo) fails loudly instead of falling through to the CULANE branch.
        model_type = ModelType(model_type)
        if model_type is ModelType.TUSIMPLE:
            self.init_tusimple_config()
        else:
            self.init_culane_config()

    def init_tusimple_config(self):
        """Populate the attributes with the TuSimple settings."""
        self.img_w = 1280
        self.img_h = 720
        self.griding_num = 100
        self.cls_num_per_lane = 56

    def init_culane_config(self):
        """Populate the attributes with the CULane settings."""
        self.img_w = 1640
        self.img_h = 590
        self.griding_num = 200
        self.cls_num_per_lane = 18
def _strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from keys.

    Checkpoints saved from a ``DataParallel``-style wrapper prefix parameter
    names with ``'module.'`` (presumably how these checkpoints were produced);
    the bare network expects the un-prefixed names. Only a true prefix is
    stripped, so keys that merely contain ``'module.'`` are left intact.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def convert_model(model_path, onnx_file_path, model_type=ModelType.TUSIMPLE,
                  input_size=(288, 800)):
    """Export a trained ``parsingNet`` checkpoint to ONNX and validate it.

    Args:
        model_path: path to a checkpoint loadable by ``torch.load`` whose
            ``'model'`` entry holds the network's state dict.
        onnx_file_path: destination path for the exported ONNX model.
        model_type: ``ModelType`` selecting the dataset-specific head size.
        input_size: ``(height, width)`` of the dummy input traced during
            export; defaults to 288x800.

    Raises:
        Whatever ``onnx.checker.check_model`` raises if the exported graph
        fails validation.
    """
    import warnings

    cfg = ModelConfig(model_type)
    net = parsingNet(pretrained=False, backbone='18',
                     cls_dim=(cfg.griding_num + 1, cfg.cls_num_per_lane, 4),
                     use_aux=False)

    # NOTE(review): torch.load unpickles arbitrary objects -- only convert
    # checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location='cpu')['model']

    # strict=False is kept so checkpoints carrying extra entries still load,
    # but parameters the network expects and did not receive would leave it
    # partly uninitialised, so report them instead of hiding them.
    incompatible = net.load_state_dict(_strip_data_parallel_prefix(state_dict),
                                       strict=False)
    if incompatible.missing_keys:
        warnings.warn('Checkpoint is missing {} parameter(s): {}'.format(
            len(incompatible.missing_keys), incompatible.missing_keys))

    # Export inference behaviour (e.g. BatchNorm running statistics).
    net.eval()

    dummy_input = torch.zeros(1, 3, *input_size, device='cpu')
    torch.onnx.export(net, dummy_input, onnx_file_path,
                      input_names=['input'],
                      output_names=['output'],
                      # Batch dimension (axis 0) is dynamic on both ends.
                      dynamic_axes={'input': {0: '-1'}, 'output': {0: '-1'}},
                      verbose=False)

    onnx.checker.check_model(onnx.load(onnx_file_path))
if __name__ == '__main__':
    # Command-line entry point: export the given checkpoint to ONNX.
    parser = argparse.ArgumentParser(
        description='Convert a trained parsingNet checkpoint to an ONNX model.')
    parser.add_argument('--model-path', type=str, required=True,
                        help='path to weights file.')
    parser.add_argument('--onnx-path', type=str, required=True,
                        help='path to save onnx file.')
    parser.add_argument('--model-type', type=int, default=0,
                        choices=[member.value for member in ModelType],
                        help='choose a dataset: {0: TuSimple, 1: CULane}.')
    args = parser.parse_args()
    convert_model(args.model_path, args.onnx_path, ModelType(args.model_type))
    print('ONNX generated.')