import json
import os
import logging
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
# Configure the root logger at import time: INFO level, each record rendered as
# "timestamp - logger name - level - message".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger used by the model class below.
logger = logging.getLogger(__name__)
# Restrict PyTorch to a single intra-op CPU thread for this process.
torch.set_num_threads(1)
class TritonPythonModel:
    """Triton Python model for CosyVoice3 vocoder (CausalHiFTGenerator).

    Stateless: converts mel spectrogram to waveform.
    CausalHiFTGenerator manages its own internal cache.

    Inputs read per request: ``mel`` and ``finalize``.
    Output produced per request: ``tts_speech``.
    """

    @staticmethod
    def _select_device(args):
        """Return the CUDA device this model instance should run on.

        Triton passes ``model_instance_kind`` and ``model_instance_device_id``
        (both strings) to ``initialize``.  When the instance is a GPU instance
        with a parseable device id, that specific GPU is used so instances
        spread over several GPUs do not all end up on the default device.
        In every other case the previous behaviour is kept: plain ``"cuda"``.

        Args:
            args: The mapping Triton hands to ``initialize``.

        Returns:
            A ``torch.device`` of type ``cuda``.
        """
        if args.get("model_instance_kind") == "GPU":
            try:
                return torch.device("cuda", int(args.get("model_instance_device_id")))
            except (TypeError, ValueError):
                # Missing or malformed id: fall through to the default device.
                pass
        return torch.device("cuda")

    def initialize(self, args):
        """Build the vocoder and load its weights.

        Reads the ``model_dir`` parameter from the JSON ``model_config``,
        instantiates the ``hift`` object described in
        ``<model_dir>/cosyvoice3.yaml`` and loads ``<model_dir>/hift.pt``
        into it.

        Args:
            args: Mapping supplied by Triton; ``args['model_config']`` must be
                a JSON string whose ``parameters`` contain ``model_dir``.

        Raises:
            KeyError: If ``model_dir`` (or ``parameters``) is absent.
            OSError: If the YAML or checkpoint file cannot be opened.
            RuntimeError: If the checkpoint keys do not match the model
                (``strict=True``).
        """
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        model_dir = model_params["model_dir"]
        self.device = self._select_device(args)

        with open(os.path.join(model_dir, 'cosyvoice3.yaml'), 'r') as f:
            configs = load_hyperpyyaml(f, overrides={
                'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
            })
        self.hift = configs['hift']

        # Checkpoint keys carry a 'generator.' component that the bare model
        # does not use; drop it before loading.
        checkpoint = torch.load(
            os.path.join(model_dir, 'hift.pt'),
            map_location='cpu', weights_only=True
        )
        hift_state_dict = {
            k.replace('generator.', ''): v for k, v in checkpoint.items()
        }
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()
        logger.info("CausalHiFTGenerator initialized successfully")

    def _synthesize(self, request):
        """Run the vocoder for one request and return its response.

        Args:
            request: A single Triton inference request carrying the ``mel``
                and ``finalize`` inputs.

        Returns:
            A ``pb_utils.InferenceResponse`` holding the ``tts_speech`` tensor.

        Raises:
            ValueError: If a required input tensor is missing.
        """
        mel_input = pb_utils.get_input_tensor_by_name(request, "mel")
        if mel_input is None:
            raise ValueError('required input "mel" is missing')
        finalize_input = pb_utils.get_input_tensor_by_name(request, "finalize")
        if finalize_input is None:
            raise ValueError('required input "finalize" is missing')

        mel = torch.utils.dlpack.from_dlpack(mel_input.to_dlpack()).to(self.device)
        if mel.dim() == 2:
            # Add the leading batch dimension when the caller omitted it.
            mel = mel.unsqueeze(0)
        finalize = finalize_input.as_numpy().item()

        with torch.no_grad():
            speech, _ = self.hift.inference(speech_feat=mel, finalize=finalize)

        # Drop singleton dims, then restore one leading dim for the output.
        # contiguous() is a no-op for already-contiguous tensors; it guards the
        # DLPack export, which is understood to need C-contiguous memory.
        speech = speech.squeeze().unsqueeze(0).contiguous()
        speech_tensor = pb_utils.Tensor.from_dlpack(
            "tts_speech", to_dlpack(speech))
        return pb_utils.InferenceResponse(output_tensors=[speech_tensor])

    def execute(self, requests):
        """Convert each request's mel spectrogram into a waveform.

        A failure while handling one request is logged and reported back as an
        error response for that request only, so the remaining requests in the
        batch are still served.

        Args:
            requests: List of Triton inference requests.

        Returns:
            A list with exactly one ``pb_utils.InferenceResponse`` per
            request, in the same order.
        """
        responses = []
        for request in requests:
            try:
                response = self._synthesize(request)
            except Exception as exc:  # request boundary: report, don't crash the batch
                logger.exception("Vocoder inference failed")
                response = pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(str(exc)))
            responses.append(response)
        return responses