# yuekaiz: fix lint
# 33aee03e, created on 2025-10-09 (commit history)
# Copyright (c)  2023 by manyeyes
# Copyright (c)  2023  Xiaomi Corporation

"""
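This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a non-streaming model and to score the transcripts against
reference labels.

NOTE: main() in this file only builds a paraformer recognizer and asserts
that the NeMo CTC, WeNet CTC, Whisper and tdnn options are empty, so only
example (1) below runs as-is; examples (2)-(6) are kept for reference.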

(1) For paraformer

    ./python-api-examples/offline-decode-files.py  \
      --tokens=/path/to/tokens.txt \
      --paraformer=/path/to/paraformer.onnx \
      --num-threads=2 \
      --decoding-method=greedy_search \
      --debug=false \
      --sample-rate=16000 \
      --feature-dim=80 \
      /path/to/0.wav \
      /path/to/1.wav

(2) For transducer models from icefall

    ./python-api-examples/offline-decode-files.py  \
      --tokens=/path/to/tokens.txt \
      --encoder=/path/to/encoder.onnx \
      --decoder=/path/to/decoder.onnx \
      --joiner=/path/to/joiner.onnx \
      --num-threads=2 \
      --decoding-method=greedy_search \
      --debug=false \
      --sample-rate=16000 \
      --feature-dim=80 \
      /path/to/0.wav \
      /path/to/1.wav

(3) For CTC models from NeMo

python3 ./python-api-examples/offline-decode-files.py \
  --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
  --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
  --num-threads=2 \
  --decoding-method=greedy_search \
  --debug=false \
  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav

(4) For Whisper models

python3 ./python-api-examples/offline-decode-files.py \
  --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
  --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
  --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
  --whisper-task=transcribe \
  --num-threads=1 \
  ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
  ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
  ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav

(5) For CTC models from WeNet

python3 ./python-api-examples/offline-decode-files.py \
  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav

(6) For tdnn models of the yesno recipe from icefall

python3 ./python-api-examples/offline-decode-files.py \
  --sample-rate=8000 \
  --feature-dim=23 \
  --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
  --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
  ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
  ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
  ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav

Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
import time
import wave
from pathlib import Path
from typing import List, Tuple, Dict, Iterable, TextIO, Union

import numpy as np
import sherpa_onnx
import soundfile as sf
from datasets import load_dataset
import logging
from collections import defaultdict
import kaldialign
from zhon.hanzi import punctuation
import string
# Every character treated as punctuation by remove_punctuation(): the CJK
# punctuation from zhon.hanzi followed by the ASCII punctuation from `string`.
punctuation_all = punctuation + string.punctuation
# A filesystem path given either as a plain string or as a pathlib.Path.
Pathlike = Union[str, Path]


def remove_punctuation(text: str) -> str:
    """Delete every character of ``punctuation_all`` from ``text``.

    The apostrophe is the only punctuation character that is kept.

    Args:
      text:
        The string to clean.
    Returns:
      Return ``text`` without the punctuation characters.
    """
    # Map each punctuation code point (apostrophe excluded) to None so that
    # str.translate drops it.
    drop_table = {ord(ch): None for ch in punctuation_all if ch != "'"}
    return text.translate(drop_table)


def store_transcripts(
    filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
    """Write reference and predicted transcripts to a text file.

    Two lines are written per utterance, ``<cut_id>:\\tref=...`` followed by
    ``<cut_id>:\\thyp=...``.

    Args:
      filename:
        File to save the results to. It is overwritten if it exists.
      texts:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted
        result. The ref and hyp may be strings or lists of strings.
      char_level:
        If True, the ref and hyp are joined with "" and split into a list of
        single characters before being written.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf8") as out:
        for utt_id, reference, hypothesis in texts:
            if char_level:
                reference = list("".join(reference))
                hypothesis = list("".join(hypothesis))
            out.write(f"{utt_id}:\tref={reference}\n")
            out.write(f"{utt_id}:\thyp={hypothesis}\n")


def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, str, str]],
    enable_log: bool = True,
    compute_CER: bool = False,
    sclite_mode: bool = False,
) -> float:
    """Write statistics based on predicted results and reference transcripts.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

            FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
      f:
        An open, writable text file that receives the report.
      test_set_name:
        Name of the test set; it is only used in the log message.
      results:
        A list of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
        The list is never modified by this function.
      enable_log:
        If True, also log the overall error rate with `logging.info`.
        Otherwise, it is written only to the given file.
      compute_CER:
        If True, the ref and hyp of every entry are joined with "" and split
        into single characters before alignment, i.e., a character error rate
        is computed.
      sclite_mode:
        Forwarded to `kaldialign.align` when counting the errors.
    Returns:
      Return the total error rate in percent, rounded to two decimals.
      If the references are empty, 0.0 is returned when there are no errors
      and `inf` otherwise.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"

    if compute_CER:
        # Build a new list so that the list passed in by the caller is left
        # untouched.
        results = [
            (cut_id, list("".join(ref)), list("".join(hyp)))
            for cut_id, ref, hyp in results
        ]

    for _cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum(len(r) for _, r, _ in results)
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs

    # Guard against empty references, which would otherwise raise
    # ZeroDivisionError.
    if ref_len > 0:
        err_ratio = tot_errs / ref_len
    else:
        err_ratio = 0.0 if tot_errs == 0 else float("inf")
    tot_err_rate = "%.2f" % (100.0 * err_ratio)

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {err_ratio:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp)  ", file=f)
    for cut_id, ref, hyp in results:
        # NOTE(review): unlike the counting pass above, this alignment does not
        # forward `sclite_mode`. It is kept as-is to preserve the report;
        # confirm whether that is intended.
        ali = kaldialign.align(ref, hyp, ERR)

        # Merge runs of successive errors into a single (ref->hyp) group:
        # each error pair is folded into the following pair when that one is
        # an error as well, and the folded pair is emptied.
        ali = [[[x], [y]] for x, y in ali]
        for i in range(len(ali) - 1):
            if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                ali[i] = [[], []]
        # Drop the ERR placeholders and the pairs emptied by the merge.
        ali = [
            [[a for a in x if a != ERR], [a for a in y if a != ERR]]
            for x, y in ali
        ]
        ali = [pair for pair in ali if pair != [[], []]]
        # A side that has no word left is shown as ERR again.
        ali = [
            [
                ERR if x == [] else " ".join(x),
                ERR if y == [] else " ".join(y),
            ]
            for x, y in ali
        ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count}   {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count}   {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count}   {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        # Distinct local names so the `ins`/`dels` dicts and the overall
        # `tot_errs` are not shadowed.
        (corr, ref_sub, hyp_sub, word_ins, word_dels) = counts
        word_errs = ref_sub + hyp_sub + word_ins + word_dels
        ref_count = corr + ref_sub + word_dels
        hyp_count = corr + hyp_sub + word_ins

        print(f"{word}   {corr} {word_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)


def _str2bool(value: str) -> bool:
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``."""
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def get_args(argv: Union[List[str], None] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
      argv:
        The argument list to parse. If None, the arguments are taken from
        the command line (``sys.argv[1:]``).
    Returns:
      Return the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--hotwords-file",
        type=str,
        default="",
        help="""
        The file containing hotwords, one words/phrases per line, like
        HELLO WORLD
        你好世界
        """,
    )

    parser.add_argument(
        "--hotwords-score",
        type=float,
        default=1.5,
        help="""
        The hotword score of each token for biasing word/phrase. Used only if
        --hotwords-file is given.
        """,
    )

    parser.add_argument(
        "--modeling-unit",
        type=str,
        default="",
        help="""
        The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
        Used only when hotwords-file is given.
        """,
    )

    parser.add_argument(
        "--bpe-vocab",
        type=str,
        default="",
        help="""
        The path to the bpe vocabulary, the bpe vocabulary is generated by
        sentencepiece, you can also export the bpe vocabulary through a bpe model
        by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
        and modeling-unit is bpe or cjkchar+bpe.
        """,
    )

    parser.add_argument(
        "--encoder",
        default="",
        type=str,
        help="Path to the encoder model",
    )

    parser.add_argument(
        "--decoder",
        default="",
        type=str,
        help="Path to the decoder model",
    )

    parser.add_argument(
        "--joiner",
        default="",
        type=str,
        help="Path to the joiner model",
    )

    parser.add_argument(
        "--paraformer",
        default="",
        type=str,
        help="Path to the model.onnx from Paraformer",
    )

    parser.add_argument(
        "--nemo-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from NeMo CTC",
    )

    parser.add_argument(
        "--wenet-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from WeNet CTC",
    )

    parser.add_argument(
        "--tdnn-model",
        default="",
        type=str,
        help="Path to the model.onnx for the tdnn model of the yesno recipe",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )

    parser.add_argument(
        "--whisper-encoder",
        default="",
        type=str,
        help="Path to whisper encoder model",
    )

    parser.add_argument(
        "--whisper-decoder",
        default="",
        type=str,
        help="Path to whisper decoder model",
    )

    parser.add_argument(
        "--whisper-language",
        default="",
        type=str,
        help="""It specifies the spoken language in the input audio file.
        Example values: en, fr, de, zh, jp.
        Available languages for multilingual models can be found at
        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
        If not specified, we infer the language from the input audio file.
        """,
    )

    parser.add_argument(
        "--whisper-task",
        default="transcribe",
        choices=["transcribe", "translate"],
        type=str,
        help="""For multilingual models, if you specify translate, the output
        will be in English.
        """,
    )

    parser.add_argument(
        "--whisper-tail-paddings",
        default=-1,
        type=int,
        help="""Number of tail padding frames.
        We have removed the 30-second constraint from whisper, so you need to
        choose the amount of tail padding frames by yourself.
        Use -1 to use a default value for tail padding.
        """,
    )

    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="Valid values are greedy_search and modified_beam_search",
    )
    parser.add_argument(
        "--debug",
        # `type=bool` would turn every non-empty string, including "false",
        # into True, so the value is parsed explicitly.
        type=_str2bool,
        default=False,
        help="True to show debug messages",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="""Sample rate of the feature extractor. Must match the one
        expected  by the model. Note: The input sound files can have a
        different sample rate from this argument.""",
    )

    parser.add_argument(
        "--feature-dim",
        type=int,
        default=80,
        help="Feature dimension. Must match the one expected by the model",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to decode. Each file must be of WAVE "
        "format with a single channel, and each sample has 16-bit, "
        "i.e., int16_t. "
        "The sample rate of the file can be arbitrary and does not need to "
        "be 16 kHz",
    )

    parser.add_argument(
        "--name",
        type=str,
        default="",
        help="Name used in the output file names, i.e., recogs-<name>.txt "
        "and errs-<name>.txt",
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        default="",
        help="The directory where recogs-<name>.txt and errs-<name>.txt "
        "are written",
    )

    parser.add_argument(
        "--label",
        type=str,
        default=None,
        help="Path to a label file with one utterance per line in the form "
        "audio_path|prompt_text|prompt_audio|text. If not given, labels are "
        "loaded from a Huggingface dataset",
    )

    # Dataset related arguments for loading labels when label file is not provided
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="yuekai/seed_tts_cosy2",
        help="Huggingface dataset name for loading labels",
    )

    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        help="Dataset split name for loading labels",
    )

    return parser.parse_args(argv)


def assert_file_exists(filename: str):
    """Check that ``filename`` is an existing regular file.

    Args:
      filename:
        Path of the file to check.
    Raises:
      AssertionError: If ``filename`` is not an existing regular file.
    """
    # Raise explicitly instead of using an `assert` statement so that the check
    # still runs when Python is started with -O, which strips asserts.
    if not Path(filename).is_file():
        raise AssertionError(
            f"{filename} does not exist!\n"
            "Please refer to "
            "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
        )


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Read a single-channel sound file as float32 samples.

    Args:
      wave_filename:
        Path to a sound file readable by `soundfile`. It must contain a
        single channel. Its sample rate can be arbitrary and is returned
        unchanged.

    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples.
         NOTE(review): presumably normalized to the range [-1, 1] by
         soundfile for integer PCM input - confirm against its docs.
       - Sample rate of the wave file.

    Raises:
      AssertionError: If the file has more than one channel.
    """
    # dtype="float32" already yields a float32 array, so no extra conversion
    # (and copy) is needed afterwards.
    samples, sample_rate = sf.read(wave_filename, dtype="float32")
    if samples.ndim != 1:
        # A multi-channel file is expected to come back as a 2-D
        # (frames, channels) array, so the channel count is the size of the
        # last axis, not the number of dimensions.
        raise AssertionError(
            f"Expected single channel, but got {samples.shape[-1]} channels."
        )

    return samples, sample_rate


def normalize_text_alimeeting(text: str) -> str:
    """
    Text normalization similar to M2MeT challenge baseline.
    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl

    The following steps are applied in order:
      1. Remove no-break spaces and spaces.
      2. Remove the annotation tags <sil>, <%>, <->, <$>, <#>, <_>, <space>.
      3. Upper-case the text if it contains an ASCII letter.
      4. Map the fullwidth letters A, a, b, c, k, t to ASCII upper case and
         delete a few ASCII and CJK punctuation characters.
      5. Remove all remaining punctuation (except the apostrophe) with
         `remove_punctuation`.

    Args:
      text:
        The transcript to normalize.
    Returns:
      Return the normalized transcript.
    """
    import re

    text = text.replace("\u00A0", "")  # test_hard
    text = text.replace(" ", "")

    # Tags are removed one after another, before any punctuation is stripped,
    # so that the angle brackets are still present for matching.
    for tag in ("<sil>", "<%>", "<->", "<$>", "<#>", "<_>", "<space>"):
        text = text.replace(tag, "")

    if re.search("[a-zA-Z]", text):
        text = text.upper()

    # NOTE(review): the letter replacements were written as "A"->"A",
    # "a"->"A", ... which can never change the text; the fullwidth forms used
    # by the referenced normalization script are assumed here - confirm.
    # Escapes are used so the code points survive any re-encoding of the file.
    char_table = {
        ord("`"): None,
        ord("&"): None,
        ord(","): None,
        ord("?"): None,
        0xFF0C: None,  # fullwidth comma
        0x4E36: None,  # the character U+4E36
        0x3002: None,  # ideographic full stop
        0x3001: None,  # ideographic comma
        0xFF1F: None,  # fullwidth question mark
        0xFF21: "A",  # fullwidth A
        0xFF41: "A",  # fullwidth a
        0xFF42: "B",  # fullwidth b
        0xFF43: "C",  # fullwidth c
        0xFF4B: "K",  # fullwidth k
        0xFF54: "T",  # fullwidth t
    }
    text = text.translate(char_table)

    return remove_punctuation(text)


def main():
    """Decode the given sound files with a paraformer model and score them.

    The steps are:
      1. Parse the command line and create a paraformer recognizer.
      2. Decode all sound files in batches and print every result together
         with the real time factor.
      3. Load reference labels, either from the file given by --label or from
         a Huggingface dataset.
      4. Write recogs-<name>.txt and errs-<name>.txt to --log-dir and print
         the first two lines of the error report.
    """
    args = get_args()
    assert_file_exists(args.tokens)
    assert args.num_threads > 0, args.num_threads

    # Only paraformer models are handled below, so the options of all other
    # model types must be left empty.
    assert len(args.nemo_ctc) == 0, args.nemo_ctc
    assert len(args.wenet_ctc) == 0, args.wenet_ctc
    assert len(args.whisper_encoder) == 0, args.whisper_encoder
    assert len(args.whisper_decoder) == 0, args.whisper_decoder
    assert len(args.tdnn_model) == 0, args.tdnn_model

    assert_file_exists(args.paraformer)

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=args.paraformer,
        tokens=args.tokens,
        num_threads=args.num_threads,
        sample_rate=args.sample_rate,
        feature_dim=args.feature_dim,
        decoding_method=args.decoding_method,
        debug=args.debug,
    )

    print("Started!")
    start_time = time.time()

    streams, results = [], []
    total_duration = 0

    for i, wave_filename in enumerate(args.sound_files):
        assert_file_exists(wave_filename)
        samples, sample_rate = read_wave(wave_filename)
        duration = len(samples) / sample_rate
        total_duration += duration
        s = recognizer.create_stream()
        s.accept_waveform(sample_rate, samples)

        streams.append(s)
        # The pending streams are decoded whenever the index is a multiple of
        # 10, so the first batch holds a single file. The batch boundaries are
        # deliberately left unchanged so existing results stay reproducible.
        if i % 10 == 0:
            recognizer.decode_streams(streams)
            results += [s.result.text for s in streams]
            streams = []
            # i is zero-based, so i + 1 files have been decoded so far.
            print(f"Processed {i + 1} files")
    # process the last batch
    if streams:
        recognizer.decode_streams(streams)
        results += [s.result.text for s in streams]
    end_time = time.time()
    print("Done!")

    results_dict = {}
    for wave_filename, result in zip(args.sound_files, results):
        print(f"{wave_filename}\n{result}")
        print("-" * 10)
        wave_basename = Path(wave_filename).stem
        results_dict[wave_basename] = result

    elapsed_seconds = end_time - start_time
    # Avoid ZeroDivisionError when all input files are empty.
    rtf = elapsed_seconds / total_duration if total_duration > 0 else float("inf")
    print(f"num_threads: {args.num_threads}")
    print(f"decoding_method: {args.decoding_method}")
    print(f"Wave duration: {total_duration:.3f} s")
    print(f"Elapsed time: {elapsed_seconds:.3f} s")
    print(
        f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
    )

    # Load labels either from file or from dataset
    labels_dict = {}

    if args.label:
        # Load labels from file (original functionality)
        print(f"Loading labels from file: {args.label}")
        # The transcripts may contain non-ASCII text, so the encoding is
        # given explicitly instead of relying on the platform default.
        with open(args.label, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Tolerate blank lines, e.g., at the end of the file.
                    continue
                # Each line is: audio_path|prompt_text|prompt_audio|text
                fields = [item for item in line.split("|") if item]
                if len(fields) != 4:
                    raise ValueError(
                        f"{args.label}:{line_no}: expected 4 '|'-separated "
                        f"fields, but got {len(fields)}"
                    )
                audio_path, prompt_text, prompt_audio, text = fields
                labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
    else:
        # Load labels from dataset (new functionality)
        # An explicitly given --dataset-name is honored. With the default
        # name, splits containing 'zero' are loaded from yuekai/CV3-Eval,
        # exactly as before.
        if "zero" in args.split_name and args.dataset_name == "yuekai/seed_tts_cosy2":
            dataset_name = "yuekai/CV3-Eval"
        else:
            dataset_name = args.dataset_name
        print(f"Loading labels from dataset: {dataset_name}, split: {args.split_name}")
        # NOTE(review): trust_remote_code=True presumably lets the dataset
        # repository run its own loading code; only use trusted datasets.
        dataset = load_dataset(
            dataset_name,
            split=args.split_name,
            trust_remote_code=True,
        )

        for item in dataset:
            audio_id = item["id"]
            labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])

        print(f"Loaded {len(labels_dict)} labels from dataset")

    # Perform evaluation if labels are available
    if labels_dict:

        final_results = []
        for key, value in results_dict.items():
            if key in labels_dict:
                final_results.append((key, labels_dict[key], value))
            else:
                print(f"Warning: No label found for {key}, skipping...")

        if final_results:
            # Path("") is the current directory, so an empty --log-dir no
            # longer resolves to the filesystem root.
            log_dir = Path(args.log_dir)
            log_dir.mkdir(parents=True, exist_ok=True)
            recogs_path = log_dir / f"recogs-{args.name}.txt"
            errs_path = log_dir / f"errs-{args.name}.txt"

            store_transcripts(filename=recogs_path, texts=final_results)
            with open(errs_path, "w", encoding="utf8") as f:
                write_error_stats(f, "test-set", final_results, enable_log=True)

            with open(errs_path, "r", encoding="utf8") as f:
                print(f.readline())  # WER
                print(f.readline())  # Detailed errors
        else:
            print("No matching labels found for evaluation")
    else:
        print("No labels available for evaluation")

if __name__ == "__main__":
    main()