diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py

index 7ab04a7..dcb743f 100644

--- a/cosyvoice/cli/cosyvoice.py

+++ b/cosyvoice/cli/cosyvoice.py

@@ -22,6 +22,7 @@ from cosyvoice.cli.frontend import CosyVoiceFrontEnd

 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model

 from cosyvoice.utils.file_utils import logging

 from cosyvoice.utils.class_utils import get_model_type

+from cosyvoice.utils.file_utils import load_wav

 

 

 class CosyVoice:

@@ -74,6 +75,11 @@ class CosyVoice:

         self.frontend.spk2info[zero_shot_spk_id] = model_input

         return True

 

+    def add_sft_spk(self, sft_spk_id):

+        model_input = self.frontend.frontend_from_spk2info(sft_spk_id)

+        self.frontend.spk2info_sft[sft_spk_id] = model_input

+        return True

+

     def save_spkinfo(self):

         torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))

 

@@ -188,7 +194,8 @@ class CosyVoice2(CosyVoice):

 

 class CosyVoice3(CosyVoice2):

 

-    def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):

+    def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1,

+                 speaker_info_dir=None):

         self.model_dir = model_dir

         self.fp16 = fp16

         if not os.path.exists(model_dir):

@@ -224,6 +231,66 @@ class CosyVoice3(CosyVoice2):

                                 self.fp16)

         del configs

 

+        self.speaker_info_dir = speaker_info_dir or os.path.join(model_dir, 'speaker_info')

+        if not os.path.exists(self.speaker_info_dir):

+            logging.warning(f'speaker_info_dir {self.speaker_info_dir} does not exist')

+

+        self.promote_wave_info = {}

+        self._preload_all_wave_info()

+

+    def _preload_all_wave_info(self):

+        if not os.path.exists(self.speaker_info_dir):

+            return

+

+        import json

+

+        speaker_info_file = os.path.join(self.speaker_info_dir, 'speaker_info.json')

+        if not os.path.exists(speaker_info_file):

+            logging.warning(f'speaker_info.json not found in {self.speaker_info_dir}')

+            return

+

+        try:

+            with open(speaker_info_file, 'r', encoding='utf-8') as f:

+                all_speakers = json.load(f)

+

+            for spk_id, wave_info in all_speakers.items():

+                try:

+                    prompt_wav_path = os.path.join(self.speaker_info_dir, wave_info['prompt_wav'])

+                    if os.path.exists(prompt_wav_path):

+                        prompt_wav = load_wav(prompt_wav_path, 16000)

+                        wave_info['prompt_wav'] = prompt_wav

+                        self.promote_wave_info[spk_id] = wave_info

+                        logging.info(f"Loaded speaker {spk_id} from {prompt_wav_path}, promote={wave_info['prompt_text']}")

+                    else:

+                        logging.error(f'Audio file not found for speaker {spk_id}: {prompt_wav_path}')

+                except Exception as e:

+                    logging.error(f'Failed to load speaker {spk_id}: {e}')

+

+            logging.info(f'Preloaded {len(self.promote_wave_info)} speakers from {speaker_info_file}')

+

+        except Exception as e:

+            logging.error(f'Failed to load {speaker_info_file}: {e}')

+

+    def inference_zero_shot_by_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):

+        if spk_id not in self.promote_wave_info:

+            raise ValueError(f'Speaker ID {spk_id} not found. Available IDs: {list(self.promote_wave_info.keys())}')

+        wave_info = self.promote_wave_info[spk_id]

+

+        if spk_id not in self.frontend.spk2info:

+            self.add_zero_shot_spk(wave_info['prompt_text'], wave_info['prompt_wav'], spk_id)

+

+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):

+            model_input = self.frontend.frontend_zero_shot(i, wave_info['prompt_text'], wave_info['prompt_wav'], self.sample_rate, spk_id)

+            start_time = time.time()

+            logging.info('synthesis text {}'.format(i))

+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):

+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate

+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))

+                # 2.save audio

+                #res = torch.cat((res, model_output['tts_speech']), dim=1)

+                yield model_output

+                start_time = time.time()

+

 

 def AutoModel(**kwargs):

     if not os.path.exists(kwargs['model_dir']):

diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py

index 6d397cc..10b0a87 100644

--- a/cosyvoice/cli/frontend.py

+++ b/cosyvoice/cli/frontend.py

@@ -26,6 +26,7 @@ import inflect

 from cosyvoice.utils.file_utils import logging, load_wav

 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation

 

+import threading

 

 class CosyVoiceFrontEnd:

 

@@ -50,6 +51,8 @@ class CosyVoiceFrontEnd:

             self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)

         else:

             self.spk2info = {}

+        logging.info(f"spk2info file = {spk2info}, CosyVoiceFrontEnd.spk2info = {self.spk2info.keys()}")

+        self.spk2info_sft = {}

         self.allowed_special = allowed_special

         self.inflect_parser = inflect.engine()

         # NOTE compatible when no text frontend tool is avaliable

@@ -74,6 +77,12 @@ class CosyVoiceFrontEnd:

                 self.text_frontend = ''

                 logging.info('no frontend is avaliable')

 

+        self.lock = threading.Lock()

+

+    def tokenizer_encode_with_lock(self, *args, **kwargs):

+        with self.lock:

+            output = self.tokenizer.encode(*args, **kwargs)

+        return output

 

     def _extract_text_token(self, text):

         if isinstance(text, Generator):

@@ -81,7 +90,7 @@ class CosyVoiceFrontEnd:

             # NOTE add a dummy text_token_len for compatibility

             return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)

         else:

-            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)

+            text_token = self.tokenizer_encode_with_lock(text, allowed_special=self.allowed_special)

             text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)

             text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)

             return text_token, text_token_len

@@ -148,13 +157,13 @@ class CosyVoiceFrontEnd:

                 text = text.replace(" - ", ",")

                 text = remove_bracket(text)

                 text = re.sub(r'[,,、]+$', '。', text)

-                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,

+                texts = list(split_paragraph(text, partial(self.tokenizer_encode_with_lock, allowed_special=self.allowed_special), "zh", token_max_n=80,

                                              token_min_n=60, merge_len=20, comma_split=False))

             else:

                 if self.text_frontend == 'wetext':

                     text = self.en_tn_model.normalize(text)

                 text = spell_out_number(text, self.inflect_parser)

-                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,

+                texts = list(split_paragraph(text, partial(self.tokenizer_encode_with_lock, allowed_special=self.allowed_special), "en", token_max_n=80,

                                              token_min_n=60, merge_len=20, comma_split=False))

         texts = [i for i in texts if not is_only_punctuation(i)]

         return texts if split is True else text

@@ -222,3 +231,13 @@ class CosyVoiceFrontEnd:

                        'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,

                        'flow_embedding': embedding}

         return model_input

+

+    def frontend_from_spk2info(self, spk_id):

+        spk_info = self.spk2info[spk_id]

+        embedding = spk_info['embedding'].to(self.device)

+        model_input = {

+            'llm_embedding': embedding,

+            'flow_embedding': embedding,

+        }

+        return model_input

+

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py

index 92a15d9..f5362e1 100644

--- a/cosyvoice/cli/model.py

+++ b/cosyvoice/cli/model.py

@@ -100,7 +100,7 @@ class CosyVoiceModel:

 

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):

         cur_silent_token_num, max_silent_token_num = 0, 5

-        with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):

+        with torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):

             if isinstance(text, Generator):

                 assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'

                 token_generator = self.llm.inference_bistream(text=text,

@@ -282,9 +282,13 @@ class CosyVoice2Model(CosyVoiceModel):

         export_cosyvoice2_vllm(self.llm, model_dir, self.device)

         from vllm import EngineArgs, LLMEngine

         engine_args = EngineArgs(model=model_dir,

+                                 dtype='float16',

                                  skip_tokenizer_init=True,

                                  enable_prompt_embeds=True,

-                                 gpu_memory_utilization=0.2)

+                                 gpu_memory_utilization=0.5,

+                                 additional_config={"torchair_graph_config":{"enabled":True}},

+                                 compilation_config={'cudagraph_capture_sizes':[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,512,1024],"cudagraph_mode": "FULL"},

+                                 enable_prefix_caching=True)

         self.llm.vllm = LLMEngine.from_engine_args(engine_args)

         self.llm.lock = threading.Lock()

         del self.llm.llm.model.model.layers

diff --git a/cosyvoice/flow/DiT/modules.py b/cosyvoice/flow/DiT/modules.py

index be8caec..d188914 100644

--- a/cosyvoice/flow/DiT/modules.py

+++ b/cosyvoice/flow/DiT/modules.py

@@ -18,7 +18,7 @@ import torch.nn.functional as F

 import torchaudio

 

 from x_transformers.x_transformers import apply_rotary_pos_emb

-

+import torch_npu

 

 # raw wav to mel spec

 class MelSpec(nn.Module):

@@ -84,7 +84,12 @@ class SinusPositionEmbedding(nn.Module):

 

 

 # convolutional position embedding

+class Mish(nn.Module):

+    def __init__(self):

+        super().__init__()

 

+    def forward(self, x):

+        return x * torch.tanh(torch.log(1 + torch.exp(x)))

 

 class ConvPositionEmbedding(nn.Module):

     def __init__(self, dim, kernel_size=31, groups=16):

@@ -92,9 +97,9 @@ class ConvPositionEmbedding(nn.Module):

         assert kernel_size % 2 != 0

         self.conv1d = nn.Sequential(

             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),

-            nn.Mish(),

+            Mish(),

             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),

-            nn.Mish(),

+            Mish(),

         )

 

     def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722

@@ -119,11 +124,11 @@ class CausalConvPositionEmbedding(nn.Module):

         self.kernel_size = kernel_size

         self.conv1 = nn.Sequential(

             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),

-            nn.Mish(),

+            Mish(),

         )

         self.conv2 = nn.Sequential(

             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),

-            nn.Mish(),

+            Mish(),

         )

 

     def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722

@@ -388,7 +393,17 @@ class AttnProcessor:

         else:

             attn_mask = None

 

-        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)

+        atten_mask_npu = torch.logical_not(attn_mask)

+        head_num = query.shape[1]

+        x = torch_npu.npu_fusion_attention(

+                       query, key, value, head_num, input_layout="BNSD",

+                       pse=None,

+                       atten_mask=atten_mask_npu,

+                       scale=1.0 / math.sqrt(query.shape[-1]),

+                       pre_tockens=2147483647,

+                       next_tockens=2147483647,

+                       keep_prob=1

+                   )[0]

         x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

         x = x.to(query.dtype)

 

diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py

index c255186..9a20c34 100644

--- a/cosyvoice/flow/flow.py

+++ b/cosyvoice/flow/flow.py

@@ -376,7 +376,8 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):

                   prompt_feat_len,

                   embedding,

                   streaming,

-                  finalize):

+                  finalize,

+                  n_timesteps=None):

         assert token.shape[0] == 1

         # xvec projection

         embedding = F.normalize(embedding, dim=1)

@@ -406,7 +407,7 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):

             mask=mask.unsqueeze(1),

             spks=embedding,

             cond=conds,

-            n_timesteps=10,

+            n_timesteps=10 if n_timesteps is None else n_timesteps,

             streaming=streaming

         )

         feat = feat[:, :, mel_len1:]

diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py

index bbc2a21..0cedae6 100644

--- a/cosyvoice/hifigan/generator.py

+++ b/cosyvoice/hifigan/generator.py

@@ -498,10 +498,24 @@ class HiFTGenerator(nn.Module):

 

     def _istft(self, magnitude, phase):

         magnitude = torch.clip(magnitude, max=1e2)

+        device = magnitude.device

+        is_npu = device.type == 'npu'

         real = magnitude * torch.cos(phase)

         img = magnitude * torch.sin(phase)

-        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],

-                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))

+        if is_npu:

+            complex_tensor = torch.complex(real.cpu(), img.cpu())

+            window_xpu = self.stft_window.cpu()

+        else:

+            complex_tensor = torch.complex(real, img)

+            window_xpu = self.stft_window.to(device)

+

+        inverse_transform = torch.istft(complex_tensor,

+                                        self.istft_params["n_fft"],

+                                        self.istft_params["hop_len"],

+                                        self.istft_params["n_fft"],

+                                        window=window_xpu)

+        if is_npu:

+            inverse_transform = inverse_transform.to(device)

         return inverse_transform

 

     def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:

diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py

index b173ef2..fefb363 100644

--- a/cosyvoice/utils/file_utils.py

+++ b/cosyvoice/utils/file_utils.py

@@ -20,7 +20,7 @@ import torch

 import torchaudio

 import logging

 logging.getLogger('matplotlib').setLevel(logging.WARNING)

-logging.basicConfig(level=logging.DEBUG,

+logging.basicConfig(level=logging.INFO,

                     format='%(asctime)s %(levelname)s %(message)s')

 

 

@@ -42,8 +42,22 @@ def read_json_lists(list_file):

 

 

 def load_wav(wav, target_sr, min_sr=16000):

-    speech, sample_rate = torchaudio.load(wav, backend='soundfile')

-    speech = speech.mean(dim=0, keepdim=True)

+    if isinstance(wav, torch.Tensor):

+        speech = wav

+        sample_rate = 16000 # assumed

+    elif not isinstance(wav, str):

+        try:

+            speech = torch.as_tensor(wav)

+            sample_rate = 16000 # assumed

+        except:

+            raise TypeError(f"Expect path or Tensor, but got {type(wav)}")

+    else:

+        speech, sample_rate = torchaudio.load(wav, backend='soundfile')

+

+    if speech.ndim > 1:

+        speech = speech.mean(dim=0, keepdim=True)

+    elif speech.ndim == 1:

+        speech = speech.unsqueeze(0)

     if sample_rate != target_sr:

         assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)

         speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)

diff --git a/cosyvoice/utils/mask.py b/cosyvoice/utils/mask.py

index 5d3dfd6..179dc73 100644

--- a/cosyvoice/utils/mask.py

+++ b/cosyvoice/utils/mask.py

@@ -230,9 +230,8 @@ def add_optional_chunk_mask(xs: torch.Tensor,

     else:

         chunk_masks = masks

     assert chunk_masks.dtype == torch.bool

-    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:

-        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')

-        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True

+    new_mask = (chunk_masks.sum(dim=-1, keepdim=True) == 0)

+    chunk_masks = torch.logical_or(chunk_masks, new_mask)

     return chunk_masks

 

 

diff --git a/diff.patch b/diff.patch

new file mode 100644

index 0000000..c4e5bfa

--- /dev/null

+++ b/diff.patch

@@ -0,0 +1,353 @@

+diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py

+index 7ab04a7..dcb743f 100644

+--- a/cosyvoice/cli/cosyvoice.py

++++ b/cosyvoice/cli/cosyvoice.py

+@@ -22,6 +22,7 @@ from cosyvoice.cli.frontend import CosyVoiceFrontEnd

+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model

+ from cosyvoice.utils.file_utils import logging

+ from cosyvoice.utils.class_utils import get_model_type

++from cosyvoice.utils.file_utils import load_wav

+ 

+ 

+ class CosyVoice:

+@@ -74,6 +75,11 @@ class CosyVoice:

+         self.frontend.spk2info[zero_shot_spk_id] = model_input

+         return True

+ 

++    def add_sft_spk(self, sft_spk_id):

++        model_input = self.frontend.frontend_from_spk2info(sft_spk_id)

++        self.frontend.spk2info_sft[sft_spk_id] = model_input

++        return True

++

+     def save_spkinfo(self):

+         torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))

+ 

+@@ -188,7 +194,8 @@ class CosyVoice2(CosyVoice):

+ 

+ class CosyVoice3(CosyVoice2):

+ 

+-    def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):

++    def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1,

++                 speaker_info_dir=None):

+         self.model_dir = model_dir

+         self.fp16 = fp16

+         if not os.path.exists(model_dir):

+@@ -224,6 +231,66 @@ class CosyVoice3(CosyVoice2):

+                                 self.fp16)

+         del configs

+ 

++        self.speaker_info_dir = speaker_info_dir or os.path.join(model_dir, 'speaker_info')

++        if not os.path.exists(self.speaker_info_dir):

++            logging.warning(f'speaker_info_dir {self.speaker_info_dir} does not exist')

++

++        self.promote_wave_info = {}

++        self._preload_all_wave_info()

++

++    def _preload_all_wave_info(self):

++        if not os.path.exists(self.speaker_info_dir):

++            return

++

++        import json

++

++        speaker_info_file = os.path.join(self.speaker_info_dir, 'speaker_info.json')

++        if not os.path.exists(speaker_info_file):

++            logging.warning(f'speaker_info.json not found in {self.speaker_info_dir}')

++            return

++

++        try:

++            with open(speaker_info_file, 'r', encoding='utf-8') as f:

++                all_speakers = json.load(f)

++

++            for spk_id, wave_info in all_speakers.items():

++                try:

++                    prompt_wav_path = os.path.join(self.speaker_info_dir, wave_info['prompt_wav'])

++                    if os.path.exists(prompt_wav_path):

++                        prompt_wav = load_wav(prompt_wav_path, 16000)

++                        wave_info['prompt_wav'] = prompt_wav

++                        self.promote_wave_info[spk_id] = wave_info

++                        logging.info(f"Loaded speaker {spk_id} from {prompt_wav_path}, promote={wave_info['prompt_text']}")

++                    else:

++                        logging.error(f'Audio file not found for speaker {spk_id}: {prompt_wav_path}')

++                except Exception as e:

++                    logging.error(f'Failed to load speaker {spk_id}: {e}')

++

++            logging.info(f'Preloaded {len(self.promote_wave_info)} speakers from {speaker_info_file}')

++

++        except Exception as e:

++            logging.error(f'Failed to load {speaker_info_file}: {e}')

++

++    def inference_zero_shot_by_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):

++        if spk_id not in self.promote_wave_info:

++            raise ValueError(f'Speaker ID {spk_id} not found. Available IDs: {list(self.promote_wave_info.keys())}')

++        wave_info = self.promote_wave_info[spk_id]

++

++        if spk_id not in self.frontend.spk2info:

++            self.add_zero_shot_spk(wave_info['prompt_text'], wave_info['prompt_wav'], spk_id)

++

++        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):

++            model_input = self.frontend.frontend_zero_shot(i, wave_info['prompt_text'], wave_info['prompt_wav'], self.sample_rate, spk_id)

++            start_time = time.time()

++            logging.info('synthesis text {}'.format(i))

++            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):

++                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate

++                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))

++                # 2.save audio

++                #res = torch.cat((res, model_output['tts_speech']), dim=1)

++                yield model_output

++                start_time = time.time()

++

+ 

+ def AutoModel(**kwargs):

+     if not os.path.exists(kwargs['model_dir']):

+diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py

+index 6d397cc..10b0a87 100644

+--- a/cosyvoice/cli/frontend.py

++++ b/cosyvoice/cli/frontend.py

+@@ -26,6 +26,7 @@ import inflect

+ from cosyvoice.utils.file_utils import logging, load_wav

+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation

+ 

++import threading

+ 

+ class CosyVoiceFrontEnd:

+ 

+@@ -50,6 +51,8 @@ class CosyVoiceFrontEnd:

+             self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)

+         else:

+             self.spk2info = {}

++        logging.info(f"spk2info file = {spk2info}, CosyVoiceFrontEnd.spk2info = {self.spk2info.keys()}")

++        self.spk2info_sft = {}

+         self.allowed_special = allowed_special

+         self.inflect_parser = inflect.engine()

+         # NOTE compatible when no text frontend tool is avaliable

+@@ -74,6 +77,12 @@ class CosyVoiceFrontEnd:

+                 self.text_frontend = ''

+                 logging.info('no frontend is avaliable')

+ 

++        self.lock = threading.Lock()

++

++    def tokenizer_encode_with_lock(self, *args, **kwargs):

++        with self.lock:

++            output = self.tokenizer.encode(*args, **kwargs)

++        return output

+ 

+     def _extract_text_token(self, text):

+         if isinstance(text, Generator):

+@@ -81,7 +90,7 @@ class CosyVoiceFrontEnd:

+             # NOTE add a dummy text_token_len for compatibility

+             return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)

+         else:

+-            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)

++            text_token = self.tokenizer_encode_with_lock(text, allowed_special=self.allowed_special)

+             text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)

+             text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)

+             return text_token, text_token_len

+@@ -148,13 +157,13 @@ class CosyVoiceFrontEnd:

+                 text = text.replace(" - ", ",")

+                 text = remove_bracket(text)

+                 text = re.sub(r'[,,、]+$', '。', text)

+-                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,

++                texts = list(split_paragraph(text, partial(self.tokenizer_encode_with_lock, allowed_special=self.allowed_special), "zh", token_max_n=80,

+                                              token_min_n=60, merge_len=20, comma_split=False))

+             else:

+                 if self.text_frontend == 'wetext':

+                     text = self.en_tn_model.normalize(text)

+                 text = spell_out_number(text, self.inflect_parser)

+-                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,

++                texts = list(split_paragraph(text, partial(self.tokenizer_encode_with_lock, allowed_special=self.allowed_special), "en", token_max_n=80,

+                                              token_min_n=60, merge_len=20, comma_split=False))

+         texts = [i for i in texts if not is_only_punctuation(i)]

+         return texts if split is True else text

+@@ -222,3 +231,13 @@ class CosyVoiceFrontEnd:

+                        'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,

+                        'flow_embedding': embedding}

+         return model_input

++

++    def frontend_from_spk2info(self, spk_id):

++        spk_info = self.spk2info[spk_id]

++        embedding = spk_info['embedding'].to(self.device)

++        model_input = {

++            'llm_embedding': embedding,

++            'flow_embedding': embedding,

++        }

++        return model_input

++

+diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py

+index 92a15d9..f5362e1 100644

+--- a/cosyvoice/cli/model.py

++++ b/cosyvoice/cli/model.py

+@@ -100,7 +100,7 @@ class CosyVoiceModel:

+ 

+     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):

+         cur_silent_token_num, max_silent_token_num = 0, 5

+-        with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):

++        with torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):

+             if isinstance(text, Generator):

+                 assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'

+                 token_generator = self.llm.inference_bistream(text=text,

+@@ -282,9 +282,13 @@ class CosyVoice2Model(CosyVoiceModel):

+         export_cosyvoice2_vllm(self.llm, model_dir, self.device)

+         from vllm import EngineArgs, LLMEngine

+         engine_args = EngineArgs(model=model_dir,

++                                 dtype='float16',

+                                  skip_tokenizer_init=True,

+                                  enable_prompt_embeds=True,

+-                                 gpu_memory_utilization=0.2)

++                                 gpu_memory_utilization=0.5,

++                                 additional_config={"torchair_graph_config":{"enabled":True}},

++                                 compilation_config={'cudagraph_capture_sizes':[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,512,1024],"cudagraph_mode": "FULL"},

++                                 enable_prefix_caching=True)

+         self.llm.vllm = LLMEngine.from_engine_args(engine_args)

+         self.llm.lock = threading.Lock()

+         del self.llm.llm.model.model.layers

+diff --git a/cosyvoice/flow/DiT/modules.py b/cosyvoice/flow/DiT/modules.py

+index be8caec..d188914 100644

+--- a/cosyvoice/flow/DiT/modules.py

++++ b/cosyvoice/flow/DiT/modules.py

+@@ -18,7 +18,7 @@ import torch.nn.functional as F

+ import torchaudio

+ 

+ from x_transformers.x_transformers import apply_rotary_pos_emb

+-

++import torch_npu

+ 

+ # raw wav to mel spec

+ class MelSpec(nn.Module):

+@@ -84,7 +84,12 @@ class SinusPositionEmbedding(nn.Module):

+ 

+ 

+ # convolutional position embedding

++class Mish(nn.Module):

++    def __init__(self):

++        super().__init__()

+ 

++    def forward(self, x):

++        return x * torch.tanh(torch.log(1 + torch.exp(x)))

+ 

+ class ConvPositionEmbedding(nn.Module):

+     def __init__(self, dim, kernel_size=31, groups=16):

+@@ -92,9 +97,9 @@ class ConvPositionEmbedding(nn.Module):

+         assert kernel_size % 2 != 0

+         self.conv1d = nn.Sequential(

+             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),

+-            nn.Mish(),

++            Mish(),

+             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),

+-            nn.Mish(),

++            Mish(),

+         )

+ 

+     def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722

+@@ -119,11 +124,11 @@ class CausalConvPositionEmbedding(nn.Module):

+         self.kernel_size = kernel_size

+         self.conv1 = nn.Sequential(

+             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),

+-            nn.Mish(),

++            Mish(),

+         )

+         self.conv2 = nn.Sequential(

+             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),

+-            nn.Mish(),

++            Mish(),

+         )

+ 

+     def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722

+@@ -388,7 +393,17 @@ class AttnProcessor:

+         else:

+             attn_mask = None

+ 

+-        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)

++        atten_mask_npu = torch.logical_not(attn_mask)

++        head_num = query.shape[1]

++        x = torch_npu.npu_fusion_attention(

++                       query, key, value, head_num, input_layout="BNSD",

++                       pse=None,

++                       atten_mask=atten_mask_npu,

++                       scale=1.0 / math.sqrt(query.shape[-1]),

++                       pre_tockens=2147483647,

++                       next_tockens=2147483647,

++                       keep_prob=1

++                   )[0]

+         x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

+         x = x.to(query.dtype)

+ 

+diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py

+index c255186..9a20c34 100644

+--- a/cosyvoice/flow/flow.py

++++ b/cosyvoice/flow/flow.py

+@@ -376,7 +376,8 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):

+                   prompt_feat_len,

+                   embedding,

+                   streaming,

+-                  finalize):

++                  finalize,

++                  n_timesteps=None):

+         assert token.shape[0] == 1

+         # xvec projection

+         embedding = F.normalize(embedding, dim=1)

+@@ -406,7 +407,7 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):

+             mask=mask.unsqueeze(1),

+             spks=embedding,

+             cond=conds,

+-            n_timesteps=10,

++            n_timesteps=10 if n_timesteps is None else n_timesteps,

+             streaming=streaming

+         )

+         feat = feat[:, :, mel_len1:]

+diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py

+index bbc2a21..0cedae6 100644

+--- a/cosyvoice/hifigan/generator.py

++++ b/cosyvoice/hifigan/generator.py

+@@ -498,10 +498,24 @@ class HiFTGenerator(nn.Module):

+ 

+     def _istft(self, magnitude, phase):

+         magnitude = torch.clip(magnitude, max=1e2)

++        device = magnitude.device

++        is_npu = device.type == 'npu'

+         real = magnitude * torch.cos(phase)

+         img = magnitude * torch.sin(phase)

+-        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],

+-                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))

++        if is_npu:

++            complex_tensor = torch.complex(real.cpu(), img.cpu())

++            window_xpu = self.stft_window.cpu()

++        else:

++            complex_tensor = torch.complex(real, img)

++            window_xpu = self.stft_window.to(device)

++

++        inverse_transform = torch.istft(complex_tensor,

++                                        self.istft_params["n_fft"],

++                                        self.istft_params["hop_len"],

++                                        self.istft_params["n_fft"],

++                                        window=window_xpu)

++        if is_npu:

++            inverse_transform = inverse_transform.to(device)

+         return inverse_transform

+ 

+     def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:

+diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py

+index b173ef2..fefb363 100644

+--- a/cosyvoice/utils/file_utils.py

++++ b/cosyvoice/utils/file_utils.py

+@@ -20,7 +20,7 @@ import torch

+ import torchaudio

+ import logging

+ logging.getLogger('matplotlib').setLevel(logging.WARNING)

+-logging.basicConfig(level=logging.DEBUG,

++logging.basicConfig(level=logging.INFO,

+                     format='%(asctime)s %(levelname)s %(message)s')

+ 

+ 

+@@ -42,8 +42,22 @@ def read_json_lists(list_file):

+ 

+ 

+ def load_wav(wav, target_sr, min_sr=16000):

+-    speech, sample_rate = torchaudio.load(wav, backend='soundfile')

+-    speech = speech.mean(dim=0, keepdim=True)

++    if isinstance(wav, torch.Tensor):

++        speech = wav

++        sample_rate = 16000 # assumed

++    elif not isinstance(wav, str):

++        try:

++            speech = torch.as_tensor(wav)

++            sample

\ No newline at end of file

diff --git a/runtime/python/grpc/batch_test_client.py b/runtime/python/grpc/batch_test_client.py

new file mode 100644

index 0000000..1bb9709

--- /dev/null

+++ b/runtime/python/grpc/batch_test_client.py

@@ -0,0 +1,249 @@

+import os

+import sys

+import logging

+import argparse

+import soundfile

+import cosyvoice_pb2

+import cosyvoice_pb2_grpc

+import grpc

+import torch

+import numpy as np

+import time

+import datetime

+sys.path.insert(0, '../../../')

+from cosyvoice.utils.file_utils import load_wav

+from concurrent.futures import ThreadPoolExecutor, as_completed

+from threading import Barrier, Lock

+

+logging.basicConfig(level=logging.INFO,

+                    format='%(asctime)s %(levelname)s %(message)s')

+

+TARGET_SR = 24000

+

+

+def percentile(data, p):

+    if len(data) == 0:

+        return 0

+    return np.percentile(data, p)

+

+

+def build_request(args, prompt_audio_bytes, tts_text):

+    request = cosyvoice_pb2.Request()

+

+    if args.mode == 'zero_shot':

+        zero_shot_request = cosyvoice_pb2.zeroshotRequest()

+        zero_shot_request.tts_text = tts_text

+        zero_shot_request.prompt_text = args.prompt_text

+        zero_shot_request.prompt_audio = prompt_audio_bytes

+        zero_shot_request.stream = args.stream

+        request.zero_shot_request.CopyFrom(zero_shot_request)

+    elif args.mode == 'zero_shot_by_id':

+        req = cosyvoice_pb2.zeroshotByIdRequest()

+        req.tts_text = tts_text

+        req.spk_id = args.spk_id

+        req.stream = args.stream

+        request.zero_shot_by_id_request.CopyFrom(req)

+    else:

+        sft_request = cosyvoice_pb2.sftRequest(

+            spk_id=args.spk_id,

+            tts_text=tts_text,

+            stream=args.stream

+        )

+        request.sft_request.CopyFrom(sft_request)

+

+    return request

+

+

+def single_request(task_id, args, stub, barrier, prompt_audio_bytes, tts_text, use_barrier):

+    try:

+        request = build_request(args, prompt_audio_bytes, tts_text)

+

+        if use_barrier and barrier is not None:

+            barrier.wait()

+

+        start_t = time.time()

+        response = stub.Inference(request)

+        first_packet_time = None

+        chunks = []

+        for r in response:

+            if first_packet_time is None:

+                first_packet_time = time.time()

+            chunks.append(np.frombuffer(r.tts_audio, dtype=np.int16))

+        end_t = time.time()

+

+        if len(chunks) == 0:

+            return {"success": False}

+

+        full_audio = np.concatenate(chunks)

+        audio_dur = len(full_audio) / TARGET_SR

+        total_time = end_t - start_t

+        first_fix = (first_packet_time - start_t) * 1000 if first_packet_time else 0

+        rtf = total_time / audio_dur if audio_dur > 0 else 0

+        # logging.info(f"start = {datetime.datetime.fromtimestamp(start_t):%H:%M:%S.%f}, "

+        #       f"first = {datetime.datetime.fromtimestamp(first_packet_time):%H:%M:%S.%f}, "

+        #       f"end = {datetime.datetime.fromtimestamp(end_t):%H:%M:%S.%f}, "

+        #       f"total_time = {total_time:.3f}, "

+        #       f"audio_dur = {audio_dur:.3f}, "

+        #       f"first = {first_fix:.3f}, "

+        #       f"rtf = {rtf:.3f}")

+

+        if args.save_task_id is not None and (task_id == args.save_task_id or args.save_task_id == -1):

+            os.makedirs(args.save_audio_dir, exist_ok=True)

+            save_path = os.path.join(

+                args.save_audio_dir,

+                f"task_{task_id}.wav"

+            )

+            soundfile.write(save_path, full_audio, TARGET_SR)

+            logging.info(f"Audio saved: {save_path}")

+        return {

+            "success": True,

+            "first_fix": first_fix,

+            "rtf": rtf,

+            "audio_dur": audio_dur

+        }

+

+    except Exception as e:

+        logging.error(f"Task {task_id} failed: {str(e)}")

+        return {"success": False}

+

+

+def main():

+    parser = argparse.ArgumentParser()

+    parser.add_argument('--host', type=str, default='0.0.0.0')

+    parser.add_argument('--port', type=int, default=50000)

+    parser.add_argument('--mode',

+                        default='sft',

+                        choices=['sft', 'zero_shot', 'zero_shot_by_id'])

+    parser.add_argument('--tts_text', type=str,

+                        default='旅游是一种很好的放松方式,可以让我们暂时远离繁忙的工作,享受大自然的美好与宁静。')

+    parser.add_argument('--spk_id', type=str, default='spk1')

+    parser.add_argument('--prompt_text', type=str,

+                        default='You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。')

+    parser.add_argument('--prompt_wav', type=str,

+                        default='/data/code/CosyVoice/asset/zero_shot_prompt.wav')

+    parser.add_argument('--stream', action='store_true')

+    parser.add_argument('--total_req', type=int, default=100)

+    parser.add_argument('--concurrency', type=int, default=8)

+    parser.add_argument('--save_task_id', type=int, default=None)

+    parser.add_argument('--save_audio_dir', type=str, default='./saved_audio')

+    parser.add_argument('--text_file', type=str, default=None,

+                        help='文本文件路径,每行一个句子,用于多样化测试')

+    parser.add_argument('--sync_start', action='store_true',

+                        help='启用同步启动模式:等待所有并发线程就绪后同时发起请求(会降低吞吐量)')

+    args = parser.parse_args()

+

+    logging.info(f"Start benchmark total = {args.total_req} concurrency = {args.concurrency}")

+

+    text_list = []

+    if args.text_file and os.path.exists(args.text_file):

+        with open(args.text_file, 'r', encoding='utf-8') as f:

+            text_list = [line.strip() for line in f if line.strip()]

+        logging.info(f"Loaded {len(text_list)} texts from {args.text_file}")

+

+    if not text_list:

+        text_list = [args.tts_text]

+        logging.info("Using default text with task_id suffix for diversity")

+

+    prompt_audio_bytes = None

+

+    if args.mode == "zero_shot":

+        prompt_speech = load_wav(args.prompt_wav, 16000)

+        prompt_audio_bytes = (

+            (prompt_speech.numpy() * (2**15))

+            .astype(np.int16)

+            .tobytes()

+        )

+

+    options = [

+        ("grpc.keepalive_time_ms", 120000),

+        ("grpc.keepalive_timeout_ms", 60000),

+        ("grpc.http2.max_pings_without_data", 0),

+        ("grpc.keepalive_permit_without_calls", 1),

+        ("grpc.enable_http_proxy", 0),

+    ]

+

+    channel = grpc.insecure_channel(

+        f"{args.host}:{args.port}",

+        options=options

+    )

+

+    stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel)

+

+    barrier = Barrier(args.concurrency) if args.sync_start else None

+

+    if args.sync_start:

+        logging.info("sync mode")

+    else:

+        logging.info("async mode")

+

+    results = []

+    progress = 0

+    progress_lock = Lock()

+    start_bench = time.time()

+

+    with ThreadPoolExecutor(max_workers=args.concurrency) as executor:

+        futures = []

+        for i in range(args.total_req):

+            if len(text_list) == 1:

+                tts_text = f"{text_list[0]}第{i+1}句。"

+            else:

+                tts_text = text_list[i % len(text_list)]

+

+            future = executor.submit(

+                single_request,

+                i,

+                args,

+                stub,

+                barrier,

+                prompt_audio_bytes,

+                tts_text,

+                args.sync_start

+            )

+            futures.append(future)

+

+        for future in as_completed(futures):

+            r = future.result()

+            results.append(r)

+            with progress_lock:

+                progress += 1

+                percent = progress / args.total_req * 100

+                logging.info(f"Progress: {progress}/{args.total_req} ({percent:.1f}%)")

+

+    end_bench = time.time()

+    success_results = [r for r in results if r["success"]]

+

+    if not success_results:

+        logging.info("All requests failed.")

+        return

+

+    first_fix_list = [r["first_fix"] for r in success_results]

+    rtf_list = [r["rtf"] for r in success_results]

+    audio_dur_list = [r["audio_dur"] for r in success_results]

+    total_audio_dur = sum(audio_dur_list)

+    throughput = total_audio_dur / (end_bench - start_bench)

+    avg_first_fix = sum(r['first_fix'] for r in success_results) / len(success_results)

+    avg_rtf = sum(r['rtf'] for r in success_results) / len(success_results)

+

+    logging.info("=" * 64)

+    logging.info("SUMMARY:")

+    logging.info(f"success ratio: {len(success_results)}/{args.total_req}")

+    logging.info(f"duration: {end_bench-start_bench:.2f}s")

+    logging.info(f"throughput: {throughput:.2f} seconds of audio / s")

+    logging.info("=" * 64)

+    logging.info(f"AVG First Latency: {avg_first_fix:.0f}")

+    logging.info(f"P50: {percentile(first_fix_list,50):.0f}")

+    logging.info(f"P70: {percentile(first_fix_list,70):.0f}")

+    logging.info(f"P80: {percentile(first_fix_list,80):.0f}")

+    logging.info(f"P90: {percentile(first_fix_list,90):.0f}")

+    logging.info(f"P99: {percentile(first_fix_list,99):.0f}")

+    logging.info("=" * 64)

+    logging.info(f"AVG RTF: {avg_rtf:.2f}")

+    logging.info(f"P50: {percentile(rtf_list,50):.2f}")

+    logging.info(f"P70: {percentile(rtf_list,70):.2f}")

+    logging.info(f"P80: {percentile(rtf_list,80):.2f}")

+    logging.info(f"P90: {percentile(rtf_list,90):.2f}")

+    logging.info(f"P99: {percentile(rtf_list,99):.2f}")

+    logging.info("="*20)

+

+if __name__ == "__main__":

+    main()

\ No newline at end of file

diff --git a/runtime/python/grpc/client.py b/runtime/python/grpc/client.py

index 9885130..b825434 100644

--- a/runtime/python/grpc/client.py

+++ b/runtime/python/grpc/client.py

@@ -44,6 +44,7 @@ def main():

             zero_shot_request.prompt_text = args.prompt_text

             prompt_speech = load_wav(args.prompt_wav, 16000)

             zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()

+            zero_shot_request.stream = args.stream

             request.zero_shot_request.CopyFrom(zero_shot_request)

         elif args.mode == 'cross_lingual':

             logging.info('send cross_lingual request')

@@ -52,6 +53,13 @@ def main():

             prompt_speech = load_wav(args.prompt_wav, 16000)

             cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()

             request.cross_lingual_request.CopyFrom(cross_lingual_request)

+        elif args.mode == 'zero_shot_by_id':

+            logging.info('send zero_shot_by_id request')

+            zero_shot_by_id_request = cosyvoice_pb2.zeroshotByIdRequest()

+            zero_shot_by_id_request.tts_text = args.tts_text

+            zero_shot_by_id_request.spk_id = args.spk_id

+            zero_shot_by_id_request.stream = args.stream

+            request.zero_shot_by_id_request.CopyFrom(zero_shot_by_id_request)

         else:

             logging.info('send instruct request')

             instruct_request = cosyvoice_pb2.instructRequest()

@@ -80,7 +88,7 @@ if __name__ == "__main__":

                         default='50000')

     parser.add_argument('--mode',

                         default='sft',

-                        choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],

+                        choices=['sft', 'zero_shot', 'cross_lingual', 'instruct', 'zero_shot_by_id'],

                         help='request mode')

     parser.add_argument('--tts_text',

                         type=str,

@@ -101,6 +109,9 @@ if __name__ == "__main__":

     parser.add_argument('--tts_wav',

                         type=str,

                         default='demo.wav')

+    parser.add_argument('--stream', 

+                        action='store_true', 

+                        help='whether to use streaming inference')

     args = parser.parse_args()

-    prompt_sr, target_sr = 16000, 22050

+    prompt_sr, target_sr = 16000, 24000

     main()

diff --git a/runtime/python/grpc/cosyvoice.proto b/runtime/python/grpc/cosyvoice.proto

index fe0c3ad..3feeea4 100644

--- a/runtime/python/grpc/cosyvoice.proto

+++ b/runtime/python/grpc/cosyvoice.proto

@@ -13,18 +13,21 @@ message Request{

     zeroshotRequest zero_shot_request = 2;

     crosslingualRequest cross_lingual_request = 3;

     instructRequest instruct_request = 4;

+    zeroshotByIdRequest zero_shot_by_id_request = 5;

   }

 }

 

 message sftRequest{

   string spk_id = 1;

   string tts_text = 2;

+  bool stream = 3;

 }

 

 message zeroshotRequest{

   string tts_text = 1;

   string prompt_text = 2;

   bytes prompt_audio = 3;

+  bool stream = 4; 

 }

 

 message crosslingualRequest{

@@ -38,6 +41,12 @@ message instructRequest{

   string instruct_text = 3;

 }

 

+message zeroshotByIdRequest{

+  string tts_text = 1;

+  string spk_id = 2;

+  bool stream = 3;

+}

+

 message Response{

   bytes tts_audio = 1;

 }

\ No newline at end of file

diff --git a/runtime/python/grpc/run_server.sh b/runtime/python/grpc/run_server.sh

new file mode 100644

index 0000000..3c66d30

--- /dev/null

+++ b/runtime/python/grpc/run_server.sh

@@ -0,0 +1,15 @@

+#!/bin/bash

+

+export VLLM_WORKER_MULTIPROC_METHOD=spawn

+export ASCEND_RT_VISIBLE_DEVICES=0

+export TASK_QUEUE_ENABLE=1

+export CPU_AFFINITY_CONF=1

+export VLLM_ASCEND_ENABLE_MLP_OPTIMIZE=1

+export VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION=1

+export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2:$LD_PRELOAD

+

+# 配置参数

+MODEL_DIR="/home/Fun-CosyVoice3-0.5B-2512/"

+MAX_CONC=8

+

+python server.py --port 50099 --model_dir "$MODEL_DIR" --max_conc "$MAX_CONC" --graph_mode

diff --git a/runtime/python/grpc/server.py b/runtime/python/grpc/server.py

index 28ecc19..c0f3252 100644

--- a/runtime/python/grpc/server.py

+++ b/runtime/python/grpc/server.py

@@ -27,49 +27,952 @@ sys.path.append('{}/../../..'.format(ROOT_DIR))

 sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))

 from cosyvoice.cli.cosyvoice import AutoModel

 

-logging.basicConfig(level=logging.DEBUG,

+logging.basicConfig(level=logging.INFO,

                     format='%(asctime)s %(levelname)s %(message)s')

 

+logging.info("Initializing NPU environment...")

+try:

+    import torch_npu

+    from torch_npu.contrib import transfer_to_npu

+    torch_npu.npu.set_compile_mode(jit_compile=False)

+    torch.npu.config.allow_internal_format = False

+    logging.info("NPU initialized successfully.")

+except ImportError:

+    logging.warning("NPU libraries not found. Using CPU.")

+

+import threading

+import time

+import uuid

+import vllm

+import queue

+import torchair as tng

+from torchair.configs.compiler_config import CompilerConfig

+import datetime

+import json

+import gc

+from typing import Optional

+from enum import Enum

+from dataclasses import dataclass, field

+

+from hyperpyyaml import load_hyperpyyaml

+from cosyvoice.utils.file_utils import load_wav

+from cosyvoice.cli.frontend import CosyVoiceFrontEnd

+

+

+class ThcSectionStatus(str, Enum):

+    LLM_DOING = 'llm_doing'

+    T2W_WAITING = 't2w_waiting'

+    T2W_DOING = 't2w_doing'

+    DONE = 'done'

+

+@dataclass

+class ThcSection:

+    idx: tuple

+    mode: str

+    spk_id: str

+    first: bool

+    finalize: bool

+    tokens: list

+    token_begin_pos: int

+    token_end_pos: int

+    token_begin_pos_in_batch: Optional[int] = None

+    token_end_pos_in_batch: Optional[int] = None

+    flow_begin_pos_in_batch: Optional[int] = None

+    flow_end_pos_in_batch: Optional[int] = None

+    flow_cache: Optional[torch.Tensor] = None

+    speech_begin_pos_in_batch: Optional[int] = None

+    speech_end_pos_in_batch: Optional[int] = None

+    speech: Optional[torch.Tensor] = None

+    speech_cpu: Optional[np.array] = None

+    speech_dt: Optional[float] = None

+    status: ThcSectionStatus = ThcSectionStatus.LLM_DOING

+

+class ThcMetric:

+    def __init__(self, begin=None):

+        self.begin: Optional[datetime.datetime] = begin

+        self.end: Optional[datetime.datetime] = None

+        self.dt: Optional[float] = None

+

+    def start(self):

+        self.begin = self.begin or datetime.datetime.now()

+

+    def stop(self):

+        self.end = self.end or datetime.datetime.now()

+        self.dt = (self.end - self.begin).total_seconds()

+

+@dataclass

+class ThcAllMetric:

+    e2e: ThcMetric = field(default_factory=ThcMetric)

+    llm: ThcMetric = field(default_factory=ThcMetric)

+    t2w: list = field(default_factory=list)

+

+@dataclass

+class ThcRequestInfo:

+    idx: int

+    request_id: str

+    mode: str

+    spk_id: str

+    first: bool = False

+    finalize: bool = False

+    done: bool = False

+    llm_output: list = field(default_factory=list)

+    llm_done: bool = False

+    section_queue: queue.Queue = field(default_factory=queue.Queue)

+    flow_cache: Optional[torch.Tensor] = None

+    speech: Optional[torch.Tensor] = None

+    speech_cum_dt: float = 0.0

+    dynamic_rtf: Optional[float] = None

+    dynamic_slack: Optional[float] = None

+    release_dt: Optional[float] = None

+    sections: list = field(default_factory=list)

+    release: queue.Queue = field(default_factory=queue.Queue)

+    metric: ThcAllMetric = field(default_factory=ThcAllMetric)

+

+@dataclass

+class ThcT2WBatchInput:

+    batch_requests: list = field(default_factory=ThcSection)

+    tokens: Optional[torch.Tensor] = None

+    token_len: Optional[torch.Tensor] = None

+    flow_prompt_speech_token: Optional[torch.Tensor] = None

+    flow_prompt_speech_token_len: Optional[torch.Tensor] = None

+    prompt_speech_feat: Optional[torch.Tensor] = None

+    prompt_speech_feat_len: Optional[torch.Tensor] = None

+    flow_embedding: Optional[torch.Tensor] = None

+    stream: bool = False

+    finalize: bool = False

+    n_timesteps: int = 10

+

+

+class Token2Wave():

+    def __init__(self,

+                 device,

+                 fp16: bool = False,

+                 model_dir: str = None,

+                 speaker_info_dir: str = None,

+                 graph_mode: bool = False):

+        hyper_yaml_path = os.path.join(model_dir, 'cosyvoice3.yaml')

+        with open(hyper_yaml_path, 'r') as f:

+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})

+

+        self.device = device

+        self.flow_len_per_token = 2

+        self.speech_len_per_token = 960

+        self.speech_sample_rate = configs['sample_rate']

+        self.token_sample_rate = configs['sample_rate'] / self.speech_len_per_token

+

+        self.model_dir = model_dir

+        self.graph_mode = graph_mode

+

+        frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],

+                                     configs['feat_extractor'],

+                                     os.path.join(model_dir, 'campplus.onnx'),

+                                     os.path.join(model_dir, 'speech_tokenizer_v3.onnx'),

+                                     os.path.join(model_dir, 'spk2info.pt'),

+                                     configs['allowed_special'])

+        sample_rate = configs['sample_rate']

+

+        flow_ckpt_file = os.path.join(model_dir, 'flow.pt')

+        hift_ckpt_file = os.path.join(model_dir, 'hift.pt')

+

+        self.flow = configs['flow'].to(self.device)

+        self.flow.load_state_dict(torch.load(flow_ckpt_file, map_location=self.device, weights_only=True), strict=True)

+        self.flow.eval()

+        # in case hift_model is a hifigan model

+        self.hift = configs['hift'].to(self.device)

+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_ckpt_file, map_location=self.device, weights_only=True).items()}

+        self.hift.load_state_dict(hift_state_dict, strict=True)

+        self.hift.eval()

+

+        self.prompt_wave_info = None

+

+        self.lookahead_token_len = self.flow.pre_lookahead_len

+        self.pre_gap_token_len = self.flow.pre_lookahead_len * 4

+        self.post_gap_token_len = self.flow.pre_lookahead_len * 1

+

+        self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]

+

+        self.perfs = []

+

+        self.batches_requests = queue.Queue()

+        self.inference_event = threading.Event()

+        self.lock = threading.Lock()

+

+        del configs

+

+        if graph_mode:

+            config = CompilerConfig()

+            npu_backend = tng.get_npu_backend(compiler_config=config)

+            self.flow.decoder.estimator.forward = torch.compile(

+                self.flow.decoder.estimator.forward,

+                dynamic=True,

+                fullgraph=True,

+                backend=npu_backend

+            )

+            logging.info("Token2Wave: compile graph success")

+

+    def add_requests(self, batch_requests):

+        self.batches_requests.put(batch_requests)

+        self.inference_event.set()

+

+    def prepare_batch_data(self, batch_requests):

+        merged_tokens = []

+        mode = None

+        spk_id = None

+        pos = 0

+        for request in batch_requests:

+            mode = mode or request.mode

+            assert request.mode == mode

+            spk_id = spk_id or request.spk_id

+            assert request.spk_id == spk_id

+

+            merged_tokens.extend(request.tokens)

+            request.token_begin_pos_in_batch = pos + request.token_begin_pos

+            request.token_end_pos_in_batch = pos + request.token_end_pos

+            pos += request.token_end_pos

+            pos += self.post_gap_token_len

+            request.flow_begin_pos_in_batch = int(request.token_begin_pos_in_batch * self.flow_len_per_token)

+            request.flow_end_pos_in_batch = int(request.token_end_pos_in_batch * self.flow_len_per_token)

+            request.status = ThcSectionStatus.T2W_DOING

+            logging.debug(f"Token2Wave.prepare_batch_data: {request.idx}, "

+                          f"merged_tokens = {len(merged_tokens)}, "

+                          f"start_pos_in_batch = {request.token_begin_pos_in_batch}, end_pos_in_batch = {request.token_end_pos_in_batch}, "

+                          f"pos = {pos}")

+        assert pos == len(merged_tokens), f"pos = {pos}, len(merged_tokens) = {len(merged_tokens)}"

+

+        tokens = torch.tensor(merged_tokens, dtype=torch.int32, device=self.device).unsqueeze(dim=0)

+        token_len = torch.tensor([tokens.shape[1]], dtype=torch.int32, device=self.device)

+

+        if mode == 'zero_shot_by_id':

+            flow_prompt_speech_token = self.prompt_wave_info[spk_id]['flow_prompt_speech_token']

+            flow_prompt_speech_token_len = self.prompt_wave_info[spk_id]['flow_prompt_speech_token_len']

+            prompt_speech_feat = self.prompt_wave_info[spk_id]['prompt_speech_feat']

+            prompt_speech_feat_len = self.prompt_wave_info[spk_id]['prompt_speech_feat_len']

+        elif mode == 'sft':

+            flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32, device=self.device)

+            flow_prompt_speech_token_len = torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32, device=self.device)

+            prompt_speech_feat = torch.zeros(1, 0, 80, device=self.device)

+            prompt_speech_feat_len = torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32, device=self.device)

+

+        flow_embedding = self.prompt_wave_info[spk_id]['flow_embedding']

+        stream = True

+        finalize = False

+        n_timesteps = 5 if any([request.first for request in batch_requests]) else 10

+

+        batch_data = ThcT2WBatchInput(batch_requests=batch_requests,

+                                      tokens=tokens,

+                                      token_len=token_len,

+                                      flow_prompt_speech_token=flow_prompt_speech_token,

+                                      flow_prompt_speech_token_len=flow_prompt_speech_token_len,

+                                      prompt_speech_feat=prompt_speech_feat,

+                                      prompt_speech_feat_len=prompt_speech_feat_len,

+                                      flow_embedding=flow_embedding,

+                                      stream=stream,

+                                      finalize=finalize,

+                                      n_timesteps=n_timesteps)

+

+        return batch_data

+

+    def step(self):

+        # modify if enable multi_threads

+        batch_requests = self.batches_requests.get_nowait()

+

+        metric = ThcMetric()

+        metric.start()

+        logging.debug(f"Token2Wave.infer: start")

+

+        batch_data = self.prepare_batch_data(batch_requests)

+

+        tts_mel, _ = self.flow.inference(token=batch_data.tokens,

+                                         token_len=batch_data.token_len,

+                                         prompt_token=batch_data.flow_prompt_speech_token,

+                                         prompt_token_len=batch_data.flow_prompt_speech_token_len,

+                                         prompt_feat=batch_data.prompt_speech_feat,

+                                         prompt_feat_len=batch_data.prompt_speech_feat_len,

+                                         embedding=batch_data.flow_embedding,

+                                         streaming=batch_data.stream,

+                                         finalize=batch_data.finalize,

+                                         n_timesteps=batch_data.n_timesteps)

+

+        logging.debug(f"tts_mel = {tts_mel.shape}")

+        assert tts_mel.shape[-1] == (batch_data.tokens.shape[-1] - self.lookahead_token_len) * 2

+

+        logging.debug(f"self.hift.f0_predictor.condnet[0].causal_padding = {self.hift.f0_predictor.condnet[0].causal_padding}")

+        prefix_flow_len = 16 * self.flow_len_per_token

+        token_offset = 4

+

+        speech_pos = 0

+        merged_tts_mel = []

+        for request in batch_data.batch_requests:

+            request_tts_mel = tts_mel[..., request.flow_begin_pos_in_batch : request.flow_end_pos_in_batch]

+            speech_len = (request.token_end_pos - request.token_begin_pos) * self.speech_len_per_token

+            assert request.first == (request.flow_cache is None)

+            if request.first:

+                flow_prefix = request_tts_mel.new_zeros([*request_tts_mel.shape[: -1], prefix_flow_len])

+                request.flow_cache = request_tts_mel

+                request.speech_begin_pos_in_batch = speech_pos + flow_prefix.shape[-1] // self.flow_len_per_token * self.speech_len_per_token

+                speech_len -= token_offset * self.speech_len_per_token

+            else:

+                flow_prefix = request.flow_cache[..., -prefix_flow_len :]

+                request.flow_cache = torch.cat([request.flow_cache, request_tts_mel], dim=-1)

+                assert flow_prefix.shape[-1] >= token_offset * self.flow_len_per_token, f"request.flow_cache = {request.flow_cache.shape}, flow_prefix = {flow_prefix.shape}"

+                request.speech_begin_pos_in_batch = speech_pos + (flow_prefix.shape[-1] // self.flow_len_per_token - token_offset) * self.speech_len_per_token

+

+            if request.finalize:

+                flow_postfix = request_tts_mel.new_zeros([*request_tts_mel.shape[: -1], prefix_flow_len])

+                request_tts_mel = torch.cat([flow_prefix, request_tts_mel, flow_postfix], dim=-1)

+                speech_len += token_offset * self.speech_len_per_token

+            else:

+                request_tts_mel = torch.cat([flow_prefix, request_tts_mel], dim=-1)

+            request.speech_end_pos_in_batch = request.speech_begin_pos_in_batch + speech_len

+            merged_tts_mel.append(request_tts_mel)

+            speech_pos += request_tts_mel.shape[-1] // self.flow_len_per_token * self.speech_len_per_token

+            logging.debug(f"Token2Wave.step: {request.idx}, finalize = {request.finalize}, "

+                          f"speech_begin_pos_in_batch = {request.speech_begin_pos_in_batch}, speech_end_pos_in_batch = {request.speech_end_pos_in_batch}, "

+                          f"speech_pos = {speech_pos}")

+        merged_tts_mel = torch.cat(merged_tts_mel, dim=-1)

+

+        tts_speech, _ = self.hift.inference(speech_feat=merged_tts_mel, finalize=True)

+        tts_speech_cpu = (tts_speech * (2 ** 15)).to(torch.int16).view(-1).cpu().numpy()

+        logging.debug(f"merged_tts_mel = {merged_tts_mel.shape}, tts_speech_cpu = {tts_speech_cpu.shape}")

+

+        for request in batch_data.batch_requests:

+            request.speech = tts_speech[:, request.speech_begin_pos_in_batch : request.speech_end_pos_in_batch]

+            request.speech_cpu = tts_speech_cpu[request.speech_begin_pos_in_batch : request.speech_end_pos_in_batch]

+            assert request.speech_cpu.shape[0] == request.speech_end_pos_in_batch - request.speech_begin_pos_in_batch, f"{request.speech_cpu.shape[0]} != {request.speech_end_pos_in_batch} - {request.speech_begin_pos_in_batch}"

+            request.speech_dt = request.speech_cpu.shape[0] / self.speech_sample_rate

+            request.status = ThcSectionStatus.DONE

+

+        metric.stop()

+

+        perf = {'token_len': batch_data.token_len[0].item(), 'tts_speech_shape': list(tts_speech.shape), 'n_timesteps': batch_data.n_timesteps, 'metric': metric}

+        self.perfs.append(perf)

+        logging.debug(f"Token2Wave.infer: end: token_len = {perf['token_len']}, tts_speech_shape = {perf['tts_speech_shape']}, dt = {metric.dt:.3f}")

+

+        return batch_data.batch_requests

 

 class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):

     def __init__(self, args):

-        self.cosyvoice = AutoModel(model_dir=args.model_dir)

-        logging.info('grpc service initialized')

+        self.args = args

+

+        self.cosyvoice = AutoModel(

+            model_dir=self.args.model_dir,

+            load_vllm=True,

+            speaker_info_dir=self.args.speaker_info_dir,

+        )

+        logging.info(f'grpc service initialized, graph_mode={self.args.graph_mode}')

+

+        # taohouchao

+        self.prompt_emb = {}

+        for spk_id, spk_info in self.cosyvoice.frontend.spk2info.items():

+            if spk_id not in self.cosyvoice.frontend.spk2info_sft:

+                self.cosyvoice.add_sft_spk(spk_id)

+            model_input = self.cosyvoice.frontend.spk2info_sft[spk_id]

+            assert self.prompt_emb.get(spk_id) is None

+            self.prompt_emb[spk_id] = {

+                'sos_emb': self.cosyvoice.model.llm.speech_embedding.weight[self.cosyvoice.model.llm.sos].reshape(1, 1, -1),

+                'task_id_emb': self.cosyvoice.model.llm.speech_embedding.weight[self.cosyvoice.model.llm.task_id].reshape(1, 1, -1),

+            }

+

+        for spk_id, wave_info in self.cosyvoice.promote_wave_info.items():

+            if spk_id not in self.cosyvoice.frontend.spk2info:

+                self.cosyvoice.add_zero_shot_spk(wave_info['prompt_text'], wave_info['prompt_wav'], spk_id)

+            model_input = self.cosyvoice.frontend.spk2info[spk_id]

+            prompt_text = model_input['prompt_text']

+            llm_prompt_speech_token = model_input['llm_prompt_speech_token']

+            assert self.prompt_emb.get(spk_id) is None

+            self.prompt_emb[spk_id] = {

+                'sos_emb': self.cosyvoice.model.llm.speech_embedding.weight[self.cosyvoice.model.llm.sos].reshape(1, 1, -1),

+                'task_id_emb': self.cosyvoice.model.llm.speech_embedding.weight[self.cosyvoice.model.llm.task_id].reshape(1, 1, -1),

+                'prompt_text_emb': self.cosyvoice.model.llm.llm.model.model.embed_tokens(prompt_text),

+                'prompt_speech_token_emb': self.cosyvoice.model.llm.speech_embedding(llm_prompt_speech_token),

+            }

+

+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

+        self.requests_info = {}

+        self.requests_id = [None for _ in range(self.args.max_conc)]

+

+        self.lock = threading.Lock()

+        self.t2w = Token2Wave(device=self.device,

+                              fp16=False,

+                              model_dir=self.args.model_dir,

+                              speaker_info_dir=self.args.speaker_info_dir,

+                              graph_mode=self.args.graph_mode)

+        self.t2w.prompt_wave_info = {**self.cosyvoice.frontend.spk2info, **self.cosyvoice.frontend.spk2info_sft}

+        self.t2w_first_token_len = 20 + self.t2w.post_gap_token_len

+        self.t2w_min_token_len = 25

+        self.t2w_total_token_len = {True: 800, False: 300}

+        self.t2w_delay = 0.2

+        self.release_timeout = 1

+        self.prev_release_time = None

+

+        # modify if enable multi_threads

+        # self.t2w_inference_work_thread = threading.Thread(target=self.t2w.inference, daemon=True)

+        # self.t2w_inference_work_thread.start()

+

+        self.inference_event = threading.Event()

+        self.inference_work_thread = threading.Thread(target=self.infer, daemon=True)

+        self.inference_work_thread.start()

+

+        if self.args.enable_vllm_profile:

+            assert False

+            try:

+                self.cosyvoice.model.llm.vllm.start_profile()

+                logging.info('vLLM profile started')

+            except Exception as e:

+                logging.warning(f'Failed to start vLLM profile: {e}')

+

+    def postprocess_vllm_outputs(self, llm_step_outputs, now):

+        with self.lock:

+            abort_requests_id = []

+            for output in llm_step_outputs:

+                request_id = output.request_id

+                token_ids = output.outputs[0].token_ids

+                request_info = self.requests_info.get(request_id)

+                if request_info is None:

+                    abort_requests_id.append(request_id)

+                    continue

+                request_info.llm_output = token_ids[: -1] if token_ids[-1] in self.cosyvoice.model.llm.stop_token_ids else token_ids

+                request_info.llm_done = output.finished

+                if request_info.llm_done:

+                    request_info.metric.llm.stop()

+

+            if len(abort_requests_id) > 0:

+                logging.info(f"postprocess_vllm_outputs: abort_requests_id = {abort_requests_id}")

+                self.cosyvoice.model.llm.vllm.abort_request(abort_requests_id)

+

+    def schedule_nonfirst_requests_t2w(self, now):

+        requests_history_token_len = {}

+        requests_remain_token_len = {}

+        requests_slack = {}

+

+        for request_id, request_info in self.requests_info.items():

+            if len(request_info.sections) == 0 or any([section.status != ThcSectionStatus.DONE for section in request_info.sections]):

+                continue

+            history_token_len = sum([section.token_end_pos - section.token_begin_pos for section in request_info.sections])

+            requests_history_token_len[request_id] = history_token_len

+            new_token_len = len(request_info.llm_output) - history_token_len

+            if new_token_len == 0:

+                continue

+            requests_remain_token_len[request_id] = new_token_len

+            calc_cum_dt = (now - request_info.metric.t2w[0].end).total_seconds()

+            requests_slack[request_id] = request_info.speech_cum_dt - calc_cum_dt

+

+        flag = [len(request_info.sections) >= 1 for request_info in self.requests_info.values()]

+        available_token_len = self.t2w_total_token_len[len(flag) == self.args.max_conc and all(flag)]

+        t2w_min_token_len = max(self.t2w_min_token_len, available_token_len // len(requests_remain_token_len) - (self.t2w.pre_gap_token_len + self.t2w.post_gap_token_len))

+

+        flag = [token_len >= t2w_min_token_len or self.requests_info[request_id].llm_done for request_id, token_len in requests_remain_token_len.items()]

+        if len(flag) == 0 or all(flag) == False:

+            return {}

+

+        requests_priority = {}

+        requests_t2w_sched = {}

+        for request_id in requests_slack.keys():

+            requests_priority[request_id] = int((max(requests_slack.values()) - requests_slack[request_id]) * self.t2w.token_sample_rate)

+            requests_t2w_sched[request_id] = {'history_len': requests_history_token_len[request_id], 'alloc_len': 0, 'full': None}

+

+        estimate_token_len_flag = False

+

+        if estimate_token_len_flag:

+            estimate_t2w_calc_dt = lambda v: int(212 * (v ** 2) + 646.7 * v - 153.5)

+            expected_t2w_calc_dt = min(1.0, min(requests_slack.values()))

+            available_token_len = max(100, estimate_t2w_calc_dt(expected_t2w_calc_dt))

+

+        brief = {self.requests_info[request_id].idx: {'slack': round(requests_slack[request_id], 3), 'remain': requests_remain_token_len[request_id], 'sections': len(self.requests_info[request_id].sections)} for request_id in requests_slack.keys()}

+        logging.debug(f"available_token_len = {available_token_len}, {brief}")

+        logging.info(f"schedule_nonfirst_requests_t2w: available_token_len = {available_token_len}")

+

+        for request_id in requests_remain_token_len.keys():

+            request_info = self.requests_info[request_id]

+            assert requests_t2w_sched[request_id]['alloc_len'] == 0

+            alloc_token_len = min(requests_remain_token_len[request_id], t2w_min_token_len)

+            alloc_token_len_with_overhead = alloc_token_len + self.t2w.pre_gap_token_len + self.t2w.post_gap_token_len

+            available_token_len -= alloc_token_len_with_overhead

+            requests_remain_token_len[request_id] -= alloc_token_len

+            requests_priority[request_id] -= alloc_token_len

+            requests_t2w_sched[request_id]['alloc_len'] += alloc_token_len

+            requests_t2w_sched[request_id]['full'] = requests_remain_token_len[request_id] == 0

+            assert requests_remain_token_len[request_id] >= 0

+

+        while available_token_len > 0 and max(requests_remain_token_len.values()) > 0:

+            requests_remain_token_len = {request_id: token_len for request_id, token_len in requests_remain_token_len.items() if token_len > 0}

+            requests_priority = {request_id: requests_priority[request_id] for request_id in requests_remain_token_len.keys()}

+            request_id = max(requests_priority.items(), key=lambda v: v[1])[0]

+            assert requests_t2w_sched[request_id]['alloc_len'] > 0

+            if requests_t2w_sched[request_id]['alloc_len'] == 0:

+                alloc_token_len = min(requests_remain_token_len[request_id], t2w_min_token_len)

+                alloc_token_len_with_overhead = alloc_token_len + self.t2w.pre_gap_token_len + self.t2w.post_gap_token_len

+            else:

+                alloc_token_len = 1

+                alloc_token_len_with_overhead = 1

+            available_token_len -= alloc_token_len_with_overhead

+            requests_remain_token_len[request_id] -= alloc_token_len

+            requests_priority[request_id] -= alloc_token_len

+            requests_t2w_sched[request_id]['alloc_len'] += alloc_token_len

+            requests_t2w_sched[request_id]['full'] = requests_remain_token_len[request_id] == 0

+            assert requests_remain_token_len[request_id] >= 0

+

+        requests_t2w_sched = {request_id: sched for request_id, sched in requests_t2w_sched.items() if sched['alloc_len'] > 0}

+

+        return requests_t2w_sched

+

+    def maybe_t2w_infer(self, now):

+        with self.lock:

+            first_flag = []

+            for request_id, request_info in self.requests_info.items():

+                if len(request_info.sections) == 0:

+                    request_info.first = (len(request_info.llm_output) >= self.t2w_first_token_len) or request_info.llm_done

+                    first_flag.append(request_info.first)

+                else:

+                    request_info.first = False

+

+            logging.debug(f"maybe_t2w_infer: first_flag = {first_flag}")

+

+            if len(first_flag) > 0:

+                if all(first_flag):

+                    requests_t2w_sched = {request_id: {'history_len': 0,

+                                                       'alloc_len': min(self.t2w_first_token_len, len(request_info.llm_output)),

+                                                       'full': len(request_info.llm_output) <= self.t2w_first_token_len} \

+                                          for request_id, request_info in self.requests_info.items() if request_info.first}

+                else:

+                    requests_t2w_sched = {}

+            else:

+                requests_t2w_sched = self.schedule_nonfirst_requests_t2w(now)

+            if len(requests_t2w_sched) == 0:

+                return None

+

+            t2w_requests = []

+            for request_id, sched in requests_t2w_sched.items():

+                request_info = self.requests_info[request_id]

+                t2w_finalize = request_info.llm_done and sched['full']

+

+                valid_tokens = request_info.llm_output[sched['history_len'] : sched['history_len'] + sched['alloc_len']]

+

+                if self.t2w.pre_gap_token_len > 0:

+                    pre_gap = request_info.llm_output[: sched['history_len']][-self.t2w.pre_gap_token_len :]

+                    pre_gap = [self.t2w.silent_tokens[0]] * (self.t2w.pre_gap_token_len - len(pre_gap)) + pre_gap

+                else:

+                    pre_gap = []

+

+                if t2w_finalize:

+                    post_gap = [self.t2w.silent_tokens[0]] * self.t2w.post_gap_token_len

+                else:

+                    post_gap = []

+

+                tokens = pre_gap + valid_tokens + post_gap

+                rel_begin_pos = len(pre_gap)

+                rel_end_pos = rel_begin_pos + len(valid_tokens) - (0 if t2w_finalize else self.t2w.post_gap_token_len)

+

+                t2w_request = ThcSection(idx=(request_info.idx, request_id, len(request_info.sections)),

+                                         mode=request_info.mode,

+                                         spk_id=request_info.spk_id,

+                                         first=request_info.first,

+                                         finalize=t2w_finalize,

+                                         tokens=tokens,

+                                         token_begin_pos=rel_begin_pos,

+                                         token_end_pos=rel_end_pos,

+                                         flow_cache=request_info.flow_cache,

+                                         status=ThcSectionStatus.T2W_DOING)

+                request_info.sections.append(t2w_request)

+                t2w_requests.append(t2w_request)

+                request_info.metric.t2w.append(ThcMetric(begin=now))

+                logging.debug(f"t2w_infer: {request_info.idx}: t2w_request: {request_info.metric.t2w[-1].begin:%H:%M:%S.%f}, "

+                              f"rel_begin_pos = {rel_begin_pos}, rel_end_pos = {rel_end_pos}, "

+                              f"first = {t2w_request.first}, finalize = {t2w_request.finalize}")

+

+        assert len(t2w_requests) > 0

+        sections_brief = {f'{t2w_request.idx[0]}-{t2w_request.idx[2]}': len(t2w_request.tokens) for t2w_request in t2w_requests}

+        logging.info(f"t2w_infer: sections: {sections_brief}")

+        self.t2w.add_requests(t2w_requests)

+        outputs = self.t2w.step()

+

+        return outputs

+

+    def postprocess_t2w_outputs(self, t2w_step_outputs):

+        with self.lock:

+            for output in t2w_step_outputs:

+                request_id = output.idx[1]

+                section_idx = output.idx[2]

+                assert isinstance(request_id, str)

+                request_info = self.requests_info.get(request_id, None)

+                if request_info is None:

+                    continue

+                request_info.flow_cache = output.flow_cache

+                if request_info.speech is None:

+                    request_info.speech = output.speech

+                else:

+                    request_info.speech = torch.cat([request_info.speech, output.speech], dim=-1)

+                request_info.speech_cum_dt += output.speech_dt

+                request_info.metric.t2w[section_idx].stop()

+                request_info.done = any([section.finalize for section in request_info.sections])

+                request_info.section_queue.put(output)

+

+            now = datetime.datetime.now()

+            for request_info in self.requests_info.values():

+                if request_info.done:

+                    calc_cum_dt = (now - request_info.metric.e2e.begin).total_seconds()

+                    request_info.dynamic_rtf = calc_cum_dt / request_info.speech_cum_dt

+                    request_info.dynamic_slack = request_info.speech_cum_dt - calc_cum_dt

+                    request_info.release_dt = (now - request_info.metric.t2w[-1].end).total_seconds()

+                    logging.debug(f"infer: {request_info.idx}: "

+                                 f"begin = {request_info.metric.e2e.begin:%H:%M:%S.%f}, "

+                                 f"now = {now:%H:%M:%S.%f}, "

+                                 f"done = {request_info.done}, "

+                                 f"calc_cum_dt = {calc_cum_dt:.3f}, "

+                                 f"speech_cum_dt = {request_info.speech_cum_dt:.3f}, "

+                                 f"dynamic_rtf = {request_info.dynamic_rtf:.3f}, "

+                                 f"dynamic_slack = {request_info.dynamic_slack:.3f}")

+

+            batch_release_flag = True

+

+            if batch_release_flag:

+                can_release_flag = []

+                must_release_flag = []

+                for request_info in self.requests_info.values():

+                    can_release_flag.append(request_info.done)

+                    if request_info.done:

+                        if len(request_info.sections) == 1:

+                            flag = True

+                        else:

+                            dt = (now - request_info.metric.t2w[-2].end).total_seconds()

+                            flag = dt > self.release_timeout

+                        must_release_flag.append(flag)

+                    else:

+                        must_release_flag.append(False)

+

+                if all(can_release_flag) or any(must_release_flag):

+                    self.prev_release_time = now

+                    requests_id = list(self.requests_info.keys())

+                    for request_id in requests_id:

+                        request_info = self.requests_info[request_id]

+                        if request_info.done:

+                            self.requests_info.pop(request_id)

+                            self.requests_id[request_info.idx] = None

+                            request_info.release.put(True)

+                            logging.debug(f"infer: remove: request_idx: {request_info.idx}, dynamic_rtf = {request_info.dynamic_rtf:.3f}, dynamic_slack = {request_info.dynamic_slack:.3f}")

+                else:

+                    self.prev_release_time = None

+            else:

+                requests_id = list(self.requests_info.keys())

+                for request_id in requests_id:

+                    request_info = self.requests_info[request_id]

+                    if request_info.done:

+                        self.prev_release_time = now

+                        self.requests_info.pop(request_id)

+                        self.requests_id[request_info.idx] = None

+                        request_info.release.put(True)

+                        logging.debug(f"infer: remove: request_idx: {request_info.idx}, dynamic_rtf = {request_info.dynamic_rtf:.3f}, dynamic_slack = {request_info.dynamic_slack:.3f}")

+

+

+    def infer(self):

+        llm = self.cosyvoice.model.llm

+        print_wait_flag = True

+        empty_cache_flag = True

+

+        logging.info("infer: start")

+

+        while True:

+            if print_wait_flag:

+                logging.info("infer: wait")

+                print_wait_flag = False

+

+            if not self.inference_event.wait(timeout=30):

+                if empty_cache_flag:

+                    logging.info("infer: empty_cache")

+                    gc.collect()

+                    if self.device.type == 'npu':

+                        torch_npu.npu.empty_cache()

+                    empty_cache_flag = False

+                continue

+

+            if llm.vllm.get_num_unfinished_requests() == 0:

+                continue

+

+            print_wait_flag = True

+            empty_cache_flag = True

+

+            llm_section_begin = datetime.datetime.now()

+            self.prev_release_time = None

+            n_step = 0

+            logging.debug(f"infer: llm.vllm.get_num_unfinished_requests() = {llm.vllm.get_num_unfinished_requests()}")

+

+            while llm.vllm.has_unfinished_requests() or len(self.requests_info) > 0:

+                self.inference_event.clear()

+

+                if llm.vllm.has_unfinished_requests():

+                    llm_step_outputs = llm.vllm.step()

+

+                    now = datetime.datetime.now()

+                    llm_section_end = now

+                    logging.debug(f"infer: n_step = {n_step}, len(llm_step_outputs) = {len(llm_step_outputs)}, "

+                                  f"{[len(request_output.outputs) for request_output in llm_step_outputs]}, "

+                                  f"{[len(request_output.outputs[0].token_ids) for request_output in llm_step_outputs]}, "

+                                  f"{[request_output.outputs[0].token_ids[-1] for request_output in llm_step_outputs]}, "

+                                  f"{[request_output.finished for request_output in llm_step_outputs]}, "

+                                  f"{[request_output.outputs[0].finish_reason for request_output in llm_step_outputs]}")

+                    n_step += 1

+

+                    self.postprocess_vllm_outputs(llm_step_outputs, now)

+

+                if self.prev_release_time is None or \

+                   (datetime.datetime.now() - self.prev_release_time).total_seconds() >= self.t2w_delay or \

+                   len(self.requests_info) == self.args.max_conc:

+                    t2w_step_outputs = self.maybe_t2w_infer(now)

+                    if t2w_step_outputs is not None:

+                        llm_section_dt = (llm_section_end - llm_section_begin).total_seconds()

+                        t2w_perf = self.t2w.perfs[-1]

+                        logging.debug(f"infer: llm_dt = {llm_section_dt:.3f}, t2w_token_len = {t2w_perf['token_len']}, t2w_dt = {t2w_perf['metric'].dt:.3f}")

+                        llm_section_begin = datetime.datetime.now()

+                        self.postprocess_t2w_outputs(t2w_step_outputs)

+                elif llm.vllm.has_unfinished_requests():

+                    pass

+                else:

+                    time.sleep(0.01)

+

+    def add_request(self, mode, tts_text, spk_id, stream_flag):

+        if stream_flag != True:

+            raise NotImplementedError("only support stream mode")

+

+        cosyvoice = self.cosyvoice

+        frontend = cosyvoice.frontend

+        model = cosyvoice.model

+        llm = model.llm

+

+        request_id = str(uuid.uuid1())

+

+        spk_prompt_emb = self.prompt_emb[spk_id]

+        sos_emb = spk_prompt_emb['sos_emb']

+        task_id_emb = spk_prompt_emb['task_id_emb']

+

+        # modify if iter text

+        tts_texts = frontend.text_normalize(tts_text, split=True, text_frontend=frontend.text_frontend)

+        tts_text = ''.join(tts_texts)

+        if mode == 'zero_shot_by_id':

+            model_input = frontend.frontend_zero_shot(tts_text, None, None, None, spk_id)

+            text = model_input['text']

+            text_len = model_input['text_len']

+            text_emb = llm.llm.model.model.embed_tokens(text)

+            prompt_text_emb = spk_prompt_emb['prompt_text_emb']

+            prompt_speech_token_emb = spk_prompt_emb['prompt_speech_token_emb']

+            lm_input = torch.concat([sos_emb, prompt_text_emb, text_emb, task_id_emb, prompt_speech_token_emb], dim=1)

+        elif mode == 'sft':

+            tts_text = '<|endofprompt|>' + tts_text

+            model_input = frontend.frontend_sft(tts_text, spk_id)

+            text = model_input['text']

+            text_len = model_input['text_len']

+            text_emb = llm.llm.model.model.embed_tokens(text)

+            lm_input = torch.concat([sos_emb, text_emb, task_id_emb], dim=1)

+

+        logging.debug(f"add_request: lm_input: [{lm_input.dtype}, lm_input.shape: {lm_input.shape}]")

+

+        min_token_text_ratio = 2

+        max_token_text_ratio = 20

+        min_len = int(text_len * min_token_text_ratio)

+        max_len = int(text_len * max_token_text_ratio)

+

+        sampling = 25

+        sampling_params = vllm.SamplingParams(top_k=sampling,

+                                              stop_token_ids=llm.stop_token_ids,

+                                              min_tokens=min_len,

+                                              max_tokens=max_len)

+        logging.debug(f"add_request: sampling_params: {sampling_params}")

+        now = datetime.datetime.now()

+

+        with self.lock:

+            for n in range(len(self.requests_id)):

+                other_request_id = self.requests_id[n]

+                other_requests_info = self.requests_info.get(other_request_id, None)

+                if other_requests_info is not None and (now - other_requests_info.metric.e2e.begin).total_seconds() > 60:

+                    logging.info(f"add_request: request_idx: {n}, {other_request_id}, timeout")

+                    self.requests_info.pop(other_request_id)

+                    self.requests_id[n] = None

+

+            request_idx = [n for n in range(self.args.max_conc) if self.requests_id[n] is None]

+            assert len(request_idx) > 0, f"requests_id = {self.requests_id}"

+            request_idx = request_idx[0]

+            request_info = ThcRequestInfo(

+                idx=request_idx,

+                request_id=request_id,

+                mode=mode,

+                spk_id=spk_id,

+            )

+            request_info.metric.e2e.start()

+            request_info.metric.llm.start()

+

+            self.requests_id[request_idx] = request_id

+            self.requests_info[request_id] = request_info

+

+            llm.vllm.add_request(request_id, {'prompt_embeds': lm_input.squeeze(0).to(torch.bfloat16)}, sampling_params)

+

+        self.inference_event.set()

+

+        logging.info(f"add_request: request_idx: {request_idx}, {request_id}, mode = {mode}, tts_text = {tts_text}, spk_id = {spk_id}")

+

+        return request_info

+

+    def check_context_and_get_queue(self, request_info, context, q, timeout, q_name):

+        request_idx = request_info.idx

+        request_id = request_info.request_id

+        unit = 0.2

+

+        for n in range(max(1, int(timeout / unit))):

+            if (context is not None) and (not context.is_active()):

+                logging.info(f"check_context: request_idx: {request_idx}, {request_id}, client disconnected")

+                return None

+            try:

+                data = q.get(timeout=unit)

+                return data

+            except queue.Empty:

+                continue

+        else:

+            logging.info(f"check_queue: request_idx: {request_idx}, {request_id}, get {q_name} timeout")

+            return None

+

+    def postprocess_request(self, request_info):

+        request_idx = request_info.idx

+        request_id = request_info.request_id

+

+        worst_slack = 999

+        for n in range(1, len(request_info.sections)):

+            speech_cum_dt = sum([section.speech_dt for section in request_info.sections[: n]])

+            calc_cum_dt = (request_info.metric.t2w[n].end - request_info.metric.t2w[0].end).total_seconds()

+            worst_slack = min(worst_slack, speech_cum_dt - calc_cum_dt)

+

+        logging.debug(f"request_info.llm_output = {request_info.llm_output}")

+

+        request_info.metric.e2e.stop()

+        first_latency = (request_info.metric.t2w[0].end - request_info.metric.e2e.begin).total_seconds()

+        rtf = request_info.metric.e2e.dt / request_info.speech_cum_dt

+        logging.info(f"complete_request: request_idx: {request_idx}, {request_id}, "

+                    f"submit_time = {request_info.metric.e2e.begin:%H:%M:%S.%f}, "

+                    f"e2e = {request_info.metric.e2e.dt:.3f}, "

+                    f"slack = {worst_slack:.3f}, "

+                    f"first_latency = {first_latency:.3f}, "

+                    f"rtf = {rtf:.3f}, "

+                    f"speech_dt = {request_info.speech_cum_dt:.3f}")

+

+

+    def process_request(self, mode, tts_text, spk_id, stream_flag, context):

+        logging.debug(f"process_request: tts_text = {tts_text}")

+        request_info = self.add_request(mode, tts_text, spk_id, stream_flag)

+        request_idx = request_info.idx

+        request_id = request_info.request_id

+

+        n_section = 0

+        while True:

+            section = self.check_context_and_get_queue(request_info, context, request_info.section_queue, timeout=10, q_name='section')

+            if section is None:

+                with self.lock:

+                    self.requests_info.pop(request_id)

+                    self.requests_id[request_idx] = None

+                return

+            assert section.idx[2] == n_section, f"section.idx[2] = {section.idx[2]}, n_section = {n_section}"

+            n_section += 1

+

+            logging.debug(f"{request_id}, section.speech_cpu = {section.speech_cpu.shape}, speech_dt = {section.speech_dt}")

+

+            if section.finalize and request_info.done:

+                release = self.check_context_and_get_queue(request_info, context, request_info.release, timeout=30, q_name='release')

+                if release is None:

+                    with self.lock:

+                        self.requests_info.pop(request_id)

+                        self.requests_id[request_idx] = None

+                    return

+

+                self.postprocess_request(request_info)

+

+            yield section.speech_cpu, section.finalize and request_info.done

+

+            if request_info.done:

+                break

+

+    def warmup(self):

+        tts_text = '1,2,3,4,5,6,7,8,9,10'

+        spk_id = 'spk1'

+        stream_flag = True

+

+        for n_req in range(self.args.max_conc):

+            request_info = self.add_request('zero_shot_by_id', tts_text, spk_id, stream_flag)

+

+        while True:

+            time.sleep(1)

+

+            with self.lock:

+                request_num = len(self.requests_info)

+

+            if request_num == 0:

+                break

 

     def Inference(self, request, context):

-        if request.HasField('sft_request'):

-            logging.info('get sft inference request')

-            model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id)

-        elif request.HasField('zero_shot_request'):

-            logging.info('get zero_shot inference request')

-            prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)

-            prompt_speech_16k = prompt_speech_16k.float() / (2**15)

-            model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text,

-                                                              request.zero_shot_request.prompt_text,

-                                                              prompt_speech_16k)

-        elif request.HasField('cross_lingual_request'):

-            logging.info('get cross_lingual inference request')

-            prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)

-            prompt_speech_16k = prompt_speech_16k.float() / (2**15)

-            model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)

+        logging.debug(f"Inference: context = {context}, PID = {os.getpid()}, TID = {threading.get_ident()}")

+        logging.debug(f"Inference: tts_text = {request.zero_shot_by_id_request.tts_text}, spk_id = {request.zero_shot_by_id_request.spk_id}")

+        high_performace_flag = True

+

+        if high_performace_flag:

+            if request.HasField('sft_request'):

+                mode = 'sft'

+                tts_text = request.sft_request.tts_text

+                spk_id = request.sft_request.spk_id

+                stream_flag = request.sft_request.stream

+            elif request.HasField('zero_shot_by_id_request'):

+                mode = 'zero_shot_by_id'

+                tts_text = request.zero_shot_by_id_request.tts_text

+                spk_id = request.zero_shot_by_id_request.spk_id

+                stream_flag = request.zero_shot_by_id_request.stream

+            else:

+                raise NotImplementedError('Mode not implemented!')

+

+            for tts_speech, done_flag in self.process_request(mode, tts_text, spk_id, stream_flag, context):

+                response = cosyvoice_pb2.Response()

+                response.tts_audio = tts_speech.tobytes()

+                yield response

+                if done_flag:

+                    break

+

         else:

-            logging.info('get instruct inference request')

-            model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,

-                                                             request.instruct_request.spk_id,

-                                                             request.instruct_request.instruct_text)

+            if request.HasField('zero_shot_by_id_request'):

+                logging.info('get zero_shot_by_id inference request')

+                model_output = self.cosyvoice.inference_zero_shot_by_id(

+                    request.zero_shot_by_id_request.tts_text,

+                    request.zero_shot_by_id_request.spk_id,

+                    stream=request.zero_shot_by_id_request.stream

+                )

+            else:

+                assert False

+                logging.info('get instruct inference request')

+                model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,

+                                                                request.instruct_request.spk_id,

+                                                                request.instruct_request.instruct_text)

 

-        logging.info('send inference response')

-        for i in model_output:

-            response = cosyvoice_pb2.Response()

-            response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()

-            yield response

+            logging.info('send inference response')

+

+            count=0

+            for i in model_output:

+                response = cosyvoice_pb2.Response()

+                response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()

+                yield response

+                count+=1

+                if count==4 and self.args.enable_vllm_profile:

+                    assert False

+                    try:

+                        self.cosyvoice.model.llm.vllm.stop_profile()

+                        logging.info('vLLM profile stopped')

+                    except Exception as e:

+                        logging.warning(f'Failed to stop vLLM profile: {e}')

 

 

 def main():

+    service = CosyVoiceServiceImpl(args)

     grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)

-    cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)

+    cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(service, grpcServer)

     grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port))

     grpcServer.start()

+

+    if args.graph_mode:

+        logging.info(f"warmup: start")

+        service.warmup()

+        logging.info(f"warmup: end")

+

     logging.info("server listening on 0.0.0.0:{}".format(args.port))

     grpcServer.wait_for_termination()

 

@@ -86,5 +989,16 @@ if __name__ == '__main__':

                         type=str,

                         default='iic/CosyVoice2-0.5B',

                         help='local path or modelscope repo id')

+    parser.add_argument('--graph_mode',

+                        action='store_true',

+                        help='whether to use graph mode')

+    parser.add_argument('--speaker_info_dir',

+                        type=str,

+                        default='.',

+                        help='directory containing speaker info files (default: model_dir/speaker_info)')

+    parser.add_argument('--enable_vllm_profile',

+                        action='store_true',

+                        help='whether to enable vLLM profiling (default: False)')

     args = parser.parse_args()

+    logging.info(f"args: {args}")

     main()

diff --git a/runtime/python/grpc/speaker_info.json b/runtime/python/grpc/speaker_info.json

new file mode 100644

index 0000000..814a9c3

--- /dev/null

+++ b/runtime/python/grpc/speaker_info.json

@@ -0,0 +1,6 @@

+{

+    "spk1":{

+        "prompt_wav":"../../../asset/zero_shot_prompt.wav",

+        "prompt_text":"You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"

+    }

+}

\ No newline at end of file

diff --git a/runtime/python/grpc/update_grpc_file.sh b/runtime/python/grpc/update_grpc_file.sh

new file mode 100644

index 0000000..1c0a8da

--- /dev/null

+++ b/runtime/python/grpc/update_grpc_file.sh

@@ -0,0 +1 @@

+python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto

\ No newline at end of file