import os
import sys
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.onnx
class Net(nn.Module):
    """Image classifier with the MobileNet-v1 layer layout.

    A standard 3x3 conv stem, 13 depthwise-separable conv stages, a 7x7
    average pool and a 1024 -> 1000 fully connected head. The submodule
    layout (``model.<stage>.<layer>`` and ``fc``) fixes the state-dict keys,
    so it has to stay exactly as is for checkpoints to load.
    """

    def __init__(self):
        super().__init__()

        def stem(in_ch, out_ch, stride):
            # Full 3x3 convolution -> BatchNorm -> ReLU.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        def separable(in_ch, out_ch, stride):
            # Depthwise 3x3 (groups == channels) then pointwise 1x1,
            # each followed by BatchNorm + ReLU.
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch, 3, stride, 1, groups=in_ch, bias=False),
                nn.BatchNorm2d(in_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        # (in_channels, out_channels, stride) for each separable stage. List
        # order determines each stage's index inside ``self.model``.
        stages = (
            [
                (32, 64, 1),
                (64, 128, 2),
                (128, 128, 1),
                (128, 256, 2),
                (256, 256, 1),
                (256, 512, 2),
            ]
            + [(512, 512, 1)] * 5
            + [(512, 1024, 2), (1024, 1024, 1)]
        )

        layers = [stem(3, 32, 2)]
        for in_ch, out_ch, stride in stages:
            layers.append(separable(in_ch, out_ch, stride))
        # Total stride is 2**5 = 32, so a 224x224 input reaches this pool as
        # a 7x7 map and leaves it as 1x1.
        layers.append(nn.AvgPool2d(7))

        self.model = nn.Sequential(*layers)
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x):
        """Return the raw ``fc`` outputs (no softmax) for the image batch ``x``."""
        # NOTE(review): view(-1, 1024) assumes the pooled map is 1x1 (i.e.
        # ~224x224 inputs); larger inputs would fold spatial positions into
        # the batch dimension instead of raising.
        features = self.model(x).view(-1, 1024)
        return self.fc(features)
def proc_nodes_module(checkpoint, AttrName):
    """Return ``checkpoint[AttrName]`` with any leading ``"module."`` removed from keys.

    The ``"module."`` prefix is typically added when a model was saved while
    wrapped in (Distributed)DataParallel. Stripping it lets the weights load
    into a bare ``nn.Module``.

    Args:
        checkpoint: mapping whose ``AttrName`` entry is itself a mapping of
            parameter names to values.
        AttrName: key of the state dict inside ``checkpoint``.

    Returns:
        A new ``OrderedDict`` with the same values, in the same order. If two
        keys collapse to the same name, the later value wins.
    """
    prefix = "module."
    cut = len(prefix)
    return OrderedDict(
        (key[cut:] if key.startswith(prefix) else key, value)
        for key, value in checkpoint[AttrName].items()
    )
def convert_model_to_onnx(model_state, output_file):
    """Export a ``Net`` instance to an ONNX file.

    Args:
        model_state: state dict to load into a fresh ``Net``. When it is empty
            (falsy), the freshly initialised weights are exported as-is.
        output_file: destination path of the ONNX model.
    """
    net = Net()
    if model_state:
        net.load_state_dict(model_state)
    net.eval()

    # Axis 0 (batch) of both the input and the output is exported as a
    # symbolic dimension named '-1'.
    axes = {'image': {0: '-1'}, 'class': {0: '-1'}}
    # Tracing sample; only its shape/dtype matter for the exported graph.
    sample = torch.randn(32, 3, 224, 224)
    torch.onnx.export(
        net,
        sample,
        output_file,
        input_names=["image"],
        dynamic_axes=axes,
        output_names=["class"],
        opset_version=11,
        verbose=True,
    )
if __name__ == '__main__':
    # Usage: python <script> <checkpoint_file> <output_onnx_file>
    # Fail with a usage message instead of an IndexError traceback when
    # arguments are missing (extra arguments are ignored, as before).
    if len(sys.argv) < 3:
        sys.exit("Usage: {} <checkpoint_file> <output_onnx_file>".format(sys.argv[0]))
    checkpoint_file = sys.argv[1]
    output_file = sys.argv[2]
    if os.path.isfile(checkpoint_file):
        # SECURITY NOTE: torch.load unpickles the file (unless the installed
        # torch defaults to weights_only=True), which can execute arbitrary
        # code - only convert checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        print("{} successfully loaded.".format(checkpoint_file))
        # Expects a dict with a 'state_dict' entry; "module." prefixes are stripped.
        model_state = proc_nodes_module(checkpoint, 'state_dict')
    else:
        print("Failed to load checkpoint from {}! Output model with initial state.".format(checkpoint_file))
        # An empty state makes convert_model_to_onnx skip load_state_dict.
        model_state = OrderedDict()
    convert_model_to_onnx(model_state, output_file)