@@ -22,6 +22,7 @@ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type
+from cosyvoice.utils.file_utils import load_wav
class CosyVoice:
@@ -74,6 +75,11 @@ class CosyVoice:
self.frontend.spk2info[zero_shot_spk_id] = model_input
return True
+ def add_sft_spk(self, sft_spk_id):
+     """Register an SFT speaker derived from an existing ``spk2info`` entry.
+
+     The frontend builds the model input for ``sft_spk_id`` via
+     ``frontend_from_spk2info`` and the result is stored in the frontend's
+     ``spk2info_sft`` table under the same id.
+
+     Returns:
+         True once the entry has been stored.
+     """
+     frontend = self.frontend
+     frontend.spk2info_sft[sft_spk_id] = frontend.frontend_from_spk2info(sft_spk_id)
+     return True
+
def save_spkinfo(self):
torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
@@ -188,7 +194,8 @@ class CosyVoice2(CosyVoice):
class CosyVoice3(CosyVoice2):
- def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
+ def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1,
+ speaker_info_dir=None):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
@@ -224,6 +231,66 @@ class CosyVoice3(CosyVoice2):
self.fp16)
del configs
+ self.speaker_info_dir = speaker_info_dir or os.path.join(model_dir, 'speaker_info')
+ if not os.path.exists(self.speaker_info_dir):
+ logging.warning(f'speaker_info_dir {self.speaker_info_dir} does not exist')
+
+ self.promote_wave_info = {}
+ self._preload_all_wave_info()
+
+ def _preload_all_wave_info(self):
+     """Load every speaker described in ``speaker_info.json`` into memory.
+
+     Reads ``<speaker_info_dir>/speaker_info.json``, a JSON object mapping a
+     speaker id to a dict with at least ``prompt_wav`` (path relative to
+     ``speaker_info_dir``) and ``prompt_text``.  For each valid entry the
+     audio is loaded with ``load_wav(path, 16000)`` and the entry, with
+     ``prompt_wav`` replaced by the loaded audio, is stored in
+     ``self.promote_wave_info``.
+
+     Loading is best-effort: a missing directory or file, unreadable JSON,
+     or a broken speaker entry is logged and skipped, never raised.  An
+     entry is only registered once both of its required keys are present
+     and its audio has loaded, so every registered speaker is usable.
+     """
+     if not os.path.exists(self.speaker_info_dir):
+         return
+
+     import json
+
+     speaker_info_file = os.path.join(self.speaker_info_dir, 'speaker_info.json')
+     if not os.path.exists(speaker_info_file):
+         logging.warning(f'speaker_info.json not found in {self.speaker_info_dir}')
+         return
+
+     try:
+         with open(speaker_info_file, 'r', encoding='utf-8') as f:
+             all_speakers = json.load(f)
+     except (OSError, ValueError) as e:
+         # ValueError covers both JSONDecodeError and UnicodeDecodeError.
+         logging.error(f'Failed to load {speaker_info_file}: {e}')
+         return
+
+     if not isinstance(all_speakers, dict):
+         logging.error(f'Failed to load {speaker_info_file}: '
+                       f'expected a JSON object, got {type(all_speakers).__name__}')
+         return
+
+     for spk_id, wave_info in all_speakers.items():
+         try:
+             # Read both required keys up front so an incomplete entry is
+             # rejected before it can be registered.
+             prompt_text = wave_info['prompt_text']
+             prompt_wav_path = os.path.join(self.speaker_info_dir, wave_info['prompt_wav'])
+             if not os.path.exists(prompt_wav_path):
+                 logging.error(f'Audio file not found for speaker {spk_id}: {prompt_wav_path}')
+                 continue
+             prompt_wav = load_wav(prompt_wav_path, 16000)
+         except Exception as e:
+             # Deliberately broad: one bad speaker must not block the rest.
+             logging.error(f'Failed to load speaker {spk_id}: {e}')
+             continue
+
+         self.promote_wave_info[spk_id] = {**wave_info, 'prompt_wav': prompt_wav}
+         logging.info(f"Loaded speaker {spk_id} from {prompt_wav_path}, promote={prompt_text}")
+
+     logging.info(f'Preloaded {len(self.promote_wave_info)} speakers from {speaker_info_file}')
+
+ def inference_zero_shot_by_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
+     """Synthesize ``tts_text`` with a speaker preloaded in ``promote_wave_info``.
+
+     On first use of ``spk_id`` the speaker is registered through
+     ``add_zero_shot_spk`` so that it is present in ``frontend.spk2info``.
+     The text is then normalized and split, and every segment is passed
+     through ``frontend_zero_shot`` and ``model.tts``.
+
+     Yields:
+         Each model output dict produced by ``self.model.tts``.
+
+     Raises:
+         ValueError: ``spk_id`` is not a preloaded speaker.
+     """
+     if spk_id not in self.promote_wave_info:
+         raise ValueError(f'Speaker ID {spk_id} not found. Available IDs: {list(self.promote_wave_info.keys())}')
+     wave_info = self.promote_wave_info[spk_id]
+
+     if spk_id not in self.frontend.spk2info:
+         self.add_zero_shot_spk(wave_info['prompt_text'], wave_info['prompt_wav'], spk_id)
+
+     segments = self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)
+     for segment in tqdm(segments):
+         model_input = self.frontend.frontend_zero_shot(
+             segment, wave_info['prompt_text'], wave_info['prompt_wav'], self.sample_rate, spk_id)
+         tick = time.time()
+         logging.info(f'synthesis text {segment}')
+         for chunk in self.model.tts(**model_input, stream=stream, speed=speed):
+             speech_len = chunk['tts_speech'].shape[1] / self.sample_rate
+             rtf = (time.time() - tick) / speech_len
+             logging.info(f'yield speech len {speech_len}, rtf {rtf}')
+             yield chunk
+             # Restart the clock so the next RTF covers only the next chunk.
+             tick = time.time()
+
def AutoModel(**kwargs):
if not os.path.exists(kwargs['model_dir']):
@@ -26,6 +26,7 @@ import inflect
from cosyvoice.utils.file_utils import logging, load_wav
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
+import threading
class CosyVoiceFrontEnd:
@@ -50,6 +51,8 @@ class CosyVoiceFrontEnd:
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
else:
self.spk2info = {}
+ logging.info(f"spk2info file = {spk2info}, CosyVoiceFrontEnd.spk2info = {self.spk2info.keys()}")
+ self.spk2info_sft = {}
self.allowed_special = allowed_special
self.inflect_parser = inflect.engine()
# NOTE compatible when no text frontend tool is avaliable
@@ -74,6 +77,12 @@ class CosyVoiceFrontEnd:
self.text_frontend = ''
logging.info('no frontend is avaliable')
+ self.lock = threading.Lock()
+
+ def tokenizer_encode_with_lock(self, *args, **kwargs):
+     """Call ``self.tokenizer.encode`` while holding ``self.lock``.
+
+     All positional and keyword arguments are forwarded unchanged and the
+     tokenizer's return value is passed back to the caller.
+     """
+     with self.lock:
+         return self.tokenizer.encode(*args, **kwargs)
def _extract_text_token(self, text):
if isinstance(text, Generator):
@@ -81,7 +90,7 @@ class CosyVoiceFrontEnd:
# NOTE add a dummy text_token_len for compatibility
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
else:
- text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+ text_token = self.tokenizer_encode_with_lock(text, allowed_special=self.allowed_special)
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
return text_token, text_token_len
@@ -148,13 +157,13 @@ class CosyVoiceFrontEnd:
text = text.replace(" - ", ",")
text = remove_bracket(text)
text = re.sub(r'[,,、]+$', '。', text)
- texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+ texts = list(split_paragraph(text, partial(self.tokenizer_encode_with_lock, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
else:
if self.text_frontend == 'wetext':
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
- texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+ texts = list(split_paragraph(text, partial(self.tokenizer_encode_with_lock, allowed_special=self.allowed_special), "en", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
texts = [i for i in texts if not is_only_punctuation(i)]
return texts if split is True else text
@@ -222,3 +231,13 @@ class CosyVoiceFrontEnd:
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
'flow_embedding': embedding}
return model_input
+
+ def frontend_from_spk2info(self, spk_id):
+     """Build a model input from the stored embedding of ``spk_id``.
+
+     Looks up ``self.spk2info[spk_id]['embedding']``, moves it to
+     ``self.device`` and returns it under both ``llm_embedding`` and
+     ``flow_embedding`` (the same tensor object for both keys).
+
+     Raises:
+         KeyError: ``spk_id`` is unknown or its entry has no ``embedding``.
+     """
+     embedding = self.spk2info[spk_id]['embedding'].to(self.device)
+     return dict(llm_embedding=embedding, flow_embedding=embedding)
+
@@ -100,7 +100,7 @@ class CosyVoiceModel:
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
cur_silent_token_num, max_silent_token_num = 0, 5
- with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
+ with torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
if isinstance(text, Generator):
assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
token_generator = self.llm.inference_bistream(text=text,
@@ -282,9 +282,13 @@ class CosyVoice2Model(CosyVoiceModel):
export_cosyvoice2_vllm(self.llm, model_dir, self.device)
from vllm import EngineArgs, LLMEngine
engine_args = EngineArgs(model=model_dir,
+ dtype='float16',
skip_tokenizer_init=True,
enable_prompt_embeds=True,
- gpu_memory_utilization=0.2)
+ gpu_memory_utilization=0.5,
+ additional_config={"torchair_graph_config":{"enabled":True}},
+ compilation_config={'cudagraph_capture_sizes':[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,512,1024],"cudagraph_mode": "FULL"},
+ enable_prefix_caching=True)
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
self.llm.lock = threading.Lock()
del self.llm.llm.model.model.layers
@@ -18,7 +18,7 @@ import torch.nn.functional as F
import torchaudio
from x_transformers.x_transformers import apply_rotary_pos_emb
-
+import torch_npu
# raw wav to mel spec
class MelSpec(nn.Module):
@@ -84,7 +84,12 @@ class SinusPositionEmbedding(nn.Module):
# convolutional position embedding
+class Mish(nn.Module):
+    """Mish activation, ``x * tanh(softplus(x))``, built from elementary ops.
+
+    Used in place of ``nn.Mish`` in the position-embedding convolutions
+    (NOTE(review): presumably for a backend without a native Mish kernel;
+    confirm).
+    """
+
+    def forward(self, x):
+        # log1p keeps softplus accurate when exp(x) is far below 1, where
+        # log(1 + exp(x)) would round to exactly 0.  For large x, exp(x)
+        # may overflow to inf; log1p(inf) = inf and tanh(inf) = 1, so the
+        # result is still x.
+        return x * torch.tanh(torch.log1p(torch.exp(x)))
class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
@@ -92,9 +97,9 @@ class ConvPositionEmbedding(nn.Module):
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
- nn.Mish(),
+ Mish(),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
- nn.Mish(),
+ Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
@@ -119,11 +124,11 @@ class CausalConvPositionEmbedding(nn.Module):
self.kernel_size = kernel_size
self.conv1 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
- nn.Mish(),
+ Mish(),
)
self.conv2 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
- nn.Mish(),
+ Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
@@ -388,7 +393,17 @@ class AttnProcessor:
else:
attn_mask = None
- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+ atten_mask_npu = torch.logical_not(attn_mask)
+ head_num = query.shape[1]
+ x = torch_npu.npu_fusion_attention(
+ query, key, value, head_num, input_layout="BNSD",
+ pse=None,
+ atten_mask=atten_mask_npu,
+ scale=1.0 / math.sqrt(query.shape[-1]),
+ pre_tockens=2147483647,
+ next_tockens=2147483647,
+ keep_prob=1
+ )[0]
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
@@ -376,7 +376,8 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
prompt_feat_len,
embedding,
streaming,
- finalize):
+ finalize,
+ n_timesteps=None):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
@@ -406,7 +407,7 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
- n_timesteps=10,
+ n_timesteps=10 if n_timesteps is None else n_timesteps,
streaming=streaming
)
feat = feat[:, :, mel_len1:]
@@ -498,10 +498,24 @@ class HiFTGenerator(nn.Module):
def _istft(self, magnitude, phase):
+ # Rebuild a waveform from STFT magnitude and phase with torch.istft.
+ # The magnitude is clipped to at most 1e2 before it is used.
magnitude = torch.clip(magnitude, max=1e2)
+ device = magnitude.device
+ is_npu = device.type == 'npu'
real = magnitude * torch.cos(phase)
img = magnitude * torch.sin(phase)
- inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
- self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+ # On an 'npu' device the complex spectrum and the window are moved to
+ # the CPU and the inverse STFT runs there; on any other device they
+ # stay on the input's device.
+ # NOTE(review): presumably a workaround for missing complex/istft
+ # support on NPU - confirm.
+ if is_npu:
+ complex_tensor = torch.complex(real.cpu(), img.cpu())
+ window_xpu = self.stft_window.cpu()
+ else:
+ complex_tensor = torch.complex(real, img)
+ window_xpu = self.stft_window.to(device)
+
+ inverse_transform = torch.istft(complex_tensor,
+ self.istft_params["n_fft"],
+ self.istft_params["hop_len"],
+ self.istft_params["n_fft"],
+ window=window_xpu)
+ # Return the CPU-computed result to the original device.
+ if is_npu:
+ inverse_transform = inverse_transform.to(device)
return inverse_transform
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
@@ -20,7 +20,7 @@ import torch
import torchaudio
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
-logging.basicConfig(level=logging.DEBUG,
+logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)s %(message)s')
@@ -42,8 +42,22 @@ def read_json_lists(list_file):
def load_wav(wav, target_sr, min_sr=16000):
- speech, sample_rate = torchaudio.load(wav, backend='soundfile')
- speech = speech.mean(dim=0, keepdim=True)
+ if isinstance(wav, torch.Tensor):
+ speech = wav
+ sample_rate = 16000 # assumed
+ elif not isinstance(wav, str):
+ try:
+ speech = torch.as_tensor(wav)
+ sample_rate = 16000 # assumed
+ except:
+ raise TypeError(f"Expect path or Tensor, but got {type(wav)}")
+ else:
+ speech, sample_rate = torchaudio.load(wav, backend='soundfile')
+
+ if speech.ndim > 1:
+ speech = speech.mean(dim=0, keepdim=True)
+ elif speech.ndim == 1:
+ speech = speech.unsqueeze(0)
if sample_rate != target_sr:
assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
@@ -230,9 +230,8 @@ def add_optional_chunk_mask(xs: torch.Tensor,
else:
chunk_masks = masks
assert chunk_masks.dtype == torch.bool
- if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
- print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
- chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
+ new_mask = (chunk_masks.sum(dim=-1, keepdim=True) == 0)
+ chunk_masks = torch.logical_or(chunk_masks, new_mask)
return chunk_masks
new file mode 100644
@@ -0,0 +1,353 @@
+diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
+index 7ab04a7..dcb743f 100644
+--- a/cosyvoice/cli/cosyvoice.py
+@@ -22,6 +22,7 @@ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
+ from cosyvoice.utils.file_utils import logging
+ from cosyvoice.utils.class_utils import get_model_type
++from cosyvoice.utils.file_utils import load_wav
+
+
+ class CosyVoice:
+@@ -74,6 +75,11 @@ class CosyVoice:
+ self.frontend.spk2info[zero_shot_spk_id] = model_input
+ return True
+
++ def add_sft_spk(self, sft_spk_id):
++ model_input = self.frontend.frontend_from_spk2info(sft_spk_id)
++ self.frontend.spk2info_sft[sft_spk_id] = model_input
++ return True
++
+ def save_spkinfo(self):
+ torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
+
+@@ -188,7 +194,8 @@ class CosyVoice2(CosyVoice):
+
+ class CosyVoice3(CosyVoice2):
+
+- def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
++ def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1,
++ speaker_info_dir=None):
+ self.model_dir = model_dir
+ self.fp16 = fp16
+ if not os.path.exists(model_dir):
+@@ -224,6 +231,66 @@ class CosyVoice3(CosyVoice2):
+ self.fp16)
+ del configs
+
++ self.speaker_info_dir = speaker_info_dir or os.path.join(model_dir, 'speaker_info')
++ if not os.path.exists(self.speaker_info_dir):
++ logging.warning(f'speaker_info_dir {self.speaker_info_dir} does not exist')
++
++ self.promote_wave_info = {}
++ self._preload_all_wave_info()
++
++ def _preload_all_wave_info(self):
++ if not os.path.exists(self.speaker_info_dir):
++ return
++
++ import json
++
++ speaker_info_file = os.path.join(self.speaker_info_dir, 'speaker_info.json')
++ if not os.path.exists(speaker_info_file):
++ logging.warning(f'speaker_info.json not found in {self.speaker_info_dir}')
++ return
++
++ try:
++ with open(speaker_info_file, 'r', encoding='utf-8') as f:
++ all_speakers = json.load(f)
++
++ for spk_id, wave_info in all_speakers.items():
++ try:
++ prompt_wav_path = os.path.join(self.speaker_info_dir, wave_info['prompt_wav'])
++ if os.path.exists(prompt_wav_path):
++ prompt_wav = load_wav(prompt_wav_path, 16000)
++ wave_info['prompt_wav'] = prompt_wav
++ self.promote_wave_info[spk_id] = wave_info
++ logging.info(f"Loaded speaker {spk_id} from {prompt_wav_path}, promote={wave_info['prompt_text']}")
++ else:
++ logging.error(f'Audio file not found for speaker {spk_id}: {prompt_wav_path}')
++ except Exception as e:
++ logging.error(f'Failed to load speaker {spk_id}: {e}')
++
++ logging.info(f'Preloaded {len(self.promote_wave_info)} speakers from {speaker_info_file}')
++
++ except Exception as e:
++ logging.error(f'Failed to load {speaker_info_file}: {e}')
++
++ def inference_zero_shot_by_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
++ if spk_id not in self.promote_wave_info:
++ raise ValueError(f'Speaker ID {spk_id} not found. Available IDs: {list(self.promote_wave_info.keys())}')
++ wave_info = self.promote_wave_info[spk_id]
++
++ if spk_id not in self.frontend.spk2info:
++ self.add_zero_shot_spk(wave_info['prompt_text'], wave_info['prompt_wav'], spk_id)
++
++ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
++ model_input = self.frontend.frontend_zero_shot(i, wave_info['prompt_text'], wave_info['prompt_wav'], self.sample_rate, spk_id)
++ start_time = time.time()
++ logging.info('synthesis text {}'.format(i))
++ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
++ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
++ # 2.save audio
++ #res = torch.cat((res, model_output['tts_speech']), dim=1)
++ yield model_output
++ start_time = time.time()
++
+
+ def AutoModel(**kwargs):
+ if not os.path.exists(kwargs['model_dir']):
+diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
+index 6d397cc..10b0a87 100644
+--- a/cosyvoice/cli/frontend.py
+@@ -26,6 +26,7 @@ import inflect
+ from cosyvoice.utils.file_utils import logging, load_wav
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
+
++import threading
+
+ class CosyVoiceFrontEnd:
+
+@@ -50,6 +51,8 @@ class CosyVoiceFrontEnd:
+ self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
+ else:
+ self.spk2info = {}
++ logging.info(f"spk2info file = {spk2info}, CosyVoiceFrontEnd.spk2info = {self.spk2info.keys()}")
++ self.spk2info_sft = {}
+ self.allowed_special = allowed_special
+ self.inflect_parser = inflect.engine()
+ # NOTE compatible when no text frontend tool is avaliable
+@@ -74,6 +77,12 @@ class CosyVoiceFrontEnd:
+ self.text_frontend = ''
+ logging.info('no frontend is avaliable')
+
++ self.lock = threading.Lock()
++
++ def tokenizer_encode_with_lock(self, *args, **kwargs):
++ with self.lock:
++ output = self.tokenizer.encode(*args, **kwargs)
++ return output
+
+ def _extract_text_token(self, text):
+ if isinstance(text, Generator):
+@@ -81,7 +90,7 @@ class CosyVoiceFrontEnd:
+ # NOTE add a dummy text_token_len for compatibility
+ return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+ else:
+- text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
++ text_token = self.tokenizer_encode_with_lock(text, allowed_special=self.allowed_special)
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+ return text_token, text_token_len
+@@ -148,13 +157,13 @@ class CosyVoiceFrontEnd:
+ text = text.replace(" - ", ",")
+ text = remove_bracket(text)
+ text = re.sub(r'[,,、]+$', '。', text)
+- texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
++ texts = list(split_paragraph(text, partial(self.tokenizer_encode_with_lock, allowed_special=self.allowed_special), "zh", token_max_n=80,
+ token_min_n=60, merge_len=20, comma_split=False))
+ else:
+ if self.text_frontend == 'wetext':
+ text = self.en_tn_model.normalize(text)
+ text = spell_out_number(text, self.inflect_parser)
+- texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
++ texts = list(split_paragraph(text, partial(self.tokenizer_encode_with_lock, allowed_special=self.allowed_special), "en", token_max_n=80,
+ token_min_n=60, merge_len=20, comma_split=False))
+ texts = [i for i in texts if not is_only_punctuation(i)]
+ return texts if split is True else text
+@@ -222,3 +231,13 @@ class CosyVoiceFrontEnd:
+ 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
+ 'flow_embedding': embedding}
+ return model_input
++
++ def frontend_from_spk2info(self, spk_id):
++ spk_info = self.spk2info[spk_id]
++ embedding = spk_info['embedding'].to(self.device)
++ model_input = {
++ 'llm_embedding': embedding,
++ 'flow_embedding': embedding,
++ }
++ return model_input
++
+diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
+index 92a15d9..f5362e1 100644
+--- a/cosyvoice/cli/model.py
+@@ -100,7 +100,7 @@ class CosyVoiceModel:
+
+ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+ cur_silent_token_num, max_silent_token_num = 0, 5
+- with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
++ with torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
+ if isinstance(text, Generator):
+ assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
+ token_generator = self.llm.inference_bistream(text=text,
+@@ -282,9 +282,13 @@ class CosyVoice2Model(CosyVoiceModel):
+ export_cosyvoice2_vllm(self.llm, model_dir, self.device)
+ from vllm import EngineArgs, LLMEngine
+ engine_args = EngineArgs(model=model_dir,
++ dtype='float16',
+ skip_tokenizer_init=True,
+ enable_prompt_embeds=True,
+- gpu_memory_utilization=0.2)
++ gpu_memory_utilization=0.5,
++ additional_config={"torchair_graph_config":{"enabled":True}},
++ compilation_config={'cudagraph_capture_sizes':[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,512,1024],"cudagraph_mode": "FULL"},
++ enable_prefix_caching=True)
+ self.llm.vllm = LLMEngine.from_engine_args(engine_args)
+ self.llm.lock = threading.Lock()
+ del self.llm.llm.model.model.layers
+diff --git a/cosyvoice/flow/DiT/modules.py b/cosyvoice/flow/DiT/modules.py
+index be8caec..d188914 100644
+--- a/cosyvoice/flow/DiT/modules.py
+@@ -18,7 +18,7 @@ import torch.nn.functional as F
+ import torchaudio
+
+ from x_transformers.x_transformers import apply_rotary_pos_emb
+-
++import torch_npu
+
+ # raw wav to mel spec
+ class MelSpec(nn.Module):
+@@ -84,7 +84,12 @@ class SinusPositionEmbedding(nn.Module):
+
+
+ # convolutional position embedding
++class Mish(nn.Module):
++ def __init__(self):
++ super().__init__()
+
++ def forward(self, x):
++ return x * torch.tanh(torch.log(1 + torch.exp(x)))
+
+ class ConvPositionEmbedding(nn.Module):
+ def __init__(self, dim, kernel_size=31, groups=16):
+@@ -92,9 +97,9 @@ class ConvPositionEmbedding(nn.Module):
+ assert kernel_size % 2 != 0
+ self.conv1d = nn.Sequential(
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
+- nn.Mish(),
++ Mish(),
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
+- nn.Mish(),
++ Mish(),
+ )
+
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
+@@ -119,11 +124,11 @@ class CausalConvPositionEmbedding(nn.Module):
+ self.kernel_size = kernel_size
+ self.conv1 = nn.Sequential(
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
+- nn.Mish(),
++ Mish(),
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
+- nn.Mish(),
++ Mish(),
+ )
+
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
+@@ -388,7 +393,17 @@ class AttnProcessor:
+ else:
+ attn_mask = None
+
+- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
++ atten_mask_npu = torch.logical_not(attn_mask)
++ head_num = query.shape[1]
++ x = torch_npu.npu_fusion_attention(
++ query, key, value, head_num, input_layout="BNSD",
++ pse=None,
++ atten_mask=atten_mask_npu,
++ scale=1.0 / math.sqrt(query.shape[-1]),
++ pre_tockens=2147483647,
++ next_tockens=2147483647,
++ keep_prob=1
++ )[0]
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ x = x.to(query.dtype)
+
+diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
+index c255186..9a20c34 100644
+--- a/cosyvoice/flow/flow.py
+@@ -376,7 +376,8 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
+ prompt_feat_len,
+ embedding,
+ streaming,
+- finalize):
++ finalize,
++ n_timesteps=None):
+ assert token.shape[0] == 1
+ # xvec projection
+ embedding = F.normalize(embedding, dim=1)
+@@ -406,7 +407,7 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
+ mask=mask.unsqueeze(1),
+ spks=embedding,
+ cond=conds,
+- n_timesteps=10,
++ n_timesteps=10 if n_timesteps is None else n_timesteps,
+ streaming=streaming
+ )
+ feat = feat[:, :, mel_len1:]
+diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py
+index bbc2a21..0cedae6 100644
+--- a/cosyvoice/hifigan/generator.py
+@@ -498,10 +498,24 @@ class HiFTGenerator(nn.Module):
+
+ def _istft(self, magnitude, phase):
+ magnitude = torch.clip(magnitude, max=1e2)
++ device = magnitude.device
++ is_npu = device.type == 'npu'
+ real = magnitude * torch.cos(phase)
+ img = magnitude * torch.sin(phase)
+- inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+- self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
++ if is_npu:
++ complex_tensor = torch.complex(real.cpu(), img.cpu())
++ window_xpu = self.stft_window.cpu()
++ else:
++ complex_tensor = torch.complex(real, img)
++ window_xpu = self.stft_window.to(device)
++
++ inverse_transform = torch.istft(complex_tensor,
++ self.istft_params["n_fft"],
++ self.istft_params["hop_len"],
++ self.istft_params["n_fft"],
++ window=window_xpu)
++ if is_npu:
++ inverse_transform = inverse_transform.to(device)
+ return inverse_transform
+
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py
+index b173ef2..fefb363 100644
+--- a/cosyvoice/utils/file_utils.py
+@@ -20,7 +20,7 @@ import torch
+ import torchaudio
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+-logging.basicConfig(level=logging.DEBUG,
++logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+
+@@ -42,8 +42,22 @@ def read_json_lists(list_file):
+
+
+ def load_wav(wav, target_sr, min_sr=16000):
+- speech, sample_rate = torchaudio.load(wav, backend='soundfile')
+- speech = speech.mean(dim=0, keepdim=True)
++ if isinstance(wav, torch.Tensor):
++ speech = wav
++ sample_rate = 16000 # assumed
++ elif not isinstance(wav, str):
++ try:
++ speech = torch.as_tensor(wav)
++ sample
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,249 @@
+import os
+import sys
+import logging
+import argparse
+import soundfile
+import cosyvoice_pb2
+import cosyvoice_pb2_grpc
+import grpc
+import torch
+import numpy as np
+import time
+import datetime
+sys.path.insert(0, '../../../')
+from cosyvoice.utils.file_utils import load_wav
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from threading import Barrier, Lock
+
+logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+TARGET_SR = 24000
+
+
+def percentile(data, p):
+    """Return the ``p``-th percentile of ``data``, or 0 when ``data`` is empty."""
+    return np.percentile(data, p) if len(data) else 0
+
+
+def build_request(args, prompt_audio_bytes, tts_text):
+    """Build the ``cosyvoice_pb2.Request`` for one benchmark call.
+
+    ``args.mode`` selects which sub-request is filled:
+
+    * ``'zero_shot'``: text, ``args.prompt_text``, the prompt audio bytes
+      and the stream flag.
+    * ``'zero_shot_by_id'``: text, ``args.spk_id`` and the stream flag.
+    * anything else: an SFT request with ``args.spk_id``, text and the
+      stream flag.
+    """
+    request = cosyvoice_pb2.Request()
+    mode = args.mode
+
+    if mode == 'zero_shot':
+        payload = cosyvoice_pb2.zeroshotRequest()
+        payload.tts_text = tts_text
+        payload.prompt_text = args.prompt_text
+        payload.prompt_audio = prompt_audio_bytes
+        payload.stream = args.stream
+        request.zero_shot_request.CopyFrom(payload)
+        return request
+
+    if mode == 'zero_shot_by_id':
+        payload = cosyvoice_pb2.zeroshotByIdRequest()
+        payload.tts_text = tts_text
+        payload.spk_id = args.spk_id
+        payload.stream = args.stream
+        request.zero_shot_by_id_request.CopyFrom(payload)
+        return request
+
+    # Every remaining mode is sent as an SFT request.
+    request.sft_request.CopyFrom(
+        cosyvoice_pb2.sftRequest(spk_id=args.spk_id, tts_text=tts_text, stream=args.stream)
+    )
+    return request
+
+
+def single_request(task_id, args, stub, barrier, prompt_audio_bytes, tts_text, use_barrier):
+    """Send one inference request and measure first-packet latency and RTF.
+
+    Args:
+        task_id: Index of this request; used in log messages and in the
+            file name of a saved audio clip.
+        args: Parsed command-line arguments (mode, speaker, save options).
+        stub: gRPC stub exposing ``Inference``.
+        barrier: Optional ``threading.Barrier`` shared by the tasks that
+            should start together.
+        prompt_audio_bytes: Prompt audio passed through to ``build_request``.
+        tts_text: Text to synthesize.
+        use_barrier: Wait on ``barrier`` before sending when true.
+
+    Returns:
+        ``{"success": False}`` on any failure or when no audio came back,
+        otherwise a dict with ``success``, ``first_fix`` (milliseconds to
+        the first packet), ``rtf`` and ``audio_dur`` (seconds).
+    """
+    from threading import BrokenBarrierError
+
+    sync_start = use_barrier and barrier is not None
+
+    try:
+        request = build_request(args, prompt_audio_bytes, tts_text)
+    except Exception as e:
+        # This task will never reach the barrier; abort it so the peers
+        # waiting for this party are released instead of hanging forever.
+        if sync_start:
+            barrier.abort()
+        logging.error(f"Task {task_id} failed: {str(e)}")
+        return {"success": False}
+
+    try:
+        if sync_start:
+            try:
+                barrier.wait()
+            except BrokenBarrierError:
+                # A peer aborted the barrier; still send the request, just
+                # without a synchronised start.
+                logging.warning(f"Task {task_id}: start barrier is broken, sending without sync")
+
+        start_t = time.time()
+        response = stub.Inference(request)
+        first_packet_time = None
+        chunks = []
+        for r in response:
+            if first_packet_time is None:
+                first_packet_time = time.time()
+            chunks.append(np.frombuffer(r.tts_audio, dtype=np.int16))
+        end_t = time.time()
+
+        if len(chunks) == 0:
+            return {"success": False}
+
+        full_audio = np.concatenate(chunks)
+        audio_dur = len(full_audio) / TARGET_SR
+        total_time = end_t - start_t
+        first_fix = (first_packet_time - start_t) * 1000 if first_packet_time else 0
+        rtf = total_time / audio_dur if audio_dur > 0 else 0
+
+        # save_task_id == -1 means "save every task's audio".
+        if args.save_task_id is not None and (task_id == args.save_task_id or args.save_task_id == -1):
+            os.makedirs(args.save_audio_dir, exist_ok=True)
+            save_path = os.path.join(
+                args.save_audio_dir,
+                f"task_{task_id}.wav"
+            )
+            soundfile.write(save_path, full_audio, TARGET_SR)
+            logging.info(f"Audio saved: {save_path}")
+        return {
+            "success": True,
+            "first_fix": first_fix,
+            "rtf": rtf,
+            "audio_dur": audio_dur
+        }
+
+    except Exception as e:
+        logging.error(f"Task {task_id} failed: {str(e)}")
+        return {"success": False}
+
+
+def main():
+    """Run a concurrent gRPC benchmark against a CosyVoice server.
+
+    Parses the command line, submits ``--total_req`` requests through a
+    thread pool of ``--concurrency`` workers, and logs success ratio,
+    throughput, and the average and percentiles of first-packet latency
+    and RTF.
+
+    With ``--sync_start`` the requests are released in batches of up to
+    ``--concurrency``.  Each batch has its own barrier sized to that batch,
+    so a final short batch (``total_req`` not a multiple of
+    ``concurrency``) is released as well instead of waiting forever.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', type=str, default='0.0.0.0')
+    parser.add_argument('--port', type=int, default=50000)
+    parser.add_argument('--mode',
+                        default='sft',
+                        choices=['sft', 'zero_shot', 'zero_shot_by_id'])
+    parser.add_argument('--tts_text', type=str,
+                        default='旅游是一种很好的放松方式,可以让我们暂时远离繁忙的工作,享受大自然的美好与宁静。')
+    parser.add_argument('--spk_id', type=str, default='spk1')
+    parser.add_argument('--prompt_text', type=str,
+                        default='You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。')
+    parser.add_argument('--prompt_wav', type=str,
+                        default='/data/code/CosyVoice/asset/zero_shot_prompt.wav')
+    parser.add_argument('--stream', action='store_true')
+    parser.add_argument('--total_req', type=int, default=100)
+    parser.add_argument('--concurrency', type=int, default=8)
+    parser.add_argument('--save_task_id', type=int, default=None)
+    parser.add_argument('--save_audio_dir', type=str, default='./saved_audio')
+    parser.add_argument('--text_file', type=str, default=None,
+                        help='文本文件路径,每行一个句子,用于多样化测试')
+    parser.add_argument('--sync_start', action='store_true',
+                        help='启用同步启动模式:等待所有并发线程就绪后同时发起请求(会降低吞吐量)')
+    args = parser.parse_args()
+    if args.concurrency < 1:
+        parser.error('--concurrency must be >= 1')
+
+    logging.info(f"Start benchmark total = {args.total_req} concurrency = {args.concurrency}")
+
+    text_list = []
+    if args.text_file and os.path.exists(args.text_file):
+        with open(args.text_file, 'r', encoding='utf-8') as f:
+            text_list = [line.strip() for line in f if line.strip()]
+        logging.info(f"Loaded {len(text_list)} texts from {args.text_file}")
+
+    if not text_list:
+        text_list = [args.tts_text]
+        logging.info("Using default text with task_id suffix for diversity")
+
+    prompt_audio_bytes = None
+
+    if args.mode == "zero_shot":
+        prompt_speech = load_wav(args.prompt_wav, 16000)
+        # Scale the loaded samples by 2**15 and send them as int16 bytes.
+        prompt_audio_bytes = (
+            (prompt_speech.numpy() * (2**15))
+            .astype(np.int16)
+            .tobytes()
+        )
+
+    options = [
+        ("grpc.keepalive_time_ms", 120000),
+        ("grpc.keepalive_timeout_ms", 60000),
+        ("grpc.http2.max_pings_without_data", 0),
+        ("grpc.keepalive_permit_without_calls", 1),
+        ("grpc.enable_http_proxy", 0),
+    ]
+
+    # One barrier per batch of `concurrency` consecutive tasks.  The pool
+    # starts tasks in submission order, so the tasks of a batch are exactly
+    # the ones running together, and the last batch may be shorter.
+    barriers = []
+    if args.sync_start:
+        barriers = [
+            Barrier(min(args.concurrency, args.total_req - first))
+            for first in range(0, args.total_req, args.concurrency)
+        ]
+        logging.info("sync mode")
+    else:
+        logging.info("async mode")
+
+    results = []
+    start_bench = time.time()
+
+    # The channel is closed on exit, including when a worker raises.
+    with grpc.insecure_channel(f"{args.host}:{args.port}", options=options) as channel:
+        stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel)
+
+        with ThreadPoolExecutor(max_workers=args.concurrency) as executor:
+            futures = []
+            for i in range(args.total_req):
+                if len(text_list) == 1:
+                    # Make every request's text unique when only one text is available.
+                    tts_text = f"{text_list[0]}第{i+1}句。"
+                else:
+                    tts_text = text_list[i % len(text_list)]
+
+                barrier = barriers[i // args.concurrency] if barriers else None
+                futures.append(executor.submit(
+                    single_request,
+                    i,
+                    args,
+                    stub,
+                    barrier,
+                    prompt_audio_bytes,
+                    tts_text,
+                    args.sync_start
+                ))
+
+            # Results are consumed on this thread only, so the progress
+            # counter needs no lock.
+            for progress, future in enumerate(as_completed(futures), start=1):
+                results.append(future.result())
+                percent = progress / args.total_req * 100
+                logging.info(f"Progress: {progress}/{args.total_req} ({percent:.1f}%)")
+
+        end_bench = time.time()
+
+    success_results = [r for r in results if r["success"]]
+
+    if not success_results:
+        logging.info("All requests failed.")
+        return
+
+    first_fix_list = [r["first_fix"] for r in success_results]
+    rtf_list = [r["rtf"] for r in success_results]
+    audio_dur_list = [r["audio_dur"] for r in success_results]
+    total_audio_dur = sum(audio_dur_list)
+    throughput = total_audio_dur / (end_bench - start_bench)
+    avg_first_fix = sum(first_fix_list) / len(success_results)
+    avg_rtf = sum(rtf_list) / len(success_results)
+
+    logging.info("=" * 64)
+    logging.info("SUMMARY:")
+    logging.info(f"success ratio: {len(success_results)}/{args.total_req}")
+    logging.info(f"duration: {end_bench-start_bench:.2f}s")
+    logging.info(f"throughput: {throughput:.2f} seconds of audio / s")
+    logging.info("=" * 64)
+    logging.info(f"AVG First Latency: {avg_first_fix:.0f}")
+    logging.info(f"P50: {percentile(first_fix_list,50):.0f}")
+    logging.info(f"P70: {percentile(first_fix_list,70):.0f}")
+    logging.info(f"P80: {percentile(first_fix_list,80):.0f}")
+    logging.info(f"P90: {percentile(first_fix_list,90):.0f}")
+    logging.info(f"P99: {percentile(first_fix_list,99):.0f}")
+    logging.info("=" * 64)
+    logging.info(f"AVG RTF: {avg_rtf:.2f}")
+    logging.info(f"P50: {percentile(rtf_list,50):.2f}")
+    logging.info(f"P70: {percentile(rtf_list,70):.2f}")
+    logging.info(f"P80: {percentile(rtf_list,80):.2f}")
+    logging.info(f"P90: {percentile(rtf_list,90):.2f}")
+    logging.info(f"P99: {percentile(rtf_list,99):.2f}")
+    logging.info("="*20)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
@@ -44,6 +44,7 @@ def main():
zero_shot_request.prompt_text = args.prompt_text
prompt_speech = load_wav(args.prompt_wav, 16000)
zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
+ zero_shot_request.stream = args.stream
request.zero_shot_request.CopyFrom(zero_shot_request)
elif args.mode == 'cross_lingual':
logging.info('send cross_lingual request')
@@ -52,6 +53,13 @@ def main():
prompt_speech = load_wav(args.prompt_wav, 16000)
cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
request.cross_lingual_request.CopyFrom(cross_lingual_request)
+ elif args.mode == 'zero_shot_by_id':
+ logging.info('send zero_shot_by_id request')
+ zero_shot_by_id_request = cosyvoice_pb2.zeroshotByIdRequest()
+ zero_shot_by_id_request.tts_text = args.tts_text
+ zero_shot_by_id_request.spk_id = args.spk_id
+ zero_shot_by_id_request.stream = args.stream
+ request.zero_shot_by_id_request.CopyFrom(zero_shot_by_id_request)
else:
logging.info('send instruct request')
instruct_request = cosyvoice_pb2.instructRequest()
@@ -80,7 +88,7 @@ if __name__ == "__main__":
default='50000')
parser.add_argument('--mode',
default='sft',
- choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
+ choices=['sft', 'zero_shot', 'cross_lingual', 'instruct', 'zero_shot_by_id'],
help='request mode')
parser.add_argument('--tts_text',
type=str,
@@ -101,6 +109,9 @@ if __name__ == "__main__":
parser.add_argument('--tts_wav',
type=str,
default='demo.wav')
+ parser.add_argument('--stream',
+ action='store_true',
+ help='whether to use streaming inference')
args = parser.parse_args()
- prompt_sr, target_sr = 16000, 22050
+ prompt_sr, target_sr = 16000, 24000
main()
@@ -13,18 +13,21 @@ message Request{
zeroshotRequest zero_shot_request = 2;
crosslingualRequest cross_lingual_request = 3;
instructRequest instruct_request = 4;
+ zeroshotByIdRequest zero_shot_by_id_request = 5;
}
}
+ // Payload for the `sft` request variant: synthesize tts_text with the speaker named by spk_id.
message sftRequest{
string spk_id = 1;
string tts_text = 2;
+ // The client help text describes this flag as "whether to use streaming inference".
+ bool stream = 3;
}
+ // Payload for the `zero_shot` request variant: tts_text plus a prompt (text and audio).
message zeroshotRequest{
string tts_text = 1;
string prompt_text = 2;
+ // The bundled client fills this with the raw bytes of int16 samples loaded at 16000 Hz.
bytes prompt_audio = 3;
+ // The client help text describes this flag as "whether to use streaming inference".
+ bool stream = 4;
}
message crosslingualRequest{
@@ -38,6 +41,12 @@ message instructRequest{
string instruct_text = 3;
}
+// Payload for the `zero_shot_by_id` request variant (field 5 of the Request oneof).
+// Carries only the text and a speaker id; no prompt text or prompt audio is sent.
+message zeroshotByIdRequest{
+ string tts_text = 1;
+ string spk_id = 2;
+ // The client help text describes this flag as "whether to use streaming inference".
+ bool stream = 3;
+}
+
message Response{
bytes tts_audio = 1;
}
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+export VLLM_WORKER_MULTIPROC_METHOD=spawn
+export ASCEND_RT_VISIBLE_DEVICES=0
+export TASK_QUEUE_ENABLE=1
+export CPU_AFFINITY_CONF=1
+export VLLM_ASCEND_ENABLE_MLP_OPTIMIZE=1
+export VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION=1
+export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2:$LD_PRELOAD
+
+# 配置参数
+MODEL_DIR="/home/Fun-CosyVoice3-0.5B-2512/"
+MAX_CONC=8
+
+python server.py --port 50099 --model_dir "$MODEL_DIR" --max_conc "$MAX_CONC" --graph_mode
@@ -27,49 +27,952 @@ sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import AutoModel
-logging.basicConfig(level=logging.DEBUG,
+logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)s %(message)s')
+# Optional NPU setup: when torch_npu cannot be imported only a warning is logged.
+# NOTE(review): torchair is imported unconditionally a few lines below, so the
+# warning-only path may still fail at import time on a machine without NPU packages.
+logging.info("Initializing NPU environment...")
+try:
+ import torch_npu
+ # NOTE(review): transfer_to_npu is not referenced again in this chunk —
+ # presumably imported for its side effects; confirm.
+ from torch_npu.contrib import transfer_to_npu
+ torch_npu.npu.set_compile_mode(jit_compile=False)
+ torch.npu.config.allow_internal_format = False
+ logging.info("NPU initialized successfully.")
+except ImportError:
+ logging.warning("NPU libraries not found. Using CPU.")
+
+import threading
+import time
+import uuid
+import vllm
+import queue
+import torchair as tng
+from torchair.configs.compiler_config import CompilerConfig
+import datetime
+import json
+import gc
+from typing import Optional
+from enum import Enum
+from dataclasses import dataclass, field
+
+from hyperpyyaml import load_hyperpyyaml
+from cosyvoice.utils.file_utils import load_wav
+from cosyvoice.cli.frontend import CosyVoiceFrontEnd
+
+
+# Processing state of a ThcSection. The str mixin makes each member a str subclass
+# whose value is the literal on the right-hand side.
+class ThcSectionStatus(str, Enum):
+ # Default status of a newly constructed ThcSection.
+ LLM_DOING = 'llm_doing'
+ # Not assigned anywhere in the visible part of this file.
+ T2W_WAITING = 't2w_waiting'
+ # Set when the section is created for, and batched into, a Token2Wave step.
+ T2W_DOING = 't2w_doing'
+ # Set by Token2Wave.step once the section's speech has been sliced out.
+ DONE = 'done'
+
+# One chunk of a request's speech tokens that is turned into audio by a single
+# Token2Wave step, together with the bookkeeping that step fills in.
+@dataclass
+class ThcSection:
+ # (request slot index, request_id, index of this section within the request)
+ idx: tuple
+ mode: str
+ spk_id: str
+ # True for the section built while the request has no earlier sections.
+ first: bool
+ # True when the LLM is done and this section consumes all remaining tokens.
+ finalize: bool
+ # Pre-gap context tokens + newly scheduled tokens (+ post-gap padding when finalizing).
+ tokens: list
+ # Positions inside `tokens` delimiting the part whose audio is emitted.
+ token_begin_pos: int
+ token_end_pos: int
+ # The *_in_batch fields are offsets inside the merged batch; they are filled by
+ # Token2Wave.prepare_batch_data (token/flow) and Token2Wave.step (speech).
+ token_begin_pos_in_batch: Optional[int] = None
+ token_end_pos_in_batch: Optional[int] = None
+ flow_begin_pos_in_batch: Optional[int] = None
+ flow_end_pos_in_batch: Optional[int] = None
+ # Flow output accumulated for the request so far; None on the first section.
+ flow_cache: Optional[torch.Tensor] = None
+ speech_begin_pos_in_batch: Optional[int] = None
+ speech_end_pos_in_batch: Optional[int] = None
+ # Slice of the vocoder output belonging to this section (device tensor).
+ speech: Optional[torch.Tensor] = None
+ # Same slice as int16 samples in a NumPy array.
+ # NOTE(review): np.array is a function; np.ndarray is presumably the intended annotation.
+ speech_cpu: Optional[np.array] = None
+ # Duration of speech_cpu in seconds (sample count / sample rate).
+ speech_dt: Optional[float] = None
+ status: ThcSectionStatus = ThcSectionStatus.LLM_DOING
+
+class ThcMetric:
+ def __init__(self, begin=None):
+ self.begin: Optional[datetime.datetime] = begin
+ self.end: Optional[datetime.datetime] = None
+ self.dt: Optional[float] = None
+
+ def start(self):
+ self.begin = self.begin or datetime.datetime.now()
+
+ def stop(self):
+ self.end = self.end or datetime.datetime.now()
+ self.dt = (self.end - self.begin).total_seconds()
+
+# Timers collected for one request.
+@dataclass
+class ThcAllMetric:
+ # Whole-request timer; its begin time is the reference for the dynamic RTF computation.
+ e2e: ThcMetric = field(default_factory=ThcMetric)
+ # Stopped when vLLM reports the request as finished.
+ llm: ThcMetric = field(default_factory=ThcMetric)
+ # One ThcMetric per Token2Wave section, appended in section order.
+ t2w: list = field(default_factory=list)
+
+# Per-request state shared between the gRPC handler and the inference thread.
+@dataclass
+class ThcRequestInfo:
+ # Slot index into CosyVoiceServiceImpl.requests_id; the slot is cleared on release.
+ idx: int
+ request_id: str
+ mode: str
+ spk_id: str
+ # Recomputed on every maybe_t2w_infer call: True when the first section can be scheduled.
+ first: bool = False
+ finalize: bool = False
+ # True once any of the request's sections has finalize set.
+ done: bool = False
+ # Speech token ids produced so far, with a trailing stop token removed.
+ llm_output: list = field(default_factory=list)
+ llm_done: bool = False
+ # Finished ThcSection objects are put here by postprocess_t2w_outputs.
+ section_queue: queue.Queue = field(default_factory=queue.Queue)
+ flow_cache: Optional[torch.Tensor] = None
+ # Concatenation of all section speech tensors produced so far.
+ speech: Optional[torch.Tensor] = None
+ # Total seconds of audio produced so far.
+ speech_cum_dt: float = 0.0
+ # Elapsed time since e2e.begin divided by speech_cum_dt; set once the request is done.
+ dynamic_rtf: Optional[float] = None
+ # speech_cum_dt minus elapsed time since e2e.begin; set once the request is done.
+ dynamic_slack: Optional[float] = None
+ release_dt: Optional[float] = None
+ # ThcSection objects created for this request, in order.
+ sections: list = field(default_factory=list)
+ # Receives True when the request is removed from the scheduler.
+ release: queue.Queue = field(default_factory=queue.Queue)
+ metric: ThcAllMetric = field(default_factory=ThcAllMetric)
+
+@dataclass
+class ThcT2WBatchInput:
+ batch_requests: list = field(default_factory=ThcSection)
+ tokens: Optional[torch.Tensor] = None
+ token_len: Optional[torch.Tensor] = None
+ flow_prompt_speech_token: Optional[torch.Tensor] = None
+ flow_prompt_speech_token_len: Optional[torch.Tensor] = None
+ prompt_speech_feat: Optional[torch.Tensor] = None
+ prompt_speech_feat_len: Optional[torch.Tensor] = None
+ flow_embedding: Optional[torch.Tensor] = None
+ stream: bool = False
+ finalize: bool = False
+ n_timesteps: int = 10
+
+
+class Token2Wave():
+ # Loads the flow and hift models from model_dir onto `device` and sets the constants
+ # used to convert between token, flow-frame and audio-sample positions.
+ # NOTE(review): fp16 and speaker_info_dir are accepted but not used in this method.
+ # NOTE(review): model_dir defaults to None yet is passed to os.path.join immediately,
+ # so it is effectively required.
+ def __init__(self,
+ device,
+ fp16: bool = False,
+ model_dir: str = None,
+ speaker_info_dir: str = None,
+ graph_mode: bool = False):
+ hyper_yaml_path = os.path.join(model_dir, 'cosyvoice3.yaml')
+ with open(hyper_yaml_path, 'r') as f:
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+
+ self.device = device
+ # Flow frames produced per speech token (step() asserts this 2x relation).
+ self.flow_len_per_token = 2
+ # Audio samples attributed to each speech token when slicing the vocoder output.
+ self.speech_len_per_token = 960
+ self.speech_sample_rate = configs['sample_rate']
+ # Speech tokens per second of audio.
+ self.token_sample_rate = configs['sample_rate'] / self.speech_len_per_token
+
+ self.model_dir = model_dir
+ self.graph_mode = graph_mode
+
+ # NOTE(review): `frontend` and `sample_rate` below are local and never stored or read
+ # again in this method — confirm whether constructing the frontend here is needed.
+ frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+ configs['feat_extractor'],
+ os.path.join(model_dir, 'campplus.onnx'),
+ os.path.join(model_dir, 'speech_tokenizer_v3.onnx'),
+ os.path.join(model_dir, 'spk2info.pt'),
+ configs['allowed_special'])
+ sample_rate = configs['sample_rate']
+
+ flow_ckpt_file = os.path.join(model_dir, 'flow.pt')
+ hift_ckpt_file = os.path.join(model_dir, 'hift.pt')
+
+ self.flow = configs['flow'].to(self.device)
+ self.flow.load_state_dict(torch.load(flow_ckpt_file, map_location=self.device, weights_only=True), strict=True)
+ self.flow.eval()
+ # in case hift_model is a hifigan model
+ self.hift = configs['hift'].to(self.device)
+ # Strip the 'generator.' prefix from checkpoint keys before loading.
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_ckpt_file, map_location=self.device, weights_only=True).items()}
+ self.hift.load_state_dict(hift_state_dict, strict=True)
+ self.hift.eval()
+
+ # Assigned by the owner after construction; indexed by spk_id in prepare_batch_data.
+ self.prompt_wave_info = None
+
+ # Context lengths (in tokens) derived from the flow model's pre_lookahead_len.
+ self.lookahead_token_len = self.flow.pre_lookahead_len
+ self.pre_gap_token_len = self.flow.pre_lookahead_len * 4
+ self.post_gap_token_len = self.flow.pre_lookahead_len * 1
+
+ # The first entry is used by the scheduler as padding for the pre/post gap.
+ self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
+
+ # One perf record is appended per step().
+ self.perfs = []
+
+ self.batches_requests = queue.Queue()
+ self.inference_event = threading.Event()
+ self.lock = threading.Lock()
+
+ del configs
+
+ if graph_mode:
+ # Wrap the flow decoder estimator's forward with torch.compile using the torchair NPU backend.
+ config = CompilerConfig()
+ npu_backend = tng.get_npu_backend(compiler_config=config)
+ self.flow.decoder.estimator.forward = torch.compile(
+ self.flow.decoder.estimator.forward,
+ dynamic=True,
+ fullgraph=True,
+ backend=npu_backend
+ )
+ logging.info("Token2Wave: compile graph success")
+
+ def add_requests(self, batch_requests):
+ self.batches_requests.put(batch_requests)
+ self.inference_event.set()
+
+ def prepare_batch_data(self, batch_requests):
+ merged_tokens = []
+ mode = None
+ spk_id = None
+ pos = 0
+ for request in batch_requests:
+ mode = mode or request.mode
+ assert request.mode == mode
+ spk_id = spk_id or request.spk_id
+ assert request.spk_id == spk_id
+
+ merged_tokens.extend(request.tokens)
+ request.token_begin_pos_in_batch = pos + request.token_begin_pos
+ request.token_end_pos_in_batch = pos + request.token_end_pos
+ pos += request.token_end_pos
+ pos += self.post_gap_token_len
+ request.flow_begin_pos_in_batch = int(request.token_begin_pos_in_batch * self.flow_len_per_token)
+ request.flow_end_pos_in_batch = int(request.token_end_pos_in_batch * self.flow_len_per_token)
+ request.status = ThcSectionStatus.T2W_DOING
+ logging.debug(f"Token2Wave.prepare_batch_data: {request.idx}, "
+ f"merged_tokens = {len(merged_tokens)}, "
+ f"start_pos_in_batch = {request.token_begin_pos_in_batch}, end_pos_in_batch = {request.token_end_pos_in_batch}, "
+ f"pos = {pos}")
+ assert pos == len(merged_tokens), f"pos = {pos}, len(merged_tokens) = {len(merged_tokens)}"
+
+ tokens = torch.tensor(merged_tokens, dtype=torch.int32, device=self.device).unsqueeze(dim=0)
+ token_len = torch.tensor([tokens.shape[1]], dtype=torch.int32, device=self.device)
+
+ if mode == 'zero_shot_by_id':
+ flow_prompt_speech_token = self.prompt_wave_info[spk_id]['flow_prompt_speech_token']
+ flow_prompt_speech_token_len = self.prompt_wave_info[spk_id]['flow_prompt_speech_token_len']
+ prompt_speech_feat = self.prompt_wave_info[spk_id]['prompt_speech_feat']
+ prompt_speech_feat_len = self.prompt_wave_info[spk_id]['prompt_speech_feat_len']
+ elif mode == 'sft':
+ flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32, device=self.device)
+ flow_prompt_speech_token_len = torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32, device=self.device)
+ prompt_speech_feat = torch.zeros(1, 0, 80, device=self.device)
+ prompt_speech_feat_len = torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32, device=self.device)
+
+ flow_embedding = self.prompt_wave_info[spk_id]['flow_embedding']
+ stream = True
+ finalize = False
+ n_timesteps = 5 if any([request.first for request in batch_requests]) else 10
+
+ batch_data = ThcT2WBatchInput(batch_requests=batch_requests,
+ tokens=tokens,
+ token_len=token_len,
+ flow_prompt_speech_token=flow_prompt_speech_token,
+ flow_prompt_speech_token_len=flow_prompt_speech_token_len,
+ prompt_speech_feat=prompt_speech_feat,
+ prompt_speech_feat_len=prompt_speech_feat_len,
+ flow_embedding=flow_embedding,
+ stream=stream,
+ finalize=finalize,
+ n_timesteps=n_timesteps)
+
+ return batch_data
+
+ # Runs one Token2Wave pass: takes one queued batch, runs the flow model on the merged
+ # tokens, runs the hift vocoder on the per-section mel segments, and writes each
+ # section's speech slice back onto the section. Returns the list of processed sections.
+ # Raises queue.Empty when no batch has been queued (get_nowait).
+ def step(self):
+ # modify if enable multi_threads
+ batch_requests = self.batches_requests.get_nowait()
+
+ metric = ThcMetric()
+ metric.start()
+ logging.debug(f"Token2Wave.infer: start")
+
+ batch_data = self.prepare_batch_data(batch_requests)
+
+ tts_mel, _ = self.flow.inference(token=batch_data.tokens,
+ token_len=batch_data.token_len,
+ prompt_token=batch_data.flow_prompt_speech_token,
+ prompt_token_len=batch_data.flow_prompt_speech_token_len,
+ prompt_feat=batch_data.prompt_speech_feat,
+ prompt_feat_len=batch_data.prompt_speech_feat_len,
+ embedding=batch_data.flow_embedding,
+ streaming=batch_data.stream,
+ finalize=batch_data.finalize,
+ n_timesteps=batch_data.n_timesteps)
+
+ logging.debug(f"tts_mel = {tts_mel.shape}")
+ # The flow output is expected to cover every token except the trailing lookahead, at 2 frames per token.
+ assert tts_mel.shape[-1] == (batch_data.tokens.shape[-1] - self.lookahead_token_len) * 2
+
+ # NOTE(review): the f-string below is evaluated even when DEBUG logging is off, so the
+ # attribute chain must exist on the hift model.
+ logging.debug(f"self.hift.f0_predictor.condnet[0].causal_padding = {self.hift.f0_predictor.condnet[0].causal_padding}")
+ # Number of flow frames placed in front of each section's mel before vocoding.
+ prefix_flow_len = 16 * self.flow_len_per_token
+ # Tokens of audio withheld from a first section and added back on a finalizing one.
+ # NOTE(review): presumably compensates for vocoder latency — confirm.
+ token_offset = 4
+
+ speech_pos = 0
+ merged_tts_mel = []
+ for request in batch_data.batch_requests:
+ request_tts_mel = tts_mel[..., request.flow_begin_pos_in_batch : request.flow_end_pos_in_batch]
+ speech_len = (request.token_end_pos - request.token_begin_pos) * self.speech_len_per_token
+ assert request.first == (request.flow_cache is None)
+ if request.first:
+ # First section: no history yet, so the prefix is zeros.
+ flow_prefix = request_tts_mel.new_zeros([*request_tts_mel.shape[: -1], prefix_flow_len])
+ request.flow_cache = request_tts_mel
+ request.speech_begin_pos_in_batch = speech_pos + flow_prefix.shape[-1] // self.flow_len_per_token * self.speech_len_per_token
+ speech_len -= token_offset * self.speech_len_per_token
+ else:
+ # Later sections: the prefix is the tail of the mel produced so far for this request.
+ flow_prefix = request.flow_cache[..., -prefix_flow_len :]
+ request.flow_cache = torch.cat([request.flow_cache, request_tts_mel], dim=-1)
+ assert flow_prefix.shape[-1] >= token_offset * self.flow_len_per_token, f"request.flow_cache = {request.flow_cache.shape}, flow_prefix = {flow_prefix.shape}"
+ request.speech_begin_pos_in_batch = speech_pos + (flow_prefix.shape[-1] // self.flow_len_per_token - token_offset) * self.speech_len_per_token
+
+ if request.finalize:
+ # Final section: append zero frames after the mel and emit the withheld audio.
+ flow_postfix = request_tts_mel.new_zeros([*request_tts_mel.shape[: -1], prefix_flow_len])
+ request_tts_mel = torch.cat([flow_prefix, request_tts_mel, flow_postfix], dim=-1)
+ speech_len += token_offset * self.speech_len_per_token
+ else:
+ request_tts_mel = torch.cat([flow_prefix, request_tts_mel], dim=-1)
+ request.speech_end_pos_in_batch = request.speech_begin_pos_in_batch + speech_len
+ merged_tts_mel.append(request_tts_mel)
+ # Advance by the number of samples this section's (prefixed) mel occupies in the vocoder output.
+ speech_pos += request_tts_mel.shape[-1] // self.flow_len_per_token * self.speech_len_per_token
+ logging.debug(f"Token2Wave.step: {request.idx}, finalize = {request.finalize}, "
+ f"speech_begin_pos_in_batch = {request.speech_begin_pos_in_batch}, speech_end_pos_in_batch = {request.speech_end_pos_in_batch}, "
+ f"speech_pos = {speech_pos}")
+ merged_tts_mel = torch.cat(merged_tts_mel, dim=-1)
+
+ # One vocoder call for the whole batch; results are sliced per section below.
+ tts_speech, _ = self.hift.inference(speech_feat=merged_tts_mel, finalize=True)
+ # Scale by 2**15 and convert to a flat int16 NumPy array.
+ tts_speech_cpu = (tts_speech * (2 ** 15)).to(torch.int16).view(-1).cpu().numpy()
+ logging.debug(f"merged_tts_mel = {merged_tts_mel.shape}, tts_speech_cpu = {tts_speech_cpu.shape}")
+
+ for request in batch_data.batch_requests:
+ request.speech = tts_speech[:, request.speech_begin_pos_in_batch : request.speech_end_pos_in_batch]
+ request.speech_cpu = tts_speech_cpu[request.speech_begin_pos_in_batch : request.speech_end_pos_in_batch]
+ assert request.speech_cpu.shape[0] == request.speech_end_pos_in_batch - request.speech_begin_pos_in_batch, f"{request.speech_cpu.shape[0]} != {request.speech_end_pos_in_batch} - {request.speech_begin_pos_in_batch}"
+ request.speech_dt = request.speech_cpu.shape[0] / self.speech_sample_rate
+ request.status = ThcSectionStatus.DONE
+
+ metric.stop()
+
+ # NOTE(review): self.perfs grows by one entry per step and is never trimmed in this chunk.
+ perf = {'token_len': batch_data.token_len[0].item(), 'tts_speech_shape': list(tts_speech.shape), 'n_timesteps': batch_data.n_timesteps, 'metric': metric}
+ self.perfs.append(perf)
+ logging.debug(f"Token2Wave.infer: end: token_len = {perf['token_len']}, tts_speech_shape = {perf['tts_speech_shape']}, dt = {metric.dt:.3f}")
+
+ return batch_data.batch_requests
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
+ # Builds the service: loads the model with vLLM enabled, precomputes per-speaker prompt
+ # embeddings, creates the Token2Wave stage and starts the background inference thread.
def __init__(self, args):
- self.cosyvoice = AutoModel(model_dir=args.model_dir)
- logging.info('grpc service initialized')
+ self.args = args
+
+ self.cosyvoice = AutoModel(
+ model_dir=self.args.model_dir,
+ load_vllm=True,
+ speaker_info_dir=self.args.speaker_info_dir,
+ )
+ logging.info(f'grpc service initialized, graph_mode={self.args.graph_mode}')
+
+ # taohouchao
+ # Per-speaker embeddings reused by add_request so they are not recomputed per call.
+ self.prompt_emb = {}
+ # Speakers from spk2info: registered as sft speakers if missing; only sos/task_id embeddings stored.
+ # NOTE(review): spk_info and model_input are assigned in this loop but not used afterwards.
+ for spk_id, spk_info in self.cosyvoice.frontend.spk2info.items():
+ if spk_id not in self.cosyvoice.frontend.spk2info_sft:
+ self.cosyvoice.add_sft_spk(spk_id)
+ model_input = self.cosyvoice.frontend.spk2info_sft[spk_id]
+ assert self.prompt_emb.get(spk_id) is None
+ self.prompt_emb[spk_id] = {
+ 'sos_emb': self.cosyvoice.model.llm.speech_embedding.weight[self.cosyvoice.model.llm.sos].reshape(1, 1, -1),
+ 'task_id_emb': self.cosyvoice.model.llm.speech_embedding.weight[self.cosyvoice.model.llm.task_id].reshape(1, 1, -1),
+ }
+
+ # Speakers preloaded from the speaker-info directory: registered as zero-shot speakers
+ # if missing, and their prompt text / prompt speech token embeddings are stored too.
+ # NOTE(review): the assert below fails if a spk_id was already handled by the loop above.
+ for spk_id, wave_info in self.cosyvoice.promote_wave_info.items():
+ if spk_id not in self.cosyvoice.frontend.spk2info:
+ self.cosyvoice.add_zero_shot_spk(wave_info['prompt_text'], wave_info['prompt_wav'], spk_id)
+ model_input = self.cosyvoice.frontend.spk2info[spk_id]
+ prompt_text = model_input['prompt_text']
+ llm_prompt_speech_token = model_input['llm_prompt_speech_token']
+ assert self.prompt_emb.get(spk_id) is None
+ self.prompt_emb[spk_id] = {
+ 'sos_emb': self.cosyvoice.model.llm.speech_embedding.weight[self.cosyvoice.model.llm.sos].reshape(1, 1, -1),
+ 'task_id_emb': self.cosyvoice.model.llm.speech_embedding.weight[self.cosyvoice.model.llm.task_id].reshape(1, 1, -1),
+ 'prompt_text_emb': self.cosyvoice.model.llm.llm.model.model.embed_tokens(prompt_text),
+ 'prompt_speech_token_emb': self.cosyvoice.model.llm.speech_embedding(llm_prompt_speech_token),
+ }
+
+ # NOTE(review): infer() later compares self.device.type with 'npu'; whether 'cuda' is
+ # remapped to an NPU device here depends on the torch_npu import above — confirm.
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ # Active requests keyed by request_id, plus a fixed-size table of max_conc slots.
+ self.requests_info = {}
+ self.requests_id = [None for _ in range(self.args.max_conc)]
+
+ self.lock = threading.Lock()
+ self.t2w = Token2Wave(device=self.device,
+ fp16=False,
+ model_dir=self.args.model_dir,
+ speaker_info_dir=self.args.speaker_info_dir,
+ graph_mode=self.args.graph_mode)
+ # On duplicate spk_id keys the spk2info_sft entry wins (it is unpacked last).
+ self.t2w.prompt_wave_info = {**self.cosyvoice.frontend.spk2info, **self.cosyvoice.frontend.spk2info_sft}
+ # Tokens required before a request's first section is scheduled.
+ self.t2w_first_token_len = 20 + self.t2w.post_gap_token_len
+ # Lower bound on tokens per request in later sections.
+ self.t2w_min_token_len = 25
+ # Token budget per Token2Wave batch: the True entry applies when all max_conc slots
+ # are occupied by requests that already have a section.
+ self.t2w_total_token_len = {True: 800, False: 300}
+ # Seconds to wait after a release before Token2Wave is attempted again (see infer()).
+ self.t2w_delay = 0.2
+ # Seconds after which a finished request is released without waiting for the others.
+ self.release_timeout = 1
+ self.prev_release_time = None
+
+ # modify if enable multi_threads
+ # self.t2w_inference_work_thread = threading.Thread(target=self.t2w.inference, daemon=True)
+ # self.t2w_inference_work_thread.start()
+
+ self.inference_event = threading.Event()
+ self.inference_work_thread = threading.Thread(target=self.infer, daemon=True)
+ self.inference_work_thread.start()
+
+ if self.args.enable_vllm_profile:
+ # NOTE(review): this assert always fails when the flag is set (unless asserts are
+ # disabled), so the profiling code below is normally unreachable.
+ assert False
+ try:
+ self.cosyvoice.model.llm.vllm.start_profile()
+ logging.info('vLLM profile started')
+ except Exception as e:
+ logging.warning(f'Failed to start vLLM profile: {e}')
+
+ def postprocess_vllm_outputs(self, llm_step_outputs, now):
+ with self.lock:
+ abort_requests_id = []
+ for output in llm_step_outputs:
+ request_id = output.request_id
+ token_ids = output.outputs[0].token_ids
+ request_info = self.requests_info.get(request_id)
+ if request_info is None:
+ abort_requests_id.append(request_id)
+ continue
+ request_info.llm_output = token_ids[: -1] if token_ids[-1] in self.cosyvoice.model.llm.stop_token_ids else token_ids
+ request_info.llm_done = output.finished
+ if request_info.llm_done:
+ request_info.metric.llm.stop()
+
+ if len(abort_requests_id) > 0:
+ logging.info(f"postprocess_vllm_outputs: abort_requests_id = {abort_requests_id}")
+ self.cosyvoice.model.llm.vllm.abort_request(abort_requests_id)
+
+ def schedule_nonfirst_requests_t2w(self, now):
+ requests_history_token_len = {}
+ requests_remain_token_len = {}
+ requests_slack = {}
+
+ for request_id, request_info in self.requests_info.items():
+ if len(request_info.sections) == 0 or any([section.status != ThcSectionStatus.DONE for section in request_info.sections]):
+ continue
+ history_token_len = sum([section.token_end_pos - section.token_begin_pos for section in request_info.sections])
+ requests_history_token_len[request_id] = history_token_len
+ new_token_len = len(request_info.llm_output) - history_token_len
+ if new_token_len == 0:
+ continue
+ requests_remain_token_len[request_id] = new_token_len
+ calc_cum_dt = (now - request_info.metric.t2w[0].end).total_seconds()
+ requests_slack[request_id] = request_info.speech_cum_dt - calc_cum_dt
+
+ flag = [len(request_info.sections) >= 1 for request_info in self.requests_info.values()]
+ available_token_len = self.t2w_total_token_len[len(flag) == self.args.max_conc and all(flag)]
+ t2w_min_token_len = max(self.t2w_min_token_len, available_token_len // len(requests_remain_token_len) - (self.t2w.pre_gap_token_len + self.t2w.post_gap_token_len))
+
+ flag = [token_len >= t2w_min_token_len or self.requests_info[request_id].llm_done for request_id, token_len in requests_remain_token_len.items()]
+ if len(flag) == 0 or all(flag) == False:
+ return {}
+
+ requests_priority = {}
+ requests_t2w_sched = {}
+ for request_id in requests_slack.keys():
+ requests_priority[request_id] = int((max(requests_slack.values()) - requests_slack[request_id]) * self.t2w.token_sample_rate)
+ requests_t2w_sched[request_id] = {'history_len': requests_history_token_len[request_id], 'alloc_len': 0, 'full': None}
+
+ estimate_token_len_flag = False
+
+ if estimate_token_len_flag:
+ estimate_t2w_calc_dt = lambda v: int(212 * (v ** 2) + 646.7 * v - 153.5)
+ expected_t2w_calc_dt = min(1.0, min(requests_slack.values()))
+ available_token_len = max(100, estimate_t2w_calc_dt(expected_t2w_calc_dt))
+
+ brief = {self.requests_info[request_id].idx: {'slack': round(requests_slack[request_id], 3), 'remain': requests_remain_token_len[request_id], 'sections': len(self.requests_info[request_id].sections)} for request_id in requests_slack.keys()}
+ logging.debug(f"available_token_len = {available_token_len}, {brief}")
+ logging.info(f"schedule_nonfirst_requests_t2w: available_token_len = {available_token_len}")
+
+ for request_id in requests_remain_token_len.keys():
+ request_info = self.requests_info[request_id]
+ assert requests_t2w_sched[request_id]['alloc_len'] == 0
+ alloc_token_len = min(requests_remain_token_len[request_id], t2w_min_token_len)
+ alloc_token_len_with_overhead = alloc_token_len + self.t2w.pre_gap_token_len + self.t2w.post_gap_token_len
+ available_token_len -= alloc_token_len_with_overhead
+ requests_remain_token_len[request_id] -= alloc_token_len
+ requests_priority[request_id] -= alloc_token_len
+ requests_t2w_sched[request_id]['alloc_len'] += alloc_token_len
+ requests_t2w_sched[request_id]['full'] = requests_remain_token_len[request_id] == 0
+ assert requests_remain_token_len[request_id] >= 0
+
+ while available_token_len > 0 and max(requests_remain_token_len.values()) > 0:
+ requests_remain_token_len = {request_id: token_len for request_id, token_len in requests_remain_token_len.items() if token_len > 0}
+ requests_priority = {request_id: requests_priority[request_id] for request_id in requests_remain_token_len.keys()}
+ request_id = max(requests_priority.items(), key=lambda v: v[1])[0]
+ assert requests_t2w_sched[request_id]['alloc_len'] > 0
+ if requests_t2w_sched[request_id]['alloc_len'] == 0:
+ alloc_token_len = min(requests_remain_token_len[request_id], t2w_min_token_len)
+ alloc_token_len_with_overhead = alloc_token_len + self.t2w.pre_gap_token_len + self.t2w.post_gap_token_len
+ else:
+ alloc_token_len = 1
+ alloc_token_len_with_overhead = 1
+ available_token_len -= alloc_token_len_with_overhead
+ requests_remain_token_len[request_id] -= alloc_token_len
+ requests_priority[request_id] -= alloc_token_len
+ requests_t2w_sched[request_id]['alloc_len'] += alloc_token_len
+ requests_t2w_sched[request_id]['full'] = requests_remain_token_len[request_id] == 0
+ assert requests_remain_token_len[request_id] >= 0
+
+ requests_t2w_sched = {request_id: sched for request_id, sched in requests_t2w_sched.items() if sched['alloc_len'] > 0}
+
+ return requests_t2w_sched
+
+ # Decides whether a Token2Wave batch should run now and, if so, builds one ThcSection per
+ # scheduled request, runs Token2Wave.step() and returns the processed sections.
+ # Returns None when nothing is scheduled.
+ def maybe_t2w_infer(self, now):
+ with self.lock:
+ # Requests without any section yet: record whether each has enough tokens (or a
+ # finished LLM) to produce its first section.
+ first_flag = []
+ for request_id, request_info in self.requests_info.items():
+ if len(request_info.sections) == 0:
+ request_info.first = (len(request_info.llm_output) >= self.t2w_first_token_len) or request_info.llm_done
+ first_flag.append(request_info.first)
+ else:
+ request_info.first = False
+
+ logging.debug(f"maybe_t2w_infer: first_flag = {first_flag}")
+
+ if len(first_flag) > 0:
+ # While any request still lacks a first section, only first sections are scheduled,
+ # and only once all such requests are ready.
+ if all(first_flag):
+ requests_t2w_sched = {request_id: {'history_len': 0,
+ 'alloc_len': min(self.t2w_first_token_len, len(request_info.llm_output)),
+ 'full': len(request_info.llm_output) <= self.t2w_first_token_len} \
+ for request_id, request_info in self.requests_info.items() if request_info.first}
+ else:
+ requests_t2w_sched = {}
+ else:
+ requests_t2w_sched = self.schedule_nonfirst_requests_t2w(now)
+ if len(requests_t2w_sched) == 0:
+ return None
+
+ t2w_requests = []
+ for request_id, sched in requests_t2w_sched.items():
+ request_info = self.requests_info[request_id]
+ # Finalize only when the LLM is done and this section takes all remaining tokens.
+ t2w_finalize = request_info.llm_done and sched['full']
+
+ valid_tokens = request_info.llm_output[sched['history_len'] : sched['history_len'] + sched['alloc_len']]
+
+ if self.t2w.pre_gap_token_len > 0:
+ # Left context: the last pre_gap_token_len history tokens, left-padded with the
+ # first silent token when there is not enough history.
+ pre_gap = request_info.llm_output[: sched['history_len']][-self.t2w.pre_gap_token_len :]
+ pre_gap = [self.t2w.silent_tokens[0]] * (self.t2w.pre_gap_token_len - len(pre_gap)) + pre_gap
+ else:
+ pre_gap = []
+
+ if t2w_finalize:
+ # Final section: pad the right side with silent tokens.
+ post_gap = [self.t2w.silent_tokens[0]] * self.t2w.post_gap_token_len
+ else:
+ post_gap = []
+
+ tokens = pre_gap + valid_tokens + post_gap
+ rel_begin_pos = len(pre_gap)
+ # Non-final sections hold back the last post_gap_token_len scheduled tokens; they are
+ # not counted as emitted and remain available to the next section.
+ rel_end_pos = rel_begin_pos + len(valid_tokens) - (0 if t2w_finalize else self.t2w.post_gap_token_len)
+
+ t2w_request = ThcSection(idx=(request_info.idx, request_id, len(request_info.sections)),
+ mode=request_info.mode,
+ spk_id=request_info.spk_id,
+ first=request_info.first,
+ finalize=t2w_finalize,
+ tokens=tokens,
+ token_begin_pos=rel_begin_pos,
+ token_end_pos=rel_end_pos,
+ flow_cache=request_info.flow_cache,
+ status=ThcSectionStatus.T2W_DOING)
+ request_info.sections.append(t2w_request)
+ t2w_requests.append(t2w_request)
+ request_info.metric.t2w.append(ThcMetric(begin=now))
+ logging.debug(f"t2w_infer: {request_info.idx}: t2w_request: {request_info.metric.t2w[-1].begin:%H:%M:%S.%f}, "
+ f"rel_begin_pos = {rel_begin_pos}, rel_end_pos = {rel_end_pos}, "
+ f"first = {t2w_request.first}, finalize = {t2w_request.finalize}")
+
+ assert len(t2w_requests) > 0
+ sections_brief = {f'{t2w_request.idx[0]}-{t2w_request.idx[2]}': len(t2w_request.tokens) for t2w_request in t2w_requests}
+ logging.info(f"t2w_infer: sections: {sections_brief}")
+ # The batch is queued and consumed immediately by the synchronous step() call.
+ self.t2w.add_requests(t2w_requests)
+ outputs = self.t2w.step()
+
+ return outputs
+
+ # Publishes the sections produced by one Token2Wave step to their requests, updates the
+ # per-request timing statistics, and removes finished requests from the scheduler.
+ def postprocess_t2w_outputs(self, t2w_step_outputs):
+ with self.lock:
+ for output in t2w_step_outputs:
+ # output.idx is (slot index, request_id, section index).
+ request_id = output.idx[1]
+ section_idx = output.idx[2]
+ assert isinstance(request_id, str)
+ request_info = self.requests_info.get(request_id, None)
+ if request_info is None:
+ # The request is no longer tracked; drop its output.
+ continue
+ request_info.flow_cache = output.flow_cache
+ if request_info.speech is None:
+ request_info.speech = output.speech
+ else:
+ request_info.speech = torch.cat([request_info.speech, output.speech], dim=-1)
+ request_info.speech_cum_dt += output.speech_dt
+ request_info.metric.t2w[section_idx].stop()
+ request_info.done = any([section.finalize for section in request_info.sections])
+ # Hand the finished section to whoever reads section_queue.
+ request_info.section_queue.put(output)
+
+ now = datetime.datetime.now()
+ for request_info in self.requests_info.values():
+ if request_info.done:
+ calc_cum_dt = (now - request_info.metric.e2e.begin).total_seconds()
+ # NOTE(review): divides by speech_cum_dt — assumes a finished request produced some audio.
+ request_info.dynamic_rtf = calc_cum_dt / request_info.speech_cum_dt
+ request_info.dynamic_slack = request_info.speech_cum_dt - calc_cum_dt
+ request_info.release_dt = (now - request_info.metric.t2w[-1].end).total_seconds()
+ logging.debug(f"infer: {request_info.idx}: "
+ f"begin = {request_info.metric.e2e.begin:%H:%M:%S.%f}, "
+ f"now = {now:%H:%M:%S.%f}, "
+ f"done = {request_info.done}, "
+ f"calc_cum_dt = {calc_cum_dt:.3f}, "
+ f"speech_cum_dt = {request_info.speech_cum_dt:.3f}, "
+ f"dynamic_rtf = {request_info.dynamic_rtf:.3f}, "
+ f"dynamic_slack = {request_info.dynamic_slack:.3f}")
+
+ # Hard-coded switch: True selects batched release, so the else branch below is currently unused.
+ batch_release_flag = True
+
+ if batch_release_flag:
+ can_release_flag = []
+ must_release_flag = []
+ for request_info in self.requests_info.values():
+ can_release_flag.append(request_info.done)
+ if request_info.done:
+ if len(request_info.sections) == 1:
+ # A request that finished in its only section is released right away.
+ flag = True
+ else:
+ # Otherwise force a release once the second-to-last section ended more than
+ # release_timeout seconds ago.
+ dt = (now - request_info.metric.t2w[-2].end).total_seconds()
+ flag = dt > self.release_timeout
+ must_release_flag.append(flag)
+ else:
+ must_release_flag.append(False)
+
+ # Release when every tracked request is done, or when any finished one must not wait longer.
+ if all(can_release_flag) or any(must_release_flag):
+ self.prev_release_time = now
+ # Iterate over a snapshot of the keys because entries are popped inside the loop.
+ requests_id = list(self.requests_info.keys())
+ for request_id in requests_id:
+ request_info = self.requests_info[request_id]
+ if request_info.done:
+ self.requests_info.pop(request_id)
+ self.requests_id[request_info.idx] = None
+ request_info.release.put(True)
+ logging.debug(f"infer: remove: request_idx: {request_info.idx}, dynamic_rtf = {request_info.dynamic_rtf:.3f}, dynamic_slack = {request_info.dynamic_slack:.3f}")
+ else:
+ self.prev_release_time = None
+ else:
+ requests_id = list(self.requests_info.keys())
+ for request_id in requests_id:
+ request_info = self.requests_info[request_id]
+ if request_info.done:
+ self.prev_release_time = now
+ self.requests_info.pop(request_id)
+ self.requests_id[request_info.idx] = None
+ request_info.release.put(True)
+ logging.debug(f"infer: remove: request_idx: {request_info.idx}, dynamic_rtf = {request_info.dynamic_rtf:.3f}, dynamic_slack = {request_info.dynamic_slack:.3f}")
+
+
+ # Body of the background inference thread. Waits on inference_event, then alternates
+ # vLLM decoding steps with Token2Wave batches until neither vLLM nor the scheduler has
+ # pending work. Never returns.
+ def infer(self):
+ llm = self.cosyvoice.model.llm
+ print_wait_flag = True
+ empty_cache_flag = True
+
+ logging.info("infer: start")
+
+ while True:
+ if print_wait_flag:
+ logging.info("infer: wait")
+ print_wait_flag = False
+
+ # Idle path: after 30 s without a wake-up, collect garbage once (and empty the NPU
+ # cache when running on an NPU device) until work arrives again.
+ if not self.inference_event.wait(timeout=30):
+ if empty_cache_flag:
+ logging.info("infer: empty_cache")
+ gc.collect()
+ if self.device.type == 'npu':
+ torch_npu.npu.empty_cache()
+ empty_cache_flag = False
+ continue
+
+ # NOTE(review): the event is only cleared inside the inner loop below, so if it is set
+ # while vLLM has no unfinished requests this `continue` re-enters wait() immediately
+ # and the loop spins — confirm the event is always set after the request is added.
+ if llm.vllm.get_num_unfinished_requests() == 0:
+ continue
+
+ print_wait_flag = True
+ empty_cache_flag = True
+
+ llm_section_begin = datetime.datetime.now()
+ self.prev_release_time = None
+ n_step = 0
+ logging.debug(f"infer: llm.vllm.get_num_unfinished_requests() = {llm.vllm.get_num_unfinished_requests()}")
+
+ # Keep going while vLLM is still decoding or any request is still tracked.
+ while llm.vllm.has_unfinished_requests() or len(self.requests_info) > 0:
+ self.inference_event.clear()
+
+ if llm.vllm.has_unfinished_requests():
+ llm_step_outputs = llm.vllm.step()
+
+ # NOTE(review): `now` is only assigned alongside a vLLM step; once vLLM has finished,
+ # later iterations appear to pass the last assigned value to maybe_t2w_infer — confirm intended.
+ now = datetime.datetime.now()
+ llm_section_end = now
+ logging.debug(f"infer: n_step = {n_step}, len(llm_step_outputs) = {len(llm_step_outputs)}, "
+ f"{[len(request_output.outputs) for request_output in llm_step_outputs]}, "
+ f"{[len(request_output.outputs[0].token_ids) for request_output in llm_step_outputs]}, "
+ f"{[request_output.outputs[0].token_ids[-1] for request_output in llm_step_outputs]}, "
+ f"{[request_output.finished for request_output in llm_step_outputs]}, "
+ f"{[request_output.outputs[0].finish_reason for request_output in llm_step_outputs]}")
+ n_step += 1
+
+ self.postprocess_vllm_outputs(llm_step_outputs, now)
+
+ # Attempt Token2Wave unless a release happened less than t2w_delay seconds ago;
+ # a full set of max_conc requests overrides the delay.
+ if self.prev_release_time is None or \
+ (datetime.datetime.now() - self.prev_release_time).total_seconds() >= self.t2w_delay or \
+ len(self.requests_info) == self.args.max_conc:
+ t2w_step_outputs = self.maybe_t2w_infer(now)
+ if t2w_step_outputs is not None:
+ llm_section_dt = (llm_section_end - llm_section_begin).total_seconds()
+ t2w_perf = self.t2w.perfs[-1]
+ logging.debug(f"infer: llm_dt = {llm_section_dt:.3f}, t2w_token_len = {t2w_perf['token_len']}, t2w_dt = {t2w_perf['metric'].dt:.3f}")
+ llm_section_begin = datetime.datetime.now()
+ self.postprocess_t2w_outputs(t2w_step_outputs)
+ elif llm.vllm.has_unfinished_requests():
+ pass
+ else:
+ # Nothing to decode and nothing to vocode: back off briefly instead of spinning.
+ time.sleep(0.01)
+
+    def add_request(self, mode, tts_text, spk_id, stream_flag):
+        """Build the LLM prompt embeddings for one TTS request and queue it in vLLM.
+
+        Args:
+            mode: 'zero_shot_by_id' or 'sft'.
+            tts_text: Text to synthesize; it is normalized by the frontend first.
+            spk_id: Key into self.prompt_emb; also forwarded to the frontend.
+            stream_flag: Must be truthy; only streaming is supported.
+
+        Returns:
+            The ThcRequestInfo registered for this request.
+
+        Raises:
+            NotImplementedError: stream_flag is falsy, or mode is not supported.
+            RuntimeError: none of the args.max_conc request slots is free.
+        """
+        if not stream_flag:
+            raise NotImplementedError("only support stream mode")
+        if mode not in ('zero_shot_by_id', 'sft'):
+            # Fail fast: without this check an unknown mode skipped both branches
+            # below and only failed later on the unbound lm_input / text_len.
+            raise NotImplementedError(f"mode {mode!r} is not implemented")
+
+        frontend = self.cosyvoice.frontend
+        llm = self.cosyvoice.model.llm
+
+        request_id = str(uuid.uuid1())
+
+        spk_prompt_emb = self.prompt_emb[spk_id]
+        sos_emb = spk_prompt_emb['sos_emb']
+        task_id_emb = spk_prompt_emb['task_id_emb']
+
+        # Original author's note: "modify if iter text". The normalized segments
+        # are joined back into a single string here.
+        tts_texts = frontend.text_normalize(tts_text, split=True, text_frontend=frontend.text_frontend)
+        tts_text = ''.join(tts_texts)
+        if mode == 'zero_shot_by_id':
+            model_input = frontend.frontend_zero_shot(tts_text, None, None, None, spk_id)
+            text = model_input['text']
+            text_len = model_input['text_len']
+            text_emb = llm.llm.model.model.embed_tokens(text)
+            prompt_text_emb = spk_prompt_emb['prompt_text_emb']
+            prompt_speech_token_emb = spk_prompt_emb['prompt_speech_token_emb']
+            lm_input = torch.concat([sos_emb, prompt_text_emb, text_emb, task_id_emb, prompt_speech_token_emb], dim=1)
+        else:  # mode == 'sft'
+            tts_text = '<|endofprompt|>' + tts_text
+            model_input = frontend.frontend_sft(tts_text, spk_id)
+            text = model_input['text']
+            text_len = model_input['text_len']
+            text_emb = llm.llm.model.model.embed_tokens(text)
+            lm_input = torch.concat([sos_emb, text_emb, task_id_emb], dim=1)
+
+        logging.debug(f"add_request: lm_input: [{lm_input.dtype}, lm_input.shape: {lm_input.shape}]")
+
+        # The number of generated tokens is bounded relative to the text length.
+        min_token_text_ratio = 2
+        max_token_text_ratio = 20
+        min_len = int(text_len * min_token_text_ratio)
+        max_len = int(text_len * max_token_text_ratio)
+
+        sampling = 25
+        sampling_params = vllm.SamplingParams(top_k=sampling,
+                                              stop_token_ids=llm.stop_token_ids,
+                                              min_tokens=min_len,
+                                              max_tokens=max_len)
+        logging.debug(f"add_request: sampling_params: {sampling_params}")
+        now = datetime.datetime.now()
+
+        # NOTE(review): the extent of the locked region is inferred (indentation is
+        # not visible in the patch); slot bookkeeping is kept under the lock and the
+        # engine submission after it - confirm against the real file.
+        with self.lock:
+            # Evict requests whose e2e timer started more than 60 s ago so their
+            # slots can be reused. NOTE(review): nothing here aborts the matching
+            # vLLM request.
+            for n, other_request_id in enumerate(self.requests_id):
+                other_requests_info = self.requests_info.get(other_request_id, None)
+                if other_requests_info is not None and (now - other_requests_info.metric.e2e.begin).total_seconds() > 60:
+                    logging.info(f"add_request: request_idx: {n}, {other_request_id}, timeout")
+                    self.requests_info.pop(other_request_id)
+                    self.requests_id[n] = None
+
+            free_slots = [n for n in range(self.args.max_conc) if self.requests_id[n] is None]
+            if not free_slots:
+                # An explicit raise survives `python -O`, unlike an assert.
+                raise RuntimeError(f"requests_id = {self.requests_id}")
+            request_idx = free_slots[0]
+            request_info = ThcRequestInfo(
+                idx=request_idx,
+                request_id=request_id,
+                mode=mode,
+                spk_id=spk_id,
+            )
+            request_info.metric.e2e.start()
+            request_info.metric.llm.start()
+
+            self.requests_id[request_idx] = request_id
+            self.requests_info[request_id] = request_info
+
+        llm.vllm.add_request(request_id, {'prompt_embeds': lm_input.squeeze(0).to(torch.bfloat16)}, sampling_params)
+
+        # Wake the infer() loop, which waits on this event.
+        self.inference_event.set()
+
+        logging.info(f"add_request: request_idx: {request_idx}, {request_id}, mode = {mode}, tts_text = {tts_text}, spk_id = {spk_id}")
+
+        return request_info
+
+    def check_context_and_get_queue(self, request_info, context, q, timeout, q_name):
+        """Poll q for one item while watching for a client disconnect.
+
+        The queue is polled in 0.2 s slices, at least once and at most
+        int(timeout / 0.2) times. Before each slice, a non-None context whose
+        is_active() is false ends the wait.
+
+        Returns:
+            The item taken from q, or None on disconnect or timeout (both logged).
+        """
+        poll_interval = 0.2
+        attempts = max(1, int(timeout / poll_interval))
+        request_idx, request_id = request_info.idx, request_info.request_id
+
+        for _ in range(attempts):
+            client_gone = context is not None and not context.is_active()
+            if client_gone:
+                logging.info(f"check_context: request_idx: {request_idx}, {request_id}, client disconnected")
+                return None
+            try:
+                return q.get(timeout=poll_interval)
+            except queue.Empty:
+                pass
+
+        # Every poll slice came back empty.
+        logging.info(f"check_queue: request_idx: {request_idx}, {request_id}, get {q_name} timeout")
+        return None
+
+    def postprocess_request(self, request_info):
+        """Stop the request's e2e timer and log its timing summary.
+
+        The logged "slack" is the smallest value, over n >= 1, of the speech
+        duration of sections[:n] minus the time between the end of t2w[0] and
+        the end of t2w[n]; it stays at 999 when there is at most one section.
+        """
+        idx = request_info.idx
+        rid = request_info.request_id
+        sections = request_info.sections
+
+        slacks = [
+            sum([section.speech_dt for section in sections[: n]])
+            - (request_info.metric.t2w[n].end - request_info.metric.t2w[0].end).total_seconds()
+            for n in range(1, len(sections))
+        ]
+        worst_slack = min([999] + slacks)
+
+        logging.debug(f"request_info.llm_output = {request_info.llm_output}")
+
+        metric = request_info.metric
+        metric.e2e.stop()
+        # Time from submission until the first t2w result was ready.
+        first_latency = (metric.t2w[0].end - metric.e2e.begin).total_seconds()
+        # Wall-clock time divided by the duration of the produced speech.
+        rtf = metric.e2e.dt / request_info.speech_cum_dt
+        logging.info(f"complete_request: request_idx: {idx}, {rid}, "
+                     f"submit_time = {metric.e2e.begin:%H:%M:%S.%f}, "
+                     f"e2e = {metric.e2e.dt:.3f}, "
+                     f"slack = {worst_slack:.3f}, "
+                     f"first_latency = {first_latency:.3f}, "
+                     f"rtf = {rtf:.3f}, "
+                     f"speech_dt = {request_info.speech_cum_dt:.3f}")
+
+
+    def process_request(self, mode, tts_text, spk_id, stream_flag, context):
+        """Submit one TTS request and stream its synthesized sections back.
+
+        Generator: yields ``(section.speech_cpu, is_last)`` for each section taken
+        from the request's section queue, where ``is_last`` is
+        ``section.finalize and request_info.done``.
+
+        Args:
+            mode, tts_text, spk_id, stream_flag: forwarded to add_request().
+            context: gRPC context used to detect a disconnected client; may be None.
+
+        The generator ends early, after dropping the request's bookkeeping, when
+        the client disconnects or when no section arrives within 10 s (no release
+        signal within 30 s for the last section).
+        """
+        logging.debug(f"process_request: tts_text = {tts_text}")
+        request_info = self.add_request(mode, tts_text, spk_id, stream_flag)
+        request_idx = request_info.idx
+        request_id = request_info.request_id
+
+        def discard():
+            # add_request() evicts requests older than 60 s and may hand the slot to
+            # a newer request, so only undo bookkeeping this request still owns.
+            # Previously the bare pop() raised KeyError after such an eviction and
+            # the slot was cleared even when it already belonged to someone else.
+            with self.lock:
+                self.requests_info.pop(request_id, None)
+                if self.requests_id[request_idx] == request_id:
+                    self.requests_id[request_idx] = None
+
+        n_section = 0
+        while True:
+            section = self.check_context_and_get_queue(request_info, context, request_info.section_queue, timeout=10, q_name='section')
+            if section is None:
+                discard()
+                return
+            # Sections are expected to arrive in order, numbered by idx[2].
+            assert section.idx[2] == n_section, f"section.idx[2] = {section.idx[2]}, n_section = {n_section}"
+            n_section += 1
+
+            logging.debug(f"{request_id}, section.speech_cpu = {section.speech_cpu.shape}, speech_dt = {section.speech_dt}")
+
+            if section.finalize and request_info.done:
+                # Wait for the release signal before logging the request summary.
+                release = self.check_context_and_get_queue(request_info, context, request_info.release, timeout=30, q_name='release')
+                if release is None:
+                    discard()
+                    return
+
+                self.postprocess_request(request_info)
+
+            yield section.speech_cpu, section.finalize and request_info.done
+
+            if request_info.done:
+                break
+
+ def warmup(self):
+ """Submit dummy zero-shot requests and poll until self.requests_info is empty.
+
+ One request is submitted per iteration over range(args.max_conc), using the
+ fixed text below and speaker 'spk1' (add_request() looks that id up in
+ self.prompt_emb). Nothing here consumes the produced audio.
+
+ NOTE(review): indentation is not visible in this patch view, so it cannot be
+ told whether the wait loop runs once after all submissions or after each
+ one - confirm against the real file.
+ """
+ tts_text = '1,2,3,4,5,6,7,8,9,10'
+ spk_id = 'spk1'
+ stream_flag = True
+
+ # The returned request_info and the loop index are not used afterwards.
+ for n_req in range(self.args.max_conc):
+ request_info = self.add_request('zero_shot_by_id', tts_text, spk_id, stream_flag)
+
+ # Poll once per second; there is no upper bound on how long this waits.
+ while True:
+ time.sleep(1)
+
+ with self.lock:
+ request_num = len(self.requests_info)
+
+ if request_num == 0:
+ break
def Inference(self, request, context):
+ """gRPC handler that streams synthesized audio for one request.
+
+ Yields cosyvoice_pb2.Response messages whose tts_audio field carries raw bytes.
+
+ NOTE(review): indentation is not visible in this patch view; the branch
+ structure described in the comments below is inferred from statement order.
+ """
- if request.HasField('sft_request'):
- logging.info('get sft inference request')
- model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id)
- elif request.HasField('zero_shot_request'):
- logging.info('get zero_shot inference request')
- prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
- prompt_speech_16k = prompt_speech_16k.float() / (2**15)
- model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text,
- request.zero_shot_request.prompt_text,
- prompt_speech_16k)
- elif request.HasField('cross_lingual_request'):
- logging.info('get cross_lingual inference request')
- prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
- prompt_speech_16k = prompt_speech_16k.float() / (2**15)
- model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
+ logging.debug(f"Inference: context = {context}, PID = {os.getpid()}, TID = {threading.get_ident()}")
+ # NOTE(review): this reads zero_shot_by_id_request before checking which request
+ # type is set, so for an sft_request the logged text/speaker are presumably just
+ # field defaults - confirm.
+ logging.debug(f"Inference: tts_text = {request.zero_shot_by_id_request.tts_text}, spk_id = {request.zero_shot_by_id_request.spk_id}")
+ # Hard-coded switch: while it is True only the process_request() path is taken.
+ high_performace_flag = True
+
+ if high_performace_flag:
+ # Map the populated request field to a mode string for process_request().
+ if request.HasField('sft_request'):
+ mode = 'sft'
+ tts_text = request.sft_request.tts_text
+ spk_id = request.sft_request.spk_id
+ stream_flag = request.sft_request.stream
+ elif request.HasField('zero_shot_by_id_request'):
+ mode = 'zero_shot_by_id'
+ tts_text = request.zero_shot_by_id_request.tts_text
+ spk_id = request.zero_shot_by_id_request.spk_id
+ stream_flag = request.zero_shot_by_id_request.stream
+ else:
+ raise NotImplementedError('Mode not implemented!')
+
+ for tts_speech, done_flag in self.process_request(mode, tts_text, spk_id, stream_flag, context):
+ response = cosyvoice_pb2.Response()
+ # NOTE(review): unlike the non-streaming branch below there is no int16
+ # scaling here; presumably process_request() already yields int16
+ # samples - confirm.
+ response.tts_audio = tts_speech.tobytes()
+ yield response
+ if done_flag:
+ break
+
else:
- logging.info('get instruct inference request')
- model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,
- request.instruct_request.spk_id,
- request.instruct_request.instruct_text)
+ # Fallback path through the CosyVoice inference API; not taken while the flag
+ # above is True.
+ if request.HasField('zero_shot_by_id_request'):
+ logging.info('get zero_shot_by_id inference request')
+ model_output = self.cosyvoice.inference_zero_shot_by_id(
+ request.zero_shot_by_id_request.tts_text,
+ request.zero_shot_by_id_request.spk_id,
+ stream=request.zero_shot_by_id_request.stream
+ )
+ else:
+ # NOTE(review): the assert makes the instruct call below unreachable in a
+ # normal run; under `python -O` asserts are removed and it would execute.
+ assert False
+ logging.info('get instruct inference request')
+ model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,
+ request.instruct_request.spk_id,
+ request.instruct_request.instruct_text)
- logging.info('send inference response')
- for i in model_output:
- response = cosyvoice_pb2.Response()
- response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
- yield response
+ logging.info('send inference response')
+
+ # count is the number of responses sent so far; it gates the profiling hook below.
+ count=0
+ for i in model_output:
+ response = cosyvoice_pb2.Response()
+ # Scale tts_speech by 2**15 and cast to int16 before serializing.
+ response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
+ yield response
+ count+=1
+ if count==4 and self.args.enable_vllm_profile:
+ # NOTE(review): the assert fires before stop_profile() can run, so with
+ # profiling enabled the fourth chunk raises AssertionError - confirm intent.
+ assert False
+ try:
+ self.cosyvoice.model.llm.vllm.stop_profile()
+ logging.info('vLLM profile stopped')
+ except Exception as e:
+ logging.warning(f'Failed to stop vLLM profile: {e}')
def main():
+ """Create the service, start the gRPC server, optionally warm up, then block.
+
+ Reads the module-level `args`: max_conc sizes both the thread pool and the
+ concurrent-RPC limit, port selects the listening port, graph_mode enables warmup.
+ """
+ # The service is built once and kept so that warmup() can be called on it below.
+ service = CosyVoiceServiceImpl(args)
grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
- cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
+ cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(service, grpcServer)
grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port))
grpcServer.start()
+
+ # NOTE(review): warmup runs after grpcServer.start(), i.e. while the server is
+ # already started and before the "server listening" line is logged. warmup()
+ # submits args.max_conc requests of its own - confirm client requests arriving
+ # during warmup are handled as intended.
+ if args.graph_mode:
+ logging.info(f"warmup: start")
+ service.warmup()
+ logging.info(f"warmup: end")
+
logging.info("server listening on 0.0.0.0:{}".format(args.port))
grpcServer.wait_for_termination()
@@ -86,5 +989,16 @@ if __name__ == '__main__':
type=str,
default='iic/CosyVoice2-0.5B',
help='local path or modelscope repo id')
+ parser.add_argument('--graph_mode',
+ action='store_true',
+ help='whether to use graph mode')  # main() calls service.warmup() when set
+ parser.add_argument('--speaker_info_dir',
+ type=str,
+ default='.',  # help text below now matches this default
+ help='directory containing speaker info files (default: current directory)')
+ parser.add_argument('--enable_vllm_profile',
+ action='store_true',
+ help='whether to enable vLLM profiling (default: False)')
args = parser.parse_args()
+ logging.info(f"args: {args}")  # record the effective configuration at startup
main()
new file mode 100644
@@ -0,0 +1,6 @@
+{
+ "spk1":{
+ "prompt_wav":"../../../asset/zero_shot_prompt.wav",
+ "prompt_text":"You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"
+ }
+}
\ No newline at end of file
new file mode 100644
@@ -0,0 +1 @@
+python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
\ No newline at end of file