# 8f0b28b8 — created March 16 (commit history)
# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
#                2023  Nvidia              (authors: Yuekai Zhang)
#                2023  Recurrent.ai        (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a test set (from the Hugging Face hub, a local manifest, or a
single reference audio) and sends it to the Triton server for speech synthesis
using several concurrent tasks.

Usage:
num_task=2

# For offline F5-TTS
python3 client_grpc.py \
    --server-addr localhost \
    --model-name f5_tts \
    --num-tasks $num_task \
    --huggingface-dataset yuekai/seed_tts \
    --split-name test_zh \
    --log-dir ./log_concurrent_tasks_${num_task}

# For offline Spark-TTS-0.5B
python3 client_grpc.py \
    --server-addr localhost \
    --model-name spark_tts \
    --num-tasks $num_task \
    --huggingface-dataset yuekai/seed_tts \
    --split-name wenetspeech4tts \
    --log-dir ./log_concurrent_tasks_${num_task}
"""

import argparse
import asyncio
import json
import queue
import uuid
import functools

import os
import time
import types
from pathlib import Path

import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient_aio
import tritonclient.grpc as grpcclient_sync
from tritonclient.utils import np_to_triton_dtype, InferenceServerException


class UserData:
    """Per-request state shared between a stream callback and its consumer.

    Responses (or errors) for one request are pushed onto
    ``_completed_requests``; ``_first_chunk_time`` / ``_second_chunk_time``
    are filled in with the arrival times of the first two successful
    responses so streaming latency can be measured.
    """

    def __init__(self):
        self._completed_requests = queue.Queue()
        self._first_chunk_time = None
        self._second_chunk_time = None
        self._start_time = None

    def record_start_time(self):
        """Mark the moment the request is sent."""
        self._start_time = time.time()

    def get_first_chunk_latency(self):
        """Seconds from ``record_start_time`` to the first response, or None if unknown."""
        # Compare against None explicitly: a timestamp of 0.0 is still valid.
        if self._first_chunk_time is not None and self._start_time is not None:
            return self._first_chunk_time - self._start_time
        return None

    def get_second_chunk_latency(self):
        """Seconds between the first and second responses, or None if unknown."""
        if self._first_chunk_time is not None and self._second_chunk_time is not None:
            return self._second_chunk_time - self._first_chunk_time
        return None


def callback(user_data, result, error):
    """Hand a single response (or error) over to the request's consumer.

    Errors are queued as-is without touching the timing fields. For a
    successful response, the arrival time is stored in the first free slot of
    ``_first_chunk_time`` / ``_second_chunk_time`` *before* the result is
    queued, so the consumer always sees up-to-date timestamps.
    """
    pending = user_data._completed_requests
    if error:
        pending.put(error)
        return

    if user_data._first_chunk_time is None:
        user_data._first_chunk_time = time.time()
    elif user_data._second_chunk_time is None:
        user_data._second_chunk_time = time.time()
    pending.put(result)


def stream_callback(user_data_map, result, error):
    """Dispatch a response from the shared gRPC stream to its request's ``UserData``.

    Meant to be bound as ``functools.partial(stream_callback, user_data_map)``,
    where ``user_data_map`` maps request ids to ``UserData`` objects.

    Successful results are routed by ``result.get_response().id``. An error
    callback gives no usable response here (``result`` is presumably None),
    so its request id is unknown. Dropping the error would leave the waiting
    consumer blocked until its queue timeout. When exactly one request is
    pending, the error is therefore delivered to it, so the consumer can
    report the failure immediately. ``send_streaming`` waits for each request
    before sending the next, so one pending request is the normal case. With
    zero or several pending requests the error cannot be attributed and is
    only printed.
    """
    if error:
        print(f"An error occurred in the stream callback: {error}")
        # Snapshot the values: the map is presumably mutated by the sending
        # coroutine while this callback runs on the stream's thread.
        pending = list(user_data_map.values())
        if len(pending) == 1:
            callback(pending[0], result, error)
        return

    request_id = result.get_response().id
    if not request_id:
        return

    user_data = user_data_map.get(request_id)
    if user_data is not None:
        callback(user_data, result, error)
    else:
        print(f"Warning: Could not find user_data for request_id {request_id}")


def write_triton_stats(stats, summary_file):
    """Dump a human-readable summary of Triton inference statistics to a text file.

    Args:
        stats: Statistics dict with a ``"model_stats"`` list (the JSON layout
            produced by ``get_inference_statistics(..., as_json=True)``).
            Counters are converted with ``int()``, so string-encoded numbers
            are accepted.
        summary_file: Destination path; the file is overwritten.
    """

    def as_seconds(duration):
        return int(duration["ns"]) / 1e9

    def as_millis(duration):
        return int(duration["ns"]) / 1e6

    with open(summary_file, "w") as out:
        for model in stats["model_stats"]:
            # Only entries that carry a "last_inference" field are reported.
            if "last_inference" not in model:
                continue
            out.write(f"model name is {model['name']} \n")

            totals = model["inference_stats"]
            queue_s, infer_s, input_s, output_s = (
                as_seconds(totals[key])
                for key in ("queue", "compute_infer", "compute_input", "compute_output")
            )
            out.write(
                f"queue time {queue_s:<5.2f} s, "
                f"compute infer time {infer_s:<5.2f} s, "
                f"compute input time {input_s:<5.2f} s, "
                f"compute output time {output_s:<5.2f} s \n"
            )

            for batch in model["batch_stats"]:
                size = int(batch["batch_size"])
                inp = batch["compute_input"]
                outp = batch["compute_output"]
                infer = batch["compute_infer"]
                count = int(infer["count"])
                # Batch sizes that were never executed have nothing to report.
                if count == 0:
                    continue
                assert infer["count"] == outp["count"] == inp["count"]

                infer_ms = as_millis(infer)
                input_ms = as_millis(inp)
                output_ms = as_millis(outp)
                out.write(
                    f"execuate inference with batch_size {size:<2} total {count:<5} times, "
                    f"total_infer_time {infer_ms:<9.2f} ms, "
                    f"avg_infer_time {infer_ms:<9.2f}/{count:<5}="
                    f"{infer_ms / count:.2f} ms, "
                    f"avg_infer_time_per_sample {infer_ms:<9.2f}/{count:<5}/{size}="
                    f"{infer_ms / count / size:.2f} ms \n"
                )
                out.write(
                    f"input {input_ms:<9.2f} ms, avg {input_ms / count:.2f} ms, "
                    f"output {output_ms:<9.2f} ms, avg {output_ms / count:.2f} ms \n"
                )


def subtract_stats(stats_after, stats_before):
    """Subtracts two Triton inference statistics objects.

    Both arguments are statistics dicts with a ``"model_stats"`` list (the
    JSON layout of ``get_inference_statistics(..., as_json=True)``).

    Returns:
        A deep copy of ``stats_after``. For every model that also appears in
        ``stats_before``, matched by (name, version), the cumulative counters
        are replaced by ``after - before``. The differences are stored as
        strings, like the input. Models, stat entries and batch sizes absent
        from ``stats_before`` are left unchanged (their "before" value is
        zero). Neither input is modified.

    Missing counters are treated as zero. Protobuf's JSON mapping may omit
    zero-valued fields, so an unused entry in ``stats_before`` (e.g. "fail"
    or "cache_hit") can be an empty dict.
    """

    def _subtract_counters(after_entry, before_entry):
        # In-place ``after - before`` on the "count"/"ns" counters of one entry.
        for field in ("count", "ns"):
            if field in after_entry:
                after_entry[field] = str(int(after_entry[field]) - int(before_entry.get(field, 0)))

    # JSON round-trip doubles as a deep copy so stats_after stays untouched.
    stats_diff = json.loads(json.dumps(stats_after))

    before_by_model = {
        (s["name"], s.get("version", "")): s for s in stats_before.get("model_stats", [])
    }

    for model_after in stats_diff.get("model_stats", []):
        model_before = before_by_model.get((model_after["name"], model_after.get("version", "")))
        if model_before is None:
            continue

        for counter in ("inference_count", "execution_count"):
            model_after[counter] = str(
                int(model_after.get(counter, 0)) - int(model_before.get(counter, 0))
            )

        infer_after = model_after.get("inference_stats", {})
        infer_before = model_before.get("inference_stats", {})
        for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
            if key in infer_after and key in infer_before:
                _subtract_counters(infer_after[key], infer_before[key])

        batch_before_by_size = {b["batch_size"]: b for b in model_before.get("batch_stats", [])}
        for batch_after in model_after.get("batch_stats", []):
            batch_before = batch_before_by_size.get(batch_after["batch_size"])
            if batch_before is None:
                continue
            for key in ("compute_input", "compute_infer", "compute_output"):
                if key in batch_after and key in batch_before:
                    _subtract_counters(batch_after[key], batch_before[key])

    return stats_diff


def get_args():
    """Parse the benchmark client's command-line options."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=8001,
        help="gRPC port of the Triton server",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default=None,
        help="Path to a single reference audio file. When given, a single request is sent "
        "and --huggingface-dataset / --manifest-path are ignored.",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="",
        help="Transcript of --reference-audio (single-request mode only).",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="",
        help="Text to synthesize in single-request mode (used with --reference-audio).",
    )

    parser.add_argument(
        "--huggingface-dataset",
        type=str,
        default="yuekai/seed_tts",
        help="dataset name in huggingface dataset hub",
    )

    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="Dataset split name.",
    )

    parser.add_argument(
        "--manifest-path",
        type=str,
        default=None,
        help="Path to a manifest file with one 'utt|prompt_text|prompt_wav|target_text' line "
        "per utterance; relative prompt_wav paths are resolved against the manifest's "
        "directory. Only used when --reference-audio is not given and "
        "--huggingface-dataset is empty.",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        choices=[
            "f5_tts",
            "spark_tts",
            "cosyvoice3",
            "cosyvoice2",
            "cosyvoice2_dit"],
        help="triton model_repo module name to request",
    )

    parser.add_argument(
        "--num-tasks",
        type=int,
        default=1,
        help="Number of concurrent tasks for sending",
    )

    parser.add_argument(
        "--log-interval",
        type=int,
        default=5,
        help="Controls how frequently we print the log.",
    )

    parser.add_argument(
        "--compute-wer",
        action="store_true",
        default=False,
        help="True to compute WER.",
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        required=False,
        default="./tmp",
        help="log directory",
    )

    parser.add_argument(
        "--mode",
        type=str,
        default="offline",
        choices=["offline", "streaming"],
        help="Select offline or streaming benchmark mode."
    )
    parser.add_argument(
        "--chunk-overlap-duration",
        type=float,
        default=0.1,
        help="Chunk overlap duration (in seconds) used to cross-fade spark_tts chunks "
        "when reconstructing streaming output."
    )

    # NOTE(review): parsed as a string ("False"/"True"), not a bool — confirm
    # how main() converts it before it reaches use_spk2info_cache.
    parser.add_argument(
        "--use-spk2info-cache",
        type=str,
        default="False",
        help="Use spk2info cache for reference audio.",
    )

    return parser.parse_args()


def load_audio(wav_path, target_sample_rate=16000):
    """Load reference audio and resample it to ``target_sample_rate``.

    Args:
        wav_path: Either a path readable by ``soundfile.read`` or a dict with
            ``"array"`` and ``"sampling_rate"`` keys (the form used for
            Hugging Face dataset audio in this script).
        target_sample_rate: Output sample rate; only 16000 is supported.

    Returns:
        ``(waveform, target_sample_rate)``.
    """
    assert target_sample_rate == 16000, "hard coding in server"
    if isinstance(wav_path, dict):
        waveform = wav_path["array"]
        sample_rate = wav_path["sampling_rate"]
    else:
        waveform, sample_rate = sf.read(wav_path)
        # soundfile returns (frames, channels) for multi-channel files, but
        # requests are built from 1-D audio: downmix by averaging channels.
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=1)
    if sample_rate != target_sample_rate:
        from scipy.signal import resample

        num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
        waveform = resample(waveform, num_samples)
    return waveform, target_sample_rate


def prepare_request_input_output(
    protocol_client,
    waveform,
    reference_text,
    target_text,
    sample_rate=16000,
    padding_duration: int = None,
    use_spk2info_cache: bool = False
):
    """Prepares inputs for Triton inference (offline or streaming).

    Builds the ``reference_wav``, ``reference_wav_len``, ``reference_text``
    and ``target_text`` inputs plus the requested ``waveform`` output.

    When ``padding_duration`` is set, the reference audio is zero-padded up
    to a multiple of ``padding_duration`` seconds. The target length is
    estimated from the reference's seconds-per-character ratio, and the
    padded length exceeds reference + estimated target duration. The real
    length is still sent in ``reference_wav_len``.

    With ``use_spk2info_cache`` only the ``target_text`` input is returned.
    """
    assert len(waveform.shape) == 1, "waveform should be 1D"
    num_samples = len(waveform)
    lengths = np.array([[num_samples]], dtype=np.int32)

    if not padding_duration:
        samples = waveform.reshape(1, -1).astype(np.float32)
    else:
        ref_seconds = num_samples / sample_rate
        if reference_text:
            target_seconds = ref_seconds / len(reference_text) * len(target_text)
        else:
            target_seconds = ref_seconds
        num_buckets = (int(target_seconds + ref_seconds) // padding_duration) + 1
        samples = np.zeros((1, padding_duration * sample_rate * num_buckets), dtype=np.float32)
        samples[0, :num_samples] = waveform

    def _text_input(input_name, text):
        tensor = protocol_client.InferInput(input_name, [1, 1], "BYTES")
        tensor.set_data_from_numpy(np.array([text], dtype=object).reshape((1, 1)))
        return tensor

    wav_input = protocol_client.InferInput(
        "reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)
    )
    wav_input.set_data_from_numpy(samples)
    len_input = protocol_client.InferInput(
        "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
    )
    len_input.set_data_from_numpy(lengths)

    inputs = [
        wav_input,
        len_input,
        _text_input("reference_text", reference_text),
        _text_input("target_text", target_text),
    ]
    outputs = [protocol_client.InferRequestedOutput("waveform")]
    return (inputs[-1:] if use_spk2info_cache else inputs), outputs


def _cross_fade_concatenate(chunks, cross_fade_samples):
    """Join non-empty audio chunks, linearly cross-fading ``cross_fade_samples`` at each boundary.

    A non-positive overlap falls back to plain concatenation. Slicing with
    ``[:-0]`` / ``[-0:]`` would otherwise drop or duplicate whole chunks.
    """
    if len(chunks) == 1:
        return chunks[0]
    if cross_fade_samples <= 0:
        return np.concatenate(chunks)

    fade_out = np.linspace(1, 0, cross_fade_samples)
    fade_in = np.linspace(0, 1, cross_fade_samples)
    pieces = [chunks[0][:-cross_fade_samples]]
    for prev, cur in zip(chunks, chunks[1:]):
        # Blend the tail of the previous chunk with the head of the current one.
        pieces.append(cur[:cross_fade_samples] * fade_in + prev[-cross_fade_samples:] * fade_out)
        pieces.append(cur[cross_fade_samples:-cross_fade_samples])
    pieces.append(chunks[-1][-cross_fade_samples:])
    return np.concatenate(pieces)


def run_sync_streaming_inference(
    sync_triton_client: tritonclient.grpc.InferenceServerClient,
    model_name: str,
    inputs: list,
    outputs: list,
    request_id: str,
    user_data: UserData,
    chunk_overlap_duration: float,
    save_sample_rate: int,
    audio_save_path: str,
):
    """Send one streaming request and block until its final response arrives.

    Responses are delivered by the stream callback into
    ``user_data._completed_requests``. This function drains that queue until
    a response flagged ``triton_final_response`` arrives, then writes the
    received audio to ``audio_save_path``. For ``spark_tts`` the chunks are
    cross-faded over ``chunk_overlap_duration`` seconds; otherwise they are
    concatenated.

    Returns:
        ``(total_request_latency, first_chunk_latency, second_chunk_latency,
        actual_duration)`` in seconds, or four ``None`` values if an
        ``InferenceServerException`` was received or nothing arrived within
        200 s.
    """
    start_time_total = time.time()
    user_data.record_start_time()

    sync_triton_client.async_stream_infer(
        model_name,
        inputs,
        request_id=request_id,
        outputs=outputs,
        enable_empty_final_response=True,
    )

    audios = []
    while True:
        try:
            result = user_data._completed_requests.get(timeout=200)
        except queue.Empty:
            print(f"Timeout waiting for response for request id {request_id}")
            return None, None, None, None
        if isinstance(result, InferenceServerException):
            print(f"Received InferenceServerException: {result}")
            return None, None, None, None

        response = result.get_response()
        # enable_empty_final_response=True makes the server flag the last
        # response of this request; stop reading there.
        if response.parameters["triton_final_response"].bool_param is True:
            break

        audio_chunk = result.as_numpy("waveform").reshape(-1)
        if audio_chunk.size > 0:
            audios.append(audio_chunk)
        else:
            print("Warning: received empty audio chunk.")

    total_request_latency = time.time() - start_time_total
    first_chunk_latency = user_data.get_first_chunk_latency()
    second_chunk_latency = user_data.get_second_chunk_latency()

    if not audios:
        print("Warning: No audio chunks received.")
        return total_request_latency, first_chunk_latency, second_chunk_latency, 0

    if model_name == "spark_tts":
        cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
        reconstructed_audio = _cross_fade_concatenate(audios, cross_fade_samples)
    else:
        reconstructed_audio = np.concatenate(audios)

    if reconstructed_audio.size == 0:
        print("Warning: No audio chunks received or reconstructed.")
        return total_request_latency, first_chunk_latency, second_chunk_latency, 0

    actual_duration = len(reconstructed_audio) / save_sample_rate
    sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
    return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration


async def send_streaming(
    manifest_item_list: list,
    name: str,
    server_url: str,
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    chunk_overlap_duration: float = 0.1,
    padding_duration: int = None,
    use_spk2info_cache: bool = False,
):
    """Run one benchmark task in streaming mode over a single gRPC stream.

    Items are processed one at a time. Each blocking streaming request runs
    in a worker thread via ``asyncio.to_thread`` so the event loop stays
    free for other tasks. Failed items are logged and skipped.

    Returns:
        ``(total_duration, latency_data)``: total seconds of audio
        synthesized, and one ``(total_latency, first_chunk_latency,
        second_chunk_latency, duration)`` tuple per successful item.
    """
    total_duration = 0.0
    latency_data = []
    sync_triton_client = None
    user_data_map = {}

    try:
        print(f"{name}: Initializing sync client for streaming...")
        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
        sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))

        print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
        for i, item in enumerate(manifest_item_list):
            if i % log_interval == 0:
                print(f"{name}: Processing item {i}/{len(manifest_item_list)}")

            request_id = None
            try:
                waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
                reference_text, target_text = item["reference_text"], item["target_text"]

                inputs, outputs = prepare_request_input_output(
                    protocol_client,
                    waveform,
                    reference_text,
                    target_text,
                    sample_rate,
                    padding_duration=padding_duration,
                    use_spk2info_cache=use_spk2info_cache
                )

                request_id = str(uuid.uuid4())
                user_data = UserData()
                user_data_map[request_id] = user_data

                audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
                total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
                    run_sync_streaming_inference,
                    sync_triton_client,
                    model_name,
                    inputs,
                    outputs,
                    request_id,
                    user_data,
                    chunk_overlap_duration,
                    save_sample_rate,
                    audio_save_path
                )

                if total_request_latency is not None:
                    print(
                        f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
                        f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
                        f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
                    )
                    latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
                    total_duration += actual_duration
                else:
                    print(f"{name}: Item {i} failed.")

            except FileNotFoundError:
                print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
            except Exception as e:
                print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
                import traceback
                traceback.print_exc()
            finally:
                # Unregister even when the request raised, so the map only
                # ever holds the request currently in flight.
                if request_id is not None:
                    user_data_map.pop(request_id, None)

    finally:
        if sync_triton_client:
            try:
                print(f"{name}: Closing stream and sync client...")
                sync_triton_client.stop_stream()
                sync_triton_client.close()
            except Exception as e:
                print(f"{name}: Error closing sync client: {e}")

    print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
    return total_duration, latency_data


async def send(
    manifest_item_list: list,
    name: str,
    triton_client: tritonclient.grpc.aio.InferenceServerClient,
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    padding_duration: int = None,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    use_spk2info_cache: bool = False,
):
    """Run one benchmark task in offline mode, one request per item.

    Each synthesized waveform is written to
    ``<audio_save_dir>/<target_audio_path>.wav``.

    Returns:
        ``(total_duration, latency_data)``: total seconds of audio
        synthesized, and one ``(latency, duration)`` tuple per item. The
        latency spans from sending the request to extracting the waveform.
    """
    total_duration = 0.0
    latency_data = []

    for i, item in enumerate(manifest_item_list):
        if i % log_interval == 0:
            print(f"{name}: {i}/{len(manifest_item_list)}")
        waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
        reference_text, target_text = item["reference_text"], item["target_text"]

        inputs, outputs = prepare_request_input_output(
            protocol_client,
            waveform,
            reference_text,
            target_text,
            sample_rate,
            padding_duration=padding_duration,
            use_spk2info_cache=use_spk2info_cache
        )
        # A random UUID keeps request ids unique across all concurrent tasks
        # (same scheme as send_streaming); a fixed per-task numeric offset
        # collides as soon as a task sends more items than the offset stride.
        request_id = str(uuid.uuid4())
        start = time.time()
        response = await triton_client.infer(model_name, inputs, request_id=request_id, outputs=outputs)

        audio = response.as_numpy("waveform").reshape(-1)
        actual_duration = len(audio) / save_sample_rate

        latency = time.time() - start

        audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")

        latency_data.append((latency, actual_duration))
        total_duration += actual_duration

    return total_duration, latency_data


def load_manifests(manifest_path):
    """Parse a ``|``-separated manifest file into request items.

    Each non-blank line must hold exactly four fields:
    ``utt|prompt_text|prompt_wav|target_text``. A relative ``prompt_wav``
    is resolved against the manifest's directory. The output file stem is
    ``Path(utt).stem``.

    Returns:
        A list of dicts with keys ``audio_filepath``, ``reference_text``,
        ``target_text`` and ``target_audio_path``.

    Raises:
        ValueError: if a non-blank line does not have exactly four fields.
    """
    manifest_dir = os.path.dirname(manifest_path)
    manifest_list = []
    # Manifests carry transcripts (possibly non-ASCII), so don't rely on the
    # platform's default encoding.
    with open(manifest_path, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            fields = line.split("|")
            if len(fields) != 4:
                raise ValueError(
                    f"{manifest_path}:{line_no}: expected 4 '|'-separated fields, got {len(fields)}"
                )
            utt, prompt_text, prompt_wav, gt_text = fields
            if not os.path.isabs(prompt_wav):
                prompt_wav = os.path.join(manifest_dir, prompt_wav)
            manifest_list.append(
                {
                    "audio_filepath": prompt_wav,
                    "reference_text": prompt_text,
                    "target_text": gt_text,
                    "target_audio_path": Path(utt).stem,
                }
            )
    return manifest_list


def split_data(data, k):
    """Split ``data`` into ``k`` contiguous, nearly equal-sized slices.

    The first ``len(data) % k`` slices get one extra element. If ``data`` has
    fewer than ``k`` items, ``k`` is reduced to ``len(data)`` (with a
    warning) so no slice is empty. Empty ``data`` yields ``[]``.

    Raises:
        ValueError: if ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    n = len(data)
    if n == 0:
        return []
    if n < k:
        print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
        k = n

    quotient, remainder = divmod(n, k)

    result = []
    start = 0
    for i in range(k):
        end = start + quotient + (1 if i < remainder else 0)
        result.append(data[start:end])
        start = end

    return result


def _format_latency_stats(prefix, latencies):
    """Render the report lines for one non-empty latency series.

    ``latencies`` are per-request durations in seconds (``time.time()``
    deltas in offline mode; assumed to be the same unit for streaming).
    Emits ``<prefix>_variance``, the 50/90/95/99th percentiles in ms and
    ``average_<prefix>_ms``, one ``\\n``-terminated line each.
    """
    # NOTE(review): np.var over seconds scaled by 1000 is neither s^2 nor ms^2
    # (ms^2 would need *1e6); kept unchanged so reports stay comparable.
    variance = np.var(latencies, dtype=np.float64) * 1000.0
    avg_ms = sum(latencies) / len(latencies) * 1000.0
    lines = [f"{prefix}_variance: {variance:.2f}\n"]
    lines += [
        f"{prefix}_{pct}_percentile_ms: {np.percentile(latencies, pct) * 1000.0:.2f}\n"
        for pct in (50, 90, 95, 99)
    ]
    lines.append(f"average_{prefix}_ms: {avg_ms:.2f}\n")
    return "".join(lines)


def _result_name(args):
    """Pick the suffix used for report file names written to ``args.log_dir``."""
    if args.manifest_path:
        return Path(args.manifest_path).stem
    if args.split_name:
        return args.split_name
    if args.reference_audio:
        return Path(args.reference_audio).stem
    return "results"


async def main():
    """Run the TTS benchmark against a Triton server and write the reports.

    Steps:
      1. Parse CLI arguments and create the gRPC client for ``args.mode``.
      2. Build the request list from ``--reference-audio``, a Hugging Face
         dataset, or a manifest file.
      3. Split it over ``num_tasks`` concurrent asyncio tasks (``send`` for
         offline mode, ``send_streaming`` for streaming mode).
      4. Print the RTF and latency percentiles and write them to
         ``<log_dir>/rtf-<name>.txt``.
      5. If server statistics were fetched before the run, write the stats
         delta and the model config next to it.

    The asyncio gRPC clients created here are closed in a ``finally`` block,
    so they are released even when a task raises.

    Raises:
        ValueError: if ``args.mode`` is neither "offline" nor "streaming".
    """
    args = get_args()
    url = f"{args.server_addr}:{args.server_port}"

    triton_client = None
    protocol_client = None
    if args.mode == "offline":
        print("Initializing gRPC client for offline mode...")
        triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        protocol_client = grpcclient_aio
    elif args.mode == "streaming":
        print("Initializing gRPC client for streaming mode...")
        # No shared client here: streaming tasks receive ``server_url`` and the
        # sync protocol module instead.
        protocol_client = grpcclient_sync
    else:
        raise ValueError(f"Invalid mode: {args.mode}")

    stats_client = None
    try:
        if args.reference_audio:
            # Single ad-hoc request: one task, log every item.
            args.num_tasks = 1
            args.log_interval = 1
            manifest_item_list = [
                {
                    "reference_text": args.reference_text,
                    "target_text": args.target_text,
                    "audio_filepath": args.reference_audio,
                    "target_audio_path": "test",
                }
            ]
        elif args.huggingface_dataset:
            import datasets

            dataset = datasets.load_dataset(
                args.huggingface_dataset,
                split=args.split_name,
                trust_remote_code=True,
            )
            # Iterate rows once: indexing ``dataset[i]`` per field re-fetches
            # (and re-decodes any audio column of) the whole row each time.
            manifest_item_list = [
                {
                    "audio_filepath": row["prompt_audio"],
                    "reference_text": row["prompt_text"],
                    "target_audio_path": row["id"],
                    "target_text": row["target_text"],
                }
                for row in dataset
            ]
        else:
            manifest_item_list = load_manifests(args.manifest_path)

        stats_before = None
        try:
            print("Initializing temporary async client for fetching stats...")
            stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
            print("Fetching inference statistics before running tasks...")
            stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
        except Exception as e:
            # Stats are best-effort; the benchmark still runs without them.
            print(f"Could not retrieve statistics before running tasks: {e}")

        num_tasks = min(args.num_tasks, len(manifest_item_list))
        manifest_item_list = split_data(manifest_item_list, num_tasks)

        os.makedirs(args.log_dir, exist_ok=True)
        args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
        save_sample_rate = 16000 if args.model_name == "spark_tts" else 24000
        tasks = []
        start_time = time.time()
        for i in range(num_tasks):
            if args.mode == "offline":
                coro = send(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    triton_client=triton_client,
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=1,
                    save_sample_rate=save_sample_rate,
                    use_spk2info_cache=args.use_spk2info_cache,
                )
            else:  # streaming (mode validated above)
                coro = send_streaming(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    server_url=url,
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=10,
                    save_sample_rate=save_sample_rate,
                    chunk_overlap_duration=args.chunk_overlap_duration,
                    use_spk2info_cache=args.use_spk2info_cache,
                )
            tasks.append(asyncio.create_task(coro))

        ans_list = await asyncio.gather(*tasks)

        end_time = time.time()
        elapsed = end_time - start_time

        total_duration = 0.0
        latency_data = []
        for ans in ans_list:
            if ans:
                total_duration += ans[0]
                latency_data.extend(ans[1])
            else:
                print("Warning: A task returned None, possibly due to an error.")

        if total_duration == 0:
            print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
            rtf = float('inf')
        else:
            rtf = elapsed / total_duration

        s = f"Mode: {args.mode}\n"
        s += f"RTF: {rtf:.4f}\n"
        s += f"total_duration: {total_duration:.3f} seconds\n"
        s += f"({total_duration / 3600:.2f} hours)\n"
        s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"

        if not latency_data:
            s += "No latency data collected.\n"
        elif args.mode == "offline":
            # Offline entries are (request_latency, audio_duration) pairs.
            latency_list = [chunk_end for (chunk_end, _chunk_duration) in latency_data]
            s += _format_latency_stats("latency", latency_list)
        else:
            # Streaming entries are (total, first_chunk, second_chunk, duration);
            # any of the latency fields may be None for a failed request.
            sections = (
                ("Total Request Latency", "total_request_latency", 0,
                 "No total request latency data collected.\n"),
                ("First Chunk Latency", "first_chunk_latency", 1,
                 "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"),
                ("Second Chunk Latency", "second_chunk_latency", 2,
                 "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"),
            )
            for title, prefix, idx, empty_msg in sections:
                values = [entry[idx] for entry in latency_data if entry[idx] is not None]
                s += f"\n--- {title} ---\n"
                s += _format_latency_stats(prefix, values) if values else empty_msg

        print(s)
        name = _result_name(args)
        with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
            f.write(s)

        try:
            if stats_client and stats_before:
                print("Fetching inference statistics after running tasks...")
                stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)

                print("Calculating statistics difference...")
                stats = subtract_stats(stats_after, stats_before)

                print("Fetching model config...")
                metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)

                write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")

                with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
                    json.dump(metadata, f, indent=4)
            else:
                print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
        except Exception as e:
            print(f"Could not retrieve statistics or config: {e}")
    finally:
        # Release both asyncio gRPC channels on every exit path, including
        # when a task or the report writing raises.
        if stats_client:
            try:
                print("Closing temporary async stats client...")
                await stats_client.close()
            except Exception as e:
                print(f"Error closing async stats client: {e}")
        if triton_client:
            try:
                await triton_client.close()
            except Exception as e:
                print(f"Error closing async triton client: {e}")


if __name__ == "__main__":

    async def run_main():
        """Run ``main`` and report any failure; return the process exit status."""
        try:
            await main()
        except Exception as e:
            print(f"An error occurred in main: {e}")
            import traceback

            traceback.print_exc()
            return 1
        return 0

    # Exit non-zero when main() fails so shell scripts / CI can detect a failed run.
    raise SystemExit(asyncio.run(run_main()))