import json
import os
import logging
import queue
import torch
import numpy as np
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
# Process-wide logging: INFO level with timestamped records. Note that
# logging.basicConfig does nothing if the root logger already has handlers
# (for example when the hosting process configured logging first).
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class TrtContextWrapper:
    """Pool of TensorRT execution contexts sharing one engine.

    Each pool entry is a ``[context, stream]`` list, where ``stream`` is the
    context manager returned by ``torch.cuda.stream`` for a dedicated CUDA
    stream on ``device``. The pool is a ``queue.Queue``, so
    ``acquire_estimator`` blocks until an entry is free. Every acquired entry
    must be handed back with ``release_estimator``.
    """

    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
        """Create ``trt_concurrent`` execution contexts for ``trt_engine``.

        Args:
            trt_engine: deserialized TensorRT engine.
            trt_concurrent: number of contexts (and CUDA streams) to pool.
            device: torch device string on which the streams are created.

        Raises:
            ValueError: if ``trt_concurrent`` < 1. The pool would be empty and
                unbounded, so ``acquire_estimator`` would block forever.
            RuntimeError: if the engine fails to create an execution context.
        """
        if trt_concurrent < 1:
            raise ValueError(f'trt_concurrent must be >= 1, got {trt_concurrent}')
        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
        self.trt_engine = trt_engine
        self.device = device
        for _ in range(trt_concurrent):
            trt_context = trt_engine.create_execution_context()
            # Explicit check (not `assert`) so it still fires under `python -O`.
            # It runs before any CUDA stream is allocated.
            if trt_context is None:
                raise RuntimeError('failed to create TensorRT execution context')
            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
            self.trt_context_pool.put([trt_context, trt_stream])

    def acquire_estimator(self):
        """Take a ``[context, stream]`` entry (blocking) and return it with the engine."""
        return self.trt_context_pool.get(), self.trt_engine

    def release_estimator(self, context, stream):
        """Return a previously acquired context/stream pair to the pool."""
        self.trt_context_pool.put([context, stream])
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16, autocast_mode=False,
                        workspace_size=1 << 32):
    """Build a serialized TensorRT engine from an ONNX model and write it to disk.

    Args:
        trt_model: output path of the serialized engine (plan file).
        trt_kwargs: dict with ``input_names``, ``min_shape``, ``opt_shape`` and
            ``max_shape`` lists. Entry ``i`` of each shape list belongs to
            ``input_names[i]``. Together they define a single optimization profile.
        onnx_model: path of the ONNX model to parse.
        fp16: only used when ``autocast_mode`` is False. It enables the FP16
            builder flag and declares all network inputs/outputs as HALF
            (FLOAT when False).
        autocast_mode: build a STRONGLY_TYPED network, so tensor types come
            from the ONNX model itself; ``fp16`` is then ignored.
        workspace_size: builder WORKSPACE memory-pool limit in bytes.

    Raises:
        ValueError: if the ONNX model cannot be parsed.
        RuntimeError: if TensorRT fails to build the engine.
    """
    import tensorrt as trt
    logger.info("Converting onnx to trt...")
    if autocast_mode:
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    else:
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    trt_logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, trt_logger)
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size)
    if not autocast_mode and fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    profile = builder.create_optimization_profile()
    with open(onnx_model, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                logger.error("ONNX parse error: %s", parser.get_error(i))
            raise ValueError(f'failed to parse {onnx_model}')
    for i, name in enumerate(trt_kwargs['input_names']):
        profile.set_shape(name,
                          trt_kwargs['min_shape'][i],
                          trt_kwargs['opt_shape'][i],
                          trt_kwargs['max_shape'][i])
    if not autocast_mode:
        tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
        for i in range(network.num_inputs):
            network.get_input(i).dtype = tensor_dtype
        for i in range(network.num_outputs):
            network.get_output(i).dtype = tensor_dtype
    config.add_optimization_profile(profile)
    engine_bytes = builder.build_serialized_network(network, config)
    # build_serialized_network returns None on failure. Fail before touching
    # the output path, so no empty or truncated plan is left behind.
    if engine_bytes is None:
        raise RuntimeError(f'failed to build TensorRT engine from {onnx_model}')
    # Write to a per-process temp file, then atomically rename it. Readers
    # therefore never see a partially written plan at `trt_model`.
    tmp_path = f"{trt_model}.{os.getpid()}.tmp"
    with open(tmp_path, "wb") as f:
        f.write(engine_bytes)
    os.replace(tmp_path, trt_model)
    logger.info("Successfully converted onnx to trt")
# Limit PyTorch's intra-op CPU thread pool to a single thread for this process.
# NOTE(review): presumably avoids CPU oversubscription when several Triton
# model instances share the host (GPU work is unaffected) — confirm.
torch.set_num_threads(1)
class TritonPythonModel:
    """Triton Python model for CosyVoice3 token2wav (flow-only, stateless).

    Converts speech tokens to mel spectrogram using the CausalMaskedDiffWithDiT flow model.
    The flow decoder's estimator runs as a TensorRT engine. The engine is loaded
    from a per-device plan file in ``model_dir``, or built from the ONNX export there.
    """

    def initialize(self, args):
        """Build the flow model from ``model_dir`` and attach the TensorRT estimator.

        Args:
            args: Triton initialization arguments. ``args['model_config']`` is a
                JSON string whose ``parameters`` must include ``model_dir`` as a
                ``string_value``.
        """
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        model_dir = model_params["model_dir"]
        # NOTE(review): uses the current CUDA device. Triton also passes
        # args['model_instance_device_id'], so confirm single-GPU deployment.
        self.device = torch.device("cuda")
        with open(os.path.join(model_dir, 'cosyvoice3.yaml'), 'r') as f:
            configs = load_hyperpyyaml(f, overrides={
                'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
            })
        self.flow = configs['flow']
        self.fp16 = True
        # Convert to half precision first. load_state_dict then copies the
        # checkpoint tensors into the already-fp16 parameters.
        self.flow.half()
        self.flow.load_state_dict(
            torch.load(os.path.join(model_dir, 'flow.pt'),
                       map_location='cpu', weights_only=True),
            strict=True
        )
        self.flow.to(self.device).eval()
        self.load_trt(model_dir)
        # Mel frames per speech token. execute() uses it to turn token_offset
        # into a mel-frame index.
        self.token_mel_ratio = self.flow.token_mel_ratio
        logger.info(f"Token2wav (flow-only) initialized, token_mel_ratio={self.token_mel_ratio}")

    def load_trt(self, model_dir, trt_concurrent=1):
        """Replace ``self.flow.decoder.estimator`` with a TensorRT engine wrapper.

        The plan file is keyed by the current CUDA device id. It is built from
        the autocast-fp16 ONNX export when it is missing or empty.

        Args:
            model_dir: directory holding the ONNX export and the plan file.
            trt_concurrent: number of execution contexts in the wrapper's pool.

        Raises:
            RuntimeError: if the plan file cannot be deserialized.
        """
        device_id = torch.cuda.current_device()
        onnx_path = os.path.join(model_dir, 'flow.decoder.estimator.autocast_fp16.onnx')
        trt_path = os.path.join(model_dir, f'flow.decoder.estimator.autocast_fp16.{device_id}.plan')
        if not os.path.exists(trt_path) or os.path.getsize(trt_path) == 0:
            convert_onnx_to_trt(trt_path, self.get_trt_kwargs(), onnx_path,
                                fp16=True, autocast_mode=True)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(trt_path, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        # Explicit check so it still fires under `python -O`.
        if estimator_engine is None:
            raise RuntimeError(f'failed to load trt {trt_path}')
        self.flow.decoder.estimator = TrtContextWrapper(
            estimator_engine, trt_concurrent=trt_concurrent, device=str(self.device))

    def get_trt_kwargs(self):
        """Return optimization-profile shapes for the estimator inputs.

        The inputs are ``x``, ``mask``, ``mu`` and ``cond``, with 80 channels
        (1 for ``mask``). The last axis ranges over 4..3000 and is tuned for 500.
        The leading dim is fixed at 2 (presumably the classifier-free-guidance
        pair — confirm).
        """
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape,
                'max_shape': max_shape, 'input_names': input_names}

    @staticmethod
    def _add_batch_dim(tensor, unbatched_ndim):
        """Prepend a batch axis when ``tensor`` has exactly ``unbatched_ndim`` dims."""
        return tensor.unsqueeze(0) if tensor.dim() == unbatched_ndim else tensor

    @staticmethod
    def _get_scalar(request, name, default):
        """Return optional scalar input ``name`` as a Python value, or ``default``."""
        pb_tensor = pb_utils.get_input_tensor_by_name(request, name)
        return default if pb_tensor is None else pb_tensor.as_numpy().item()

    def _require_input(self, request, name, unbatched_ndim):
        """Fetch input ``name`` as a torch tensor on ``self.device`` with a batch axis.

        Raises:
            ValueError: if the request does not carry ``name``.
        """
        pb_tensor = pb_utils.get_input_tensor_by_name(request, name)
        if pb_tensor is None:
            raise ValueError(f"{name} is required for CosyVoice3 token2wav")
        tensor = torch.utils.dlpack.from_dlpack(pb_tensor.to_dlpack()).to(self.device)
        return self._add_batch_dim(tensor, unbatched_ndim)

    def _seq_len(self, tensor):
        """Return ``tensor.shape[1]`` as a 1-element int32 tensor on ``self.device``."""
        return torch.tensor([tensor.shape[1]], dtype=torch.int32, device=self.device)

    def execute(self, requests):
        """Run flow inference per request and return one ``mel`` output each.

        Required inputs are ``target_speech_tokens``, ``prompt_speech_tokens``,
        ``prompt_speech_feat`` and ``prompt_spk_embedding``. A leading batch axis
        is added when it is absent.

        Optional inputs:
        - ``token_offset``: mel frames before ``token_offset * token_mel_ratio``
          are dropped.
        - ``finalize``: defaults to True; streaming mode is ``not finalize``.

        A request missing a required input receives an error response. The
        remaining requests in the batch are still served.
        """
        responses = []
        for request in requests:
            try:
                target_speech_tokens = self._require_input(request, "target_speech_tokens", 1)
                prompt_speech_tokens = self._require_input(request, "prompt_speech_tokens", 1)
                prompt_speech_feat = self._require_input(request, "prompt_speech_feat", 2)
                prompt_spk_embedding = self._require_input(request, "prompt_spk_embedding", 1)
            except ValueError as err:
                logger.error("Rejecting token2wav request: %s", err)
                responses.append(pb_utils.InferenceResponse(
                    output_tensors=[], error=pb_utils.TritonError(str(err))))
                continue

            token_offset = self._get_scalar(request, "token_offset", None)
            finalize = self._get_scalar(request, "finalize", True)

            # torch.amp.autocast("cuda", ...) replaces the deprecated
            # torch.cuda.amp.autocast, with the same fp16 dtype.
            with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16, enabled=self.fp16):
                mel, _ = self.flow.inference(
                    token=target_speech_tokens,
                    token_len=self._seq_len(target_speech_tokens),
                    prompt_token=prompt_speech_tokens,
                    prompt_token_len=self._seq_len(prompt_speech_tokens),
                    prompt_feat=prompt_speech_feat,
                    prompt_feat_len=self._seq_len(prompt_speech_feat),
                    embedding=prompt_spk_embedding,
                    streaming=not finalize,
                    finalize=finalize,
                )
            if token_offset is not None:
                mel = mel[:, :, token_offset * self.token_mel_ratio:]
            mel_out = mel.squeeze(0).float().cpu()
            mel_tensor = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel_out))
            responses.append(pb_utils.InferenceResponse(output_tensors=[mel_tensor]))
        return responses