import argparse
import time
import numpy as np
import torch
import torchaudio
import torch_npu
from torch_npu.contrib import transfer_to_npu
from ais_bench.infer.interface import InferSession
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
class SenseVoiceOnnxModel:
    """Run SenseVoice inference through an offline-model session and greedy-decode the CTC output.

    The class holds only lookup tables; the inference session, tokenizer and
    feature frontend are supplied on every call to ``infer_onnx``.
    """

    def __init__(self):
        super().__init__()
        # Token id dropped from the decoded sequence (the CTC blank).
        self.blank_id = 0
        # Language name -> integer id sent to the model as its `language` input.
        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        # Text-normalisation mode -> integer id sent to the model as its `textnorm` input.
        self.textnorm_dict = {'withitn': 14, "woitn": 15}

    def _greedy_token_ids(self, ctc_logits, encoder_out_lens):
        """Greedy-decode the first batch item of ``ctc_logits`` into token ids.

        Takes the arg-max token of every valid frame, merges consecutive
        repeats and removes ``self.blank_id``.

        Args:
            ctc_logits: tensor indexed as ``[batch, frame, vocab]``.
            encoder_out_lens: tensor whose element 0 is the number of valid
                frames for batch item 0.

        Returns:
            A list of int token ids.
        """
        valid_frames = encoder_out_lens[0].item()
        frame_logits = ctc_logits[0, :valid_frames, :]
        best_ids = frame_logits.argmax(dim=-1)
        collapsed = torch.unique_consecutive(best_ids, dim=-1)
        return collapsed[collapsed != self.blank_id].tolist()

    def infer_onnx(
        self,
        data_in,
        om_sess,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Transcribe ``data_in`` with ``om_sess`` and return the text plus the elapsed time.

        Args:
            data_in: audio input forwarded to ``load_audio_text_image_video``.
            om_sess: session object exposing
                ``infer(feed, mode=..., custom_sizes=...)`` that returns
                ``(ctc_logits, encoder_out_lens)`` as numpy arrays.
            tokenizer: object providing ``decode(list_of_ids)``; also passed
                to the audio loader. Required in practice despite the default.
            frontend: feature extractor providing ``fs``; passed to
                ``extract_fbank``. Required in practice despite the default
                (``None`` fails on ``frontend.fs``).
            **kwargs: recognised keys are
                ``device`` (required), ``language`` (default ``"auto"``;
                unknown names fall back to id 0), ``use_itn`` (default
                ``False``), ``text_norm`` (``"withitn"``/``"woitn"``;
                derived from ``use_itn`` when absent), ``fs`` (default
                16000) and ``data_type`` (default ``"sound"``). Other keys
                are ignored.

        Returns:
            ``(results, cost_time)`` where ``results`` is a one-element list
            ``[{"key": "wav_file_tmp_name", "text": <decoded text>}]`` and
            ``cost_time`` is the seconds spent building the feed, running
            ``om_sess.infer`` and moving its outputs to the NPU.

        Raises:
            KeyError: if ``device`` is missing from ``kwargs`` or
                ``text_norm`` is not a key of ``self.textnorm_dict``.

        Note:
            Only batch item 0 of the model output is decoded.
        """
        key = ["wav_file_tmp_name"]
        data_type = kwargs.get("data_type", "sound")
        use_itn = kwargs.get('use_itn', False)

        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs,
            audio_fs=kwargs.get("fs", 16000),
            data_type=data_type,
            tokenizer=tokenizer,
        )
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=data_type, frontend=frontend
        )
        device = kwargs["device"]
        speech = speech.to(device=device)
        speech_lengths = speech_lengths.to(device=device)

        language_name = kwargs.get("language", "auto")
        language = torch.LongTensor([self.lid_dict.get(language_name, 0)]).to(speech.device)

        textnorm_name = kwargs.get("text_norm", None)
        if textnorm_name is None:
            textnorm_name = "withitn" if use_itn else "woitn"
        textnorm = torch.LongTensor([self.textnorm_dict[textnorm_name]]).to(speech.device)

        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed by wall-clock adjustments.
        start = time.perf_counter()
        feed = [
            speech.cpu().detach().numpy().astype(np.float32),
            speech_lengths.cpu().detach().numpy().astype(np.int32),
            language.cpu().detach().numpy().astype(np.int32),
            textnorm.cpu().detach().numpy().astype(np.int32),
        ]
        # NOTE(review): custom_sizes is presumably the output buffer size for
        # dynamic-shape mode -- confirm against the ais_bench documentation.
        ctc_logits, encoder_out_lens = om_sess.infer(feed, mode='dymshape', custom_sizes=10000000)
        ctc_logits = torch.from_numpy(ctc_logits).npu()
        encoder_out_lens = torch.from_numpy(encoder_out_lens).npu()
        cost_time = time.perf_counter() - start

        token_int = self._greedy_token_ids(ctc_logits, encoder_out_lens)
        text = tokenizer.decode(token_int)
        results = [{"key": key[0], "text": text}]
        return results, cost_time
def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"``) becomes ``True``. This
    parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    torch_npu.npu.set_compile_mode(jit_compile=False)

    parser = argparse.ArgumentParser(description="Sensevoice infer")
    parser.add_argument("--model_path", type=str, help="modelpath")
    parser.add_argument('--om_path', type=str, help='om model')
    parser.add_argument('--device', type=int, help='npu device num')
    parser.add_argument('--input', type=str, help='input audio file')
    # Accepts `--perform`, `--perform True` and `--perform False`.
    parser.add_argument('--perform', type=str2bool, nargs='?', const=True, default=False,
                        help='test performance')
    parser.add_argument('--loop', default=10, type=int, help='loop time')
    args = parser.parse_args()

    # The average below divides by --loop, so reject non-positive values early.
    if args.perform and args.loop < 1:
        parser.error("--loop must be a positive integer when --perform is enabled")

    _, kwargs = AutoModel.build_model(model=args.model_path, trust_remote_code=True)
    m = SenseVoiceOnnxModel()
    om_sess = InferSession(args.device, args.om_path)

    def run_once():
        """Run one inference on the input file with the fixed demo settings."""
        return m.infer_onnx(
            data_in=args.input,
            om_sess=om_sess,
            language="auto",
            use_itn=False,
            ban_emo_unk=False,
            **kwargs,
        )

    with torch.no_grad():
        res, _ = run_once()
        text = rich_transcription_postprocess(res[0]['text'])
        print('语音输出:')  # "Speech output:"
        print(text)

        if args.perform:
            # Average the per-call time reported by infer_onnx over --loop runs.
            total_time = 0
            for _ in range(args.loop):
                res, cost_time = run_once()
                total_time += cost_time
            print('单条数据推理耗时:')  # "Inference time per sample:"
            print(str(total_time / args.loop))