import os
import stat
import sys
import json
import time
import argparse
import numpy as np
# Raise NumPy's summarization threshold to the maximum so that printed arrays
# are shown in full instead of being abbreviated with "...".
np.set_printoptions(threshold=sys.maxsize)
def cre_groundtruth_dict_fromtxt(val_label_path):
    """
    Read the ground-truth label file into a dictionary.

    Every non-blank line must contain an image file name followed by its
    class label, separated by any run of whitespace (spaces or tabs).
    The image name is cut at its first "." so the extension is dropped.
    Blank lines are ignored.

    :param val_label_path: path of the label text file
    :return: dict; key: image name without extension, value: label exactly
             as written in the file (a string)
    :raises IndexError: if a non-blank line does not contain a label field
    """
    img_label_dict = {}
    with open(val_label_path, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            # split() without arguments tolerates tabs, repeated spaces and
            # the trailing newline, unlike split(" ").
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 2:
                # Same exception type a malformed line has always produced,
                # but now it says which line is at fault.
                raise IndexError(
                    "line {} of {} has no label field: {!r}".format(
                        line_no, val_label_path, line))
            img_name = fields[0].split(".")[0]
            img_label_dict[img_name] = fields[1]
    return img_label_dict
def load_statistical_predict_result(filepath, tfile_type):
    """
    Extract the prediction data stored in one inference result file.

    'bin' files are read as a flat sequence of int64 values, of which only
    the first one is kept. 'txt' files contribute every whitespace-separated
    field of their first line. Any other file type, or a file with no data,
    yields an empty result.

    :param filepath: path of the result file
    :param tfile_type: file extension without the dot ('bin' or 'txt')
    :return: tuple (data_vec, n_label); data_vec is a 1-D float32 numpy
             array holding the values read and n_label is its length
    :raises ValueError: if a field of a 'txt' file is not a number
    """
    tokens = []
    if tfile_type == 'bin':
        # Passing the path lets numpy open and close the file itself, so no
        # handle is leaked.
        data = np.fromfile(filepath, dtype=np.int64)
        if data.size > 0:
            tokens = [str(data[0])]
    elif tfile_type == 'txt':
        with open(filepath, 'r') as f:
            # split() ignores repeated/trailing blanks that would otherwise
            # produce empty fields and fail the float conversion below.
            tokens = f.readline().split()

    n_label = len(tokens)
    data_vec = np.array([np.float32(token) for token in tokens],
                        dtype=np.float32)
    return data_vec, n_label
def create_visualization_statistical_result(prediction_file_path,
                                            dest_path, dest_name,
                                            dict_img_label):
    """
    Compare inference results with the ground truth and write a JSON report.

    Only files in prediction_file_path whose name (before the first ".")
    ends in "_0" are evaluated; the part before that last "_" is the image
    name used to look up the ground truth. The report contains the number
    of evaluated images, the number of values read from the last evaluated
    file and the accuracy as a percentage.

    :param prediction_file_path: directory holding the inference results
    :param dest_path: directory in which the JSON report is written
    :param dest_name: file name of the JSON report
    :param dict_img_label: ground truth as a dict; key: image name,
                           value: label
    :return: NA
    :raises KeyError: if an evaluated image has no entry in dict_img_label
    """
    count = 0
    count_hit = 0
    n_labels = ""
    for tfile_name in os.listdir(prediction_file_path):
        name_parts = tfile_name.split('.')
        if len(name_parts) < 2:
            # No extension (e.g. a sub-directory): not a result file.
            continue
        stem, tfile_type = name_parts[0], name_parts[1]
        img_name, _, output_index = stem.rpartition('_')
        if output_index != '0':
            continue

        filepath = os.path.join(prediction_file_path, tfile_name)
        ret = load_statistical_predict_result(filepath, tfile_type)
        data_vec = np.asarray(ret[0])
        if data_vec.size == 0:
            # Unsupported type or empty file: nothing to compare.
            print("no prediction data in {}, skipped.".format(filepath))
            continue

        count += 1
        n_labels = ret[1]
        # NOTE(review): the stored value is shifted by -1 and, for
        # 1001-value results, the ground truth by +1 -- presumably to
        # reconcile label numbering conventions; confirm with the model.
        prediction = data_vec.astype(np.int16)[0] - 1
        gt = dict_img_label[img_name]
        real_label = int(gt) + 1 if n_labels == 1001 else int(gt)
        if real_label == int(prediction):
            count_hit += 1

    accuracy = count_hit / count if count else 0
    table_dict = {
        "title": "Overall statistical evaluation",
        "value": [
            {"key": "Number of images", "value": str(count)},
            {"key": "Number of classes", "value": str(n_labels)},
            {"key": " accuracy",
             "value": str(round(accuracy * 100, 2)) + '%'},
        ],
    }

    # The report is opened only once the evaluation has succeeded, so a
    # failure above does not leave a truncated file behind. It is created
    # readable and writable by the owner only.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    modes = stat.S_IWUSR | stat.S_IRUSR
    report_fd = os.open(os.path.join(dest_path, dest_name), flags, modes)
    with os.fdopen(report_fd, 'w') as writer:
        json.dump(table_dict, writer)
if __name__ == '__main__':
    # Command-line entry point. Expects four positional arguments: the
    # inference result directory, the ground-truth label file, the output
    # directory and the output JSON file name.
    start = time.time()
    try:
        infer_result_path = sys.argv[1]
        annotation_file_path = sys.argv[2]
        result_json_path = sys.argv[3]
        json_file_name = sys.argv[4]
    except IndexError:
        print("Stopped!")
        print("usage: {} <infer_result_path> <annotation_file_path> "
              "<result_json_path> <json_file_name>".format(sys.argv[0]))
        sys.exit(1)

    # Report every missing path, then stop: continuing would only fail
    # later with a traceback.
    inputs_ok = True
    if not os.path.exists(infer_result_path):
        print("infer result path does not exist.")
        inputs_ok = False
    if not os.path.exists(annotation_file_path):
        print("Ground truth file does not exist.")
        inputs_ok = False
    if not os.path.exists(result_json_path):
        print("Result folder doesn't exist.")
        inputs_ok = False
    if not inputs_ok:
        sys.exit(1)

    image_label_dict = cre_groundtruth_dict_fromtxt(annotation_file_path)
    create_visualization_statistical_result(infer_result_path,
                                            result_json_path, json_file_name,
                                            image_label_dict)

    elapsed = (time.time() - start)
    print("Time used:", elapsed)