import torch
import os
import numpy as np
import json
import sys
import argparse
# Command-line interface of the post-processing script.
parser = argparse.ArgumentParser(description="postprocess")
# --c: when 1, the raw inference dump is first converted to a json file.
parser.add_argument(
    '--c',
    default=1,
    type=int,
    help="convert result to json",
)
# --i: directory holding the inference output.
parser.add_argument(
    '--i',
    default='om_res',
    type=str,
    help="input om res dir path",
)
# --o: json file the converted predictions are written to / read from.
parser.add_argument(
    '--o',
    default='om_res.json',
    type=str,
    help="output om res json file",
)
# --t: json file with the ground-truth targets.
parser.add_argument(
    '--t',
    default='UCF101bin_batch_info.json',
    type=str,
    help="target json file path",
)
class EvalMetric(object):
    """Base class for a metric tracked as a running sum and an instance count.

    Subclasses implement ``update``; ``get`` reports the mean
    ``sum_metric / num_inst``.
    """

    def __init__(self, name, **kwargs):
        # Extra keyword arguments are accepted but not used here.
        self.name = str(name)
        self.reset()

    def update(self, preds, labels):
        """Accumulate one batch; must be overridden by subclasses."""
        raise NotImplementedError()

    def reset(self):
        """Clear the accumulated sum and the instance count."""
        self.num_inst = 0
        self.sum_metric = 0.0

    def get(self):
        """Return ``(name, value)``.

        ``value`` is the mean formatted as a string with four decimals, or
        a float NaN when nothing has been accumulated yet.
        """
        if self.num_inst != 0:
            mean = self.sum_metric / self.num_inst
            return (self.name, '{:0.4f}'.format(mean))
        return (self.name, float('nan'))
class Accuracy(EvalMetric):
    """Top-k classification accuracy accumulated over successive updates.

    ``sum_metric`` counts the samples whose label appears among the ``topk``
    highest-scoring classes; ``num_inst`` counts the labels seen.
    """

    def __init__(self, name='accuracy', topk=1):
        super(Accuracy, self).__init__(name)
        self.topk = topk

    def update(self, preds, labels):
        """Add one batch of predictions to the running accuracy.

        Args:
            preds: 2-D scores, one row per sample and one column per class
                (top-k is taken along dimension 1). Anything accepted by
                ``torch.as_tensor``.
            labels: class index per sample; a single scalar label is also
                accepted and counted as one instance.

        Raises:
            AssertionError: if ``topk`` exceeds the number of columns of
                ``preds``.
        """
        # as_tensor avoids the copy-construct warning torch.tensor() emits
        # when it is handed an existing tensor.
        pred = torch.as_tensor(preds)
        label = torch.as_tensor(labels)

        assert self.topk <= pred.shape[1], \
            "topk({}) should no larger than the pred dim({})".format(self.topk, pred.shape[1])

        # Indices of the k best classes per sample, transposed to (k, samples)
        # so each row can be compared against the row of labels.
        _, pred_topk = pred.topk(self.topk, 1, True, True)
        pred_topk = pred_topk.t()
        correct = pred_topk.eq(label.view(1, -1).expand_as(pred_topk))

        # .item() yields a Python scalar directly; calling float() on a
        # shape-(1,) numpy array is deprecated since NumPy 1.25.
        self.sum_metric += float(correct.sum().item())
        # A 0-dim label has no shape[0]; treat it as a single instance.
        self.num_inst += label.shape[0] if label.dim() > 0 else 1
def postProcess(result_path, class_num, output_path):
    """Convert text inference dumps into a ``{label: scores}`` json file.

    The dump files are read from the first sub-directory (in sorted order)
    of ``result_path``. Each file is parsed with ``np.loadtxt`` and reshaped
    to rows of ``class_num`` values. The label of a file is built from the
    third and fourth ``_``-separated fields of its name (text before the
    first ``.``).

    Args:
        result_path: directory that contains the dump sub-directory.
        class_num: number of classes per prediction row (coerced to int).
        output_path: path of the json file to write.

    Returns:
        dict mapping each label to its scores as a nested list.

    Raises:
        FileNotFoundError: if ``result_path`` has no sub-directory.
    """
    class_num = int(class_num)

    # os.listdir returns entries in arbitrary order and may include plain
    # files next to the dump directory; keep directories only and sort so
    # the choice is deterministic.
    subdirs = sorted(
        entry for entry in os.listdir(result_path)
        if os.path.isdir(os.path.join(result_path, entry))
    )
    if not subdirs:
        raise FileNotFoundError(
            "no result sub-directory found in {}".format(result_path))
    dump_dir = os.path.join(result_path, subdirs[0])

    res = {}
    # Sorted iteration keeps the json output stable between runs.
    for file_name in sorted(os.listdir(dump_dir)):
        name = file_name.split('.')[0]
        label = '_'.join(name.split('_')[2:4])
        scores = np.loadtxt(os.path.join(dump_dir, file_name))
        res[label] = scores.reshape(-1, class_num).tolist()

    with open(output_path, 'w') as f:
        json.dump(res, f)
    return res
def eval(res, target_file, output):
    """Compute top-1 / top-5 accuracy of ``res`` against a target json file.

    NOTE: the name shadows the ``eval`` builtin; it is kept because callers
    use it.

    Args:
        res: dict mapping a label key to its prediction scores.
        target_file: path of a json file mapping the same keys to targets.
        output: path prefix; the accuracies are written as json to
            ``output + '_result.txt'``.

    Keys of ``res`` that are missing from the targets are skipped and
    reported on stdout. Returns ``None``.
    """
    with open(target_file, 'r') as f:
        targets = json.load(f)
    if targets is None:
        print('targets can not load : error')
        return

    acc1 = Accuracy(name='acc-1', topk=1)
    acc5 = Accuracy(name='acc-5', topk=5)
    error_list = []
    for label, value in res.items():
        if label not in targets:
            error_list.append(label)
            continue
        acc5.update(preds=value, labels=targets[label])
        acc1.update(preds=value, labels=targets[label])

    if error_list:
        print('error_num', len(error_list))
        print(error_list)

    top1 = acc1.get()
    top5 = acc5.get()
    print(top1)
    print(top5)
    result = {'acc1': top1, 'acc5': top5}
    # Use the ``output`` argument: the previous code read a module-level
    # ``output_path`` that only exists when the file runs as a script.
    with open(output + '_result.txt', 'w') as f:
        json.dump(result, f)
if __name__ == '__main__':
    # Script entry point: optionally convert the raw dump to json, then
    # evaluate the json predictions against the target file.
    args = parser.parse_args()
    input_path = args.i
    target_file = args.t
    output_path = args.o
    print(args.c)

    # splitext inspects only the final extension, so dots elsewhere in the
    # path ('./res', 'a.b.json', './out/res.json') are handled correctly.
    extension = os.path.splitext(output_path)[1]
    if not extension:
        output_path = output_path + '.json'
    elif extension != '.json':
        print('file should be json')

    if args.c == 1:
        # 101 is the number of classes passed to the converter.
        ress = postProcess(input_path, 101, output_path)
    # Predictions are always read back from the json file, so a previously
    # converted file can be evaluated with --c 0.
    with open(output_path, 'r') as f:
        ress = json.load(f)
    eval(ress, target_file, output_path)