@@ -181,13 +181,7 @@ class AutoModel:
set_all_random_seed(kwargs.get("seed", 0))
- device = kwargs.get("device", "cuda")
- if ((device =="cuda" and not torch.cuda.is_available())
- or (device == "xpu" and not torch.xpu.is_available())
- or (device == "mps" and not torch.backends.mps.is_available())
- or kwargs.get("ngpu", 1) == 0):
- device = "cpu"
- kwargs["batch_size"] = 1
+ device = kwargs.get("device", "cpu")
kwargs["device"] = device
torch.set_num_threads(kwargs.get("ncpu", 4))
@@ -283,11 +277,6 @@ class AutoModel:
else:
print(f"error, init_param does not exist!: {init_param}")
- # fp16
- if kwargs.get("fp16", False):
- model.to(torch.float16)
- elif kwargs.get("bf16", False):
- model.to(torch.bfloat16)
model.to(device)
if not kwargs.get("disable_log", True):
@@ -398,6 +387,8 @@ class AutoModel:
def inference_with_vad(self, input, input_len=None, **cfg):
kwargs = self.kwargs
+ time_stats = {"input_speech_time": 0.0, "end_to_end_time": 0.0, "vad_time": 0.0,
+ "paraformer_time": 0.0, "punc_time": 0.0}
# step.1: compute the vad model
deep_update(self.vad_kwargs, cfg)
beg_vad = time.time()
@@ -405,6 +396,8 @@ class AutoModel:
input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
)
end_vad = time.time()
+ time_stats["vad_time"] = end_vad - beg_vad
+ print("Finish segmenting audios within {:.3f} seconds.".format(time_stats["vad_time"]))
# FIX(gcf): concat the vad clips for sense vocie model for better aed
if cfg.get("merge_vad", False):
@@ -418,7 +411,6 @@ class AutoModel:
deep_update(kwargs, cfg)
batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
- kwargs["batch_size"] = batch_size
key_list, data_list = prepare_data_iterator(
input, input_len=input_len, data_type=kwargs.get("data_type", None)
@@ -463,26 +455,14 @@ class AutoModel:
# pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
all_segments = []
- max_len_in_batch = 0
- end_idx = 1
- for j, _ in enumerate(range(0, n)):
- # pbar_sample.update(1)
- sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
- potential_batch_length = max(max_len_in_batch, sample_length) * (j + 1 - beg_idx)
- # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
- if (
- j < n - 1
- and sample_length < batch_size_threshold_ms
- and potential_batch_length < batch_size
- ):
- max_len_in_batch = max(max_len_in_batch, sample_length)
- end_idx += 1
- continue
-
+ batch_segments = kwargs['batch_size']
+ loop_num = n // batch_segments if n % batch_segments == 0 else n // batch_segments + 1
+ end_idx = batch_segments
+ for j in range(loop_num):
speech_j, speech_lengths_j = slice_padding_audio_samples(
speech, speech_lengths, sorted_data[beg_idx:end_idx]
)
- results = self.inference(
+ results, meta_data = self.inference_with_asr(
speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
)
if self.spk_model is not None:
@@ -503,8 +483,7 @@ class AutoModel:
)
results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
beg_idx = end_idx
- end_idx += 1
- max_len_in_batch = sample_length
+ end_idx += batch_segments
if len(results) < 1:
continue
results_sorted.extend(results)
@@ -556,6 +535,13 @@ class AutoModel:
if not len(result["text"].strip()):
continue
return_raw_text = kwargs.get("return_raw_text", False)
+
+ end_paraformer = time.time()
+ time_stats["paraformer_time"] = time_stats["paraformer_time"] + end_paraformer - beg_asr_total
+ print("\tFinish recognizing audio using Paraformer within {:.3f} seconds, "
+ "which contains {} segments and batch_size is {}."
+ .format(time_stats["paraformer_time"], n, batch_segments))
+
# step.3 compute punc model
raw_text = None
if self.punc_model is not None:
@@ -567,6 +553,9 @@ class AutoModel:
if return_raw_text:
result["raw_text"] = raw_text
result["text"] = punc_res[0]["text"]
+ end_punc = time.time()
+ time_stats["punc_time"] = time_stats["punc_time"] + end_punc - end_paraformer
+ print("\tFinish adding punctuation using PUNC model within {:.3f} seconds.".format(time_stats["punc_time"]))
# speaker embedding cluster after resorted
if self.spk_model is not None and kwargs.get("return_spk_res", True):
@@ -653,12 +642,14 @@ class AutoModel:
f"time_escape: {time_escape_total_per_sample:0.3f}"
)
- # end_total = time.time()
+ end_total = time.time()
+ time_stats["end_to_end_time"] = end_total - beg_vad
+ time_stats["input_speech_time"] = time_speech_total_all_samples
# time_escape_total_all_samples = end_total - beg_total
# print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
# f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
# f"time_escape_all: {time_escape_total_all_samples:0.3f}")
- return results_ret_list
+ return results_ret_list, time_stats
def export(self, input=None, **cfg):
"""
@@ -29,43 +29,49 @@ def cif(hidden, alphas, threshold):
batch_size, len_time, hidden_size = hidden.size()
# loop varss
- integrate = torch.zeros([batch_size], device=hidden.device)
- frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
+ frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
+
# intermediate vars along time
list_fires = []
list_frames = []
+ integrate = torch.zeros([batch_size], dtype=torch.float32, device=hidden.device)
+ threshold_tensor = torch.ones([batch_size], dtype=torch.float32, device=hidden.device) * threshold
for t in range(len_time):
alpha = alphas[:, t]
- distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
+ distribution_completion = threshold_tensor - integrate
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(
- fire_place, integrate - torch.ones([batch_size], device=hidden.device), integrate
+ fire_place, integrate - threshold_tensor, integrate
)
cur = torch.where(fire_place, distribution_completion, alpha)
remainds = alpha - cur
+ cur = cur.to(hidden.dtype) ## prevent bf16 error
frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(
- fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
+ fire_place[:, None].expand(-1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
)
fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
- list_ls = []
- len_labels = torch.round(alphas.sum(-1)).int()
- max_label_len = len_labels.max()
+ fire_idxs = fires >= threshold
+ frame_fires = torch.zeros_like(hidden)
+ max_label_len = frames[0, fire_idxs[0]].size(0)
for b in range(batch_size):
- fire = fires[b, :]
- l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
- pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
- list_ls.append(torch.cat([l, pad_l], 0))
- return torch.stack(list_ls, 0), fires
+ frame_fire = frames[b, fire_idxs[b]]
+ frame_len = frame_fire.size(0)
+ frame_fires[b, :frame_len, :] = frame_fire
+
+ if frame_len >= max_label_len:
+ max_label_len = frame_len
+ frame_fires = frame_fires[:, :max_label_len, :]
+ return frame_fires, fires
def cif_wo_hidden(alphas, threshold):
@@ -75,6 +81,7 @@ def cif_wo_hidden(alphas, threshold):
integrate = torch.zeros([batch_size], device=alphas.device)
# intermediate vars along time
list_fires = []
+ threshold_tensor = torch.ones([batch_size], dtype=alphas.dtype, device=alphas.device) * threshold
for t in range(len_time):
alpha = alphas[:, t]
@@ -85,7 +92,7 @@ def cif_wo_hidden(alphas, threshold):
fire_place = integrate >= threshold
integrate = torch.where(
fire_place,
- integrate - torch.ones([batch_size], device=alphas.device) * threshold,
+ integrate - threshold_tensor,
integrate,
)
@@ -117,6 +124,8 @@ class CifPredictorV3(torch.nn.Module):
super(CifPredictorV3, self).__init__()
self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+ self.l_order = l_order
+ self.r_order = r_order
self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
self.cif_output = torch.nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
@@ -168,7 +177,7 @@ class CifPredictorV3(torch.nn.Module):
self.smooth_factor2 = smooth_factor2
self.noise_threshold2 = noise_threshold2
- def forward(
+ def _forward(
self,
hidden,
target_label=None,
@@ -244,11 +253,39 @@ class CifPredictorV3(torch.nn.Module):
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak, token_num2
+ def process_hidden(  # Compute per-frame CIF weights (alphas) for `hidden`, then run tail processing; returns (hidden, alphas, token_num).
+ self,
+ hidden,  # unpacked as (b, t, d) by tail_process_fn; axes 1 and 2 are swapped before the Conv1d below
+ mask  # multiplied into alphas after transpose(-1, -2); assumed (batch, 1, time) — TODO confirm against callers
+ ):
+ h = hidden
+ context = h.transpose(1, 2)  # put the feature axis second, as Conv1d expects
+ queries = torch.nn.functional.pad(context, (self.l_order, self.r_order))  # functional pad with the same (l_order, r_order) widths given to self.pad
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)  # back to time-major for the linear head
+
+ output = self.cif_output(output)  # one scalar per frame
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)  # scale, subtract the noise floor, clamp at 0
+ mask = mask.transpose(-1, -2)  # no explicit dtype cast here; the product below relies on torch type promotion
+ alphas = alphas * mask  # zero the weights wherever mask is 0
+ alphas = alphas.squeeze(-1)
+ mask = mask.squeeze(-1)
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas.float(), mask=mask)  # alphas are cast to float32 before tail processing
+ return hidden, alphas, token_num
+
+ def forward(self, hidden, mask):  # Inference entry point: returns (acoustic_embeds, token_num, alphas, cif_peak).
+ hidden, alphas, token_num = self.process_hidden(hidden, mask)  # alphas come back as float32 (cast inside process_hidden)
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)  # integrate-and-fire with this predictor's threshold
+ token_num_int = torch.max(token_num).type(torch.int32).item()  # largest token count in the batch as a Python int (int32 cast, not round)
+ acoustic_embeds = acoustic_embeds[:, :token_num_int, :]  # keep at most token_num_int frames on axis 1
+ return acoustic_embeds, token_num, alphas, cif_peak
+
def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
h = hidden
b = hidden.shape[0]
context = h.transpose(1, 2)
- queries = self.pad(context)
+ queries = torch.nn.functional.pad(context, (self.l_order, self.r_order))
output = torch.relu(self.cif_conv1d(queries))
# alphas2 is an extra head for timestamp prediction
@@ -272,23 +309,23 @@ class CifPredictorV3(torch.nn.Module):
# repeat the mask in T demension to match the upsampled length
if mask is not None:
mask2 = (
- mask.repeat(1, self.upsample_times, 1)
+ mask.expand(-1, self.upsample_times, -1)
.transpose(-1, -2)
.reshape(alphas2.shape[0], -1)
)
mask2 = mask2.unsqueeze(-1)
alphas2 = alphas2 * mask2
alphas2 = alphas2.squeeze(-1)
+
+ alphas2 = alphas2.float()
_token_num = alphas2.sum(-1)
if token_num is not None:
- alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
- # re-downsample
- ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
- ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
+ alphas2 *= (token_num / _token_num)[:, None].expand(-1, alphas2.size(1))
+
# upsampled alphas and cif_peak
us_alphas = alphas2
us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
- return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+ return None, None, us_alphas, us_cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
@@ -296,6 +333,8 @@ class CifPredictorV3(torch.nn.Module):
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
+ zeros_t = zeros_t.to(hidden.dtype)
+ ones_t = ones_t.to(hidden.dtype)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
@@ -127,8 +127,8 @@ class BiCifParaformer(Paraformer):
encoder_out_mask = (
~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
).to(encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = (
- self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
+ self.predictor(encoder_out, encoder_out_mask)
)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
@@ -253,7 +253,6 @@ class BiCifParaformer(Paraformer):
data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)
)
time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(
audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
)
@@ -263,11 +262,20 @@ class BiCifParaformer(Paraformer):
speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
)
+ if kwargs.get("fp16", False):
+ speech = speech.to(torch.float16)
+ # speech_lengths = speech_lengths.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ speech = speech.to(torch.bfloat16)
+ # speech_lengths = speech_lengths.to(torch.bfloat16)
+
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
+ meta_data["load_data"] = f"{time.perf_counter() - time1:0.3f}"
# Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ use_flash_attention = kwargs.get("use_flash_attention", False)
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, use_flash_attention=use_flash_attention)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
@@ -283,7 +291,7 @@ class BiCifParaformer(Paraformer):
if torch.max(pre_token_length) < 1:
return []
decoder_outs = self.cal_decoder_with_predictor(
- encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length, use_flash_attention=use_flash_attention
)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
@@ -338,13 +346,13 @@ class BiCifParaformer(Paraformer):
if tokenizer is not None:
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
_, timestamp = ts_prediction_lfr6_standard(
us_alphas[i][: encoder_out_lens[i] * 3],
us_peaks[i][: encoder_out_lens[i] * 3],
copy.copy(token),
vad_offset=kwargs.get("begin_time", 0),
+ threshold=kwargs.get("cif_threshold", 1.0)
)
text_postprocessed, time_stamp_postprocessed, word_lists = (
@@ -57,6 +57,7 @@ class ContextualDecoderLayer(torch.nn.Module):
memory,
memory_mask,
cache=None,
+ use_flash_attention=False
):
# tgt = self.dropout(tgt)
if isinstance(tgt, Tuple):
@@ -78,7 +79,7 @@ class ContextualDecoderLayer(torch.nn.Module):
residual = x
if self.normalize_before:
x = self.norm3(x)
- x = self.src_attn(x, memory, memory_mask)
+ x = self.src_attn(x, memory, memory_mask, use_flash_attention=use_flash_attention)
x_src_attn = x
x = residual + self.dropout(x)
@@ -102,12 +103,12 @@ class ContextualBiasDecoder(torch.nn.Module):
self.dropout = torch.nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
- def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None, use_flash_attention=False):
x = tgt
if self.src_attn is not None:
if self.normalize_before:
x = self.norm3(x)
- x = self.dropout(self.src_attn(x, memory, memory_mask))
+ x = self.dropout(self.src_attn(x, memory, memory_mask, use_flash_attention=use_flash_attention))
return x, tgt_mask, memory, memory_mask, cache
@@ -254,12 +255,14 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
def forward(
self,
hs_pad: torch.Tensor,
- hlens: torch.Tensor,
+ memory_mask: torch.Tensor,
ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
+ tgt_mask: torch.Tensor,
contextual_info: torch.Tensor,
+ contextual_mask: torch.Tensor,
clas_scale: float = 1.0,
return_hidden: bool = False,
+ use_flash_attention: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -279,20 +282,16 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
olens: (batch, )
"""
tgt = ys_in_pad
- tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
memory = hs_pad
- memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
x = tgt
- x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
- _, _, x_self_attn, x_src_attn = self.last_decoder(x, tgt_mask, memory, memory_mask)
+ x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask, use_flash_attention)
+ _, _, x_self_attn, x_src_attn = self.last_decoder(x, tgt_mask, memory, memory_mask, use_flash_attention=use_flash_attention)
# contextual paraformer related
- contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
- contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
cx, tgt_mask, _, _, _ = self.bias_decoder(
- x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask
+ x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask, use_flash_attention=use_flash_attention
)
if self.bias_output is not None:
@@ -18,6 +18,7 @@ from distutils.version import LooseVersion
from funasr.register import tables
from funasr.utils import postprocess_utils
+from funasr.models.scama import utils as myutils
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.paraformer.model import Paraformer
from funasr.utils.datadir_writer import DatadirWriter
@@ -303,6 +304,7 @@ class ContextualParaformer(Paraformer):
ys_pad_lens,
hw_list=None,
clas_scale=1.0,
+ use_flash_attention=False
):
if hw_list is None:
hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
@@ -328,13 +330,21 @@ class ContextualParaformer(Paraformer):
_, (h_n, _) = self.bias_encoder(hw_embed)
hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+ # build mask before model inference
+ tgt_mask = myutils.sequence_mask(ys_pad_lens, dtype=encoder_out.dtype, device=encoder_out.device)[:, :, None]
+ memory_mask = myutils.sequence_mask(encoder_out_lens, dtype=encoder_out.dtype, device=encoder_out.device)[:, None, :]
+ contextual_length = torch.Tensor([hw_embed.shape[1]]).int().repeat(encoder_out.shape[0])
+ contextual_mask = myutils.sequence_mask(contextual_length, dtype=encoder_out.dtype, device=encoder_out.device)[:, None, :]
+
decoder_outs = self.decoder(
encoder_out,
- encoder_out_lens,
+ memory_mask,
sematic_embeds,
- ys_pad_lens,
+ tgt_mask,
contextual_info=hw_embed,
+ contextual_mask=contextual_mask,
clas_scale=clas_scale,
+ use_flash_attention=use_flash_attention
)
decoder_out = decoder_outs[0]
@@ -385,13 +395,21 @@ class ContextualParaformer(Paraformer):
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
+ if kwargs.get("fp16", False):
+ speech = speech.to(torch.float16)
+ speech_lengths = speech_lengths.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ speech = speech.to(torch.bfloat16)
+ speech_lengths = speech_lengths.to(torch.float16)
+
# hotword
self.hotword_list = self.generate_hotwords_list(
kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend
)
# Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ use_flash_attention = kwargs.get("use_flash_attention", False)
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, use_flash_attention=use_flash_attention)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
@@ -88,7 +88,8 @@ class CTTransformer(torch.nn.Module):
"""
x = self.embed(text)
# mask = self._target_mask(input)
- h, _, _ = self.encoder(x, text_lengths)
+ masks = (~make_pad_mask(text_lengths)[:, None, :]).to(x.device)
+ h, _, _ = self.encoder(x, masks)
y = self.decoder(h)
return y, None
@@ -75,7 +75,7 @@ class DecoderLayerSANM(torch.nn.Module):
self.reserve_attn = False
self.attn_mat = []
- def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ def forward(self, tgt, tgt_mask, memory, memory_mask=None, use_flash_attention=False, cache=None):
"""Compute decoded features.
Args:
@@ -114,7 +114,7 @@ class DecoderLayerSANM(torch.nn.Module):
x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
self.attn_mat.append(attn_mat)
else:
- x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)
+ x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False, use_flash_attention=use_flash_attention)
x = residual + self.dropout(x_src_attn)
# x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
@@ -359,12 +359,13 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
def forward(
self,
hs_pad: torch.Tensor,
- hlens: torch.Tensor,
+ memory_mask: torch.Tensor,
ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
+ tgt_mask: torch.Tensor,
chunk_mask: torch.Tensor = None,
return_hidden: bool = False,
return_both: bool = False,
+ use_flash_attention: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -384,17 +385,15 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
olens: (batch, )
"""
tgt = ys_in_pad
- tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
memory = hs_pad
- memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
if chunk_mask is not None:
memory_mask = memory_mask * chunk_mask
if tgt_mask.size(1) != memory_mask.size(1):
memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
x = tgt
- x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
+ x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask, use_flash_attention)
if self.decoders2 is not None:
x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
@@ -220,7 +220,7 @@ class CifPredictorV2(torch.nn.Module):
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
if mask is not None:
- mask = mask.transpose(-1, -2).float()
+ mask = mask.transpose(-1, -2)
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
@@ -347,7 +347,7 @@ class CifPredictorV2(torch.nn.Module):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
if mask is not None:
- zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
+ zeros_t = torch.zeros((b, 1), dtype=alphas.dtype, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
@@ -689,7 +689,7 @@ def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
fires[fire_idxs] = 1
fires = fires + prefix_sum - prefix_sum_floor
if return_fire_idxs:
- return fires, fire_idxs
+ return fires.to(dtype), fire_idxs
return fires
@@ -731,7 +731,8 @@ def cif_v1(hidden, alphas, threshold):
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
- frame_fires[frame_fires_idxs] = frames
+ # prevent inconsistent length between frames and frame_fires
+ frame_fires[frame_fires_idxs] = frames[:max_label_len]
return frame_fires, fires
@@ -75,7 +75,7 @@ class DecoderLayerSANM(torch.nn.Module):
self.reserve_attn = False
self.attn_mat = []
- def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ def forward(self, tgt, tgt_mask, memory, memory_mask=None, use_flash_attention=False, cache=None):
"""Compute decoded features.
Args:
@@ -114,7 +114,7 @@ class DecoderLayerSANM(torch.nn.Module):
x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
self.attn_mat.append(attn_mat)
else:
- x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)
+ x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False, use_flash_attention=use_flash_attention)
x = residual + self.dropout(x_src_attn)
# x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
@@ -365,6 +365,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
chunk_mask: torch.Tensor = None,
return_hidden: bool = False,
return_both: bool = False,
+ use_flash_attention: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -394,7 +395,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
x = tgt
- x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
+ x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask, use_flash_attention)
if self.decoders2 is not None:
x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
@@ -11,6 +11,7 @@ from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional
from funasr.register import tables
+from funasr.models.scama import utils as myutils
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
@@ -259,7 +260,9 @@ class Paraformer(torch.nn.Module):
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
- encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ use_flash_attention = kwargs.get("use_flash_attention", False)
+ masks = (~make_pad_mask(speech_lengths)[:, None, :]).to(speech.device)
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, masks, use_flash_attention=use_flash_attention)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
@@ -276,10 +279,12 @@ class Paraformer(torch.nn.Module):
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def cal_decoder_with_predictor(
- self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, use_flash_attention=False
):
- decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
+ tgt_mask = myutils.sequence_mask(ys_pad_lens, dtype=encoder_out.dtype, device=encoder_out.device)[:, :, None]
+ memory_mask = myutils.sequence_mask(encoder_out_lens, dtype=encoder_out.dtype, device=encoder_out.device)[:, None, :]
+ decoder_outs = self.decoder(encoder_out, memory_mask, sematic_embeds, tgt_mask, use_flash_attention=use_flash_attention)
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
@@ -10,6 +10,7 @@ import math
import numpy
import torch
+import torch_npu
from torch import nn
from typing import Optional, Tuple
@@ -203,6 +204,8 @@ class MultiHeadedAttentionSANM(nn.Module):
left_padding = left_padding + sanm_shfit
right_padding = kernel_size - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+ self.left_padding = left_padding
+ self.right_padding = right_padding
def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
b, t, d = inputs.size()
@@ -213,7 +216,7 @@ class MultiHeadedAttentionSANM(nn.Module):
inputs = inputs * mask
x = inputs.transpose(1, 2)
- x = self.pad_fn(x)
+ x = F.pad(x, (self.left_padding, self.right_padding))
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
@@ -289,7 +292,7 @@ class MultiHeadedAttentionSANM(nn.Module):
return self.linear_out(x) # (batch, time1, d_model)
- def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ def _forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute scaled dot product attention.
Args:
@@ -310,6 +313,29 @@ class MultiHeadedAttentionSANM(nn.Module):
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs + fsmn_memory
+ def forward_with_flash_attention(self, x, mask, mask_shfit_chunk=None):  # Self-attention through the torch_npu fused kernel, plus the FSMN memory branch.
+ q_k_v = self.linear_q_k_v(x)  # single projection producing q, k and v concatenated on the last axis
+ q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)  # chunks of width h * d_k; heads are NOT split out here
+ fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)  # memory branch computed from v
+ attn_out, _ = torch_npu.npu_fused_infer_attention_score(  # second return value is discarded
+ q,
+ k,
+ v,
+ scale=self.d_k ** (-0.5),  # 1 / sqrt(d_k)
+ atten_mask=mask.unsqueeze(1).eq(0).expand(-1, -1, q.size(1), -1),  # True where mask == 0, expanded to q.size(1) on axis 2
+ input_layout="BSH",  # q/k/v are passed with heads still merged; exact layout semantics per torch_npu docs — TODO confirm
+ num_heads=self.h,
+ sparse_mode=1  # NOTE(review): meaning of mode 1 is defined by torch_npu — verify against its API reference
+ )
+ attn_outs = self.linear_out(attn_out)
+ return attn_outs + fsmn_memory  # NOTE(review): this path takes no mask_att_chunk_encoder, so a chunk attention mask cannot be applied here
+
+ def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None, use_flash_attention=False):  # Dispatch between the fused flash-attention kernel and the reference implementation (_forward).
+ if use_flash_attention and mask_att_chunk_encoder is None:  # the fused path has no chunk-mask input; with a chunk mask, use the reference path so it is not silently dropped
+ return self.forward_with_flash_attention(x, mask, mask_shfit_chunk)
+ else:
+ return self._forward(x, mask, mask_shfit_chunk, mask_att_chunk_encoder)  # reference path: forwards the chunk mask to forward_attention
+
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
"""Compute scaled dot product attention.
@@ -494,6 +520,8 @@ class MultiHeadedAttentionSANMDecoder(nn.Module):
left_padding = left_padding + sanm_shfit
right_padding = kernel_size - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+ self.left_padding = left_padding
+ self.right_padding = right_padding
self.kernel_size = kernel_size
def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
@@ -522,7 +550,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module):
if cache is None:
# print("in fsmn, cache is None, x", x.size())
- x = self.pad_fn(x)
+ x = F.pad(x, (self.left_padding, self.right_padding))
if not self.training:
cache = x
else:
@@ -697,7 +725,7 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
return self.linear_out(x), attn # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
- def forward(self, x, memory, memory_mask, ret_attn=False):
+ def _forward(self, x, memory, memory_mask, ret_attn=False):
"""Compute scaled dot product attention.
Args:
@@ -716,6 +744,28 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
return self.forward_attention(v_h, scores, memory_mask, ret_attn=ret_attn)
+ def forward_with_flash_attention(self, x, memory, memory_mask):  # Cross-attention through the torch_npu fused kernel; returns only the projected output (no attention weights).
+ n_batch, q_s = x.shape[0], x.shape[1]  # batch size and query length, read from x
+ q_h, k_h, v_h = self.forward_qkv(x, memory)  # per-head q/k/v; presumably (batch, head, time, d_k) given the "BNSD" layout below — TODO confirm
+ attn_out, _ = torch_npu.npu_fused_infer_attention_score(  # second return value is discarded
+ q_h,
+ k_h,
+ v_h,
+ scale=self.d_k ** (-0.5),  # 1 / sqrt(d_k)
+ atten_mask=memory_mask.unsqueeze(1).eq(0).expand(-1, -1, q_s, -1),  # True where memory_mask == 0, expanded to q_s on axis 2
+ input_layout="BNSD",
+ num_heads=self.h,
+ sparse_mode=1  # NOTE(review): meaning of mode 1 is defined by torch_npu — verify against its API reference
+ )
+ attn_out = attn_out.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # swap axes 1 and 2, then merge the last two axes to width h * d_k
+ return self.linear_out(attn_out)
+
+ def forward(self, x, memory, memory_mask, ret_attn=False, use_flash_attention=False):  # Dispatch between the fused flash-attention kernel and the reference implementation (_forward).
+ if use_flash_attention and not ret_attn:  # the fused path returns only the output tensor; when ret_attn is requested, use the reference path so callers still get (output, attn)
+ return self.forward_with_flash_attention(x, memory, memory_mask)
+ else:
+ return self._forward(x, memory, memory_mask, ret_attn)  # reference path: honours ret_attn
+
def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
"""Compute scaled dot product attention.
@@ -69,7 +69,7 @@ class EncoderLayerSANM(nn.Module):
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
- def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ def forward(self, x, mask, use_flash_attention=False, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
@@ -124,6 +124,7 @@ class EncoderLayerSANM(nn.Module):
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
+ use_flash_attention=use_flash_attention
)
)
else:
@@ -133,6 +134,7 @@ class EncoderLayerSANM(nn.Module):
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
+ use_flash_attention=use_flash_attention
)
)
if not self.normalize_before:
@@ -361,9 +363,10 @@ class SANMEncoder(nn.Module):
def forward(
self,
xs_pad: torch.Tensor,
- ilens: torch.Tensor,
+ masks: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
+ use_flash_attention=False
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
@@ -374,7 +377,6 @@ class SANMEncoder(nn.Module):
Returns:
position embedded tensor and mask
"""
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad = xs_pad * self.output_size() ** 0.5
if self.embed is None:
xs_pad = xs_pad
@@ -397,11 +399,11 @@ class SANMEncoder(nn.Module):
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
- encoder_outs = self.encoders0(xs_pad, masks)
+ encoder_outs = self.encoders0(xs_pad, masks, use_flash_attention)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks)
+ encoder_outs = self.encoders(xs_pad, masks, use_flash_attention)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
@@ -29,7 +29,7 @@ def cif_wo_hidden(alphas, threshold):
def ts_prediction_lfr6_standard(
- us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
+ us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3, threshold=1.0
):
if not len(char_list):
return "", []
@@ -43,7 +43,7 @@ def ts_prediction_lfr6_standard(
if char_list[-1] == "</s>":
char_list = char_list[:-1]
fire_place = (
- torch.where(peaks >= 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift
+ torch.where(peaks >= threshold - 1e-4)[0].cpu().numpy() + force_time_shift
) # total offset
if len(fire_place) != len(char_list) + 1:
alphas /= alphas.sum() / (len(char_list) + 1)