import argparse
import re
import time
from itertools import groupby
import numpy as np
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
import jiwer
from datasets import load_dataset
from ais_bench.infer.interface import InferSession
from funasr import AutoModel
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.ctc.ctc import CTC
from funasr.models.sense_voice.utils.ctc_alignment import ctc_forced_align
class SenseVoiceOnnxModel():
    """SenseVoice ASR pipeline: FunASR VAD followed by an Ascend OM SenseVoice model.

    The FunASR model passed as ``model`` is only used for its build kwargs
    (tokenizer, frontend, vocabulary size, encoder config); the acoustic
    network itself runs from the compiled OM file through ``InferSession``.
    """

    def __init__(self, device_id, om_path, model, vad_model):
        """Build the VAD model, open the OM session and cache model config.

        Args:
            device_id: NPU device index, used for torch tensors (``npu:<id>``)
                and for the OM inference session.
            om_path: Path to the compiled OM SenseVoice model.
            model: FunASR model name/path whose build kwargs are used.
            vad_model: FunASR VAD model name/path.
        """
        super().__init__()
        self.device = f"npu:{device_id}"
        # Only the build kwargs are kept; the returned model object is discarded
        # because the network runs on the OM session instead.
        _, self.kwargs = AutoModel.build_model(model=model, trust_remote_code=True)
        self.vad_model, self.vad_kwargs = AutoModel.build_model(model=vad_model, trust_remote_code=True)
        encoder_output_size = self.kwargs.get("encoder_conf", {}).get("output_size", 256)
        self.blank_id = 0
        self.vocab_size = self.kwargs.get("vocab_size", -1)
        self.frontend = self.kwargs.get("frontend", None)
        self.tokenizer = self.kwargs.get("tokenizer", None)
        # Ids fed to the OM model's language / text-normalisation inputs.
        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.textnorm_dict = {'withitn': 14, "woitn": 15}
        self.om_sess = InferSession(device_id, om_path)
        self.ignore_id = -1
        # NOTE: this CTC head is freshly constructed, so its projection weights are
        # NOT loaded from any checkpoint. It is kept only so the attribute still
        # exists; timestamp alignment uses the OM model's own CTC logits instead.
        self.ctc = CTC(odim=self.vocab_size, encoder_output_size=encoder_output_size)
        self.ctc.ctc_lo = self.ctc.ctc_lo.to(device=self.device)

    def is_valid_word(self, word):
        """Return True if *word* is non-empty and consists only of ASCII letters."""
        return word.isalpha() and word.isascii()

    def post_process(self, timestamp):
        """Merge sub-word token timestamps into word-level timestamps.

        Args:
            timestamp: List of ``[token, start_s, end_s]`` with times in seconds.
                A leading ``▁`` on a token marks the start of a new word.

        Returns:
            List of ``[start_ms, end_ms, word]`` with integer millisecond times.
            A piece without the ``▁`` prefix is glued onto the preceding entry
            (extending its end time) when both are ASCII-alphabetic; any other
            piece (e.g. CJK characters, punctuation) becomes its own entry.
        """
        timestamp_new = []
        prev_word = None
        for word, start, end in timestamp:
            start = int(start * 1000)
            end = int(end * 1000)
            if word == "▁":
                # A bare word-boundary marker: whatever follows starts a new word
                # and must not be glued onto the previous one.
                prev_word = None
                continue
            if word.startswith("▁"):
                # Word start (including the very first token): drop the marker.
                word = word[1:]
                timestamp_new.append([start, end, word])
            elif prev_word is not None and self.is_valid_word(prev_word) and self.is_valid_word(word):
                # Continuation piece of an ASCII word: extend the last entry.
                timestamp_new[-1][1] = end
                timestamp_new[-1][2] += word
            else:
                timestamp_new.append([start, end, word])
            prev_word = word
        return timestamp_new

    def sense_voice_infer(self, feed, vad_res_list, output_timestamp=False):
        """Run the OM model on one VAD segment and decode its transcript.

        Args:
            feed: OM inputs ``[speech (float32), speech_lengths (int32),
                language (int32), textnorm (int32)]`` as numpy arrays.
            vad_res_list: The ``[start_ms, end_ms]`` pair of this single VAD
                segment (despite the name, one segment, not a list of them).
            output_timestamp: When True, also compute word timestamps via CTC
                forced alignment.

        Returns:
            ``{'text': str}``, or ``{'text': str, 'timestamp': [[start_ms,
            end_ms, word], ...]}`` when timestamps are requested and available.
        """
        # Output buffer size for the dynamic-shape run; presumably
        # (frames + 4 prompt frames) * 4 bytes (float32) * vocab_size.
        custom_sizes = (feed[0].shape[1] + 4) * 4 * self.vocab_size
        ctc_logits, encoder_out_lens, _ = self.om_sess.infer(feed, mode='dymshape', custom_sizes=custom_sizes)
        ctc_logits = torch.from_numpy(ctc_logits).to(device=self.device)
        encoder_out_lens = torch.from_numpy(encoder_out_lens).to(device=self.device)
        out_len = encoder_out_lens[0].item()

        # Greedy CTC decoding: best label per frame, collapse repeats, drop blanks.
        yseq = ctc_logits[0, :out_len, :].argmax(dim=-1)
        yseq = torch.unique_consecutive(yseq, dim=-1)
        token_int = yseq[yseq != self.blank_id].tolist()
        text = self.tokenizer.decode(token_int)
        if not output_timestamp:
            return {'text': text}

        # The first 4 tokens of the decoded text are skipped (presumably the
        # language/emotion/event/ITN tags), matching the 4 leading frames
        # dropped from the posteriors below.
        tokens = self.tokenizer.text2tokens(text)[4:]
        token_ids = []
        for tok_ls in self.tokenizer.tokens2ids(tokens):
            if tok_ls:
                token_ids.extend(tok_ls)
            else:
                # Fallback id for tokens with no id; meaning of 124 is not
                # visible here — TODO confirm.
                token_ids.append(124)
        if len(token_ids) == 0:
            return {'text': text}

        # Posteriors from the OM model's own (trained) CTC logits. Softmax is
        # invariant to a per-frame shift, so this is correct whether the OM
        # output is raw logits or log-probabilities.
        logits_speech = torch.softmax(ctc_logits, dim=-1)[0, 4:out_len, :]
        pred = logits_speech.argmax(-1).cpu()
        # Zero the blank probability on frames whose best label is blank.
        logits_speech[pred == self.blank_id, self.blank_id] = 0
        align = ctc_forced_align(
            logits_speech.unsqueeze(0).float().cpu(),
            torch.Tensor(token_ids).unsqueeze(0).long(),
            (encoder_out_lens[0] - 4).long().cpu(),
            torch.tensor(len(token_ids)).unsqueeze(0).long(),
            ignore_id=self.ignore_id,
        )

        # Each encoder frame is treated as 60 ms. _start/_end are kept in frame
        # units shifted by the segment's VAD start (plus vad_offset ms), so
        # "x * 60 - vad_offset" yields absolute milliseconds in the input audio.
        vad_offset = 30
        _start = (vad_res_list[0] + vad_offset) / 60
        ts_max = (vad_res_list[1] + vad_offset) / 60
        timestamp = []
        token_id = 0
        for pred_token, pred_frame in groupby(align[0, :out_len].tolist()):
            _end = _start + len(list(pred_frame))
            if pred_token != self.blank_id:
                if token_id >= len(tokens):
                    # NOTE(review): more aligned segments than token strings can
                    # happen when one token maps to several ids; stop instead of
                    # raising IndexError.
                    break
                ts_left = max((_start * 60 - vad_offset) / 1000, 0)
                ts_right = min((_end * 60 - vad_offset) / 1000, (ts_max * 60 - vad_offset) / 1000)
                timestamp.append([tokens[token_id], ts_left, ts_right])
                token_id += 1
            _start = _end
        timestamp = self.post_process(timestamp)
        return {'text': text, 'timestamp': timestamp}

    def infer(self, data_in, output_timestamp):
        """Transcribe one audio input: VAD, then OM inference per speech segment.

        Args:
            data_in: Audio passed to FunASR's ``load_audio_text_image_video`` and
                to the VAD model (e.g. a waveform array — TODO confirm which
                other forms the VAD model accepts).
            output_timestamp: Forwarded to :meth:`sense_voice_infer`.

        Returns:
            dict with ``key`` (the input), ``text`` (concatenated segment
            transcripts), ``timestamp`` (concatenated word timestamps; empty
            unless requested) and ``e2e_time`` (wall-clock seconds, including
            loading and VAD).
        """
        start_time = time.time()
        data_type = self.kwargs.get("data_type", "sound")
        audio = load_audio_text_image_video(
            data_in,
            fs=self.frontend.fs,
            audio_fs=self.kwargs.get("fs", 16000),
            data_type=data_type,
            tokenizer=self.tokenizer
        )
        # VAD boundaries are in milliseconds and the audio is at the frontend's
        # sample rate, so convert ms -> sample index with that rate.
        samples_per_ms = self.frontend.fs / 1000
        vad_results, _ = self.vad_model.inference([data_in], key=['test'], **self.vad_kwargs)
        vad_segments = vad_results[0]['value']

        # The prompt inputs are identical for every segment; build them once
        # directly as the int32 arrays the OM model consumes.
        language = self.kwargs.get("language", "auto")
        language_feed = np.array([self.lid_dict.get(language, 0)], dtype=np.int32)
        use_itn = self.kwargs.get('use_itn', True)
        textnorm = self.kwargs.get("text_norm", None)
        if textnorm is None:
            textnorm = "withitn" if use_itn else "woitn"
        textnorm_feed = np.array([self.textnorm_dict.get(textnorm, 0)], dtype=np.int32)

        results = {"key": data_in, "text": "", "timestamp": []}
        for segment in vad_segments:
            seg_start, seg_end = segment
            beg_idx = int(seg_start * samples_per_ms)
            end_idx = min(int(seg_end * samples_per_ms), len(audio))
            if end_idx <= beg_idx:
                # Empty slice (e.g. a segment starting past the loaded audio).
                continue
            speech, speech_lengths = extract_fbank(
                audio[beg_idx:end_idx], data_type=data_type, frontend=self.frontend
            )
            feed = [speech.cpu().detach().numpy().astype(np.float32),
                    speech_lengths.cpu().detach().numpy().astype(np.int32),
                    language_feed,
                    textnorm_feed]
            result_i = self.sense_voice_infer(
                feed=feed,
                vad_res_list=segment,
                output_timestamp=output_timestamp)
            results['text'] += result_i.get('text', '')
            results['timestamp'] += result_i.get('timestamp', [])
        e2e_time = time.time() - start_time
        print(f'infer E2E time = {e2e_time * 1000:.2f} ms')
        results['e2e_time'] = e2e_time
        return results
def process_output(text):
    """Normalise a SenseVoice transcript for WER scoring.

    Every ``<|...|>`` tag is replaced by a space, every character that is not
    an ASCII letter or a space is removed, and the result is upper-cased.
    """
    without_tags = re.sub(r'<\|.*?\|>', ' ', text)
    letters_only = re.sub(r'[^a-zA-Z ]', '', without_tags)
    return letters_only.upper()
def benchmark(data_path, output_timestamp, model):
    """Measure transcription rate and WER of *model* on a parquet audio dataset.

    Args:
        data_path: Parquet file(s) loadable by ``datasets.load_dataset``; rows
            need ``audio.array``, ``audio.sampling_rate`` and ``text``.
        output_timestamp: Forwarded to ``model.infer``.
        model: Object whose ``infer(array, output_timestamp)`` returns a dict
            with ``text`` and ``e2e_time`` (seconds).

    Returns:
        The corpus WER, or None if the dataset is invalid.
    """
    hypotheses = []
    references = []
    total_e2e_time = 0
    total_audio_time = 0
    dataset = load_dataset("parquet", data_files=data_path).get('train', [])
    if not dataset or dataset[0].get('audio', {}).get('array', []) is None:
        print("benchmark fail, dataset is invalid")
        return None
    # Warm-up run; its result and time are not counted (presumably to keep
    # one-time initialisation out of the measurement).
    model.infer(dataset[0]['audio']['array'], False)
    for data in dataset:
        audio = data.get('audio', {})
        array = audio.get('array', [])
        if array is None:
            print("benchmark fail, dataset is invalid")
            return None
        # NOTE(review): model.infer does not receive this rate; it resamples
        # based on its own config — confirm the dataset matches it.
        sampling_rate = audio.get('sampling_rate', 16000)
        res = model.infer(array, output_timestamp)
        # Normalise reference and hypothesis identically; otherwise e.g.
        # apostrophes stripped only from the hypothesis count as errors.
        references.append(process_output(data.get('text', '')))
        hypotheses.append(process_output(res['text']))
        total_e2e_time += res['e2e_time']
        total_audio_time += len(array) / sampling_rate
    print("Transcription Rate:", total_audio_time / total_e2e_time)
    wer = jiwer.wer(references, hypotheses)
    print("WER:", wer)
    return wer
if __name__ == '__main__':
    # Parse arguments first so `--help` and argument errors never touch the NPU.
    parser = argparse.ArgumentParser(description="Sensevoice infer")
    # Paths have no usable default; without them model loading fails obscurely.
    parser.add_argument('--vad_path', type=str, required=True, help='vad path')
    parser.add_argument('--model_path', type=str, required=True, help='model path')
    parser.add_argument('--om_path', type=str, required=True, help='om model')
    parser.add_argument('--device', type=int, help='device', default=0)
    parser.add_argument('--input', type=str, required=True, help='dataset path')
    parser.add_argument('--output_timestamp', action='store_true')
    args = parser.parse_args()
    torch_npu.npu.set_compile_mode(jit_compile=False)
    model = SenseVoiceOnnxModel(args.device, args.om_path, args.model_path, args.vad_path)
    with torch.no_grad():
        benchmark(args.input, args.output_timestamp, model)