From 23192263998105aa3c8c3f99106c0de9f5b8a808 Mon Sep 17 00:00:00 2001
From: shikang <shikang12@huawei.com>
Date: Tue, 4 Mar 2025 11:22:31 +0800
Subject: [PATCH] 300I DUO patch
cosyvoice/cli/cosyvoice.py | 12 +++++++-
cosyvoice/cli/frontend.py | 17 +++++++----
cosyvoice/cli/model.py | 37 ++++++++++++------------
cosyvoice/flow/flow_matching.py | 20 +++++++++----
cosyvoice/flow/length_regulator.py | 12 ++++----
cosyvoice/hifigan/generator.py | 10 +++----
cosyvoice/llm/llm.py | 2 +-
cosyvoice/transformer/attention.py | 1 -
cosyvoice/transformer/encoder.py | 45 ++++++++++++++++++++++++++++++
9 files changed, 112 insertions(+), 44 deletions(-)
@@ -13,11 +13,13 @@
# limitations under the License.
import os
import time
+import platform
from typing import Generator
from tqdm import tqdm
from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
import torch
+from ais_bench.infer.interface import InferSession
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.utils.file_utils import logging
@@ -26,7 +28,7 @@ from cosyvoice.utils.class_utils import get_model_type
class CosyVoice:
- def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, load_om=False):
self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
self.fp16 = fp16
@@ -57,6 +59,14 @@ class CosyVoice:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
self.fp16)
+ if load_om:
+ arch = platform.machine()
+ system = platform.system().lower()
+ flow_om = InferSession(0, '{}/flow_{}_{}.om'.format(model_dir, system ,arch))
+ speech_om = InferSession(0, '{}/speech_{}_{}.om'.format(model_dir, system ,arch))
+ self.frontend.speech_om = speech_om
+ self.frontend.flow_om = flow_om
+ self.model.flow.decoder.flow_om = flow_om
del configs
def list_available_spks(self):
@@ -71,6 +71,8 @@ class CosyVoiceFrontEnd:
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
self.en_tn_model = EnNormalizer()
self.inflect_parser = inflect.engine()
+ self.speech_om = None
+ self.flow_om = None
def _extract_text_token(self, text):
if isinstance(text, Generator):
@@ -92,11 +94,16 @@ class CosyVoiceFrontEnd:
def _extract_speech_token(self, speech):
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
- speech_token = self.speech_tokenizer_session.run(None,
- {self.speech_tokenizer_session.get_inputs()[0].name:
- feat.detach().cpu().numpy(),
- self.speech_tokenizer_session.get_inputs()[1].name:
- np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+ if self.speech_om:
+ feed = [feat.detach().cpu().numpy(), np.array([feat.shape[2]], dtype=np.int32)]
+ speech_token = self.speech_om.infer(feed, mode='dymshape', custom_sizes=[100000000])[0].flatten().tolist()
+ self.flow_om.set_context()
+ else:
+ speech_token = self.speech_tokenizer_session.run(None,
+ {self.speech_tokenizer_session.get_inputs()[0].name:
+ feat.detach().cpu().numpy(),
+ self.speech_tokenizer_session.get_inputs()[1].name:
+ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
return speech_token, speech_token_len
@@ -99,25 +99,24 @@ class CosyVoiceModel:
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
- with self.llm_context:
- if isinstance(text, Generator):
- assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
- for i in self.llm.inference_bistream(text=text,
- prompt_text=prompt_text.to(self.device),
- prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
- prompt_speech_token=llm_prompt_speech_token.to(self.device),
- prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
- embedding=llm_embedding.to(self.device)):
- self.tts_speech_token_dict[uuid].append(i)
- else:
- for i in self.llm.inference(text=text.to(self.device),
- text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
- prompt_text=prompt_text.to(self.device),
- prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
- prompt_speech_token=llm_prompt_speech_token.to(self.device),
- prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
- embedding=llm_embedding.to(self.device)):
- self.tts_speech_token_dict[uuid].append(i)
+ if isinstance(text, Generator):
+ assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+ for i in self.llm.inference_bistream(text=text,
+ prompt_text=prompt_text.to(self.device),
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+ embedding=llm_embedding.to(self.device)):
+ self.tts_speech_token_dict[uuid].append(i)
+ else:
+ for i in self.llm.inference(text=text.to(self.device),
+ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+ prompt_text=prompt_text.to(self.device),
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+ embedding=llm_embedding.to(self.device)):
+ self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
@@ -14,6 +14,7 @@
import threading
import torch
import torch.nn.functional as F
+import numpy as np
from matcha.models.components.flow_matching import BASECFM
@@ -32,6 +33,7 @@ class ConditionalCFM(BASECFM):
# Just change the architecture of the estimator here
self.estimator = estimator
self.lock = threading.Lock()
+ self.flow_om = None
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
@@ -105,12 +107,18 @@ class ConditionalCFM(BASECFM):
t_in[:] = t.unsqueeze(0)
spks_in[0] = spks
cond_in[0] = cond
- dphi_dt = self.forward_estimator(
- x_in, mask_in,
- mu_in, t_in,
- spks_in,
- cond_in
- )
+ if self.flow_om:
+ feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in]
+ feed = [i.cpu().detach().numpy().astype(np.float32) for i in feed_list]
+ dphi_dt = self.flow_om.infer(feed, mode="dymshape", custom_sizes=10000000)
+ dphi_dt = torch.from_numpy(dphi_dt[0]).npu()
+ else:
+ dphi_dt = self.forward_estimator(
+ x_in, mask_in,
+ mu_in, t_in,
+ spks_in,
+ cond_in
+ )
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
@@ -53,15 +53,15 @@ class InterpolateRegulator(nn.Module):
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
# x in (B, T, D)
if x2.shape[1] > 40:
- x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
- x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
- mode='linear')
- x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+ x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous().cpu(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear').npu()
+ x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous().cpu(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
+ mode='linear').npu()
+ x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous().cpu(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear').npu()
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
else:
- x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+ x2 = F.interpolate(x2.transpose(1, 2).contiguous().cpu(), size=mel_len2, mode='linear').npu()
if x1.shape[1] != 0:
- x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+ x1 = F.interpolate(x1.transpose(1, 2).contiguous().cpu(), size=mel_len1, mode='linear').npu() # interpolate函数在300I推理卡上暂不支持,需要指定CPU计算
x = torch.concat([x1, x2], dim=2)
else:
x = x2
@@ -332,9 +332,9 @@ class HiFTGenerator(nn.Module):
def _stft(self, x):
spec = torch.stft(
- x,
- self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
- return_complex=True)
+ x.cpu(),
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.cpu(),
+ return_complex=True).npu()
spec = torch.view_as_real(spec) # [B, F, TT, 2]
return spec[..., 0], spec[..., 1]
@@ -342,8 +342,8 @@ class HiFTGenerator(nn.Module):
magnitude = torch.clip(magnitude, max=1e2)
real = magnitude * torch.cos(phase)
img = magnitude * torch.sin(phase)
- inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
- self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+ inverse_transform = torch.istft(torch.complex(real, img).cpu(), self.istft_params["n_fft"], self.istft_params["hop_len"],
+ self.istft_params["n_fft"], window=self.stft_window.cpu()).npu() # # stft函数在300I推理卡上暂不支持,需要指定CPU计算
return inverse_transform
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
@@ -206,7 +206,7 @@ class TransformerLM(torch.nn.Module):
offset = 0
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
for i in range(max_len):
- y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
+ y_pred, att_cache, cnn_cache = self.llm(lm_input, offset=offset, required_cache_size=-1,
att_cache=att_cache, cnn_cache=cnn_cache,
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
device=lm_input.device)).to(torch.bool))
@@ -108,7 +108,6 @@ class MultiHeadedAttention(nn.Module):
if mask.size(2) > 0: # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
- mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0) # (batch, head, time1, time2)
@@ -383,6 +383,51 @@ class TransformerEncoder(BaseEncoder):
dropout_rate, normalize_before) for _ in range(num_blocks)
])
+ def forward(
+ self,
+ xs: torch.Tensor,
+ offset: int,
+ required_cache_size: int,
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ seq = xs.shape[1]
+
+ tmp_masks = torch.ones(1,
+ xs.size(1),
+ device=xs.device,
+ dtype=torch.bool)
+ tmp_masks = tmp_masks.unsqueeze(1)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+
+ _, cache_t1 = att_cache.size(0), att_cache.size(2)
+ chunk_size = xs.size(1)
+ attention_key_size = cache_t1 + chunk_size
+ pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
+ size=attention_key_size)
+
+ r_att_cache = []
+
+ for i, layer in enumerate(self.encoders):
+ xs, _, new_att_cache, _ = layer(
+ xs,
+ att_mask,
+ pos_emb,
+ att_cache=att_cache[i:i+1] if seq==1 else att_cache,
+ cnn_cache=cnn_cache)
+ r_att_cache.append(new_att_cache)
+
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+
+ r_att_cache = torch.cat(r_att_cache, dim=0)
+
+ return (xs, r_att_cache, None)
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""
--
2.21.0