from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.nn.functional as F
import sys
sys.path.append('./RFCN-pytorch.1.0')
import _init_paths
import os
import numpy as np
import argparse
import pprint
import pdb
import time
import cv2
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import pickle
from roi_data_layer.roidb import combined_roidb
from roi_data_layer.roibatchLoader import roibatchLoader
from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from model.rpn.bbox_transform import clip_boxes
from model.roi_layers import nms
from model.rpn.bbox_transform import bbox_transform_inv
from model.utils.net_utils import save_net, load_net, vis_detections
from model.rfcn.resnet_atrous import resnet
import pdb
# Python 2/3 compatibility shim: Python 2 provides a lazy ``xrange``
# builtin; on Python 3 the lookup fails and ``range`` (already lazy) is
# bound under the same name so the rest of the script can use ``xrange``.
try:
    xrange  # noqa: F821 -- only exists on Python 2
except NameError:
    xrange = range
def parse_args(argv=None):
    """
    Parse input arguments.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='test the accuracy of RFCN')
    parser.add_argument("--image_folder_path", dest="file_path",
                        default="./RFCN-pytorch.1.0/data/VOCdevkit2007/VOC2007/JPEGImages/",
                        help='image of dataset')
    parser.add_argument("--input", dest="input", default="./result/dumpOutput_device0/")
    parser.add_argument("--output", dest="output", default="./output")
    # type=int so a value given on the command line has the same type as the default.
    parser.add_argument("--net_input_width", default=1344, type=int,
                        help='network input width')
    parser.add_argument("--net_input_height", default=1344, type=int,
                        help='network input height')
    parser.add_argument('--dataset', dest='dataset', help='training dataset',
                        default='pascal_voc', type=str)
    parser.add_argument('--cfg', dest='cfg_file', help='optional config file',
                        default='cfgs/res16.yml', type=str)
    parser.add_argument('--net', dest='net', help='vgg16, res50, res101, res152',
                        default='res101', type=str)
    # REMAINDER: everything after --set is collected as config key/value pairs.
    parser.add_argument('--set', dest='set_cfgs', help='set config keys',
                        default=None, nargs=argparse.REMAINDER)
    parser.add_argument('--load_dir', dest='load_dir', help='directory to load models',
                        default="models", type=str)
    parser.add_argument('--ls', dest='large_scale', help='whether use large imag scale',
                        action='store_true')
    parser.add_argument('--cag', dest='class_agnostic',
                        help='whether perform class_agnostic bbox regression',
                        action='store_true')
    parser.add_argument('--parallel_type', dest='parallel_type',
                        help='which part of model to parallel, 0: all, 1: model before roi pooling',
                        default=0, type=int)
    parser.add_argument('--bs', dest='batch_size', help='batch_size', default=1, type=int)
    parser.add_argument('--vis', dest='vis', help='visualization mode', action='store_true')
    args = parser.parse_args(argv)
    return args
# Training hyper-parameters copied out of the global ``cfg`` object.
# NOTE(review): these are evaluated at import time, before cfg_from_file /
# cfg_from_list run in the __main__ block, so they hold the default config
# values, not the ones from the loaded yml. They are not referenced
# elsewhere in this file as far as visible -- confirm before relying on them.
lr, momentum, weight_decay = (
    cfg.TRAIN.LEARNING_RATE,
    cfg.TRAIN.MOMENTUM,
    cfg.TRAIN.WEIGHT_DECAY,
)
def read_data(data_path, input_shape=None):
    """Load one raw output tensor dumped by the offline inference run.

    Args:
        data_path: path to a ``.bin`` file (read as a raw float32 buffer)
            or a ``.npy`` file (read with ``np.load``).
        input_shape: shape to reshape a ``.bin`` buffer into; ignored for
            ``.npy`` files. When ``None`` the ``.bin`` buffer is returned flat.

    Returns:
        numpy.ndarray holding the file contents.

    Raises:
        ValueError: if ``data_path`` has neither a ``.bin`` nor a ``.npy`` suffix.
    """
    if data_path.endswith('.bin'):
        data = np.fromfile(data_path, dtype=np.float32)
        return data if input_shape is None else data.reshape(input_shape)
    if data_path.endswith('.npy'):
        return np.load(data_path)
    raise ValueError('unsupported dump file (expected .bin or .npy): {}'.format(data_path))


if __name__ == '__main__':
    args = parse_args()
    print('Called with args:')
    print(args)
    args.imdbval_name = "voc_2007_test"
    # --cfg and --set from the command line are overridden here: the config
    # file is always derived from --net (and --ls) and the anchors are fixed.
    args.set_cfgs = ['ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]']
    if args.large_scale:
        args.cfg_file = "./RFCN-pytorch.1.0/cfgs/{}_ls.yml".format(args.net)
    else:
        args.cfg_file = "./RFCN-pytorch.1.0/cfgs/{}.yml".format(args.net)
    cfg_from_file(args.cfg_file)
    cfg_from_list(args.set_cfgs)
    # Seed only after the yml/list overrides are applied so an RNG_SEED set there is used.
    np.random.seed(cfg.RNG_SEED)
    cfg.TRAIN.USE_FLIPPED = False
    imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdbval_name, False)
    imdb.competition_mode(on=True)

    # Reusable holder that receives a copy of data[1] from the loader each iteration.
    im_info = torch.FloatTensor(1)

    start = time.time()
    max_per_image = 100
    vis = args.vis
    # Minimum class score for a candidate box to be kept.
    thresh = 0.05 if vis else 0.0
    num_images = len(imdb.image_index)
    # all_boxes[class][image] -> (N, 5) array: 4 box coordinates followed by the score.
    all_boxes = [[[] for _ in xrange(num_images)]
                 for _ in xrange(imdb.num_classes)]
    output = args.output
    if not os.path.exists(output):
        os.makedirs(output)
    dataset = roibatchLoader(roidb, ratio_list, ratio_index, args.batch_size,
                             imdb.num_classes, training=False, normalize=False)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                             shuffle=False, num_workers=0,
                                             pin_memory=True)
    data_iter = iter(dataloader)
    det_file = os.path.join(output, 'detections.pkl')
    # (0, 5) placeholder stored for classes without any detection in an image.
    empty_array = np.transpose(np.array([[], [], [], [], []]), (1, 0))
    dataset.resize_batch()

    npu_result = args.input
    with open("./RFCN-pytorch.1.0/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt") as f:
        # Skip blank lines so they are never treated as image ids.
        imglist = [line.strip() for line in f if line.strip()]
    # all_boxes and the data loader are sized from the imdb; more ids than that
    # would run past both, so fail with a clear message instead of an IndexError.
    if len(imglist) > num_images:
        raise ValueError('test.txt lists {} images but the imdb only has {}'
                         .format(len(imglist), num_images))
    if len(imglist) < num_images:
        print('Warning: only {} of {} images will be evaluated'.format(len(imglist), num_images))
    num_images = len(imglist)

    # Loop-invariant de-normalisation constants for the regression deltas.
    bbox_stds = torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS)
    bbox_means = torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS)

    for i in range(num_images):
        data = next(data_iter)
        im_info.resize_(data[1].size()).copy_(data[1])

        det_tic = time.time()
        # Offline outputs for this image: <id>_0.bin rois, <id>_1.bin class
        # probabilities, <id>_2.bin box deltas. The shapes are fixed to 300 rois,
        # 21 classes and a batch of 1.
        # NOTE(review): --bs > 1 is not supported by these fixed shapes.
        prefix = os.path.join(npu_result, imglist[i])
        rois = torch.from_numpy(read_data(prefix + '_0.bin', [1, 300, 5]))
        cls_prob = torch.from_numpy(read_data(prefix + '_1.bin', [1, 300, 21]))
        bbox_pred = torch.from_numpy(read_data(prefix + '_2.bin', [1, 300, 84]))

        scores = cls_prob.data
        # Columns 1-4 of each roi are the box; column 0 is presumably the batch index.
        boxes = rois.data[:, :, 1:5]
        box_deltas = bbox_pred.data.view(-1, 4) * bbox_stds + bbox_means
        # NOTE(review): deltas are always laid out per class (4 * num_classes
        # columns), which does not match the --cag (class-agnostic) branch below.
        box_deltas = box_deltas.view(args.batch_size, -1, 4 * len(imdb.classes))
        pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
        pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
        # data[1][0][2] is presumably the resize scale; map boxes back to the original image.
        pred_boxes /= data[1][0][2]
        scores = scores.squeeze()
        pred_boxes = pred_boxes.squeeze()
        det_toc = time.time()
        detect_time = det_toc - det_tic
        misc_tic = time.time()
        if vis:
            im = cv2.imread(imdb.image_path_at(i))
            im2show = np.copy(im)
        # Class 0 is skipped (the loop starts at 1).
        for j in xrange(1, imdb.num_classes):
            inds = torch.nonzero(scores[:, j] > thresh).view(-1)
            if inds.numel() > 0:
                cls_scores = scores[:, j][inds]
                _, order = torch.sort(cls_scores, 0, True)
                if args.class_agnostic:
                    cls_boxes = pred_boxes[inds, :]
                else:
                    cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]
                cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
                cls_dets = cls_dets[order]
                keep = nms(cls_boxes[order, :], cls_scores[order], cfg.TEST.NMS)
                cls_dets = cls_dets[keep.view(-1).long()]
                if vis:
                    im2show = vis_detections(im2show, imdb.classes[j], cls_dets.cpu().numpy(), 0.3)
                all_boxes[j][i] = cls_dets.cpu().numpy()
            else:
                all_boxes[j][i] = empty_array
        # Keep only the max_per_image highest-scoring detections across all classes.
        if max_per_image > 0:
            image_scores = np.hstack([all_boxes[j][i][:, -1]
                                      for j in xrange(1, imdb.num_classes)])
            if len(image_scores) > max_per_image:
                image_thresh = np.sort(image_scores)[-max_per_image]
                for j in xrange(1, imdb.num_classes):
                    keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
                    all_boxes[j][i] = all_boxes[j][i][keep, :]
        misc_toc = time.time()
        nms_time = misc_toc - misc_tic
        sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s \r'
                         .format(i + 1, num_images, detect_time, nms_time))
        sys.stdout.flush()
        if vis:
            cv2.imwrite('result.png', im2show)
            # Interactive pause after each image (presumably to inspect result.png).
            pdb.set_trace()

    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)
    print('Evaluating detections')
    imdb.evaluate_detections(all_boxes, output)
    end = time.time()
    print("test time: %0.4fs" % (end - start))