4a731972 created on May 27, 2025 (commit history)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index f5cbe01f..7d8fff13 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -7,6 +7,7 @@ import json
 import time
 import copy
 import torch
+import torch_npu
 import random
 import string
 import logging
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 70cd02e3..7adca0de 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -4,6 +4,7 @@ import time
 import numpy as np
 import torch
 import torch.nn.functional as F
+import torch_npu
 from torch import Tensor
 from torch import nn
 from torch.cuda.amp import autocast
@@ -158,20 +159,10 @@ class MultiHeadedAttentionSANM(nn.Module):
             torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

         """
-        b, t, d = x.size()
         q_k_v = self.linear_q_k_v(x)
         q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
-            1, 2
-        )  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
-            1, 2
-        )  # (batch, head, time2, d_k)
-
-        return q_h, k_h, v_h, v
+
+        return q, k, v

     def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
         """Compute attention context vector.
@@ -225,13 +216,22 @@ class MultiHeadedAttentionSANM(nn.Module):
             torch.Tensor: Output tensor (#batch, time1, d_model).

         """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
+        q, k, v = self.forward_qkv(x)
         fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        att_outs = self.npu_flash_attention(q, k, v) # use the NPU PFA (prompt flash attention) operator in place of the attention structure
         return att_outs + fsmn_memory

+    def npu_flash_attention(self, query, key, value):  # attention via the torch_npu fused kernel, then the output projection
+        x = torch_npu.npu_prompt_flash_attention(  # called without any mask argument; stands in for the removed matmul + forward_attention path
+            query,
+            key,
+            value,
+            num_heads=self.h,  # NOTE(review): q/k/v arrive un-split per head; presumably the kernel splits heads itself - confirm
+            scale_value=self.d_k ** (-0.5)  # same d_k ** -0.5 factor the removed code multiplied into q_h
+        )
+        x = self.linear_out(x)  # NOTE(review): presumably the projection forward_attention used to apply - confirm
+        return x
+
     def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
         """Compute scaled dot product attention.

@@ -278,10 +278,10 @@ class LayerNorm(nn.LayerNorm):

     def forward(self, input):
         output = F.layer_norm(
-            input.float(),
+            input,
             self.normalized_shape,
-            self.weight.float() if self.weight is not None else None,
-            self.bias.float() if self.bias is not None else None,
+            self.weight if self.weight is not None else None,
+            self.bias if self.bias is not None else None,
             self.eps,
         )
         return output.type_as(input)
@@ -342,7 +342,7 @@ class EncoderLayerSANM(nn.Module):
         skip_layer = False
         # with stochastic depth, residual connection `x + f(x)` becomes
         # `x <- x + 1 / (1 - p) * f(x)` at training time.
-        stoch_layer_coeff = 1.0
+        stoch_layer_coeff = 1.0 # the stoch_layer_coeff parameter is not needed in the inference scenario
         if self.training and self.stochastic_depth_rate > 0:
             skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
             stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
@@ -370,12 +370,12 @@ class EncoderLayerSANM(nn.Module):
                 dim=-1,
             )
             if self.in_size == self.size:
-                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+                x = residual + self.concat_linear(x_concat)
             else:
-                x = stoch_layer_coeff * self.concat_linear(x_concat)
+                x = self.concat_linear(x_concat)
         else:
             if self.in_size == self.size:
-                x = residual + stoch_layer_coeff * self.dropout(
+                x = residual + self.dropout(
                     self.self_attn(
                         x,
                         mask,
@@ -384,7 +384,7 @@ class EncoderLayerSANM(nn.Module):
                     )
                 )
             else:
-                x = stoch_layer_coeff * self.dropout(
+                x = self.dropout(
                     self.self_attn(
                         x,
                         mask,
@@ -398,7 +398,7 @@ class EncoderLayerSANM(nn.Module):
         residual = x
         if self.normalize_before:
             x = self.norm2(x)
-        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+        x = residual + self.dropout(self.feed_forward(x))
         if not self.normalize_before:
             x = self.norm2(x)

@@ -557,7 +557,7 @@ class SenseVoiceEncoderSmall(nn.Module):
     ):
         """Embed positions in tensor."""
         maxlen = xs_pad.shape[1]
-        masks = sequence_mask(ilens, maxlen=maxlen, device=ilens.device)[:, None, :]
+        masks = None

         xs_pad *= self.output_size() ** 0.5

@@ -575,7 +575,7 @@ class SenseVoiceEncoderSmall(nn.Module):
         xs_pad = self.after_norm(xs_pad)

         # forward encoder2
-        olens = masks.squeeze(1).sum(1).int()
+        olens = ilens.int()

         for layer_idx, encoder_layer in enumerate(self.tp_encoders):
             encoder_outs = encoder_layer(xs_pad, masks)
@@ -769,7 +769,7 @@ class SenseVoiceSmall(nn.Module):
         speech = torch.cat((input_query, speech), dim=1)
         speech_lengths += 3

-        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
+        encoder_out, encoder_out_lens = self.encoder(speech.to(torch.float16), speech_lengths.to(torch.float16))

         return encoder_out, encoder_out_lens

@@ -876,7 +876,7 @@ class SenseVoiceSmall(nn.Module):
         speech_lengths += 3

         # Encoder
-        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
+        encoder_out, encoder_out_lens = self.encoder(speech.to(torch.float16), speech_lengths.to(torch.float16))
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]