import argparse
import time
import torch
import torch_npu
from torchair_auto_model import TorchairAutoModel
def parse_args(argv=None):
    """Parse command-line options for the Paraformer inference script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing zero-argument callers behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``model_path``,
        ``model_vad``, ``model_punc``, ``batch_size``, ``data``,
        ``hotwords``, ``threshold``, ``warmup`` and ``device``.
    """
    parser = argparse.ArgumentParser(description="Paraformer Inference")
    parser.add_argument(
        "--model_path",
        type=str,
        default="./speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        help="Path of pretrained paraformer model",
    )
    parser.add_argument(
        "--model_vad",
        type=str,
        default=None,
        help="Path of pretrained vad model",
    )
    parser.add_argument(
        "--model_punc",
        type=str,
        default=None,
        # The original help text was a copy of the --model_vad one.
        help="Path of pretrained punc model",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size",
    )
    parser.add_argument(
        "--data",
        type=str,
        default="speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav",
        help="Path of wav file",
    )
    parser.add_argument(
        "--hotwords",
        type=str,
        default=None,
        help="Hotword.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.98,
        help="threshold for cif predictor",
    )
    parser.add_argument("--warmup", type=int, default=3, help="Warm up times")
    parser.add_argument("--device", type=int, default=0, help="Npu Device Id")
    return parser.parse_args(argv)
def main():
    """Run Paraformer inference and print the transcription results.

    Builds a ``TorchairAutoModel`` from the parsed command-line options,
    checks the ``--hotwords`` option against the loaded model variant,
    performs ``--warmup`` throw-away runs on the input, then runs one
    inference pass and prints the ``text`` entry of every result.

    Raises:
        SystemExit: with status 1 when the loaded model is not
            ``BiCifParaformer`` and ``--hotwords`` was not supplied.
    """
    args = parse_args()
    device = f"npu:{args.device}"
    model = TorchairAutoModel(
        model=args.model_path,
        vad_model=args.model_vad,
        punc_model=args.model_punc,
        batch_size=args.batch_size,
        cif_threshold=args.threshold,
        device=device,
    )

    if model.kwargs['model'] == "BiCifParaformer":
        if args.hotwords is not None:
            print("Using paraformer long context version, hotwords will not be used")
    elif args.hotwords is None:
        print("Using paraformer hotword version, but hotword is not specified, please check your input args")
        # Exit with a non-zero status so callers can detect the failure; the
        # bare exit() used previously reported success (status 0) and relies
        # on the interactive helper injected by the site module.
        raise SystemExit(1)

    with torch.inference_mode():
        print("Begin warming up.")
        for _ in range(args.warmup):
            # Results of the warm-up runs are intentionally discarded.
            model.generate(input_data=args.data, hotword=args.hotwords)
        print("Finish warming up")

        print("Begin inference")
        # generate() returns a (results, time_stats) pair; the timing
        # statistics are not reported by this script.
        results, _time_stats = model.generate(input_data=args.data, hotword=args.hotwords)

    for res in results:
        print(f"transcription result: {res['text']}")


if __name__ == '__main__':
    main()