""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 infer_cosyvoice3_token2wav.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
--token2wav-path $token2wav_model_dir \
--backend $backend \
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
--engine-dir $trt_engines_dir \
--split-name ${dataset} || exit 1
"""
import argparse
import json
import os
import time
import asyncio
import torch
import torchaudio
import s3tokenizer
import soundfile as sf
import requests
import httpx
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from functools import partial
from tqdm import tqdm
from token2wav_cosyvoice3 import CosyVoice3_Token2Wav
# Select the "spawn" start method for torch multiprocessing unless some other
# code already fixed a start method for this process (in which case it is kept).
# NOTE(review): presumably needed so DataLoader workers (num_workers > 0) can run
# the CUDA s3 tokenizer inside the collate_fn — confirm.
if torch.multiprocessing.get_start_method(allow_none=True) is None:
    torch.multiprocessing.set_start_method("spawn")
async def send_request_async(client, url, payload):
    """POST one chat-completion request and return the generated message text.

    Args:
        client: an async HTTP client exposing ``await client.post(...)``.
        url: endpoint the JSON ``payload`` is posted to.
        payload: request body, sent as JSON.

    Returns:
        ``choices[0].message.content`` from the JSON response.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status.
    """
    reply = await client.post(url, json=payload, timeout=None)
    reply.raise_for_status()
    first_choice = reply.json()['choices'][0]
    return first_choice['message']['content']
async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
    """Send one chat-completion request per chat concurrently over a shared client.

    Args:
        api_base: endpoint URL every request is POSTed to.
        model_name: value of the ``model`` field in each payload.
        chats: iterable of chat message lists, one request each.
        temperature, top_p, top_k: sampling parameters copied into each payload.

    Returns:
        List of generated message strings, in the same order as ``chats``.
    """

    def build_payload(messages):
        # Fixed decoding settings; generation stops on either EOS marker.
        return {
            "model": model_name,
            "messages": messages,
            "max_tokens": 2048,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": 1.1,
            "stop": ["<|eos1|>", "<|eos|>"],
            "stream": False,
        }

    async with httpx.AsyncClient() as client:
        pending = [
            send_request_async(client, api_base, build_payload(chat))
            for chat in chats
        ]
        return await asyncio.gather(*pending)
def extract_speech_ids(speech_tokens_str):
    """Extract integer speech IDs from token strings like ``<|s_23456|>``.

    Args:
        speech_tokens_str: iterable of token strings.

    Returns:
        List of the integers found between ``<|s_`` and ``|>``, in input order.

    Any token that is not of the form ``<|s_<integer>|>`` is reported with an
    ``Unexpected token: ...`` line and skipped. This includes a correctly framed
    token with a non-integer payload such as ``<|s_|>``, which would otherwise
    raise ``ValueError`` and abort the run.
    """
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            try:
                speech_ids.append(int(token_str[4:-2]))
                continue
            except ValueError:
                pass  # framed like a speech token but payload is not an int
        print(f"Unexpected token: {token_str}")
    return speech_ids
def convert_cosy3_tokens_to_speech_id_str(cosy3_tokens):
    """Render CosyVoice3 speech tokens as one string like ``<|s_1|><|s_23456|>``.

    Args:
        cosy3_tokens: iterable of integer token ids. Objects with a ``.cpu()``
            method (e.g. torch tensors) are first moved to CPU and converted to
            a Python list.

    Returns:
        The concatenation of ``<|s_{token}|>`` for every token, in order.
    """
    if hasattr(cosy3_tokens, 'cpu'):
        cosy3_tokens = cosy3_tokens.cpu().numpy().tolist()
    # join is linear; repeated `+=` on a str is quadratic in the worst case.
    return "".join(f"<|s_{token}|>" for token in cosy3_tokens)
def get_args(argv=None):
    """Parse command-line options for LLM + CosyVoice3 speech generation.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice3")
    parser.add_argument(
        "--split-name", type=str, default="wenetspeech4tts",
        help="huggingface dataset split name",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="dir to save result",
    )
    parser.add_argument(
        "--batch-size", default=1, type=int,
        help="batch size (per-device) for LLM inference",
    )
    parser.add_argument(
        "--token2wav-batch-size", default=1, type=int,
        help="batch size (per-device) for token2wav inference",
    )
    parser.add_argument(
        "--num-workers", type=int, default=0, help="workers for dataloader",
    )
    parser.add_argument(
        "--prefetch", type=int, default=None, help="prefetch for dataloader",
    )
    parser.add_argument(
        "--llm-model-name-or-path", required=True, type=str,
        help="CosyVoice3 HF LLM path (e.g. ./hf_cosyvoice3_llm)",
    )
    parser.add_argument(
        "--token2wav-path", required=True, type=str,
        help="CosyVoice3 model path (e.g. /workspace_yuekai/HF/Fun-CosyVoice3-0.5B-2512)",
    )
    parser.add_argument(
        "--enable-trt", action="store_true",
        help="Enable TensorRT for flow decoder estimator",
    )
    parser.add_argument(
        "--streaming", action="store_true",
        help="Enable streaming for flow decoder estimator",
    )
    parser.add_argument(
        "--top-p", type=float, default=0.95, help="top p for sampling",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.8, help="temperature for sampling",
    )
    parser.add_argument(
        "--top-k", type=int, default=15, help="top k for sampling",
    )
    parser.add_argument(
        "--backend", type=str, default="hf",
        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
        help="Backend to use for LLM inference",
    )
    parser.add_argument(
        "--engine-dir", type=str, default=None,
        help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
    )
    parser.add_argument(
        "--kv-cache-free-gpu-memory-fraction", type=float, default=0.6,
        # The runner uses this share of the *currently free* GPU memory for the KV cache.
        help="Fraction of free GPU memory to use for the KV cache (TensorRT-LLM only)",
    )
    parser.add_argument(
        "--openai-api-base", type=str,
        default="http://localhost:8000/v1/chat/completions",
        # The value is POSTed to as-is, so it must be the full endpoint URL.
        help="OpenAI-compatible chat completions endpoint URL; requests are POSTed "
             "here directly (for trtllm-serve backend)",
    )
    parser.add_argument(
        "--openai-model-name", type=str, default="trt_engines_bfloat16",
        help="Model name to use with OpenAI API (for trtllm-serve backend)",
    )
    parser.add_argument(
        "--epoch", type=int, default=1, help="Epoch to run",
    )
    return parser.parse_args(argv)
def data_collator(batch, tokenizer, s3_tokenizer):
    """Collate dataset rows into chat prompts for the speech LLM.

    For each row:
    - the prompt audio is resampled to 16 kHz;
    - it is tokenized with the v3 S3 speech tokenizer;
    - the tokens are rendered as ``<|s_N|>`` text that pre-fills the assistant
      turn of a two-message chat.

    Args:
        batch: list of dataset rows with ``prompt_text``, ``target_text``,
            ``prompt_audio`` (``array`` + ``sampling_rate``) and ``id``.
        tokenizer: LLM tokenizer providing ``apply_chat_template``.
        s3_tokenizer: speech tokenizer model; its ``device`` is used for
            quantization.

    Returns:
        Dict with these keys:
        - ``input_ids``: per-row 1-D tensors from the chat template;
        - ``ids``;
        - ``prompt_text``;
        - ``prompt_audio_list``: per-row ``(1, T)`` waveforms at 16 kHz;
        - ``chat_list``.
    """
    target_sample_rate = 16000
    device = torch.device("cpu") if s3_tokenizer is None else s3_tokenizer.device

    prompt_text_list = []
    full_text_list = []
    prompt_audio_list = []
    mels = []
    for row in batch:
        prompt_text = row["prompt_text"]
        target_text = row["target_text"]
        prompt_text_list.append(prompt_text)
        full_text_list.append(
            'You are a helpful assistant.<|endofprompt|>' + prompt_text + target_text
        )

        audio_info = row["prompt_audio"]
        waveform = torch.from_numpy(audio_info["array"]).float().unsqueeze(0)
        source_sr = audio_info["sampling_rate"]
        if source_sr != target_sample_rate:
            waveform = torchaudio.transforms.Resample(source_sr, target_sample_rate)(waveform)
        prompt_audio_list.append(waveform)
        mels.append(s3tokenizer.log_mel_spectrogram(waveform.squeeze(0)))

    # Quantize all prompt mels in one padded batch, then trim each row to its length.
    prompt_token_seqs = []
    if mels:
        mels_padded, mels_lens = s3tokenizer.padding(mels)
        codes, codes_lens = s3_tokenizer.quantize(mels_padded.to(device), mels_lens.to(device))
        prompt_token_seqs = [codes[k, :codes_lens[k].item()] for k in range(len(codes))]

    chat_list = []
    input_ids_list = []
    for full_text, prompt_tokens in zip(full_text_list, prompt_token_seqs):
        chat = [
            {"role": "user", "content": full_text},
            {"role": "assistant", "content": convert_cosy3_tokens_to_speech_id_str(prompt_tokens)},
        ]
        chat_list.append(chat)
        # continue_final_message leaves the assistant turn open so generation
        # continues the prompt's speech tokens.
        encoded = tokenizer.apply_chat_template(
            chat, tokenize=True, return_tensors='pt', continue_final_message=True)
        input_ids_list.append(encoded.squeeze(0))

    return {
        "input_ids": input_ids_list,
        "ids": [row["id"] for row in batch],
        "prompt_text": prompt_text_list,
        "prompt_audio_list": prompt_audio_list,
        "chat_list": chat_list,
    }
def _left_pad(input_ids_list, pad_id):
    """Left-pad 1-D id tensors into a (batch, max_len) tensor and matching attention mask."""
    max_len = max(len(ids) for ids in input_ids_list)
    rows, masks = [], []
    for ids in input_ids_list:
        pad = max_len - len(ids)
        rows.append(torch.cat([torch.full((pad,), pad_id, dtype=ids.dtype), ids]))
        masks.append(torch.cat([
            torch.zeros(pad, dtype=ids.dtype),
            torch.ones(len(ids), dtype=ids.dtype),
        ]))
    return torch.stack(rows), torch.stack(masks), max_len


def _generate_hf(model, tokenizer, input_ids_list, device, args):
    """Sample continuations with a Hugging Face causal LM.

    Prompts are LEFT-padded: with right padding a decoder-only model would
    continue generating after the pad tokens. Returns one list per sample
    containing the unpadded prompt ids followed by the generated ids, so the
    caller can strip the prompt with ``len(prompt)``.
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Padded positions are masked out, so any valid id works as filler.
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
    input_ids, attention_mask, max_len = _left_pad(input_ids_list, pad_id)
    generated = model.generate(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        max_new_tokens=2048,
        do_sample=True,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=1.1,
        top_k=args.top_k,
    )
    torch.cuda.synchronize()
    # New tokens start at column max_len for every row when left-padded.
    return [
        ids.tolist() + generated[row, max_len:].tolist()
        for row, ids in enumerate(input_ids_list)
    ]


def _generate_trtllm(runner, input_ids_list, end_id, args):
    """Sample with a TensorRT-LLM ModelRunnerCpp; returns prompt+generated ids per sample."""
    batch_input_ids = list(input_ids_list)
    outputs = runner.generate(
        batch_input_ids=batch_input_ids,
        max_new_tokens=2048,
        end_id=end_id,
        pad_id=end_id,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=1.1,
        num_return_sequences=1,
        streaming=False,
        output_sequence_lengths=True,
        output_generation_logits=False,
        return_dict=True,
        return_all_generated_tokens=False
    )
    torch.cuda.synchronize()
    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
    num_output_sents, num_beams, _ = output_ids.size()
    assert num_beams == 1
    batch_size = len(batch_input_ids)
    assert num_output_sents // batch_size == 1  # one return sequence per prompt
    # sequence_lengths counts prompt + generated tokens.
    return [
        output_ids[i][0][:sequence_lengths[i][0]].tolist()
        for i in range(batch_size)
    ]


def _generate_vllm(model, input_ids_list, args):
    """Sample with vLLM; returns prompt+generated ids per sample."""
    prompts = [ids.tolist() for ids in input_ids_list]
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        repetition_penalty=1.1,
        max_tokens=2048,
    )
    results = model.generate(prompt_token_ids=prompts, sampling_params=sampling_params)
    # list() because token_ids may be a non-list sequence type.
    return [prompt + list(result.outputs[0].token_ids) for prompt, result in zip(prompts, results)]


def _generate_trtllm_serve(chat_list, args):
    """Query an OpenAI-compatible server; returns the generated text per chat."""
    if args.batch_size > 1:
        return asyncio.run(send_batch_requests_async(
            args.openai_api_base,
            args.openai_model_name,
            chat_list,
            args.temperature,
            args.top_p,
            args.top_k,
        ))
    texts = []
    for chat in chat_list:
        payload = {
            "model": args.openai_model_name,
            "messages": chat,
            "max_tokens": 2048,
            "temperature": args.temperature,
            "top_p": args.top_p,
            "top_k": args.top_k,
            "repetition_penalty": 1.1,
            "stop": ["<|eos1|>", "<|eos|>"],
            "stream": False,
        }
        response = requests.post(args.openai_api_base, json=payload)
        response.raise_for_status()
        texts.append(response.json()['choices'][0]['message']['content'])
    return texts


def _speech_ids_from_output(output, input_length, tokenizer, backend):
    """Turn one LLM output into integer speech ids.

    For ``trtllm-serve`` the output is text like ``<|s_1|><|s_2|>``. For the
    other backends it is prompt+generated token ids; the first
    ``input_length`` ids are dropped.
    """
    if backend == "trtllm-serve":
        pieces = output.strip().split('><')
        if len(pieces) > 1:
            # split('><') removes the brackets at each boundary; put them back.
            pieces = [p if p.startswith('<') else '<' + p for p in pieces]
            pieces = [p if p.endswith('>') else p + '>' for p in pieces]
        return extract_speech_ids(pieces)
    generated_ids = output[input_length:]
    # Decode id-by-id so each element is a single token string.
    return extract_speech_ids(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))


def main(args):
    """Run LLM speech-token generation plus CosyVoice3 token2wav over a dataset split.

    Steps:
    1. Load the LLM for ``args.backend``, the token2wav model, the v3 S3
       tokenizer and the ``yuekai/seed_tts_cosy2`` split.
    2. Per batch, generate speech tokens.
    3. Synthesize 24 kHz wavs into ``args.output_dir``.
    4. After each epoch, write timing metrics to ``log.txt``.

    Raises:
        RuntimeError: CUDA is not available.
        ValueError: unsupported backend, or ``trtllm`` without ``--engine-dir``.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this script")
    local_rank = 0
    device = torch.device(f"cuda:{local_rank}")
    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)

    model = None
    runner = None
    end_id = None
    if args.backend == "hf":
        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
        model.eval()
        model.to(device)
    elif args.backend == "trtllm":
        if args.engine_dir is None:
            raise ValueError("--engine-dir is required when backend is 'trtllm'")
        runner = ModelRunnerCpp.from_dir(
            engine_dir=args.engine_dir,
            rank=tensorrt_llm.mpi_rank(),
            max_output_len=2048,
            enable_context_fmha_fp32_acc=False,
            max_batch_size=args.batch_size,
            max_input_len=512,
            kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
            cuda_graph_mode=False,
            gather_generation_logits=False,
        )
        # Resolved once here rather than rebuilding the vocab on every batch.
        if "<|eos1|>" in tokenizer.get_vocab():
            end_id = tokenizer.convert_tokens_to_ids("<|eos1|>")
        else:
            end_id = tokenizer.eos_token_id
    elif args.backend == "vllm":
        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
    elif args.backend != "trtllm-serve":
        raise ValueError(f"Unsupported backend: {args.backend}")

    token2wav_model = CosyVoice3_Token2Wav(
        model_dir=args.token2wav_path, enable_trt=args.enable_trt, device_id=local_rank, streaming=args.streaming
    )
    s3_tokenizer = s3tokenizer.load_model(
        f"{args.token2wav_path}/speech_tokenizer_v3.onnx"
    ).to(device).eval()
    dataset = load_dataset(
        "yuekai/seed_tts_cosy2",
        split=args.split_name,
        trust_remote_code=True,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        # torch rejects a prefetch_factor when loading in the main process.
        prefetch_factor=args.prefetch if args.num_workers > 0 else None,
        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
    )

    for epoch in range(args.epoch):
        print(f"Running epoch {epoch}")
        total_llm_time = 0
        total_token2wav_time = 0
        total_data_load_time = 0
        total_llm_post_processing_time = 0
        total_audio_save_time = 0
        total_audio_samples = 0
        start_time = time.time()
        progress_bar = tqdm(total=len(dataset), desc="Processing", unit="wavs")
        last_batch_end_time = time.time()
        for batch in dataloader:
            data_loaded_time = time.time()
            total_data_load_time += data_loaded_time - last_batch_end_time
            with torch.no_grad():
                llm_start_time = time.time()
                if args.backend == "hf":
                    outputs = _generate_hf(model, tokenizer, batch["input_ids"], device, args)
                elif args.backend == "trtllm":
                    outputs = _generate_trtllm(runner, batch["input_ids"], end_id, args)
                elif args.backend == "vllm":
                    outputs = _generate_vllm(model, batch["input_ids"], args)
                else:  # trtllm-serve
                    outputs = _generate_trtllm_serve(batch["chat_list"], args)
                llm_end_time = time.time()
                total_llm_time += llm_end_time - llm_start_time

                items_for_token_2wav = []
                for i, sample_id in enumerate(batch["ids"]):
                    llm_post_processing_start_time = time.time()
                    speech_ids = _speech_ids_from_output(
                        outputs[i], len(batch["input_ids"][i]), tokenizer, args.backend)
                    print(i, speech_ids[:10], "...", f"total={len(speech_ids)}")
                    if not speech_ids:
                        print(f"Warning: No speech tokens generated for sample {sample_id}, skipping")
                        total_llm_post_processing_time += time.time() - llm_post_processing_start_time
                        continue
                    current_prompt_audio = batch["prompt_audio_list"][i]
                    total_llm_post_processing_time += time.time() - llm_post_processing_start_time
                    items_for_token_2wav.append({
                        "speech_ids": speech_ids,
                        "prompt_audio": current_prompt_audio.squeeze(0),
                        "id": sample_id,
                    })

                for start in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
                    t2w_batch = items_for_token_2wav[start:start + args.token2wav_batch_size]
                    t2w_speech_tokens = [item["speech_ids"] for item in t2w_batch]
                    t2w_prompt_audios = [item["prompt_audio"] for item in t2w_batch]
                    # The collator resamples every prompt to 16 kHz.
                    t2w_sample_rates = [16000] * len(t2w_batch)
                    token2wav_start_time = time.time()
                    generated_wavs = token2wav_model(
                        t2w_speech_tokens, t2w_prompt_audios, t2w_sample_rates,
                        streaming=args.streaming,
                    )
                    token2wav_end_time = time.time()
                    total_token2wav_time += token2wav_end_time - token2wav_start_time
                    audio_save_start_time = time.time()
                    for j, audio_hat in enumerate(generated_wavs):
                        wav = audio_hat.squeeze().cpu().numpy()
                        total_audio_samples += len(wav)
                        sf.write(f"{args.output_dir}/{t2w_batch[j]['id']}.wav", wav, 24000)
                        print(f"Generated audio for sample {t2w_batch[j]['id']} with {len(t2w_speech_tokens[j])} tokens")
                    total_audio_save_time += time.time() - audio_save_start_time
            progress_bar.update(len(batch["ids"]))
            last_batch_end_time = time.time()
        progress_bar.close()
        end_time = time.time()
        total_audio_duration_seconds = total_audio_samples / 24000
        log_file_path = os.path.join(args.output_dir, "log.txt")
        with open(log_file_path, 'w') as f:
            log_data = {
                "args": vars(args),
                "data_load_time_seconds": total_data_load_time,
                "llm_time_seconds": total_llm_time,
                "llm_post_processing_time_seconds": total_llm_post_processing_time,
                "token2wav_time_seconds": total_token2wav_time,
                "audio_save_time_seconds": total_audio_save_time,
                "total_audio_duration_seconds": total_audio_duration_seconds,
                "pipeline_time_seconds": end_time - start_time,
            }
            print(log_data)
            f.write(json.dumps(log_data, indent=4))
        print(f"Metrics logged to {log_file_path}")
if __name__ == "__main__":
    args = get_args()
    # Backend packages are imported only for the selected backend. They are
    # bound at module scope, so main() and its helpers see them as globals.
    if args.backend == "hf":
        from transformers import AutoModelForCausalLM
    elif args.backend == "vllm":
        from vllm import LLM, SamplingParams
    elif args.backend == "trtllm":
        import tensorrt_llm
        from tensorrt_llm.runtime import ModelRunnerCpp
    elif args.backend != "trtllm-serve":
        raise ValueError(f"Unsupported backend: {args.backend}")
    main(args)