import argparse
import os
import numpy as np
import json
import tqdm
from tqdm import *
def get_args_parser():
    """Build the command-line parser for the accuracy post-processing script.

    Options:
        --result_path: directory of inference result files (default 'result').
        --target_file: JSON label file (default 'target.json').
        --save_file:   path the accuracy JSON is written to
                       (default './result.json').

    Returns:
        argparse.ArgumentParser: the configured parser. Note it is built with
        ``add_help=False``, so ``-h/--help`` is not registered on it.
    """
    parser = argparse.ArgumentParser('XCiT pre-post process scipt', add_help=False)
    parser.add_argument('--result_path', default='result', help='Om result path')
    parser.add_argument('--target_file', default='target.json', help='label file respond')
    # Previously declared with required=True *and* a default: the default could
    # never take effect and omitting the flag aborted the script. Making it
    # optional lets the declared default apply; explicit values work as before.
    parser.add_argument('--save_file', default='./result.json',
                        help='path of the JSON file the accuracy results are written to')
    return parser
def postprocess(result_path, target_file, save_file):
    """Compute top-1 / top-5 accuracy of per-image inference results.

    Every file in ``result_path`` is loaded with ``np.loadtxt`` as a score
    vector. Its ground-truth label is looked up in the JSON ``target_file``
    under a key rebuilt from the first three '_'-separated parts of the file
    name plus a trailing '_1'. The accuracies are printed, written as JSON to
    ``save_file`` and returned.

    Args:
        result_path: directory containing one result file per image.
        target_file: JSON file mapping image keys to labels (presumably integer
            class indices, since they are compared against class positions).
        save_file: output path for the accuracy JSON.

    Returns:
        dict: ``{'Accuracy@1': float, 'Accuracy@5': float}``, each the fraction
        of result files whose label is in the top-1 / top-5 scores.

    Raises:
        ValueError: if ``result_path`` contains no files (accuracy undefined).
        KeyError: if a rebuilt image key is missing from the label file.
    """
    re_files = os.listdir(result_path)
    if not re_files:
        raise ValueError('no result files found in {}'.format(result_path))

    with open(target_file, 'rb') as label_fp:
        labels = json.load(label_fp)

    top1_cnt = 0
    top5_cnt = 0
    # Size the progress bar from the actual file count (was hard-coded to 50000),
    # and close it deterministically via the context manager.
    with tqdm(total=len(re_files)) as pbar:
        for file in re_files:
            result = np.loadtxt(os.path.join(result_path, file)).ravel()
            parts = file.split('.')[0].split('_')
            img_name = '{}_{}_{}_1'.format(parts[0], parts[1], parts[2])
            label = labels[img_name]

            # Rank classes by descending score. A stable sort of the negated
            # scores breaks ties by lowest index, exactly like argmax. The old
            # approach zeroed the current max each round, which fails for
            # negative scores (the zeroed entry becomes the max again).
            top5 = np.argsort(-result, kind='stable')[:5]
            if label == top5[0]:
                top1_cnt += 1
            if label in top5:
                top5_cnt += 1
            pbar.update(1)

    total = len(re_files)
    accuracy = {
        'Accuracy@1': top1_cnt / total,
        'Accuracy@5': top5_cnt / total,
    }
    print(accuracy)
    with open(save_file, 'w') as writer:
        json.dump(accuracy, writer)
    return accuracy
if __name__ == '__main__':
    # Script entry point: read the CLI options, then evaluate the results
    # directory against the label file and write the metrics to save_file.
    cli = get_args_parser().parse_args()
    postprocess(cli.result_path, cli.target_file, cli.save_file)