# Commit 8f0b28b8, created on March 16 (see commit history).
import json
import os
import logging

import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml

# Process-wide log format: timestamp, logger name, level, message.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Restrict PyTorch's intra-op CPU thread pool to one thread for this process
# (presumably to avoid CPU oversubscription when several Triton model
# instances share a host -- confirm).
torch.set_num_threads(1)


class TritonPythonModel:
    """Triton Python model for CosyVoice3 vocoder (CausalHiFTGenerator).

    Converts mel spectrograms into waveforms. This class keeps no
    per-request state of its own; CausalHiFTGenerator manages its own
    internal cache, and each request's ``finalize`` flag is forwarded to it.
    """

    @staticmethod
    def _resolve_device(args):
        """Return the torch device this model instance should run on.

        Triton passes ``model_instance_kind`` and ``model_instance_device_id``
        (both strings) to ``initialize``. For GPU instances the model is
        pinned to that GPU, so an instance group spanning several GPUs does
        not place every instance on the default CUDA device. In every other
        case this falls back to ``torch.device("cuda")``, matching the
        previous behaviour.
        """
        kind = args.get('model_instance_kind')
        device_id = args.get('model_instance_device_id')
        if kind == 'GPU' and device_id not in (None, ''):
            return torch.device(f'cuda:{int(device_id)}')
        return torch.device('cuda')

    @staticmethod
    def _as_batched_waveform(speech):
        """Return ``speech`` as a contiguous ``[1, S]`` tensor.

        Every element is flattened onto a single sample axis, then a leading
        batch dim is added. This assumes the generator returns one utterance,
        e.g. ``[1, 1, S]`` or ``[1, S]`` (TODO confirm).

        Unlike ``squeeze().unsqueeze(0)``, a one-sample output stays ``[1, 1]``
        instead of collapsing to ``[1]``. ``.contiguous()`` is applied because
        the Triton Python backend only accepts contiguous (C-order) DLPack
        tensors.
        """
        return speech.flatten().unsqueeze(0).contiguous()

    def initialize(self, args):
        """Build CausalHiFTGenerator, load its weights, and move it to the device.

        Args:
            args: dict supplied by Triton. ``args['model_config']`` is a JSON
                string. Its ``parameters`` must include ``model_dir``: the
                directory holding ``cosyvoice3.yaml``, ``hift.pt`` and
                ``CosyVoice-BlankEN``.
        """
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        model_dir = model_params["model_dir"]

        self.device = self._resolve_device(args)

        # Instantiate the components declared in cosyvoice3.yaml; only 'hift'
        # is used here. The override points 'qwen_pretrain_path' at the
        # bundled CosyVoice-BlankEN directory.
        yaml_path = os.path.join(model_dir, 'cosyvoice3.yaml')
        with open(yaml_path, 'r', encoding='utf-8') as f:
            configs = load_hyperpyyaml(f, overrides={
                'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
            })
        self.hift = configs['hift']

        # Drop the 'generator.' qualifier from checkpoint keys (presumably
        # saved from a wrapping module) so they match the bare generator's
        # parameter names. strict=True surfaces any remaining mismatch.
        hift_state_dict = {
            k.replace('generator.', ''): v
            for k, v in torch.load(
                os.path.join(model_dir, 'hift.pt'),
                map_location='cpu', weights_only=True
            ).items()
        }
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()
        logger.info("CausalHiFTGenerator initialized successfully")

    def execute(self, requests):
        """Run the vocoder once per request.

        Each request carries:
          * ``mel``: a mel spectrogram. A 2-D input gets a leading batch dim
            (the original comment indicates ``[80, T]``; TODO confirm).
          * ``finalize``: a scalar flag forwarded to ``hift.inference``.

        Returns:
            One ``InferenceResponse`` per request, in request order. Each has a
            single ``tts_speech`` output of shape ``[1, S]``.
        """
        responses = []
        for request in requests:
            mel = pb_utils.get_input_tensor_by_name(request, "mel")
            mel = torch.utils.dlpack.from_dlpack(mel.to_dlpack()).to(self.device)
            if mel.dim() == 2:
                mel = mel.unsqueeze(0)  # [80, T] -> [1, 80, T]

            # Normalise to a Python bool whatever the input tensor's dtype.
            finalize = bool(
                pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item())

            with torch.no_grad():
                speech, _ = self.hift.inference(speech_feat=mel, finalize=finalize)

            speech_tensor = pb_utils.Tensor.from_dlpack(
                "tts_speech", to_dlpack(self._as_batched_waveform(speech)))
            responses.append(
                pb_utils.InferenceResponse(output_tensors=[speech_tensor]))

        return responses