import os
import glob
import json
import numpy as np
import argparse
import editdistance
from unittest import result
from pythainlp.util import normalize
def load_vocab_map(vocab_path):
    '''
    Load the token-to-index mapping dictionary and convert it to the index-to-token mapping.

    The file is parsed as JSON. If it is not valid JSON it is parsed as a
    plain Python literal instead, so dict-literal files still load, but no
    expression contained in the file is ever executed.

    Args:
        vocab_path: filepath of vocabulary for decode
    Returns:
        index-to-token mapping dictionary
    Raises:
        ValueError, SyntaxError: if the file is neither valid JSON nor a
            plain Python literal.
    '''
    # Local import: only needed for the non-JSON fallback below.
    import ast

    with open(vocab_path, "r", encoding="utf-8") as f:
        raw = f.read()

    try:
        token_index = json.loads(raw)
    except json.JSONDecodeError:
        # literal_eval only accepts literals (dicts, strings, numbers, ...),
        # unlike eval(), which would run arbitrary code found in the file.
        token_index = ast.literal_eval(raw)

    return {index: token for token, index in token_index.items()}
def remove_adjacent(item):
    '''
    Collapse every run of equal neighbouring elements into a single element.
    Args:
        item: a string (or other iterable of strings) to collapse
    Returns:
        the remaining elements joined into one string
    '''
    symbols = list(item)
    # Keep the first element, then every element that differs from the one
    # directly before it.
    kept = [
        current
        for pos, current in enumerate(symbols)
        if pos == 0 or current != symbols[pos - 1]
    ]
    return ''.join(kept)
def find_files(path, pattern="*.flac"):
    '''
    Find files recursively according to `pattern`.
    Args:
        path: directory to search
        pattern: glob pattern for the file name; it is appended after a
            literal `*` in the search expression
    Returns:
        list of matching paths, in the order glob yields them
    '''
    search_expr = f'{path}/**/*{pattern}'
    return list(glob.iglob(search_expr, recursive=True))
def post_process(model_out_dir, index_token_map, batch_i_filename_map):
    '''
    Decode every `.bin` model output found under `model_out_dir` into text.
    Args:
        model_out_dir: directory searched recursively for `.bin` files
        index_token_map: index-to-token mapping dictionary
        batch_i_filename_map: mapping from batch id to the names of the
            samples contained in that batch
    Returns:
        dictionary mapping each sample name to its decoded text
    '''
    decoded = {}
    for bin_path in find_files(model_out_dir, "*.bin"):
        with open(bin_path, 'rb') as stream:
            raw = stream.read()
        # Raw float32 values viewed as (samples, 312, 32); presumably
        # 312 time steps x 32 vocabulary scores -- confirm against the model.
        scores = np.frombuffer(raw, dtype=np.float32).reshape(-1, 312, 32)

        # Batch id = file stem minus its last two characters (presumably an
        # output-index suffix such as "_0" -- verify against the producer).
        stem = os.path.basename(bin_path).split('.')[0]
        batch_id = stem[:-2]

        sample_names = batch_i_filename_map[batch_id]
        for name, sample_scores in zip(sample_names, scores):
            token_ids = np.argmax(sample_scores, axis=-1)
            tokens = ''.join([index_token_map[token_id] for token_id in token_ids])
            # Collapse repeated symbols inside each "<pad>"-delimited segment,
            # then glue the segments back together.
            segments = [remove_adjacent(segment) for segment in tokens.split("<pad>")]
            sentence = normalize(''.join(segments))
            decoded[name] = sentence.replace('|', ' ')
    return decoded
def get_id_text_dict(file_path):
    '''
    Load text file and create a mapping between sound file name and text.
    Each line is stripped; everything before the first space is the key and
    everything after it is the text.
    Args:
        file_path: path of the UTF-8 text file to read
    Returns:
        dictionary mapping the first field of each line to the rest of it
    '''
    id_to_text = {}
    with open(file_path, mode='r', encoding='utf-8') as f:
        for raw_line in f:
            name, _, transcript = raw_line.strip().partition(' ')
            id_to_text[name] = transcript
    return id_to_text
def eval_accuracy(ground_truth_dict, infer_dict):
    '''
    The inference data is compared with the real data to get the error rate.
    Args:
        ground_truth_dict: a dictionary mapping from audio file to its corresponding ground truth text.
        infer_dict: a dictionary mapping from audio file to its corresponding inferred text.
    Returns:
        total edit distance divided by the total number of ground truth
        characters, over all keys present in both dictionaries
    Raises:
        ZeroDivisionError: if no ground truth characters were compared
            (no shared keys, or every matched ground truth text is empty).
    '''
    total_chars = 0
    total_edits = 0
    for key, text in infer_dict.items():
        try:
            ground_truth_text = ground_truth_dict[key]
        except KeyError:
            # Only a missing key is an expected, skippable condition.
            print("skip error key: ", key)
            continue
        total_chars += len(ground_truth_text)
        total_edits += editdistance.distance(text, ground_truth_text)

    if total_chars == 0:
        # Same exception type as the bare division would raise, but with a
        # message that explains what went wrong.
        raise ZeroDivisionError(
            "cannot compute error rate: no ground-truth characters were compared"
        )
    return total_edits / total_chars
if __name__ == '__main__':
    # Command-line options: where the model outputs live and which batch
    # size they were produced with.
    cli = argparse.ArgumentParser()
    cli.add_argument('--input', help='om model output',
                     default="./data/bin_om_out_bs1")
    cli.add_argument('--batch_size', help='batch size', default=1)
    options = cli.parse_args()

    output_dir = options.input
    bs = int(options.batch_size)

    # Mapping from batch id to the sample names inside that batch; the file
    # name depends on the batch size.
    map_path = "./data/batch_i_filename_map_bs{}.json".format(bs)
    with open(map_path, 'r', encoding='utf-8') as map_file:
        batch_to_names = json.load(map_file)

    # Decode the model outputs into text.
    id_to_token = load_vocab_map("./wav2vec2_pytorch_model/vocab.json")
    predictions = post_process(output_dir, id_to_token, batch_to_names)

    # Compare against the reference transcripts and report the result.
    references = get_id_text_dict("./data/ground_truth_texts.txt")
    error_rate = eval_accuracy(references, predictions)
    print(f"Err: {error_rate}")
    print(f"Acc: {1 - error_rate}")