@@ -7,6 +7,7 @@ import json
import time
import copy
import torch
+import torch_npu
import random
import string
import logging
@@ -4,6 +4,7 @@ import time
import numpy as np
import torch
import torch.nn.functional as F
+import torch_npu
from torch import Tensor
from torch import nn
from torch.cuda.amp import autocast
@@ -158,20 +159,10 @@ class MultiHeadedAttentionSANM(nn.Module):
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
"""
- b, t, d = x.size()
q_k_v = self.linear_q_k_v(x)
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
- q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
- 1, 2
- ) # (batch, head, time1, d_k)
- k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
- 1, 2
- ) # (batch, head, time2, d_k)
- v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
- 1, 2
- ) # (batch, head, time2, d_k)
-
- return q_h, k_h, v_h, v
+
+ return q, k, v
def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
"""Compute attention context vector.
@@ -225,13 +216,22 @@ class MultiHeadedAttentionSANM(nn.Module):
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
- q_h, k_h, v_h, v = self.forward_qkv(x)
+ q, k, v = self.forward_qkv(x)
fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
- q_h = q_h * self.d_k ** (-0.5)
- scores = torch.matmul(q_h, k_h.transpose(-2, -1))
- att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+ att_outs = self.npu_flash_attention(q, k, v) # 使用npu PFA算子替换attention结构
return att_outs + fsmn_memory
+ def npu_flash_attention(self, query, key, value):
+ x = torch_npu.npu_prompt_flash_attention(
+ query,
+ key,
+ value,
+ num_heads=self.h,
+ scale_value=self.d_k ** (-0.5)
+ )
+ x = self.linear_out(x)
+ return x
+
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
"""Compute scaled dot product attention.
@@ -278,10 +278,10 @@ class LayerNorm(nn.LayerNorm):
def forward(self, input):
output = F.layer_norm(
- input.float(),
+ input,
self.normalized_shape,
- self.weight.float() if self.weight is not None else None,
- self.bias.float() if self.bias is not None else None,
+ self.weight if self.weight is not None else None,
+ self.bias if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
@@ -342,7 +342,7 @@ class EncoderLayerSANM(nn.Module):
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
- stoch_layer_coeff = 1.0
+ stoch_layer_coeff = 1.0 # 推理场景下,不需要stoch_layer_coeff参数
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
@@ -370,12 +370,12 @@ class EncoderLayerSANM(nn.Module):
dim=-1,
)
if self.in_size == self.size:
- x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+ x = residual + self.concat_linear(x_concat)
else:
- x = stoch_layer_coeff * self.concat_linear(x_concat)
+ x = self.concat_linear(x_concat)
else:
if self.in_size == self.size:
- x = residual + stoch_layer_coeff * self.dropout(
+ x = residual + self.dropout(
self.self_attn(
x,
mask,
@@ -384,7 +384,7 @@ class EncoderLayerSANM(nn.Module):
)
)
else:
- x = stoch_layer_coeff * self.dropout(
+ x = self.dropout(
self.self_attn(
x,
mask,
@@ -398,7 +398,7 @@ class EncoderLayerSANM(nn.Module):
residual = x
if self.normalize_before:
x = self.norm2(x)
- x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+ x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
@@ -557,7 +557,7 @@ class SenseVoiceEncoderSmall(nn.Module):
):
"""Embed positions in tensor."""
maxlen = xs_pad.shape[1]
- masks = sequence_mask(ilens, maxlen=maxlen, device=ilens.device)[:, None, :]
+ masks = None
xs_pad *= self.output_size() ** 0.5
@@ -575,7 +575,7 @@ class SenseVoiceEncoderSmall(nn.Module):
xs_pad = self.after_norm(xs_pad)
# forward encoder2
- olens = masks.squeeze(1).sum(1).int()
+ olens = ilens.int()
for layer_idx, encoder_layer in enumerate(self.tp_encoders):
encoder_outs = encoder_layer(xs_pad, masks)
@@ -769,7 +769,7 @@ class SenseVoiceSmall(nn.Module):
speech = torch.cat((input_query, speech), dim=1)
speech_lengths += 3
- encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
+ encoder_out, encoder_out_lens = self.encoder(speech.to(torch.float16), speech_lengths.to(torch.float16))
return encoder_out, encoder_out_lens
@@ -876,7 +876,7 @@ class SenseVoiceSmall(nn.Module):
speech_lengths += 3
# Encoder
- encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
+ encoder_out, encoder_out_lens = self.encoder(speech.to(torch.float16), speech_lengths.to(torch.float16))
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]