import argparse
import os
import json
import numpy as np
def read_info_from_json(json_path):
    """Read the result JSON file produced by inference_tools.

    Args:
        json_path: path to the JSON file.

    Returns:
        The value stored under the top-level ``'filesinfo'`` key of the
        parsed JSON (not the whole document).

    Raises:
        FileNotFoundError: if ``json_path`` does not exist.
        KeyError: if the parsed JSON has no ``'filesinfo'`` key.
    """
    if not os.path.exists(json_path):
        # Fail fast with a clear message instead of printing a warning and
        # then crashing inside open() anyway.
        raise FileNotFoundError(f'{json_path} does not exist')
    with open(json_path, 'r') as f:
        load_data = json.load(f)
    return load_data['filesinfo']
def postProcesss(result_path):
    """Load the first output file of every entry in an inference result JSON.

    Args:
        result_path: path to the JSON file understood by
            ``read_info_from_json``.

    Returns:
        list of arrays, one ``np.loadtxt`` result per entry, in the order
        the entries appear in the JSON's ``'filesinfo'`` mapping.
        NOTE(review): that order must match the ground-truth label order
        for accuracy evaluation to be meaningful — confirm upstream.
    """
    entries = read_info_from_json(result_path)
    return [np.loadtxt(entry['outfiles'][0]) for entry in entries.values()]
def cre_groundtruth_dict_fromtxt(gtfile_path):
    """Read integer class labels from a ground-truth text file.

    Each non-blank line is split on whitespace, and its second field is
    parsed as the integer label (e.g. ``<image_name> <label>``).
    Despite the function name, a list (in file order) is returned, not a
    dict.

    Args:
        gtfile_path: path to the ground-truth text file.

    Returns:
        list[int]: one label per non-blank line.

    Raises:
        IndexError: if a non-blank line has fewer than two fields.
        ValueError: if the second field is not an integer.
    """
    labels = []
    with open(gtfile_path, 'r') as f:
        for line in f:
            # split() with no argument tolerates tabs and repeated spaces,
            # unlike split(" "), which yields empty fields.
            fields = line.split()
            if not fields:
                # Skip blank lines (e.g. a trailing empty line at EOF).
                continue
            labels.append(int(fields[1]))
    return labels
def top_k_accuracy(scores, labels, topk=(1, ), image_nums=None):
    """Compute top-k classification accuracy for each k in ``topk``.

    Args:
        scores: per-sample score vectors that stack into an array of shape
            (num_samples, num_classes); row i holds the scores of sample i.
        labels: integer ground-truth class index for each sample.
        topk: k values to evaluate; each must be >= 1.
        image_nums: number of leading samples to evaluate. Defaults to
            ``len(labels)`` (this was previously hard-coded to 50000).

    Returns:
        list of floats, the accuracy in [0, 1] for each entry of ``topk``.

    Raises:
        ValueError: if there are no samples to evaluate, if the number of
            score rows and labels disagree after truncation, or if any k < 1.
    """
    labels = np.asarray(labels)
    num = len(labels) if image_nums is None else image_nums
    labels = labels[:num, np.newaxis]  # column vector for row-wise compare
    scores = np.asarray(scores)[:num]
    if labels.shape[0] == 0:
        raise ValueError('top_k_accuracy: no samples to evaluate')
    if scores.shape[0] != labels.shape[0]:
        raise ValueError(
            f'top_k_accuracy: {scores.shape[0]} score rows but '
            f'{labels.shape[0]} labels')
    res = []
    for k in topk:
        if k < 1:
            # [:, -0:] would select every column and report accuracy 1.0.
            raise ValueError(f'top_k_accuracy: k must be >= 1, got {k}')
        # Indices of the k highest scores per row, best first.
        max_k_preds = np.argsort(scores, axis=1)[:, -k:][:, ::-1]
        match_array = (max_k_preds == labels).any(axis=1)
        res.append(match_array.sum() / match_array.shape[0])
    return res
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='postprocess of vgg16')
    # Both paths are mandatory; without them the helpers fail with an
    # unhelpful TypeError on None.
    parser.add_argument('--result_path', required=True,
                        help='result JSON file produced by inference_tools')
    parser.add_argument('--gtfile_path', required=True,
                        help='ground-truth text file, one "<name> <label>" '
                             'per line')
    opt = parser.parse_args()
    outputs = postProcesss(opt.result_path)
    labels = cre_groundtruth_dict_fromtxt(opt.gtfile_path)
    print('Evaluating top_k_accuracy ...')
    topk = (1, 5)
    top_acc = top_k_accuracy(outputs, labels, topk=topk)
    # Keep the report in sync with topk instead of hard-coding each k.
    for k, acc in zip(topk, top_acc):
        print(f'\ntop{k}_acc\t{acc:.4f}')