# 05360171 — created 2022-03-18 (see commit history)
# Copyright 2020 Huawei Technologies Co., Ltd

# 

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



import os

import sys

from collections import OrderedDict



import torch

import torch.nn as nn

import torch.onnx





class Net(nn.Module):
    """MobileNet-v1-style image classifier producing 1000 logits.

    Layout of ``self.model`` (indices are part of the checkpoint key format,
    so their order must not change):

    * ``model.0``      -- stem: 3x3 conv (stride 2) -> BatchNorm -> ReLU
    * ``model.1..13``  -- depthwise-separable stages: 3x3 depthwise conv ->
      BatchNorm -> ReLU -> 1x1 pointwise conv -> BatchNorm -> ReLU
    * ``model.14``     -- 7x7 average pooling

    ``self.fc`` maps the 1024 pooled features to 1000 outputs.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride) for each depthwise-separable stage.
        stages = (
            (32, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
            (256, 512, 2),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 1024, 2),
            (1024, 1024, 1),
        )

        # Built in the same order as before so parameter init consumes the RNG identically.
        blocks = [self._stem(3, 32, 2)]
        blocks.extend(self._separable(c_in, c_out, step) for c_in, c_out, step in stages)
        blocks.append(nn.AvgPool2d(7))
        self.model = nn.Sequential(*blocks)

        self.fc = nn.Linear(1024, 1000)

    @staticmethod
    def _stem(in_ch, out_ch, stride):
        """Full 3x3 convolution (padding 1, no bias) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _separable(in_ch, out_ch, stride):
        """Depthwise 3x3 conv -> BN -> ReLU, then pointwise 1x1 conv -> BN -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=stride, padding=1,
                      groups=in_ch, bias=False),
            nn.BatchNorm2d(in_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return class logits of shape (N, 1000) for an image batch ``x``."""
        pooled = self.model(x)
        # Collapse to (N, 1024). This assumes the pooled map is 1x1, which holds
        # for 224x224 inputs (overall stride 32 -> 7x7, then AvgPool2d(7)).
        flat = pooled.view(-1, 1024)
        return self.fc(flat)





def proc_nodes_module(checkpoint, AttrName):
    """Return ``checkpoint[AttrName]`` with a leading ``module.`` stripped from each key.

    Models wrapped in ``nn.DataParallel`` / ``DistributedDataParallel`` save
    their parameters under ``module.``-prefixed names. Only one prefix is
    removed per key. Keys without it are kept unchanged, and the original key
    order is preserved.

    Args:
        checkpoint: mapping that holds the state dict under ``AttrName``.
        AttrName: key of the state dict inside ``checkpoint``.

    Returns:
        An ``OrderedDict`` mapping the cleaned keys to the original values.
    """
    prefix = "module."
    cut = len(prefix)
    return OrderedDict(
        (key[cut:] if key.startswith(prefix) else key, value)
        for key, value in checkpoint[AttrName].items()
    )





def convert_model_to_onnx(model_state, output_file):
    """Build ``Net``, optionally load weights, and export it to ONNX.

    Args:
        model_state: state dict to load into a fresh ``Net``. If it is falsy
            (e.g. an empty ``OrderedDict``), loading is skipped and the
            randomly initialised weights are exported. Loading uses the default
            strict mode, so mismatched keys or shapes raise.
        output_file: destination path of the ``.onnx`` file.

    The model is traced on a random dummy batch of 32 images of 3x224x224.
    Axis 0 (the batch dimension) of the ``image`` input and the ``class``
    output is declared dynamic. Export uses opset 11 with verbose graph
    printing.
    """
    model = Net()
    if model_state:
        model.load_state_dict(model_state)
    # Trace in inference mode so BatchNorm uses its running statistics.
    model.eval()
    input_names = ["image"]
    output_names = ["class"]
    # Mark the batch axis as dynamic; '-1' is only the symbolic name given to that axis.
    dynamic_axes = {'image': {0: '-1'}, 'class': {0: '-1'}}
    dummy_input = torch.randn(32, 3, 224, 224)  # NCHW: (batch_size, channels, height, width)
    torch.onnx.export(model, dummy_input, output_file, input_names=input_names, dynamic_axes=dynamic_axes,
                      output_names=output_names, opset_version=11, verbose=True)





if __name__ == '__main__':
    # Usage: python <this_script> <checkpoint_file> <output_onnx_file>
    if len(sys.argv) < 3:
        # Fail with a usage message instead of an opaque IndexError.
        sys.exit("Usage: {} <checkpoint_file> <output_onnx_file>".format(sys.argv[0]))
    checkpoint_file = sys.argv[1]
    output_file = sys.argv[2]

    if os.path.isfile(checkpoint_file):
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary
        # code; only run this on checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        print("{} successfully loaded.".format(checkpoint_file))
        if 'state_dict' not in checkpoint:
            # Bare state dict (e.g. saved via torch.save(model.state_dict())):
            # wrap it so the "module." prefix stripping below still applies.
            checkpoint = {'state_dict': checkpoint}
        model_state = proc_nodes_module(checkpoint, 'state_dict')
    else:
        # Deliberate best-effort fallback: export the randomly initialised model.
        print("Failed to load checkpoint from {}! Output model with initial state.".format(checkpoint_file))
        model_state = OrderedDict()
    convert_model_to_onnx(model_state, output_file)