# Source snapshot: commit 05360171, created on 2022-03-18 (see commit history).
#-*- coding:utf-8 -*-

# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

# http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

from __future__ import division

from __future__ import absolute_import

from __future__ import print_function



import torch



from ..bbox_utils import decode, nms

from torch.autograd import Function





class Detect(Function):
    """Final inference layer of an SSD-style detector.

    Decodes per-prior location predictions against the prior boxes. Then,
    for every class except class 0, it keeps the priors whose confidence
    exceeds ``conf_thresh`` and runs non-maximum suppression on them. At
    most ``top_k`` detections are retained per class.
    """

    def __init__(self, cfg):
        """Read detection hyper-parameters from ``cfg``.

        Args:
            cfg: config object providing NUM_CLASSES (class 0 is skipped
                in ``forward``, i.e. treated as background), TOP_K,
                NMS_THRESH, CONF_THRESH and VARIANCE (forwarded to
                ``decode``).
        """
        self.num_classes = cfg.NUM_CLASSES
        self.top_k = cfg.TOP_K
        self.nms_thresh = cfg.NMS_THRESH
        self.conf_thresh = cfg.CONF_THRESH
        self.variance = cfg.VARIANCE

    def forward(self, loc_data, conf_data, prior_data):
        """Decode predictions and apply per-class NMS.

        Args:
            loc_data: (tensor) location predictions. Any contiguous shape
                that views as [batch * num_priors, 4], e.g.
                [batch, num_priors, 4] or [batch, num_priors * 4].
            conf_data: (tensor) per-class confidence scores. Any shape that
                views as [batch, num_priors, num_classes].
            prior_data: (tensor) prior boxes. Any contiguous shape that
                views as [num_priors, 4], e.g. [num_priors, 4] or
                [1, num_priors, 4].

        Returns:
            (tensor) of shape [batch, num_classes, top_k, 5]. For each image
            and each class ``cl >= 1``, the first ``count`` rows hold
            ``(score, box[0], box[1], box[2], box[3])`` for the detections
            kept by ``nms``. Box layout is whatever ``decode`` produces.
            All remaining rows, and the whole class-0 slice, are zero.
            The tensor is created with torch's default dtype and device,
            not those of the inputs.
        """
        num = loc_data.size(0)
        # Flatten first so priors given as either [num_priors, 4] or
        # [1, num_priors, 4] yield the correct prior count.
        priors = prior_data.view(-1, 4)
        num_priors = priors.size(0)

        # [batch, num_classes, num_priors]: one row of scores per class.
        conf_preds = conf_data.view(
            num, num_priors, self.num_classes).transpose(2, 1)

        # Repeat the priors for every image so all boxes decode in one call.
        batch_priors = priors.unsqueeze(0).expand(num, num_priors, 4)
        batch_priors = batch_priors.contiguous().view(-1, 4)
        decoded_boxes = decode(loc_data.view(-1, 4),
                               batch_priors, self.variance)
        decoded_boxes = decoded_boxes.view(num, num_priors, 4)

        output = torch.zeros(num, self.num_classes, self.top_k, 5)

        for i in range(num):
            boxes = decoded_boxes[i]
            conf_scores = conf_preds[i]
            # Class 0 is skipped; its output slice stays all zeros.
            for cl in range(1, self.num_classes):
                c_mask = conf_scores[cl].gt(self.conf_thresh)
                scores = conf_scores[cl][c_mask]
                # No prior clears the threshold for this class, so there
                # is nothing to suppress.
                if scores.numel() == 0:
                    continue
                # Boolean row indexing returns a copy of the selected
                # [k, 4] boxes, so nms never receives views of the inputs.
                boxes_ = boxes[c_mask]
                ids, count = nms(boxes_, scores, self.nms_thresh, self.top_k)
                keep = ids[:count]
                output[i, cl, :count] = torch.cat(
                    (scores[keep].unsqueeze(1), boxes_[keep]), 1)

        return output