From b70c9af573bf9cb4b0d8170bb6cc0658a3900a38 Mon Sep 17 00:00:00 2001
From: shikang <shikang12@huawei.com>
Date: Tue, 4 Mar 2025 14:47:07 +0800
Subject: [PATCH] Add OM-model inference path (ais_bench InferSession) for NPU
cosyvoice/cli/cosyvoice.py | 12 +++++++-
cosyvoice/cli/frontend.py | 17 +++++++----
cosyvoice/cli/model.py | 37 ++++++++++++------------
cosyvoice/flow/flow_matching.py | 20 +++++++++----
cosyvoice/llm/llm.py | 2 +-
cosyvoice/transformer/attention.py | 1 -
cosyvoice/transformer/encoder.py | 45 ++++++++++++++++++++++++++++++
7 files changed, 101 insertions(+), 33 deletions(-)
@@ -13,11 +13,13 @@
# limitations under the License.
import os
import time
+import platform
from typing import Generator
from tqdm import tqdm
from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
import torch
+from ais_bench.infer.interface import InferSession
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.utils.file_utils import logging
@@ -26,7 +28,7 @@ from cosyvoice.utils.class_utils import get_model_type
class CosyVoice:
- def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, load_om=False):
self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
self.fp16 = fp16
@@ -57,6 +59,14 @@ class CosyVoice:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
self.fp16)
+ if load_om:
+ arch = platform.machine()
+ system = platform.system().lower()
+ flow_om = InferSession(0, '{}/flow_{}_{}.om'.format(model_dir, system ,arch))
+ speech_om = InferSession(0, '{}/speech_{}_{}.om'.format(model_dir, system ,arch))
+ self.frontend.speech_om = speech_om
+ self.frontend.flow_om = flow_om
+ self.model.flow.decoder.flow_om = flow_om
del configs
def list_available_spks(self):
@@ -71,6 +71,8 @@ class CosyVoiceFrontEnd:
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
self.en_tn_model = EnNormalizer()
self.inflect_parser = inflect.engine()
+ self.speech_om = None
+ self.flow_om = None
def _extract_text_token(self, text):
if isinstance(text, Generator):
@@ -92,11 +94,16 @@ class CosyVoiceFrontEnd:
def _extract_speech_token(self, speech):
+        """Tokenize a speech waveform into discrete speech tokens.
+
+        The waveform must satisfy ``speech.shape[1] / 16000 <= 30``. A 128-bin
+        whisper log-mel spectrogram is computed and fed to the tokenizer: the
+        OM session when one is attached to ``self.speech_om``, otherwise
+        ``self.speech_tokenizer_session``.
+
+        Returns:
+            (speech_token, speech_token_len): int32 tensors on ``self.device``;
+            ``speech_token`` has a leading dimension of 1 and
+            ``speech_token_len`` holds its dimension-1 length.
+        """
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
- speech_token = self.speech_tokenizer_session.run(None,
- {self.speech_tokenizer_session.get_inputs()[0].name:
- feat.detach().cpu().numpy(),
- self.speech_tokenizer_session.get_inputs()[1].name:
- np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+        # Dispatch on whether an OM session was attached (CosyVoice(load_om=True)),
+        # not on torch.npu.is_available(): speech_om defaults to None, and
+        # torch.npu does not exist unless torch_npu has been imported.
+        if self.speech_om is not None:
+            feed = [feat.detach().cpu().numpy(), np.array([feat.shape[2]], dtype=np.int32)]
+            # NOTE(review): custom_sizes is presumably the output buffer size
+            # used in 'dymshape' mode - confirm against the ais_bench docs.
+            speech_token = self.speech_om.infer(feed, mode='dymshape', custom_sizes=[100000000])[0].flatten().tolist()
+            # NOTE(review): presumably re-activates the flow session's device
+            # context after the tokenizer session ran - confirm against ais_bench.
+            if self.flow_om is not None:
+                self.flow_om.set_context()
+        else:
+            speech_token = self.speech_tokenizer_session.run(
+                None,
+                {self.speech_tokenizer_session.get_inputs()[0].name:
+                 feat.detach().cpu().numpy(),
+                 self.speech_tokenizer_session.get_inputs()[1].name:
+                 np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
return speech_token, speech_token_len
@@ -99,25 +99,24 @@ class CosyVoiceModel:
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
- with self.llm_context:
- if isinstance(text, Generator):
- assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
- for i in self.llm.inference_bistream(text=text,
- prompt_text=prompt_text.to(self.device),
- prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
- prompt_speech_token=llm_prompt_speech_token.to(self.device),
- prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
- embedding=llm_embedding.to(self.device)):
- self.tts_speech_token_dict[uuid].append(i)
- else:
- for i in self.llm.inference(text=text.to(self.device),
- text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
- prompt_text=prompt_text.to(self.device),
- prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
- prompt_speech_token=llm_prompt_speech_token.to(self.device),
- prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
- embedding=llm_embedding.to(self.device)):
- self.tts_speech_token_dict[uuid].append(i)
+        """Run the LLM for one request and collect the speech tokens it yields.
+
+        Every token produced is appended to ``self.tts_speech_token_dict[uuid]``;
+        once the token stream is exhausted ``self.llm_end_dict[uuid]`` is set to
+        True. When ``text`` is a ``Generator`` the bistream entry point is used
+        (asserted to be available only on ``CosyVoice2Model``); otherwise the
+        regular ``inference`` entry point is used.
+        """
+        streaming = isinstance(text, Generator)
+        if streaming:
+            assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+        # The prompt-side arguments are identical for both entry points.
+        prompt_kwargs = dict(
+            prompt_text=prompt_text.to(self.device),
+            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+            prompt_speech_token=llm_prompt_speech_token.to(self.device),
+            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+            embedding=llm_embedding.to(self.device),
+        )
+        if streaming:
+            token_stream = self.llm.inference_bistream(text=text, **prompt_kwargs)
+        else:
+            token_stream = self.llm.inference(text=text.to(self.device),
+                                              text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                              **prompt_kwargs)
+        # Append one token at a time so consumers of the dict see tokens as
+        # soon as they are produced.
+        for token in token_stream:
+            self.tts_speech_token_dict[uuid].append(token)
self.llm_end_dict[uuid] = True
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
@@ -14,6 +14,7 @@
import threading
import torch
import torch.nn.functional as F
+import numpy as np
from matcha.models.components.flow_matching import BASECFM
@@ -32,6 +33,7 @@ class ConditionalCFM(BASECFM):
# Just change the architecture of the estimator here
self.estimator = estimator
self.lock = threading.Lock()
+ self.flow_om = None
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
@@ -105,12 +107,18 @@ class ConditionalCFM(BASECFM):
t_in[:] = t.unsqueeze(0)
spks_in[0] = spks
cond_in[0] = cond
- dphi_dt = self.forward_estimator(
- x_in, mask_in,
- mu_in, t_in,
- spks_in,
- cond_in
- )
+            # Dispatch on whether an OM session was attached (CosyVoice(load_om=True)),
+            # not on torch.npu.is_available(): flow_om defaults to None, and
+            # torch.npu does not exist unless torch_npu has been imported.
+            if self.flow_om is not None:
+                # The session is fed host-side float32 arrays, so every input
+                # is detached, moved to CPU and cast before the call.
+                feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in]
+                feed = [i.detach().cpu().numpy().astype(np.float32) for i in feed_list]
+                # NOTE(review): custom_sizes is presumably the output buffer
+                # size used in "dymshape" mode - confirm against the ais_bench docs.
+                dphi_dt = self.flow_om.infer(feed, mode="dymshape", custom_sizes=10000000)
+                # Return the first output to the device the solver inputs live
+                # on instead of hard-coding .npu().
+                dphi_dt = torch.from_numpy(dphi_dt[0]).to(x_in.device)
+            else:
+                dphi_dt = self.forward_estimator(
+                    x_in, mask_in,
+                    mu_in, t_in,
+                    spks_in,
+                    cond_in
+                )
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
@@ -206,7 +206,7 @@ class TransformerLM(torch.nn.Module):
offset = 0
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
for i in range(max_len):
- y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
+ y_pred, att_cache, cnn_cache = self.llm(lm_input, offset=offset, required_cache_size=-1,
att_cache=att_cache, cnn_cache=cnn_cache,
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
device=lm_input.device)).to(torch.bool))
@@ -108,7 +108,6 @@ class MultiHeadedAttention(nn.Module):
if mask.size(2) > 0: # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
- mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0) # (batch, head, time1, time2)
@@ -383,6 +383,51 @@ class TransformerEncoder(BaseEncoder):
dropout_rate, normalize_before) for _ in range(num_blocks)
])
+    def forward(
+        self,
+        xs: torch.Tensor,
+        offset: int,
+        required_cache_size: int,
+        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Run the encoder stack over ``xs`` while threading an attention cache.
+
+        Args:
+            xs: input tensor; its dimension 1 is read as the sequence length.
+            offset: position offset passed to the embedding and to the
+                positional encoding.
+            required_cache_size: accepted but not read by this method.
+            att_cache: attention cache; its dimension 2 is read as the cached
+                length. When the incoming sequence length is 1, layer ``i``
+                receives ``att_cache[i:i + 1]``; otherwise every layer
+                receives the whole tensor.
+            cnn_cache: forwarded unchanged to every layer.
+            att_mask: attention mask forwarded to every layer.
+
+        Returns:
+            ``(xs, r_att_cache, None)``: the encoder output, the per-layer new
+            attention caches concatenated on dimension 0, and ``None`` in the
+            third slot.
+            NOTE(review): the annotation declares a Tensor for the third slot
+            while ``None`` is returned - confirm callers tolerate that.
+        """
+        # Sequence length is captured before the embedding is applied.
+        seq = xs.shape[1]
+
+        tmp_masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool).unsqueeze(1)
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+
+        # The positional embedding returned here is discarded; it is rebuilt
+        # below so that it spans the cached length plus the current chunk.
+        xs, _, _ = self.embed(xs, tmp_masks, offset)
+
+        cache_t1 = att_cache.size(2)
+        attention_key_size = cache_t1 + xs.size(1)
+        pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
+                                               size=attention_key_size)
+
+        slice_per_layer = seq == 1
+        r_att_cache = []
+        for i, layer in enumerate(self.encoders):
+            layer_cache = att_cache[i:i + 1] if slice_per_layer else att_cache
+            xs, _, new_att_cache, _ = layer(xs,
+                                            att_mask,
+                                            pos_emb,
+                                            att_cache=layer_cache,
+                                            cnn_cache=cnn_cache)
+            r_att_cache.append(new_att_cache)
+
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+
+        return (xs, torch.cat(r_att_cache, dim=0), None)
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""
--
2.21.0