import os
import argparse
import numpy as np
import json
def process_pred(pred_file, num_classes=1000):
    """Return the predicted class index stored in a prediction file.

    The file is read with ``np.loadtxt`` and must hold a 1-D array of
    exactly ``num_classes`` scores; the index of the largest score is the
    predicted label.

    Args:
        pred_file (str): path to a text file of per-class scores.
        num_classes (int): expected number of scores. Defaults to 1000,
            the value previously hard-coded here.

    Returns:
        int: index of the highest score (first occurrence on ties).

    Raises:
        ValueError: if the file does not contain a 1-D array of
            ``num_classes`` values.
    """
    data = np.loadtxt(pred_file)
    # Explicit check rather than ``assert`` (asserts vanish under ``python -O``).
    # ``ndim`` is checked too: a 0-d array would make ``len`` fail, and a 2-D
    # array would make the flattened ``argmax`` meaningless as a class index.
    if data.ndim != 1 or data.shape[0] != num_classes:
        raise ValueError(
            "expected {} scores in {}, got array of shape {}".format(
                num_classes, pred_file, data.shape))
    # Convert the NumPy integer to a plain int, as documented.
    return int(data.argmax())
def pred_eval(label_file, pred_dir):
    """Compute and print the top-1 accuracy of the predictions in ``pred_dir``.

    Each file in ``pred_dir`` is matched to a ground-truth entry using the
    first three ``_``-separated fields of its file name. Its predicted label
    comes from ``process_pred``.

    Args:
        label_file (str): JSON file mapping sample names to ground-truth
            labels (presumably integer class indices comparable with the
            output of ``process_pred`` -- confirm against the label source).
        pred_dir (str): folder holding one prediction file per sample.

    Returns:
        float or None: top-1 accuracy in [0, 1], or None when ``pred_dir``
        contains no files.

    Raises:
        KeyError: if a file name's key is missing from the label file.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(label_file, 'r', encoding='utf-8') as f:
        gt = json.load(f)

    correct = 0
    total = 0
    # Sorted so that any failure on a malformed file is reproducible.
    for output_file in sorted(os.listdir(pred_dir)):
        # e.g. "a_b_c_rest" -> lookup key "a_b_c"
        output_name = '_'.join(output_file.split('_')[:3])
        gt_label = gt[output_name]
        pred_label = process_pred(os.path.join(pred_dir, output_file))
        if gt_label == pred_label:
            correct += 1
        total += 1

    print('Validation Results for', pred_dir)
    if total == 0:
        # Previously crashed with ZeroDivisionError on an empty folder.
        print("Top 1 Accuracy: N/A (no prediction files found)")
        return None
    accuracy = correct / total
    print("Top 1 Accuracy: {:.1%}".format(accuracy))
    return accuracy
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Compute top-1 accuracy of prediction files against a "
                    "JSON label file.")
    # Both paths are mandatory: the former "" defaults could only end in a
    # FileNotFoundError inside pred_eval, so fail early with a usage message.
    parser.add_argument("--label_file", required=True,
                        help="JSON file mapping sample names to labels")
    parser.add_argument("--pred_dir", required=True,
                        help="folder containing one prediction file per sample")
    args = parser.parse_args()
    pred_eval(args.label_file, args.pred_dir)