import os
import numpy as np
import argparse
import json
from tqdm import tqdm
def evaluate(args):
    """Compute Top-1 / Top-5 accuracy of saved classifier outputs.

    Every ``*.bin`` file in ``args.input_dir`` is read with
    ``np.fromfile(..., dtype=args.dtype)`` as a flat vector of per-class
    scores. The classes are ranked by descending score, and the top-1 / top-5
    indices are compared with the ground-truth class id read from
    ``args.label_path``. The accuracy dict is printed and also written as
    JSON to ``args.save_path``.

    Args:
        args: namespace with ``input_dir``, ``save_path``, ``label_path`` and
            ``dtype`` attributes (as produced by ``parse_arguments``).

    Raises:
        ValueError: if a label line is malformed, a prediction file is empty,
            the number of prediction files differs from the number of labels,
            no labels are found at all, or a prediction has no matching label.
    """
    input_dir = args.input_dir
    save_path = args.save_path
    label_path = args.label_path
    dtype = args.dtype

    # Label file: one "<image_name> <class_id>" pair per line. The key is the
    # image name without its extension.
    label_result = dict()
    with open(label_path, 'r') as f:
        for line_no, label_info in enumerate(f, start=1):
            label_info = label_info.strip()
            if not label_info:
                # Tolerate blank lines, e.g. a trailing empty line at EOF.
                continue
            try:
                # Split on the LAST whitespace run so tabs / multiple spaces work.
                image_name, label_id = label_info.rsplit(maxsplit=1)
                gt_id = int(label_id)
            except ValueError as err:
                raise ValueError(
                    "Malformed label line {} in {}: {!r}".format(
                        line_no, label_path, label_info)) from err
            label_result[os.path.splitext(image_name)[0]] = gt_id

    predict_result = dict()
    predict_files = [name for name in os.listdir(input_dir)
                     if os.path.splitext(name)[1] == ".bin"]
    for predict_name in predict_files:
        predict_path = os.path.join(input_dir, predict_name)
        scores = np.fromfile(predict_path, dtype=dtype)
        if scores.size == 0:
            raise ValueError("Empty predict result file: {}".format(predict_path))
        # Class indices ordered by descending score.
        predict_data = np.argsort(-1 * scores)
        # The last two characters of the stem are dropped to recover the image
        # name -- presumably an output-index suffix such as "_0" appended by
        # the inference tool; TODO confirm against the producer of the .bin files.
        predict_result[os.path.splitext(predict_name)[0][:-2]] = {
            "top1": predict_data[0], "top5": predict_data[:5]}

    total_num = len(label_result)
    if len(predict_result) != total_num:
        raise ValueError(
            "Num of predict results not equal to num of gt results: {} != {}".format(
                len(predict_result), total_num
            ))
    if total_num == 0:
        # Avoid a ZeroDivisionError below when both sides are empty.
        raise ValueError("No gt results found in {}".format(label_path))
    # Equal counts do not guarantee equal keys; fail clearly instead of
    # crashing on a missing label inside the loop.
    unmatched = sorted(set(predict_result) - set(label_result))
    if unmatched:
        raise ValueError(
            "Predict results without a matching gt label: {}".format(
                ", ".join(unmatched[:10])))

    num_acc1 = 0
    num_acc5 = 0
    for file_name in tqdm(predict_result):
        gt_label = label_result[file_name]
        prediction = predict_result[file_name]
        num_acc1 += int(prediction["top1"] == gt_label)
        # Broadcast comparison, so this also works with fewer than 5 classes;
        # argsort indices are unique, so at most one element matches.
        num_acc5 += int(np.any(prediction["top5"] == gt_label))

    out_result = {
        "Top1 Acc": "{:.2f}%".format(num_acc1 * 100 / total_num),
        "Top5 Acc": "{:.2f}%".format(num_acc5 * 100 / total_num)
    }
    print(out_result)
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(out_result, f, ensure_ascii=False, indent=4)
def parse_arguments():
    """Define the command-line interface of the postprocess script and parse it.

    Options (parsed from ``sys.argv``):
        -i/--input_dir   directory holding the model's ``.bin`` outputs (required)
        -l/--label_path  ground-truth label file (required)
        -s/--save_path   where the JSON result is written (default ``./result.json``)
        -d/--dtype       numpy dtype of the ``.bin`` outputs (default ``float32``)

    Returns:
        argparse.Namespace holding the parsed options.
    """
    cli = argparse.ArgumentParser(description='Vision Transformer postprocess.')
    # (flags, keyword options) in registration order, so --help output is stable.
    option_table = (
        (('-i', '--input_dir'),
         dict(type=str, required=True,
              help='result dir for vision transformer model')),
        (('-l', '--label_path'),
         dict(type=str, required=True,
              help='file path for val label')),
        (('-s', '--save_path'),
         dict(type=str, default='./result.json',
              help='save path for evaluation result')),
        (('-d', '--dtype'),
         dict(type=str, default='float32',
              help='dtype for predict result')),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    parsed = cli.parse_args()
    return parsed
if __name__ == '__main__':
    # Script entry point: parse the CLI options, then run the evaluation.
    cli_args = parse_arguments()
    evaluate(cli_args)