import sys
import os
import multiprocessing
import torch
import numpy as np
from tqdm import tqdm
sys.path.append('./AdvancedEAST-PyTorch')
from utils import eval_pre_rec_f1
def eval_func(i, out, gt_xy_list):
    """Score a single prediction and persist its metrics to disk.

    A fresh ``eval_pre_rec_f1`` accumulator is fed ``(out, gt_xy_list)``;
    its ``val()`` result (precision, recall, F1) is written as a 3-element
    array to ``eval_temp/<i>.npy``. ``np.save`` does not create parent
    directories, so ``eval_temp`` must already exist.

    Args:
        i: index of the image; used only to name the output file.
        out: prediction tensor (``eval()`` passes a ``(1, 7, 184, 184)``
            float32 tensor).
        gt_xy_list: ground-truth data (``eval()`` passes a one-element
            list holding the loaded label array).
    """
    metric = eval_pre_rec_f1()
    metric.add(out, gt_xy_list)
    precision, recall, f1 = metric.val()
    target_path = 'eval_temp/{}.npy'.format(str(i))
    np.save(target_path, [precision, recall, f1])
    # The accumulator is local and discarded right after, so this reset is
    # redundant; kept so the call sequence is unchanged.
    metric.reset()
def eval(data_dir, result_dir):
    """Evaluate every raw prediction in *result_dir* and print mean P/R/F1.

    Each file in *result_dir* is read as a flat float32 buffer and reshaped
    to ``(1, 7, 736 // 4, 736 // 4)``. Its ground truth is loaded from
    ``<data_dir>/labels_3T736/<file name without its last 6 chars>.npy``.
    Per-image metrics are computed in a process pool by ``eval_func``, which
    writes them to ``eval_temp/<index>.npy``. Those files are then read back
    and averaged.

    The function name shadows the builtin ``eval`` inside this module; it is
    kept for backward compatibility with existing callers.

    Args:
        data_dir: dataset root containing the ``labels_3T736`` directory.
        result_dir: directory containing only the raw prediction files.

    Raises:
        Any exception raised by ``eval_func`` in a worker process is
        re-raised here. Previously it was silently dropped.
    """
    train_label_dir = os.path.join(data_dir, 'labels_3T736')
    img_list = os.listdir(result_dir)
    n_images = len(img_list)
    if n_images == 0:
        # Nothing to average (the original code divided by zero here).
        print('no prediction files found in {}'.format(result_dir))
        return

    # eval_func writes into this directory; make sure it exists even when
    # eval() is called without going through the __main__ block.
    os.makedirs('eval_temp', exist_ok=True)

    # Side length of the prediction map stored in each result file.
    # NOTE(review): presumably 736 input / stride 4 — confirm with the model.
    map_size = 736 // 4

    pbar = tqdm(total=n_images)
    update = lambda *args: pbar.update()
    pending = []
    try:
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            for i, img_file in enumerate(img_list):
                label_path = os.path.join(train_label_dir, img_file[:-6] + '.npy')
                gt_xy_list = [np.load(label_path)]
                out = np.fromfile(os.path.join(result_dir, img_file), dtype=np.float32)
                out.shape = (1, 7, map_size, map_size)
                out = torch.from_numpy(out)
                pending.append(
                    pool.apply_async(eval_func, args=(i, out, gt_xy_list), callback=update)
                )
            pool.close()
            pool.join()
            # apply_async hides worker exceptions until .get() is called.
            # Without this, a failed image would silently fall back to a
            # stale eval_temp file from an earlier run.
            for result in pending:
                result.get()
    finally:
        pbar.close()

    mPre = mRec = mF1_score = 0
    for i in range(n_images):
        arr = np.load('eval_temp/{}.npy'.format(str(i)))
        mPre += arr[0]
        mRec += arr[1]
        mF1_score += arr[2]
    mPre /= n_images
    mRec /= n_images
    mF1_score /= n_images
    # NOTE(review): the '%' suffix assumes eval_pre_rec_f1.val() returns
    # percentages; if it returns fractions in [0, 1], scale by 100 — confirm.
    print('precision:{:.2f}% recall:{:.2f}% f1-score:{:.2f}%'.format(mPre, mRec, mF1_score))
if __name__ == '__main__':
    # Usage: python <script> DATA_DIR RESULT_DIR
    # Extra trailing arguments are ignored, as before.
    if len(sys.argv) < 3:
        sys.exit('usage: {} DATA_DIR RESULT_DIR'.format(sys.argv[0]))
    data_dir = sys.argv[1]
    result_dir = sys.argv[2]
    # eval_func saves per-image metrics here; exist_ok avoids the
    # exists()/mkdir() race of the previous check-then-create.
    os.makedirs('eval_temp', exist_ok=True)
    eval(data_dir, result_dir)