import argparse
import os
import json
from tqdm import tqdm
import numpy as np
def read_info_from_json(json_path):
    """Load the ``filesinfo`` section of an inference-result JSON file.

    Args:
        json_path: Path to the JSON file to read.

    Returns:
        The object stored under the top-level ``'filesinfo'`` key.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        KeyError: If the JSON document has no ``'filesinfo'`` key.
    """
    # Fail fast with an explicit message instead of printing a warning and
    # then letting open() raise a second, less specific error.
    if not os.path.exists(json_path):
        raise FileNotFoundError(f'{json_path} does not exist')
    with open(json_path, 'r') as f:
        load_data = json.load(f)
    return load_data['filesinfo']
def cre_groundtruth_dict_fromtxt(gtfile_path):
    """Build a ground-truth lookup table from a label text file.

    Each non-blank line is expected to hold at least two
    whitespace-separated fields: the image/video file name followed by its
    label number.

    :param gtfile_path: path to the text file with names and label numbers
    :return: dictionary whose key is the file name truncated at its first
        ``'.'`` and whose value is the label field as a string
    :raises ValueError: if a non-blank line has fewer than two fields
    """
    img_gt_dict = {}
    with open(gtfile_path, 'r') as f:
        # Iterate lazily instead of loading every line with readlines().
        for line_no, line in enumerate(f, start=1):
            # split() with no argument tolerates tabs and repeated spaces.
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            if len(fields) < 2:
                raise ValueError(
                    f'{gtfile_path}:{line_no}: expected "<name> <label>", '
                    f'got {line.strip()!r}')
            img_name = fields[0].split(".")[0]
            img_gt_dict[img_name] = fields[1]
    return img_gt_dict
def postProcesss(result_path, gtfile_path):
    """Collect per-sample scores and ground-truth labels for evaluation.

    Args:
        result_path: Path to the inference-result JSON file; it is read with
            ``read_info_from_json``.
        gtfile_path: Path to the ground-truth label file; it is read with
            ``cre_groundtruth_dict_fromtxt``.

    Returns:
        A tuple ``(outputs, labels)``. ``outputs`` is a list holding, for
        each entry, the first output array averaged over its first axis.
        ``labels`` is the matching list of integer labels.

    Raises:
        KeyError: If an input file name has no entry in the ground truth.
    """
    file_info = read_info_from_json(result_path)
    gt_dict = cre_groundtruth_dict_fromtxt(gtfile_path)
    outputs = []
    labels = []
    # Only the values are needed; the keys of file_info are never read.
    for entry in tqdm(file_info.values()):
        res_path = entry['outfiles'][0]
        # Strip directory and extension robustly instead of assuming that
        # the extension is exactly three characters long.
        in_file_name = os.path.splitext(
            os.path.basename(entry['infiles'][0]))[0]
        label = gt_dict[in_file_name]
        ndata = np.load(res_path)
        # NOTE(review): axis 0 presumably enumerates clips/crops of one
        # sample -- confirm against the inference pipeline.
        outputs.append(np.mean(ndata, axis=0))
        labels.append(int(label))
    return outputs, labels
def mean_class_accuracy(scores, labels):
    """Compute the accuracy averaged over classes.

    Args:
        scores: Per-sample class scores; the prediction for each sample is
            the argmax along axis 1.
        labels: Ground-truth labels, one per sample.

    Returns:
        The mean of the per-class hit ratios. A class whose confusion-matrix
        row sums to zero contributes 0.0.
    """
    predicted = np.argmax(scores, axis=1)
    matrix = confusion_matrix(predicted, labels).astype(float)
    row_totals = matrix.sum(axis=1)
    diagonal_hits = np.diag(matrix)

    ratios = []
    for total, hits in zip(row_totals, diagonal_hits):
        # Guard against dividing by zero for rows without any sample.
        if total:
            ratios.append(hits / total)
        else:
            ratios.append(0.0)
    return np.mean(ratios)
def top_k_accuracy(scores, labels, topk=(1, )):
    """Compute top-k accuracy for each requested k.

    Args:
        scores: Per-sample class scores, indexable as a 2-D array
            (samples along axis 0, classes along axis 1).
        labels: Ground-truth class index for each sample.
        topk: Iterable of k values to evaluate. Defaults to ``(1, )``.

    Returns:
        A list with one accuracy value per entry of ``topk``, in the same
        order.
    """
    labels = np.array(labels)[:, np.newaxis]
    # The ascending ranking does not depend on k, so sort only once.
    ranked = np.argsort(scores, axis=1)
    res = []
    for k in topk:
        # The k highest-scoring classes are the last k columns; their order
        # is irrelevant for the membership test below.
        max_k_preds = ranked[:, -k:]
        match_array = np.logical_or.reduce(max_k_preds == labels, axis=1)
        res.append(match_array.sum() / match_array.shape[0])
    return res
def _as_int64_array(value, name):
    """Return ``value`` as an ndarray after checking type and int64 dtype."""
    if isinstance(value, list):
        value = np.array(value)
    if not isinstance(value, np.ndarray):
        raise TypeError(
            f'{name} must be list or np.ndarray, but got {type(value)}')
    if value.dtype != np.int64:
        raise TypeError(
            f'{name} dtype must be np.int64, but got {value.dtype}')
    return value


def confusion_matrix(y_pred, y_real, normalize=None):
    """Compute a confusion matrix from predicted and real labels.

    The distinct labels found in either input are sorted and mapped to
    consecutive indices; rows of the result index the real label and
    columns index the predicted label.

    Args:
        y_pred: Predicted labels, a list or ``np.ndarray`` of dtype int64.
        y_real: Real labels, a list or ``np.ndarray`` of dtype int64.
        normalize: ``'true'`` divides each row by its sum, ``'pred'`` divides
            each column by its sum, ``'all'`` divides by the total count and
            ``None`` returns raw counts.

    Returns:
        A square ``np.ndarray`` with one row/column per distinct label.
        NaN values produced by the normalisation are replaced through
        ``np.nan_to_num``.

    Raises:
        ValueError: If ``normalize`` is not one of the accepted values.
        TypeError: If an input is neither list nor ``np.ndarray`` or does
            not have dtype int64.
    """
    if normalize not in ('true', 'pred', 'all', None):
        raise ValueError("normalize must be one of {'true', 'pred', "
                         "'all', None}")

    y_pred = _as_int64_array(y_pred, 'y_pred')
    y_real = _as_int64_array(y_real, 'y_real')

    # Compact the label values into the index range 0..num_labels-1.
    # NOTE(review): the lookup table is indexed by label value, so this
    # assumes non-negative labels.
    label_set = np.unique(np.concatenate((y_pred, y_real)))
    num_labels = len(label_set)
    label_map = np.zeros(label_set[-1] + 1, dtype=np.int64)
    for index, label in enumerate(label_set):
        label_map[label] = index

    # Encode each (real, pred) pair as one flat cell index and count them.
    flat_cells = num_labels * label_map[y_real] + label_map[y_pred]
    counts = np.bincount(flat_cells, minlength=num_labels**2)
    confusion_mat = counts.reshape(num_labels, num_labels)

    # Division by an empty row/column/total is tolerated; the resulting NaN
    # values are cleaned up below.
    with np.errstate(all='ignore'):
        if normalize == 'all':
            confusion_mat = confusion_mat / confusion_mat.sum()
        elif normalize is not None:
            axis = 1 if normalize == 'true' else 0
            confusion_mat = (
                confusion_mat / confusion_mat.sum(axis=axis, keepdims=True))
        confusion_mat = np.nan_to_num(confusion_mat)

    return confusion_mat
if __name__ == '__main__':
    # Command-line entry point: read the inference results and the ground
    # truth, then print top-1/top-5 accuracy and mean class accuracy.
    parser = argparse.ArgumentParser('postprocess of r2plus1d')
    # Both paths are mandatory; without required=True a missing option would
    # surface as a TypeError from os.path.isfile(None).
    parser.add_argument('--result_path', type=str, required=True,
                        help='path to the inference results JSON file')
    parser.add_argument('--gtfile_path', type=str, required=True,
                        help='path to label file')
    args = parser.parse_args()

    # Validate explicitly: assert statements are stripped under ``-O``.
    if not os.path.isfile(args.result_path):
        parser.error(
            f'inference results file does not exist: {args.result_path}')
    if not os.path.isfile(args.gtfile_path):
        parser.error(f'Groundtruth file does not exist: {args.gtfile_path}')

    outputs, labels = postProcesss(args.result_path, args.gtfile_path)

    print('Evaluating top_k_accuracy ...')
    top_acc = top_k_accuracy(outputs, labels, topk=(1, 5))
    print(f'\ntop1_acc\t{top_acc[0]:.4f}')
    print(f'\ntop5_acc\t{top_acc[1]:.4f}')

    print('Evaluating mean_class_accuracy ...')
    mean_acc = mean_class_accuracy(outputs, labels)
    print(f'\nmean_acc\t{mean_acc:.4f}')