# 05360171 created on 2022-03-18 (commit history)
# Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
# from base_model import L2Norm, ResNet
from resnet import ResNet, resnet34

# from nhwc.conv import Conv2d_NHWC

class SSD300(nn.Module):
    """SSD detection head on top of a ResNet-34 backbone for 300x300 images.

    The network consists of the backbone (``self.model``), five additional
    down-sampling blocks (``self.additional_blocks``) and six fused
    localisation/confidence convolutions (``self.mbox``), one per feature map.

    With a 38x38 backbone feature map (what the original authors state RN34
    always produces for a 300x300 input) the six maps are 38, 19, 10, 5, 3 and
    1 pixels wide, which with 4/6/6/6/4/4 default boxes per location yields
    8732 boxes in total.

    NOTE: the NHWC convolution path of the original implementation is disabled
    in this file; every convolution built here is a plain ``nn.Conv2d`` whose
    weights and activations are NCHW, regardless of ``use_nhwc``.  The flag is
    still stored and forwarded to the backbone constructor.

    Args:
        args: unused by this class; kept for interface compatibility.
        label_num: number of classes (including background 0).
        use_nhwc: forwarded to ``resnet34`` as ``nhwc``; stored on the module.
        pad_input: forwarded to ``resnet34``; stored on the module.
        bn_group: forwarded to ``resnet34``; stored on the module.
        pretrained: forwarded to ``resnet34``.
    """

    def __init__(self, args, label_num, use_nhwc=False, pad_input=False, bn_group=1, pretrained=True):
        super().__init__()

        self.label_num = label_num
        self.use_nhwc = use_nhwc
        self.pad_input = pad_input
        self.bn_group = bn_group

        # Channel count of each of the six feature maps fed to the multibox
        # heads.  The first entry is the backbone output (explicitly RN34 all
        # the time), the rest are produced by the additional blocks.
        self.out_chan = [256, 512, 512, 256, 256, 256]

        self.model = resnet34(
            bn_group=bn_group,
            pad_input=pad_input,
            nhwc=use_nhwc,
            pretrained=pretrained,
            ssd_mods=True,
        )

        # Construction order (backbone, additional blocks, heads) is kept so
        # that parameter registration order and RNG consumption are unchanged.
        self._build_additional_features()

        padding_channels_to = 8
        self._build_multibox_heads(use_nhwc, padding_channels_to)

        # Initialize the weights of the additional blocks and multibox heads.
        with torch.no_grad():
            self._init_weights()

    def _build_multibox_heads(self, use_nhwc, padding_channels_to=8):
        """Create one fused loc+conf 3x3 convolution per feature map.

        Each head outputs ``nd * (4 + label_num)`` channels, where ``nd`` is
        the number of default boxes for that feature map.  ``bbox_view`` splits
        those channels into confidence channels first, location channels second.

        Args:
            use_nhwc: unused; the NHWC path is disabled in this file.
            padding_channels_to: unused; channel padding was only applied on
                the NHWC path, so every padding amount is 0.
        """
        self.num_defaults = [4, 6, 6, 6, 4, 4]

        heads = []
        paddings = []
        for nd, in_channels in zip(self.num_defaults, self.out_chan):
            # Horizontally fuse loc and conf convolutions.
            fused_channels = nd * (4 + self.label_num)
            padding_amount = 0
            paddings.append(padding_amount)
            heads.append(
                nn.Conv2d(in_channels, fused_channels + padding_amount, kernel_size=3, padding=1)
            )

        self.padding_amounts = paddings
        self.mbox = nn.ModuleList(heads)

    def _build_additional_features(self):
        """Create the five extra feature blocks that follow the backbone.

        Every block is ``1x1 conv -> ReLU -> 3x3 conv -> ReLU``.  The first
        three blocks use stride 2 with padding 1, the last two use stride 1
        without padding.  According to the original authors the output size
        from RN34 is always 38x38, which these blocks reduce to
        19, 10, 5, 3 and 1.
        """
        strides = [2, 2, 2, 1, 1]
        intermediates = [256, 256, 128, 128, 128]
        paddings = [1, 1, 1, 0, 0]

        blocks = []
        block_specs = zip(self.out_chan[:-1], intermediates, self.out_chan[1:], strides, paddings)
        for in_ch, mid_ch, out_ch, stride, pad in block_specs:
            blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_ch, mid_ch, kernel_size=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=stride, padding=pad),
                    nn.ReLU(inplace=True),
                )
            )

        self.additional_blocks = nn.ModuleList(blocks)

    def _init_additional_weights(self):
        """Xavier-initialize every weight tensor of the additional blocks.

        Only parameters with more than one dimension are touched, so biases
        keep their default initialization.
        """
        for block in self.additional_blocks:
            for param in block.parameters():
                if param.dim() > 1:
                    nn.init.xavier_uniform_(param)

    def _init_multibox_weights(self):
        """Xavier-initialize the weight tensor of every multibox head.

        The heads are plain ``nn.Conv2d`` modules with NCHW weights even when
        ``use_nhwc`` is set, so the fused weight is initialized directly.  The
        former NHWC branch permuted the weight as if it were stored NHWC, which
        made ``xavier_uniform_`` compute a wrong fan-out for these layers and
        produced a different initialization depending on ``use_nhwc``.

        Only parameters with more than one dimension are touched, so biases
        keep their default initialization.
        """
        for head in self.mbox:
            for param in head.parameters():
                if param.dim() > 1:
                    nn.init.xavier_uniform_(param)

    def _init_weights(self):
        """Initialize additional blocks first, then the multibox heads."""
        self._init_additional_weights()
        self._init_multibox_weights()

    def bbox_view(self, src, mbox):
        """Apply the multibox heads and reshape their outputs to box views.

        Args:
            src: iterable of feature maps, one per head, laid out as
                ``(N, C, H, W)``.
            mbox: iterable of fused loc+conf convolutions matching ``src``.

        Returns:
            Tuple ``(locs, confs)`` where ``locs`` has shape ``(N, 4, A)`` and
            ``confs`` has shape ``(N, label_num, A)``; ``A`` is the number of
            boxes summed over all feature maps.
        """
        locs = []
        confs = []
        for feature, head, nd in zip(src, mbox, self.num_defaults):
            fused = head(feature)
            batch = feature.size(0)

            # Un-fuse along the channel dimension: confidences come first.
            conf, loc = fused.split([nd * self.label_num, nd * 4], dim=1)

            # Flatten the anchors for this layer.
            locs.append(loc.contiguous().view(batch, 4, -1))
            confs.append(conf.contiguous().view(batch, self.label_num, -1))

        cat_dim = 2
        return torch.cat(locs, cat_dim), torch.cat(confs, cat_dim)

    def forward(self, data):
        """Run the backbone, the additional blocks and the multibox heads.

        Args:
            data: input batch passed to the backbone unchanged (presumably
                ``(N, 3, 300, 300)`` images; confirm against the data loader).

        Returns:
            Tuple ``(locs, confs)`` as produced by ``bbox_view``.  For SSD300
            with six feature maps of 38, 19, 10, 5, 3 and 1 pixels this is
            ``(N, 4, 8732)`` and ``(N, label_num, 8732)``.
        """
        backbone_out = self.model(data)

        # The backbone output is the first feature map; each additional block
        # consumes the previous block's output.
        src = [backbone_out]
        x = backbone_out
        for block in self.additional_blocks:
            x = block(x)
            src.append(x)

        return self.bbox_view(src, self.mbox)