From b70c9af573bf9cb4b0d8170bb6cc0658a3900a38 Mon Sep 17 00:00:00 2001
From: shikang <shikang12@huawei.com>
Date: Tue, 4 Mar 2025 14:47:07 +0800
Subject: [PATCH] add new code

---
 cosyvoice/cli/cosyvoice.py         | 12 +++++++-
 cosyvoice/cli/frontend.py          | 17 +++++++----
 cosyvoice/cli/model.py             | 37 ++++++++++++------------
 cosyvoice/flow/flow_matching.py    | 20 +++++++++----
 cosyvoice/llm/llm.py               |  2 +-
 cosyvoice/transformer/attention.py |  1 -
 cosyvoice/transformer/encoder.py   | 45 ++++++++++++++++++++++++++++++
 7 files changed, 101 insertions(+), 33 deletions(-)

diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index e2d62e2..d3ccbfc 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 import os
 import time
+import platform
 from typing import Generator
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 import torch
+from ais_bench.infer.interface import InferSession
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
@@ -26,7 +28,7 @@ from cosyvoice.utils.class_utils import get_model_type
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, load_om=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -57,6 +59,14 @@ class CosyVoice:
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                 self.fp16)
+        if load_om:
+            arch = platform.machine()
+            system = platform.system().lower()
+            flow_om = InferSession(0, '{}/flow_{}_{}.om'.format(model_dir, system ,arch))
+            speech_om = InferSession(0, '{}/speech_{}_{}.om'.format(model_dir, system ,arch))
+            self.frontend.speech_om = speech_om
+            self.frontend.flow_om = flow_om
+            self.model.flow.decoder.flow_om = flow_om
         del configs
 
     def list_available_spks(self):
diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
index 6e10f00..964ebf3 100644
--- a/cosyvoice/cli/frontend.py
+++ b/cosyvoice/cli/frontend.py
@@ -71,6 +71,8 @@ class CosyVoiceFrontEnd:
             self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
             self.en_tn_model = EnNormalizer()
             self.inflect_parser = inflect.engine()
+        self.speech_om = None
+        self.flow_om = None
 
     def _extract_text_token(self, text):
         if isinstance(text, Generator):
@@ -92,11 +94,16 @@ class CosyVoiceFrontEnd:
     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
-        speech_token = self.speech_tokenizer_session.run(None,
-                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
-                                                          feat.detach().cpu().numpy(),
-                                                          self.speech_tokenizer_session.get_inputs()[1].name:
-                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+        if torch.npu.is_available():
+            feed = [feat.detach().cpu().numpy(), np.array([feat.shape[2]], dtype=np.int32)]
+            speech_token = self.speech_om.infer(feed, mode='dymshape', custom_sizes=[100000000])[0].flatten().tolist()
+            self.flow_om.set_context()
+        else:
+            speech_token = self.speech_tokenizer_session.run(None,
+                                                            {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                            feat.detach().cpu().numpy(),
+                                                            self.speech_tokenizer_session.get_inputs()[1].name:
+                                                            np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
         speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
         return speech_token, speech_token_len
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 9ebf8cb..530a44d 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -99,25 +99,24 @@ class CosyVoiceModel:
         self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context:
-            if isinstance(text, Generator):
-                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
-                for i in self.llm.inference_bistream(text=text,
-                                                     prompt_text=prompt_text.to(self.device),
-                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                     embedding=llm_embedding.to(self.device)):
-                    self.tts_speech_token_dict[uuid].append(i)
-            else:
-                for i in self.llm.inference(text=text.to(self.device),
-                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_text=prompt_text.to(self.device),
-                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device)):
-                    self.tts_speech_token_dict[uuid].append(i)
+        if isinstance(text, Generator):
+            assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+            for i in self.llm.inference_bistream(text=text,
+                                                    prompt_text=prompt_text.to(self.device),
+                                                    prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                    prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                    prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                    embedding=llm_embedding.to(self.device)):
+                self.tts_speech_token_dict[uuid].append(i)
+        else:
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device)):
+                self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index 6a60f6d..f59c790 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -14,6 +14,7 @@
 import threading
 import torch
 import torch.nn.functional as F
+import numpy as np
 from matcha.models.components.flow_matching import BASECFM
 
 
@@ -32,6 +33,7 @@ class ConditionalCFM(BASECFM):
         # Just change the architecture of the estimator here
         self.estimator = estimator
         self.lock = threading.Lock()
+        self.flow_om = None
 
     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
@@ -105,12 +107,18 @@ class ConditionalCFM(BASECFM):
             t_in[:] = t.unsqueeze(0)
             spks_in[0] = spks
             cond_in[0] = cond
-            dphi_dt = self.forward_estimator(
-                x_in, mask_in,
-                mu_in, t_in,
-                spks_in,
-                cond_in
-            )
+            if torch.npu.is_available():
+                feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in]
+                feed = [i.cpu().detach().numpy().astype(np.float32) for i in feed_list]
+                dphi_dt = self.flow_om.infer(feed, mode="dymshape", custom_sizes=10000000)
+                dphi_dt = torch.from_numpy(dphi_dt[0]).npu()
+            else:
+                dphi_dt = self.forward_estimator(
+                    x_in, mask_in,
+                    mu_in, t_in,
+                    spks_in,
+                    cond_in
+                )
             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index bbd3305..2a975d1 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -206,7 +206,7 @@ class TransformerLM(torch.nn.Module):
         offset = 0
         att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
         for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
+            y_pred, att_cache, cnn_cache = self.llm(lm_input, offset=offset, required_cache_size=-1,
                                                                   att_cache=att_cache, cnn_cache=cnn_cache,
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                  device=lm_input.device)).to(torch.bool))
diff --git a/cosyvoice/transformer/attention.py b/cosyvoice/transformer/attention.py
index 8c0c098..3ddebff 100644
--- a/cosyvoice/transformer/attention.py
+++ b/cosyvoice/transformer/attention.py
@@ -108,7 +108,6 @@ class MultiHeadedAttention(nn.Module):
         if mask.size(2) > 0:  # time2 > 0
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
             # For last chunk, time2 might be larger than scores.size(-1)
-            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
             scores = scores.masked_fill(mask, -float('inf'))
             attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0)  # (batch, head, time1, time2)
diff --git a/cosyvoice/transformer/encoder.py b/cosyvoice/transformer/encoder.py
index c5709d0..f85a360 100644
--- a/cosyvoice/transformer/encoder.py
+++ b/cosyvoice/transformer/encoder.py
@@ -383,6 +383,51 @@ class TransformerEncoder(BaseEncoder):
                 dropout_rate, normalize_before) for _ in range(num_blocks)
         ])
 
+    def forward(
+        self,
+        xs: torch.Tensor,
+        offset: int,
+        required_cache_size: int,
+        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+        seq = xs.shape[1]
+
+        tmp_masks = torch.ones(1,
+                               xs.size(1),
+                               device=xs.device,
+                               dtype=torch.bool)
+        tmp_masks = tmp_masks.unsqueeze(1)
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+
+        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+
+        _, cache_t1 = att_cache.size(0), att_cache.size(2)
+        chunk_size = xs.size(1)
+        attention_key_size = cache_t1 + chunk_size
+        pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
+                                               size=attention_key_size)
+
+        r_att_cache = []
+
+        for i, layer in enumerate(self.encoders):
+            xs, _, new_att_cache, _ = layer(
+                xs,
+                att_mask,
+                pos_emb,
+                att_cache=att_cache[i:i+1] if seq==1 else att_cache,
+                cnn_cache=cnn_cache)
+            r_att_cache.append(new_att_cache)
+
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+
+        r_att_cache = torch.cat(r_att_cache, dim=0)
+
+        return (xs, r_att_cache, None)
 
 class ConformerEncoder(BaseEncoder):
     """Conformer encoder module."""
-- 
2.21.0