import argparse
import os
from pprint import pprint
from pathlib import Path
import numpy as np
import torch
from mmocr.models.builder import build_convertor
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
def parse_args():
    """Parse the command-line options of the postprocess script.

    Both options are mandatory strings:

    * ``--result-dir``: output directory of inferencing.
    * ``--gt-path``: path to groundtruth file.

    Returns:
        argparse.Namespace: Parsed options, exposed as ``result_dir`` and
        ``gt_path``.
    """
    cli = argparse.ArgumentParser(description='postprocess.')
    # (flag, help text) pairs; every option is a required string.
    required_options = (
        ('--result-dir', 'output directory of inferencing.'),
        ('--gt-path', 'path to groundtruth file.'),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()
def get_pred_texts(result_dir):
    """Decode the raw inference outputs stored in ``result_dir`` into text.

    Every entry of the directory is visited in sorted path order. Entries
    whose name (with ``_0.bin`` removed) starts with ``padding`` or
    ``sumary`` are skipped. Each remaining file is read as float32 values,
    reshaped to ``(1, 25, 92)`` and decoded with an ``AttnConvertor``.

    Args:
        result_dir (str): Directory holding the inference output files.

    Returns:
        tuple[list, list[str]]: The decoded strings and the stems of the
        files they were decoded from, both in sorted path order.
    """
    convertor = build_convertor({
        'type': 'AttnConvertor',
        'dict_type': 'DICT90',
        'with_unknown': True,
        'max_seq_len': 25,
    })
    # NOTE(review): 'sumary' is matched literally against file names on
    # disk; presumably it is how the inference tool names its summary
    # file -- confirm before "correcting" the spelling.
    skipped_prefixes = ('padding', 'sumary')

    pred_texts, file_stems = [], []
    for path_str in sorted(str(entry) for entry in Path(result_dir).iterdir()):
        stem = Path(path_str).name.replace('_0.bin', '')
        if stem.startswith(skipped_prefixes):
            continue
        raw_output = np.fromfile(path_str, np.float32).reshape(1, 25, 92)
        indexes, _ = convertor.tensor2idx(torch.from_numpy(raw_output))
        # Assumes idx2str yields one string per file (batch size 1), so the
        # two returned lists stay aligned -- TODO confirm.
        pred_texts.extend(convertor.idx2str(indexes))
        file_stems.append(stem)
    return pred_texts, file_stems
def get_gt_texts(gt_path, file_stems):
    """Look up the ground-truth text for every predicted file stem.

    The ground-truth file holds one sample per line in the form
    ``<image path> <text>``. Only the first space separates the two fields,
    so the text itself may contain spaces. Samples are keyed by the image
    file name without its extension; blank lines are ignored.

    Args:
        gt_path (str): Path to the UTF-8 encoded ground-truth file.
        file_stems (list[str]): Stems to look up, in prediction order.

    Returns:
        list[str]: Ground-truth texts, in the same order as ``file_stems``.

    Raises:
        ValueError: If a non-blank line has no space-separated text field.
        AssertionError: If any stem has no ground-truth entry.
    """
    gt_dict = {}
    # The context manager closes the file even if a line fails to parse.
    with open(gt_path, 'r', encoding='utf-8') as gt_file:
        for line_no, raw_line in enumerate(gt_file, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank / trailing empty lines.
                continue
            img_path, sep, text = line.partition(' ')
            if not sep:
                raise ValueError(
                    f'{gt_path}:{line_no}: expected "<image path> <text>", '
                    f'got {line!r}')
            gt_dict[Path(img_path).stem] = text

    missing = [stem for stem in file_stems if stem not in gt_dict]
    if missing:
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; silently dropping stems would misalign predictions
        # and ground truths. The exception type matches the former assert.
        raise AssertionError(
            f'{len(missing)} of {len(file_stems)} file stem(s) have no '
            f'ground truth in {gt_path}: {missing[:10]}')
    return [gt_dict[stem] for stem in file_stems]
def evaluate(result_dir, gt_path, out_path=None):
    """Score the decoded predictions against the ground truth and print it.

    Predictions are decoded from ``result_dir``, matched with the texts in
    ``gt_path`` by file stem, and passed to ``eval_ocr_metric`` with
    ``metric='acc'``. Whatever that call returns is pretty-printed.

    Args:
        result_dir (str): Directory holding the inference output files.
        gt_path (str): Path to the ground-truth file.
        out_path: Accepted but currently unused.
    """
    predictions, stems = get_pred_texts(result_dir)
    references = get_gt_texts(gt_path, stems)
    pprint(eval_ocr_metric(predictions, references, metric='acc'))
if __name__ == '__main__':
    # Script entry point: read the CLI options and run the evaluation.
    cli_args = parse_args()
    evaluate(result_dir=cli_args.result_dir, gt_path=cli_args.gt_path)