import argparse
import json
import torch
from apex import amp
"""
Follow these instructions if you need to import kaldi_io:
1. Install Kaldi.
2. cd tools
3. make clean; make KALDI=/path/to/kaldi
4. cd test
5. source path.sh
"""
import kaldi_io
from sp_transformer import Transformer
from utils import add_results_to_json, process_dict, pad_list
from data import build_LFR_features
def parse_args(argv=None):
    """Parse command-line options for the online inference demo.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with ``model_path``, ``sp_dict``, ``recog_json``,
        ``device`` and ``result_label`` attributes.
    """
    parser = argparse.ArgumentParser(description='Online Inference Demo')
    parser.add_argument('--model_path',
                        default='./test/output/final.pth.tar',
                        type=str,
                        help='model checkpoint passed to Transformer.load_model')
    parser.add_argument('--sp_dict',
                        default='./test/data/lang_1char/train_chars.txt',
                        type=str,
                        help='dictionary file passed to process_dict')
    parser.add_argument('--recog_json',
                        default='./test/dump/test/deltafalse/data.json',
                        type=str,
                        help="input JSON whose 'utts' entries are decoded")
    parser.add_argument('--device', default='npu', help='device type')
    parser.add_argument('--result_label',
                        default='./data.json',
                        type=str,
                        help='output JSON file for the decoding results')
    return parser.parse_args(argv)
def _prepare_input(feat_path, lfr_m, lfr_n, max_len=512):
    """Load one utterance's features and build the padded NPU model inputs.

    Reads the Kaldi feature matrix at ``feat_path``, applies LFR stacking via
    ``build_LFR_features`` and pads it to ``max_len`` frames with ``pad_list``.

    Returns:
        (features, lengths) tensors moved to the NPU; ``lengths`` holds the
        frame count measured *before* padding.
    """
    feats = kaldi_io.read_mat(feat_path)
    feats = build_LFR_features(feats, lfr_m, lfr_n)
    feats = torch.from_numpy(feats).float()
    # Record the true number of frames before padding hides it.
    feats_length = torch.tensor([feats.size(0)], dtype=torch.int)
    # Add a batch dimension, then pad to a fixed length (presumably so the
    # NPU always sees the same input shape -- confirm with the model owner).
    feats = pad_list(feats.unsqueeze(0), 0, max_len=max_len)
    return feats.npu(), feats_length.npu()


def _trim_at_eos(yseq, eos_id):
    """Return the first row of ``yseq`` up to and including the first EOS.

    ``yseq`` is indexed as a 2-D tensor here (row 0 is the hypothesis);
    presumably its shape is (1, max_len) -- TODO confirm against the model.
    If several EOS tokens are present only the first one counts; if none is
    present the whole row is kept instead of raising.

    Returns:
        The kept token ids as a plain Python list.
    """
    row = yseq[0]
    eos_positions = torch.nonzero(row == eos_id, as_tuple=True)[0]
    end = int(eos_positions[0]) + 1 if eos_positions.numel() > 0 else row.size(0)
    return row[:end].cpu().numpy().tolist()


def recognize():
    """Decode every utterance listed in ``--recog_json`` and save the results.

    Loads the model from ``--model_path`` and moves it to the NPU. It then
    checks that the ``<sos>``/``<eos>`` ids from ``--sp_dict`` match the
    model's decoder and wraps the model with apex AMP (opt_level O2). Each
    utterance under the input JSON's ``'utts'`` key is decoded, and
    ``{'utts': ...}`` is written to ``--result_label`` as UTF-8 JSON.

    Raises:
        ValueError: if the dictionary's sos/eos ids differ from the model's.
    """
    args = parse_args()
    model, LFR_m, LFR_n = Transformer.load_model(args.model_path)
    model.eval()
    # NOTE(review): args.device is parsed but unused; the model and inputs
    # are always moved with .npu().
    model.npu()
    char_list, sos_id, eos_id = process_dict(args.sp_dict)
    if model.decoder.sos_id != sos_id or model.decoder.eos_id != eos_id:
        raise ValueError(
            'dictionary sos/eos ids (%d, %d) do not match the model decoder (%d, %d)'
            % (sos_id, eos_id, model.decoder.sos_id, model.decoder.eos_id))
    model = amp.initialize(model, opt_level="O2", loss_scale=128.0, combine_grad=True)

    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']

    new_js = {}
    total = len(js)
    with torch.no_grad():
        for idx, name in enumerate(js, 1):
            print('(%d/%d) decoding %s' % (idx, total, name), flush=True)
            feats, feats_length = _prepare_input(
                js[name]['input'][0]['feat'], LFR_m, LFR_n)
            nbest_hyps = model(feats, feats_length)
            # Truncate at the dictionary's EOS id (not a hard-coded 2). This
            # also copes with zero or multiple EOS tokens.
            nbest_hyps[0]['yseq'] = _trim_at_eos(nbest_hyps[0]['yseq'], eos_id)
            new_js[name] = add_results_to_json(js[name], nbest_hyps, char_list)

    with open(args.result_label, 'wb') as f:
        f.write(json.dumps({'utts': new_js}, indent=4,
                           sort_keys=True).encode('utf_8'))
if "__main__" == __name__:
    # Script entry point: decode only when run directly, never on import.
    recognize()