From 23192263998105aa3c8c3f99106c0de9f5b8a808 Mon Sep 17 00:00:00 2001
From: shikang <shikang12@huawei.com>
Date: Tue, 4 Mar 2025 11:22:31 +0800
Subject: [PATCH] 300I DUO patch

---
 cosyvoice/cli/cosyvoice.py         | 12 +++++++-
 cosyvoice/cli/frontend.py          | 17 +++++++----
 cosyvoice/cli/model.py             | 37 ++++++++++++------------
 cosyvoice/flow/flow_matching.py    | 20 +++++++++----
 cosyvoice/flow/length_regulator.py | 12 ++++----
 cosyvoice/hifigan/generator.py     | 10 +++----
 cosyvoice/llm/llm.py               |  2 +-
 cosyvoice/transformer/attention.py |  1 -
 cosyvoice/transformer/encoder.py   | 45 ++++++++++++++++++++++++++++++
 9 files changed, 112 insertions(+), 44 deletions(-)

diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index e2d62e2..d3ccbfc 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 import os
 import time
+import platform
 from typing import Generator
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 import torch
+from ais_bench.infer.interface import InferSession
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
@@ -26,7 +28,7 @@ from cosyvoice.utils.class_utils import get_model_type
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, load_om=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -57,6 +59,14 @@ class CosyVoice:
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                 self.fp16)
+        if load_om:
+            arch = platform.machine()
+            system = platform.system().lower()
+            flow_om = InferSession(0, '{}/flow_{}_{}.om'.format(model_dir, system ,arch))
+            speech_om = InferSession(0, '{}/speech_{}_{}.om'.format(model_dir, system ,arch))
+            self.frontend.speech_om = speech_om
+            self.frontend.flow_om = flow_om
+            self.model.flow.decoder.flow_om = flow_om
         del configs
 
     def list_available_spks(self):
diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
index 6e10f00..6a0c19d 100644
--- a/cosyvoice/cli/frontend.py
+++ b/cosyvoice/cli/frontend.py
@@ -71,6 +71,8 @@ class CosyVoiceFrontEnd:
             self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
             self.en_tn_model = EnNormalizer()
             self.inflect_parser = inflect.engine()
+        self.speech_om = None
+        self.flow_om = None
 
     def _extract_text_token(self, text):
         if isinstance(text, Generator):
@@ -92,11 +94,16 @@ class CosyVoiceFrontEnd:
     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
-        speech_token = self.speech_tokenizer_session.run(None,
-                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
-                                                          feat.detach().cpu().numpy(),
-                                                          self.speech_tokenizer_session.get_inputs()[1].name:
-                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+        if self.speech_om:
+            feed = [feat.detach().cpu().numpy(), np.array([feat.shape[2]], dtype=np.int32)]
+            speech_token = self.speech_om.infer(feed, mode='dymshape', custom_sizes=[100000000])[0].flatten().tolist()
+            self.flow_om.set_context()
+        else:
+            speech_token = self.speech_tokenizer_session.run(None,
+                                                            {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                            feat.detach().cpu().numpy(),
+                                                            self.speech_tokenizer_session.get_inputs()[1].name:
+                                                            np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
         speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
         return speech_token, speech_token_len
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 9ebf8cb..530a44d 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -99,25 +99,24 @@ class CosyVoiceModel:
         self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context:
-            if isinstance(text, Generator):
-                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
-                for i in self.llm.inference_bistream(text=text,
-                                                     prompt_text=prompt_text.to(self.device),
-                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                     embedding=llm_embedding.to(self.device)):
-                    self.tts_speech_token_dict[uuid].append(i)
-            else:
-                for i in self.llm.inference(text=text.to(self.device),
-                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_text=prompt_text.to(self.device),
-                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device)):
-                    self.tts_speech_token_dict[uuid].append(i)
+        if isinstance(text, Generator):
+            assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+            for i in self.llm.inference_bistream(text=text,
+                                                    prompt_text=prompt_text.to(self.device),
+                                                    prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                    prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                    prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                    embedding=llm_embedding.to(self.device)):
+                self.tts_speech_token_dict[uuid].append(i)
+        else:
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device)):
+                self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index 6a60f6d..985f194 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -14,6 +14,7 @@
 import threading
 import torch
 import torch.nn.functional as F
+import numpy as np
 from matcha.models.components.flow_matching import BASECFM
 
 
@@ -32,6 +33,7 @@ class ConditionalCFM(BASECFM):
         # Just change the architecture of the estimator here
         self.estimator = estimator
         self.lock = threading.Lock()
+        self.flow_om = None
 
     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
@@ -105,12 +107,18 @@ class ConditionalCFM(BASECFM):
             t_in[:] = t.unsqueeze(0)
             spks_in[0] = spks
             cond_in[0] = cond
-            dphi_dt = self.forward_estimator(
-                x_in, mask_in,
-                mu_in, t_in,
-                spks_in,
-                cond_in
-            )
+            if self.flow_om:
+                feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in]
+                feed = [i.cpu().detach().numpy().astype(np.float32) for i in feed_list]
+                dphi_dt = self.flow_om.infer(feed, mode="dymshape", custom_sizes=10000000)
+                dphi_dt = torch.from_numpy(dphi_dt[0]).npu()
+            else:
+                dphi_dt = self.forward_estimator(
+                    x_in, mask_in,
+                    mu_in, t_in,
+                    spks_in,
+                    cond_in
+                )
             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
diff --git a/cosyvoice/flow/length_regulator.py b/cosyvoice/flow/length_regulator.py
index 2cae42f..aae8e92 100644
--- a/cosyvoice/flow/length_regulator.py
+++ b/cosyvoice/flow/length_regulator.py
@@ -53,15 +53,15 @@ class InterpolateRegulator(nn.Module):
         # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
         # x in (B, T, D)
         if x2.shape[1] > 40:
-            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
-            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
-                                   mode='linear')
-            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous().cpu(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear').npu()
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous().cpu(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
+                                   mode='linear').npu()
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous().cpu(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear').npu()
             x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
         else:
-            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+            x2 = F.interpolate(x2.transpose(1, 2).contiguous().cpu(), size=mel_len2, mode='linear').npu()
         if x1.shape[1] != 0:
-            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+            x1 = F.interpolate(x1.transpose(1, 2).contiguous().cpu(), size=mel_len1, mode='linear').npu() # interpolate函数在300I推理卡上暂不支持,需要指定CPU计算
             x = torch.concat([x1, x2], dim=2)
         else:
             x = x2
diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py
index c47bf05..bcb4fb5 100644
--- a/cosyvoice/hifigan/generator.py
+++ b/cosyvoice/hifigan/generator.py
@@ -332,9 +332,9 @@ class HiFTGenerator(nn.Module):
 
     def _stft(self, x):
         spec = torch.stft(
-            x,
-            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
-            return_complex=True)
+            x.cpu(),
+            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.cpu(),
+            return_complex=True).npu()
         spec = torch.view_as_real(spec)  # [B, F, TT, 2]
         return spec[..., 0], spec[..., 1]
 
@@ -342,8 +342,8 @@ class HiFTGenerator(nn.Module):
         magnitude = torch.clip(magnitude, max=1e2)
         real = magnitude * torch.cos(phase)
         img = magnitude * torch.sin(phase)
-        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
-                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+        inverse_transform = torch.istft(torch.complex(real, img).cpu(), self.istft_params["n_fft"], self.istft_params["hop_len"],
+                                        self.istft_params["n_fft"], window=self.stft_window.cpu()).npu() #  # stft函数在300I推理卡上暂不支持,需要指定CPU计算
         return inverse_transform
 
     def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index bbd3305..2a975d1 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -206,7 +206,7 @@ class TransformerLM(torch.nn.Module):
         offset = 0
         att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
         for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
+            y_pred, att_cache, cnn_cache = self.llm(lm_input, offset=offset, required_cache_size=-1,
                                                                   att_cache=att_cache, cnn_cache=cnn_cache,
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                  device=lm_input.device)).to(torch.bool))
diff --git a/cosyvoice/transformer/attention.py b/cosyvoice/transformer/attention.py
index 8c0c098..3ddebff 100644
--- a/cosyvoice/transformer/attention.py
+++ b/cosyvoice/transformer/attention.py
@@ -108,7 +108,6 @@ class MultiHeadedAttention(nn.Module):
         if mask.size(2) > 0:  # time2 > 0
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
             # For last chunk, time2 might be larger than scores.size(-1)
-            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
             scores = scores.masked_fill(mask, -float('inf'))
             attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0)  # (batch, head, time1, time2)
diff --git a/cosyvoice/transformer/encoder.py b/cosyvoice/transformer/encoder.py
index c5709d0..f85a360 100644
--- a/cosyvoice/transformer/encoder.py
+++ b/cosyvoice/transformer/encoder.py
@@ -383,6 +383,51 @@ class TransformerEncoder(BaseEncoder):
                 dropout_rate, normalize_before) for _ in range(num_blocks)
         ])
 
+    def forward(
+        self,
+        xs: torch.Tensor,
+        offset: int,
+        required_cache_size: int,
+        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+        seq = xs.shape[1]
+
+        tmp_masks = torch.ones(1,
+                               xs.size(1),
+                               device=xs.device,
+                               dtype=torch.bool)
+        tmp_masks = tmp_masks.unsqueeze(1)
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+
+        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+
+        _, cache_t1 = att_cache.size(0), att_cache.size(2)
+        chunk_size = xs.size(1)
+        attention_key_size = cache_t1 + chunk_size
+        pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
+                                               size=attention_key_size)
+
+        r_att_cache = []
+
+        for i, layer in enumerate(self.encoders):
+            xs, _, new_att_cache, _ = layer(
+                xs,
+                att_mask,
+                pos_emb,
+                att_cache=att_cache[i:i+1] if seq==1 else att_cache,
+                cnn_cache=cnn_cache)
+            r_att_cache.append(new_att_cache)
+
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+
+        r_att_cache = torch.cat(r_att_cache, dim=0)
+
+        return (xs, r_att_cache, None)
 
 class ConformerEncoder(BaseEncoder):
     """Conformer encoder module."""
-- 
2.21.0