#!/usr/bin/env python
# 05360171 — created on 2022-03-18 (see commit history).

# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the BSD 3-Clause License  (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

# https://opensource.org/licenses/BSD-3-Clause

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



"""Adapted from:

    @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch

    @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn

    Licensed under The MIT License [see LICENSE for details]

"""



from __future__ import print_function

import torch

import torch.nn as nn

import torch.backends.cudnn as cudnn

from torch.autograd import Variable

from data import VOCAnnotationTransform, VOCDetection, BaseTransform

from data import VOC_CLASSES as labelmap

from data import voc_refinedet, detection_collate_test

import torch.utils.data as data



from models.refinedet import build_refinedet



import sys

import os

import time

import argparse

import numpy as np

import pickle

import cv2

from apex import amp



# Evaluation configuration: the '320' entry of the RefineDet VOC settings.
cfg = voc_refinedet['320']

# The C-accelerated cElementTree module is only imported on Python 2;
# every other interpreter uses the standard ElementTree module.
if sys.version_info[0] != 2:
    import xml.etree.ElementTree as ET
else:
    import xml.etree.cElementTree as ET



def str2bool(v):
    """Interpret a command-line string as a boolean.

    Returns True when ``v`` (case-insensitively) is one of "yes", "true",
    "t" or "1"; every other string yields False.
    """
    truthy_words = {"yes", "true", "t", "1"}
    normalized = v.lower()
    return normalized in truthy_words



class Timer(object):
    """Accumulating wall-clock stopwatch.

    Call ``tic()`` to start an interval and ``toc()`` to close it; every
    closed interval is added to ``total_time`` and counted in ``calls``.
    """

    def __init__(self):
        self.total_time = 0.    # sum of all closed intervals, in seconds
        self.calls = 0          # number of completed tic/toc pairs
        self.start_time = 0.    # time.time() recorded by the latest tic()
        self.diff = 0.          # length of the most recent interval
        self.average_time = 0.  # total_time / calls

    def tic(self):
        """Mark the start of a new interval."""
        self.start_time = time.time()

    def toc(self, average=True):
        """Close the current interval and update the statistics.

        Returns the running average interval length when ``average`` is
        true, otherwise the length of the interval just closed.
        """
        elapsed = time.time() - self.start_time
        self.diff = elapsed
        self.total_time += elapsed
        self.calls += 1
        self.average_time = self.total_time / self.calls
        return self.average_time if average else self.diff





def parse_rec(filename):
    """Parse the objects of one PASCAL VOC annotation XML file.

    Args:
        filename: path to a VOC-style ``.xml`` annotation file.

    Returns:
        A list with one dict per ``<object>`` element, with keys ``name``,
        ``pose``, ``truncated`` (int), ``difficult`` (int) and ``bbox``.
        ``bbox`` is ``[xmin, ymin, xmax, ymax]`` with 1 subtracted from each
        coordinate (the file's 1-based pixel indices become 0-based).
    """
    tree = ET.parse(filename)
    objects = []
    for obj in tree.findall('object'):
        bbox = obj.find('bndbox')
        objects.append({
            'name': obj.find('name').text,
            'pose': obj.find('pose').text,
            'truncated': int(obj.find('truncated').text),
            'difficult': int(obj.find('difficult').text),
            # Go through float() so coordinates written as e.g. "195.0" are
            # accepted; plain integer strings give exactly the same result.
            'bbox': [int(float(bbox.find(tag).text)) - 1
                     for tag in ('xmin', 'ymin', 'xmax', 'ymax')],
        })
    return objects





def get_output_dir(name, phase):
    """Return ``os.path.join(name, phase)``, creating the directory if needed.

    Args:
        name: parent directory (e.g. an experiment name).
        phase: sub-directory name (e.g. the image-set type).

    Returns:
        The joined directory path, which exists on return.

    Raises:
        OSError: if the directory cannot be created, or the path exists but
            is not a directory.
    """
    filedir = os.path.join(name, phase)
    try:
        os.makedirs(filedir)
    except OSError:
        # Creating unconditionally avoids the exists()/makedirs() race when
        # several processes start at once; only fail if no directory resulted.
        if not os.path.isdir(filedir):
            raise
    return filedir





def get_voc_results_file_template(image_set, cls, devkit=None):
    """Return the path of the VOC results file for one image set and class.

    The path is ``<devkit>/results/det_<image_set>_<cls>.txt``; the
    ``results`` directory is created if it does not exist yet.

    Args:
        image_set: image-set name, e.g. ``'test'``.
        cls: class name.
        devkit: devkit root directory.  When omitted, the module-level
            ``devkit_path`` (assigned in the ``__main__`` section) is used,
            which is what the existing callers rely on.

    Returns:
        The results-file path (the file itself is not created).
    """
    if devkit is None:
        devkit = devkit_path
    filedir = os.path.join(devkit, 'results')
    try:
        os.makedirs(filedir)
    except OSError:
        # Avoid the exists()/makedirs() race; fail only if no directory exists.
        if not os.path.isdir(filedir):
            raise
    return os.path.join(filedir, 'det_%s_%s.txt' % (image_set, cls))





def write_voc_results_file(all_boxes, dataset, set_type='test'):
    """Write per-class detections in the VOCdevkit results-file format.

    Args:
        all_boxes: ``all_boxes[cls][image]`` is either ``[]`` (no
            detections) or an N x 5 array of ``(x1, y1, x2, y2, score)``.
            Index 0 is the background class and is skipped.
        dataset: object whose ``ids`` list is aligned with the image axis of
            ``all_boxes``; ``index[1]`` of each id is written as the image
            identifier.
        set_type: image-set name embedded in each results-file name.

    Each written line is ``<image_id> <score> <x1> <y1> <x2> <y2>``.
    """
    for cls_ind, cls in enumerate(labelmap):
        print('Writing {:s} VOC results file'.format(cls))
        filename = get_voc_results_file_template(set_type, cls)
        with open(filename, 'wt') as f:
            for im_ind, index in enumerate(dataset.ids):
                dets = all_boxes[cls_ind + 1][im_ind]
                # len() works for both the [] placeholder and an ndarray.
                # `dets == []` compared an (N, 5) array with a (0,) one, which
                # only warned on old NumPy and raises on newer releases.
                if len(dets) == 0:
                    continue
                # the VOCdevkit expects 1-based indices
                for k in range(dets.shape[0]):
                    f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                            format(index[1], dets[k, -1],
                                   dets[k, 0] + 1, dets[k, 1] + 1,
                                   dets[k, 2] + 1, dets[k, 3] + 1))





def do_python_eval(output_dir='output', use_07=True, set_type='test'):
    """Score every class's results file and return the mean AP.

    For each class in ``labelmap`` this runs ``voc_eval`` at IoU 0.5, prints
    the class AP, and pickles ``{'rec', 'prec', 'ap'}`` to
    ``<output_dir>/<cls>_pr.pkl``.  It reads the module-level
    ``devkit_path``, ``annopath`` and ``imgsetpath`` (assigned in the
    ``__main__`` section).

    Args:
        output_dir: directory for the per-class pickles (created if absent).
        use_07: use the VOC2007 11-point AP metric.
        set_type: image-set name used for results files and the image list.

    Returns:
        The mean of the per-class APs as a Python float.
    """
    cache_dir = os.path.join(devkit_path, 'annotations_cache')
    class_aps = []
    # The PASCAL VOC metric changed in 2010; use_07 selects the older one.
    print('VOC07 metric? ' + ('Yes' if use_07 else 'No'))
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    for cls in labelmap:
        results_file = get_voc_results_file_template(set_type, cls)
        rec, prec, ap = voc_eval(results_file, annopath,
                                 imgsetpath.format(set_type), cls, cache_dir,
                                 ovthresh=0.5, use_07_metric=use_07)
        class_aps.append(ap)
        print('AP for {} = {:.4f}'.format(cls, ap))
        pr_path = os.path.join(output_dir, cls + '_pr.pkl')
        with open(pr_path, 'wb') as f:
            pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
    mAp = np.mean(class_aps)
    print('Mean AP on ' + set_type + ' set = {:.4f}'.format(mAp))
    return mAp.item()





def voc_ap(rec, prec, use_07_metric=True):
    """Compute VOC average precision from a recall/precision curve.

    Args:
        rec: array of recall values (as produced by ``voc_eval``).
        prec: array of precision values aligned with ``rec``.
        use_07_metric: if true (default), use the VOC2007 11-point method:
            the mean over t = 0, 0.1, ..., 1.0 of the highest precision
            reached at recall >= t (0 if recall never reaches t).  Otherwise
            integrate the monotone precision envelope over recall.

    Returns:
        The AP as a scalar.
    """
    if not use_07_metric:
        # Pad with sentinels so the curve spans recall 0 .. 1.
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))
        # Precision envelope: each entry becomes the maximum precision at it
        # or any later (higher-recall) point, i.e. a running max from the right.
        mpre = np.maximum.accumulate(mpre[::-1])[::-1]
        # Only points where recall changes contribute area.
        steps = np.flatnonzero(mrec[1:] != mrec[:-1])
        return np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1])

    # 11-point interpolated AP
    ap = 0.
    for t in np.arange(0., 1.1, 0.1):
        reached = rec >= t
        p = np.max(prec[reached]) if np.any(reached) else 0
        ap = ap + p / 11.
    return ap





def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             cachedir,
             ovthresh=0.5,
             use_07_metric=True):
    """Run the PASCAL VOC evaluation for a single class.

    rec, prec, ap = voc_eval(detpath, annopath, imagesetfile, classname,
                             cachedir, [ovthresh], [use_07_metric])

    Args:
        detpath: path to the detections; ``detpath.format(classname)`` must
            name the results file.  Each non-blank line is
            ``<image_id> <score> <x1> <y1> <x2> <y2>``, space separated.
        annopath: %-style path template; ``annopath % imagename`` must be
            the image's XML annotation file.
        imagesetfile: text file containing one image name per line.
        classname: category name to evaluate.
        cachedir: directory for the parsed-annotation cache ``annots.pkl``
            (created if missing).
        ovthresh: a detection matches a ground-truth box when their IoU is
            strictly greater than this (default 0.5).
        use_07_metric: use VOC07's 11-point AP computation (default True).

    Returns:
        ``(rec, prec, ap)``: cumulative recall and precision over the
        detections sorted by descending score, and the scalar AP.  If the
        results file holds no detections, ``rec`` and ``prec`` are empty
        arrays and ``ap`` is 0.
    """
    if not os.path.isdir(cachedir):
        os.mkdir(cachedir)
    # NOTE(review): the cache name ignores imagesetfile, so annotations cached
    # for one image set are reused for another (KeyError on unseen images);
    # delete annots.pkl when switching image sets.
    cachefile = os.path.join(cachedir, 'annots.pkl')

    # read list of images
    with open(imagesetfile, 'r') as f:
        imagenames = [x.strip() for x in f.readlines()]

    if not os.path.isfile(cachefile):
        # parse every annotation once, then cache the result
        recs = {}
        for i, imagename in enumerate(imagenames):
            recs[imagename] = parse_rec(annopath % (imagename))
            if i % 100 == 0:
                print('Reading annotation for {:d}/{:d}'.format(
                    i + 1, len(imagenames)))
        print('Saving cached annotations to {:s}'.format(cachefile))
        with open(cachefile, 'wb') as f:
            pickle.dump(recs, f)
    else:
        # pickle.load can execute arbitrary code: only use a trusted cachedir.
        with open(cachefile, 'rb') as f:
            recs = pickle.load(f)

    # extract gt objects for this class
    class_recs = {}
    npos = 0  # number of non-difficult ground-truth boxes
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        # builtin bool: the np.bool alias was removed in NumPy 1.24
        difficult = np.array([x['difficult'] for x in R]).astype(bool)
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': det}

    # read dets; blank lines are skipped (they would break the parsing below)
    detfile = detpath.format(classname)
    with open(detfile, 'r') as f:
        lines = [x for x in f.readlines() if x.strip()]
    if not lines:
        # No detections: nothing is recalled, so the AP is 0.  A -1 sentinel
        # here was averaged into the mAP by callers and dragged it down.
        return np.zeros(0), np.zeros(0), 0.

    splitlines = [x.strip().split(' ') for x in lines]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])

    # sort by confidence, highest first
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)
        if BBGT.size > 0:
            # IoU of this detection against every gt box of the class.
            # NOTE(review): widths/heights carry no "+1" pixel term, unlike
            # the official VOC devkit convention; confirm this is intended.
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin, 0.)
            ih = np.maximum(iymax - iymin, 0.)
            inters = iw * ih
            uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) +
                   (BBGT[:, 2] - BBGT[:, 0]) *
                   (BBGT[:, 3] - BBGT[:, 1]) - inters)
            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = 1
                else:
                    # duplicate detection of an already-matched box
                    fp[d] = 1.
            # a match to a "difficult" box counts as neither TP nor FP
        else:
            fp[d] = 1.

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)
    return rec, prec, ap





def test_net(save_folder, net, cuda, dataset, dataloador, transform, top_k,
             im_size=300, thresh=0.05, set_type='test'):
    """Run ``net`` over ``dataloador``, rescale detections, pickle them and
    return the VOC mAP computed by ``evaluate_detections``.

    Args:
        save_folder: not used in this body.
        net: detection model.  ``net(x).data`` per batch is concatenated
            along dim 0 and indexed as ``[image, class, k, :]`` below, where
            each 5-vector is ``(score, x1, y1, x2, y2)``.
        cuda: not used in this body; device placement is decided by
            ``cfg['cuda']`` / ``cfg['npu']`` instead.
        dataset: dataset whose length sizes ``all_boxes``; it is also passed
            on to ``evaluate_detections``.
        dataloador: iterable of ``(x, _, h, w)`` batches.
        transform, top_k, im_size, thresh: not used in this body (every
            detection with score > 0 is kept).
        set_type: image-set name; selects the output sub-directory and the
            results files.

    Returns:
        The mAP returned by ``evaluate_detections``.

    Side effects: writes ``ssd300_120000/<set_type>/detections.pkl``.
    """
    num_images = len(dataset)
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    # (index 0 is the background class and stays empty)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(len(labelmap)+1)]

    # timers ('misc' is never used in this body)
    _t = {'im_detect': Timer(), 'misc': Timer()}
    # NOTE(review): hard-coded directory name, apparently left over from an
    # SSD script.
    output_dir = get_output_dir('ssd300_120000', set_type)
    det_file = os.path.join(output_dir, 'detections.pkl')

    # Phase 1: forward every batch and keep raw outputs on the CPU.
    detection_list, h_list, w_list = [], [], []
    for i, item in enumerate(dataloador):
        # presumably h / w are per-image original heights / widths from
        # detection_collate_test — TODO confirm
        x, _, h, w  = item
        bs, _, _, _ = x.size()  # batch size (not used further)
        if cfg['cuda']:
            x = x.cuda()
        elif  cfg['npu']:
            x = x.npu()
        # NOTE(review): there is no device synchronize before toc(), so on an
        # asynchronous backend this may under-report the forward time.
        _t['im_detect'].tic()
        detections = net(x).data
        detect_time = _t['im_detect'].toc(average=False)
        detection_list.append(detections.cpu())
        h_list.extend(h)
        w_list.extend(w)
        print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1,
                                                    len(dataloador), detect_time))
    # Phase 2: per image and per class, drop score <= 0 rows, scale boxes by
    # the image size and store (x1, y1, x2, y2, score) arrays.
    strat_time = time.time()
    detections = torch.cat(detection_list, dim=0)
    for idx in range(detections.size(0)):
        h, w = h_list[idx], w_list[idx]
        # skip j = 0, because it's the background class
        for j in range(1, detections.size(1)):
            dets = detections[idx, j, :]
            # keep only rows whose score (column 0) is > 0
            mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
            dets = torch.masked_select(dets, mask).view(-1, 5)
            if dets.size(0) == 0:
                continue
            # boxes are multiplied by w / h, so the net presumably emits
            # coordinates normalized to [0, 1] — TODO confirm
            boxes = dets[:, 1:]
            boxes[:, 0] *= w
            boxes[:, 2] *= w
            boxes[:, 1] *= h
            boxes[:, 3] *= h
            scores = dets[:, 0].cpu().numpy()
            cls_dets = np.hstack((boxes.cpu().numpy(),
                                scores[:, np.newaxis])).astype(np.float32,
                                                                copy=False)
            all_boxes[j][idx] = cls_dets 
    end_time = time.time()
    print('spend time: %.3fs'%(end_time-strat_time))

    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    mAp = evaluate_detections(all_boxes, output_dir, dataset, set_type=set_type)
    return mAp





def evaluate_detections(box_list, output_dir, dataset, set_type='test'):
    """Write ``box_list`` as VOC results files, score them, and return the mAP.

    Args:
        box_list: ``box_list[cls][image]`` detections, as built by ``test_net``.
        output_dir: directory for the per-class precision/recall pickles.
        dataset: dataset whose ``ids`` name the images in ``box_list``.
        set_type: image-set name used for results and image-list files.
    """
    write_voc_results_file(box_list, dataset, set_type=set_type)
    return do_python_eval(output_dir, set_type=set_type)









if __name__ == '__main__':
    # Invocation: <script> <checkpoint.pth> <data root containing VOC2007/>
    if len(sys.argv) < 3:
        sys.exit('usage: {} <checkpoint.pth> <voc_root>'.format(sys.argv[0]))
    pth_path, data_path = sys.argv[1:3]

    if not os.path.exists(cfg['save_folder']):
        os.makedirs(cfg['save_folder'])

    # Warn only when a GPU is actually present but the config disables CUDA
    # (previously this fired on every non-CUDA run, e.g. NPU, and pointed at
    # a --cuda flag this script does not have).
    if torch.cuda.is_available() and not cfg['cuda']:
        print("WARNING: It looks like you have a CUDA device, but aren't "
              "using CUDA.  Set cfg['cuda'] for optimal eval speed.")
    # Build the model with CPU tensors; it is moved to CUDA/NPU explicitly
    # below.  (The old code first set torch.cuda.FloatTensor, but every path
    # then reset it to torch.FloatTensor, so this is the effective setting.)
    torch.set_default_tensor_type('torch.FloatTensor')

    voc_root = os.path.join(data_path, 'VOC2007')
    # These module-level names are read by the evaluation helpers above.
    annopath = os.path.join(voc_root, 'Annotations', '%s.xml')
    imgpath = os.path.join(voc_root, 'JPEGImages', '%s.jpg')
    imgsetpath = os.path.join(voc_root, 'ImageSets', 'Main', '{:s}.txt')
    YEAR = '2007'
    # os.path.join so a data_path without a trailing separator also works;
    # plain concatenation produced paths like '/data/VOCdevkitVOC2007'.
    devkit_path = os.path.join(data_path, 'VOC' + YEAR)
    dataset_mean = (104, 117, 123)
    input_size = int(cfg['input_size'])

    # load net
    num_classes = len(labelmap) + 1  # +1 for background
    net = build_refinedet('test', input_size, num_classes, batch_norm=True)

    # load data
    set_type = 'test'
    dataset = VOCDetection(root=data_path,
                           image_sets=[('2007', set_type)],
                           transform=BaseTransform(input_size, dataset_mean),
                           target_transform=VOCAnnotationTransform(),
                           dataset_name='VOC07test')

    if cfg['cuda']:
        net = net.cuda()
        cudnn.benchmark = True
    elif cfg['npu']:
        net = net.npu()
        cudnn.benchmark = True

    if cfg['amp']:
        net = amp.initialize(net, opt_level='O1', loss_scale=128)

    net.eval()
    data_loader = data.DataLoader(dataset,
                                  batch_size=128,
                                  num_workers=16,
                                  shuffle=False,
                                  collate_fn=detection_collate_test,
                                  pin_memory=True)
    save_path = pth_path
    net.load_state_dict(torch.load(save_path, map_location='cpu'))
    print('Finished loading model!   ' + save_path)

    # evaluation
    with torch.no_grad():
        mAp = test_net(cfg['save_folder'], net, cfg['cuda'], dataset, data_loader,
                       BaseTransform(input_size, dataset_mean), cfg['top_k'],
                       input_size, thresh=cfg['confidence_threshold'])