diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index a864dadd..0384ef91 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -181,13 +181,7 @@ class AutoModel:
 
         set_all_random_seed(kwargs.get("seed", 0))
 
-        device = kwargs.get("device", "cuda")
-        if ((device =="cuda" and not torch.cuda.is_available())
-            or (device == "xpu" and not torch.xpu.is_available())
-            or (device == "mps" and not torch.backends.mps.is_available())
-            or kwargs.get("ngpu", 1) == 0):
-            device = "cpu"
-            kwargs["batch_size"] = 1
+        device = kwargs.get("device", "cpu")
         kwargs["device"] = device
 
         torch.set_num_threads(kwargs.get("ncpu", 4))
@@ -283,11 +277,6 @@ class AutoModel:
             else:
                 print(f"error, init_param does not exist!: {init_param}")
 
-        # fp16
-        if kwargs.get("fp16", False):
-            model.to(torch.float16)
-        elif kwargs.get("bf16", False):
-            model.to(torch.bfloat16)
         model.to(device)
 
         if not kwargs.get("disable_log", True):
@@ -398,6 +387,8 @@ class AutoModel:
 
     def inference_with_vad(self, input, input_len=None, **cfg):
         kwargs = self.kwargs
+        time_stats = {"input_speech_time": 0.0, "end_to_end_time": 0.0, "vad_time": 0.0,
+                      "paraformer_time": 0.0, "punc_time": 0.0}
         # step.1: compute the vad model
         deep_update(self.vad_kwargs, cfg)
         beg_vad = time.time()
@@ -405,6 +396,8 @@ class AutoModel:
             input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
         )
         end_vad = time.time()
+        time_stats["vad_time"] = end_vad - beg_vad
+        print("Finish segmenting audios within {:.3f} seconds.".format(time_stats["vad_time"]))
 
         #  FIX(gcf): concat the vad clips for sense vocie model for better aed
         if cfg.get("merge_vad", False):
@@ -418,7 +411,6 @@ class AutoModel:
         deep_update(kwargs, cfg)
         batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
         batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
-        kwargs["batch_size"] = batch_size
 
         key_list, data_list = prepare_data_iterator(
             input, input_len=input_len, data_type=kwargs.get("data_type", None)
@@ -463,26 +455,14 @@ class AutoModel:
             # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
 
             all_segments = []
-            max_len_in_batch = 0
-            end_idx = 1
-            for j, _ in enumerate(range(0, n)):
-                # pbar_sample.update(1)
-                sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
-                potential_batch_length = max(max_len_in_batch, sample_length) * (j + 1 - beg_idx)
-                # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
-                if (
-                    j < n - 1
-                    and sample_length < batch_size_threshold_ms
-                    and potential_batch_length < batch_size
-                ):
-                    max_len_in_batch = max(max_len_in_batch, sample_length)
-                    end_idx += 1
-                    continue
-
+            batch_segments = max(int(kwargs.get("batch_size", 1)), 1)  # default 1 segment per batch, never < 1
+            loop_num = (n + batch_segments - 1) // batch_segments  # ceil(n / batch_segments)
+            end_idx = batch_segments
+            for j in range(loop_num):
                 speech_j, speech_lengths_j = slice_padding_audio_samples(
                     speech, speech_lengths, sorted_data[beg_idx:end_idx]
                 )
-                results = self.inference(
+                results, meta_data = self.inference_with_asr(
                     speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
                 )
                 if self.spk_model is not None:
@@ -503,8 +483,7 @@ class AutoModel:
                         )
                         results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
                 beg_idx = end_idx
-                end_idx += 1
-                max_len_in_batch = sample_length
+                end_idx += batch_segments
                 if len(results) < 1:
                     continue
                 results_sorted.extend(results)
@@ -556,6 +535,13 @@ class AutoModel:
             if not len(result["text"].strip()):
                 continue
             return_raw_text = kwargs.get("return_raw_text", False)
+
+            end_paraformer = time.time()
+            time_stats["paraformer_time"] = time_stats["paraformer_time"] + end_paraformer - beg_asr_total
+            print("\tFinish recognizing audio using Paraformer within {:.3f} seconds, "
+                  "which contains {} segments and batch_size is {}."
+                  .format(time_stats["paraformer_time"], n, batch_segments))
+
             # step.3 compute punc model
             raw_text = None
             if self.punc_model is not None:
@@ -567,6 +553,9 @@ class AutoModel:
                 if return_raw_text:
                     result["raw_text"] = raw_text
                 result["text"] = punc_res[0]["text"]
+                end_punc = time.time()
+                time_stats["punc_time"] = time_stats["punc_time"] + end_punc - end_paraformer
+                print("\tFinish adding punctuation using PUNC model within {:.3f} seconds.".format(time_stats["punc_time"]))
 
             # speaker embedding cluster after resorted
             if self.spk_model is not None and kwargs.get("return_spk_res", True):
@@ -653,12 +642,14 @@ class AutoModel:
                     f"time_escape: {time_escape_total_per_sample:0.3f}"
                 )
 
-        # end_total = time.time()
+        end_total = time.time()
+        time_stats["end_to_end_time"] = end_total - beg_vad
+        time_stats["input_speech_time"] = time_speech_total_all_samples
         # time_escape_total_all_samples = end_total - beg_total
         # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
         #                      f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
         #                      f"time_escape_all: {time_escape_total_all_samples:0.3f}")
-        return results_ret_list
+        return results_ret_list, time_stats
 
     def export(self, input=None, **cfg):
         """
diff --git a/funasr/models/bicif_paraformer/cif_predictor.py b/funasr/models/bicif_paraformer/cif_predictor.py
index ca98cdc2..eebbe7e2 100644
--- a/funasr/models/bicif_paraformer/cif_predictor.py
+++ b/funasr/models/bicif_paraformer/cif_predictor.py
@@ -29,43 +29,49 @@ def cif(hidden, alphas, threshold):
     batch_size, len_time, hidden_size = hidden.size()
 
     # loop varss
-    integrate = torch.zeros([batch_size], device=hidden.device)
-    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
+    frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
+
     # intermediate vars along time
     list_fires = []
     list_frames = []
+    integrate = torch.zeros([batch_size], dtype=torch.float32, device=hidden.device)
+    threshold_tensor = torch.ones([batch_size], dtype=torch.float32, device=hidden.device) * threshold
 
     for t in range(len_time):
         alpha = alphas[:, t]
-        distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
+        distribution_completion = threshold_tensor - integrate
 
         integrate += alpha
         list_fires.append(integrate)
 
         fire_place = integrate >= threshold
         integrate = torch.where(
-            fire_place, integrate - torch.ones([batch_size], device=hidden.device), integrate
+            fire_place, integrate - threshold_tensor, integrate
         )
         cur = torch.where(fire_place, distribution_completion, alpha)
         remainds = alpha - cur
 
+        cur = cur.to(hidden.dtype) ## prevent bf16 error
         frame += cur[:, None] * hidden[:, t, :]
         list_frames.append(frame)
         frame = torch.where(
-            fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
+            fire_place[:, None].expand(-1, hidden_size), remainds[:, None] * hidden[:, t, :], frame
         )
 
     fires = torch.stack(list_fires, 1)
     frames = torch.stack(list_frames, 1)
-    list_ls = []
-    len_labels = torch.round(alphas.sum(-1)).int()
-    max_label_len = len_labels.max()
+    fire_idxs = fires >= threshold
+    frame_fires = torch.zeros_like(hidden)
+    max_label_len = frames[0, fire_idxs[0]].size(0)
     for b in range(batch_size):
-        fire = fires[b, :]
-        l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
-        pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
-        list_ls.append(torch.cat([l, pad_l], 0))
-    return torch.stack(list_ls, 0), fires
+        frame_fire = frames[b, fire_idxs[b]]
+        frame_len = frame_fire.size(0)
+        frame_fires[b, :frame_len, :] = frame_fire
+
+        if frame_len >= max_label_len:
+            max_label_len = frame_len
+    frame_fires = frame_fires[:, :max_label_len, :]
+    return frame_fires, fires
 
 
 def cif_wo_hidden(alphas, threshold):
@@ -75,6 +81,7 @@ def cif_wo_hidden(alphas, threshold):
     integrate = torch.zeros([batch_size], device=alphas.device)
     # intermediate vars along time
     list_fires = []
+    threshold_tensor = torch.ones([batch_size], dtype=alphas.dtype, device=alphas.device) * threshold
 
     for t in range(len_time):
         alpha = alphas[:, t]
@@ -85,7 +92,7 @@ def cif_wo_hidden(alphas, threshold):
         fire_place = integrate >= threshold
         integrate = torch.where(
             fire_place,
-            integrate - torch.ones([batch_size], device=alphas.device) * threshold,
+            integrate - threshold_tensor,
             integrate,
         )
 
@@ -117,6 +124,8 @@ class CifPredictorV3(torch.nn.Module):
         super(CifPredictorV3, self).__init__()
 
         self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+        self.l_order = l_order
+        self.r_order = r_order
         self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
         self.cif_output = torch.nn.Linear(idim, 1)
         self.dropout = torch.nn.Dropout(p=dropout)
@@ -168,7 +177,7 @@ class CifPredictorV3(torch.nn.Module):
         self.smooth_factor2 = smooth_factor2
         self.noise_threshold2 = noise_threshold2
 
-    def forward(
+    def _forward(
         self,
         hidden,
         target_label=None,
@@ -244,11 +253,39 @@ class CifPredictorV3(torch.nn.Module):
             acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
         return acoustic_embeds, token_num, alphas, cif_peak, token_num2
 
+    def process_hidden(self, hidden, mask):
+        """Compute CIF weights for ``hidden`` and run ``tail_process_fn`` on them.
+
+        ``mask`` is transposed on its last two axes before it multiplies the weights.
+        Returns the ``(hidden, alphas, token_num)`` triple produced by
+        ``tail_process_fn``; the weights are handed to it as float32.
+        """
+        # pad the last axis by (l_order, r_order) before the 1-D convolution
+        padded = torch.nn.functional.pad(hidden.transpose(1, 2), (self.l_order, self.r_order))
+        conv_out = torch.relu(self.cif_conv1d(padded)).transpose(1, 2)
+
+        weights = torch.sigmoid(self.cif_output(conv_out))
+        weights = torch.nn.functional.relu(weights * self.smooth_factor - self.noise_threshold)
+        # apply the mask multiplicatively, then squeeze the last axis of both tensors
+        mask_t = mask.transpose(-1, -2)
+        weights = (weights * mask_t).squeeze(-1)
+        hidden, weights, token_num = self.tail_process_fn(
+            hidden, weights.float(), mask=mask_t.squeeze(-1)
+        )
+        return hidden, weights, token_num
+
+    def forward(self, hidden, mask):
+        """Return ``(acoustic_embeds, token_num, alphas, cif_peak)``; embeds are cut to the max token count."""
+        hidden, alphas, token_num = self.process_hidden(hidden, mask)
+        embeds, cif_peak = cif(hidden, alphas, self.threshold)
+        max_tokens = torch.max(token_num).type(torch.int32).item()
+        return embeds[:, :max_tokens, :], token_num, alphas, cif_peak
+
     def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
         h = hidden
         b = hidden.shape[0]
         context = h.transpose(1, 2)
-        queries = self.pad(context)
+        queries = torch.nn.functional.pad(context, (self.l_order, self.r_order))
         output = torch.relu(self.cif_conv1d(queries))
 
         # alphas2 is an extra head for timestamp prediction
@@ -272,23 +309,23 @@ class CifPredictorV3(torch.nn.Module):
         # repeat the mask in T demension to match the upsampled length
         if mask is not None:
             mask2 = (
-                mask.repeat(1, self.upsample_times, 1)
+                mask.expand(-1, self.upsample_times, -1)
                 .transpose(-1, -2)
                 .reshape(alphas2.shape[0], -1)
             )
             mask2 = mask2.unsqueeze(-1)
             alphas2 = alphas2 * mask2
         alphas2 = alphas2.squeeze(-1)
+
+        alphas2 = alphas2.float()
         _token_num = alphas2.sum(-1)
         if token_num is not None:
-            alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
-        # re-downsample
-        ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
-        ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
+            alphas2 *= (token_num / _token_num)[:, None].expand(-1, alphas2.size(1))
+
         # upsampled alphas and cif_peak
         us_alphas = alphas2
         us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
-        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+        return None, None, us_alphas, us_cif_peak
 
     def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
         b, t, d = hidden.size()
@@ -296,6 +333,8 @@ class CifPredictorV3(torch.nn.Module):
         if mask is not None:
             zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
             ones_t = torch.ones_like(zeros_t)
+            zeros_t = zeros_t.to(hidden.dtype)
+            ones_t = ones_t.to(hidden.dtype)
             mask_1 = torch.cat([mask, zeros_t], dim=1)
             mask_2 = torch.cat([ones_t, mask], dim=1)
             mask = mask_2 - mask_1
diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index 4db9c76c..209ca60a 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -127,8 +127,8 @@ class BiCifParaformer(Paraformer):
         encoder_out_mask = (
             ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
         ).to(encoder_out.device)
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = (
-            self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
+            self.predictor(encoder_out, encoder_out_mask)
         )
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
 
@@ -253,7 +253,6 @@ class BiCifParaformer(Paraformer):
             data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)
         )
         time2 = time.perf_counter()
-        meta_data["load_data"] = f"{time2 - time1:0.3f}"
         speech, speech_lengths = extract_fbank(
             audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
         )
@@ -263,11 +262,20 @@ class BiCifParaformer(Paraformer):
             speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
         )
 
+        if kwargs.get("fp16", False):
+            speech = speech.to(torch.float16)
+            # speech_lengths = speech_lengths.to(torch.float16)
+        elif kwargs.get("bf16", False):
+            speech = speech.to(torch.bfloat16)
+            # speech_lengths = speech_lengths.to(torch.bfloat16)
+
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
+        meta_data["load_data"] = f"{time.perf_counter() - time1:0.3f}"
 
         # Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        use_flash_attention = kwargs.get("use_flash_attention", False)
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, use_flash_attention=use_flash_attention)
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
 
@@ -283,7 +291,7 @@ class BiCifParaformer(Paraformer):
         if torch.max(pre_token_length) < 1:
             return []
         decoder_outs = self.cal_decoder_with_predictor(
-            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
+            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length, use_flash_attention=use_flash_attention
         )
         decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
@@ -338,13 +346,13 @@ class BiCifParaformer(Paraformer):
                 if tokenizer is not None:
                     # Change integer-ids to tokens
                     token = tokenizer.ids2tokens(token_int)
-                    text = tokenizer.tokens2text(token)
 
                     _, timestamp = ts_prediction_lfr6_standard(
                         us_alphas[i][: encoder_out_lens[i] * 3],
                         us_peaks[i][: encoder_out_lens[i] * 3],
                         copy.copy(token),
                         vad_offset=kwargs.get("begin_time", 0),
+                        threshold=kwargs.get("cif_threshold", 1.0)
                     )
 
                     text_postprocessed, time_stamp_postprocessed, word_lists = (
diff --git a/funasr/models/contextual_paraformer/decoder.py b/funasr/models/contextual_paraformer/decoder.py
index ba2ce9ad..57033c1b 100644
--- a/funasr/models/contextual_paraformer/decoder.py
+++ b/funasr/models/contextual_paraformer/decoder.py
@@ -57,6 +57,7 @@ class ContextualDecoderLayer(torch.nn.Module):
         memory,
         memory_mask,
         cache=None,
+        use_flash_attention=False
     ):
         # tgt = self.dropout(tgt)
         if isinstance(tgt, Tuple):
@@ -78,7 +79,7 @@ class ContextualDecoderLayer(torch.nn.Module):
         residual = x
         if self.normalize_before:
             x = self.norm3(x)
-        x = self.src_attn(x, memory, memory_mask)
+        x = self.src_attn(x, memory, memory_mask, use_flash_attention=use_flash_attention)
         x_src_attn = x
 
         x = residual + self.dropout(x)
@@ -102,12 +103,12 @@ class ContextualBiasDecoder(torch.nn.Module):
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.normalize_before = normalize_before
 
-    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None, use_flash_attention=False):
         x = tgt
         if self.src_attn is not None:
             if self.normalize_before:
                 x = self.norm3(x)
-            x = self.dropout(self.src_attn(x, memory, memory_mask))
+            x = self.dropout(self.src_attn(x, memory, memory_mask, use_flash_attention=use_flash_attention))
         return x, tgt_mask, memory, memory_mask, cache
 
 
@@ -254,12 +255,14 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
     def forward(
         self,
         hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
+        memory_mask: torch.Tensor,
         ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
+        tgt_mask: torch.Tensor,
         contextual_info: torch.Tensor,
+        contextual_mask: torch.Tensor,
         clas_scale: float = 1.0,
         return_hidden: bool = False,
+        use_flash_attention: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -279,20 +282,16 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
             olens: (batch, )
         """
         tgt = ys_in_pad
-        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
 
         memory = hs_pad
-        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
 
         x = tgt
-        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
-        _, _, x_self_attn, x_src_attn = self.last_decoder(x, tgt_mask, memory, memory_mask)
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask, use_flash_attention)
+        _, _, x_self_attn, x_src_attn = self.last_decoder(x, tgt_mask, memory, memory_mask, use_flash_attention=use_flash_attention)
 
         # contextual paraformer related
-        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
-        contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
         cx, tgt_mask, _, _, _ = self.bias_decoder(
-            x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask
+            x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask, use_flash_attention=use_flash_attention
         )
 
         if self.bias_output is not None:
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index fd882202..a4ec1d7d 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -18,6 +18,7 @@ from distutils.version import LooseVersion
 
 from funasr.register import tables
 from funasr.utils import postprocess_utils
+from funasr.models.scama import utils as myutils
 from funasr.metrics.compute_acc import th_accuracy
 from funasr.models.paraformer.model import Paraformer
 from funasr.utils.datadir_writer import DatadirWriter
@@ -303,6 +304,7 @@ class ContextualParaformer(Paraformer):
         ys_pad_lens,
         hw_list=None,
         clas_scale=1.0,
+        use_flash_attention=False
     ):
         if hw_list is None:
             hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
@@ -328,13 +330,21 @@ class ContextualParaformer(Paraformer):
             _, (h_n, _) = self.bias_encoder(hw_embed)
             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
 
+        # build mask before model inference
+        tgt_mask = myutils.sequence_mask(ys_pad_lens, dtype=encoder_out.dtype, device=encoder_out.device)[:, :, None]
+        memory_mask = myutils.sequence_mask(encoder_out_lens, dtype=encoder_out.dtype, device=encoder_out.device)[:, None, :]
+        contextual_length = torch.Tensor([hw_embed.shape[1]]).int().repeat(encoder_out.shape[0])
+        contextual_mask = myutils.sequence_mask(contextual_length, dtype=encoder_out.dtype, device=encoder_out.device)[:, None, :]
+
         decoder_outs = self.decoder(
             encoder_out,
-            encoder_out_lens,
+            memory_mask,
             sematic_embeds,
-            ys_pad_lens,
+            tgt_mask,
             contextual_info=hw_embed,
+            contextual_mask=contextual_mask,
             clas_scale=clas_scale,
+            use_flash_attention=use_flash_attention
         )
 
         decoder_out = decoder_outs[0]
@@ -385,13 +395,21 @@ class ContextualParaformer(Paraformer):
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
 
+        # Cast only the features; speech_lengths holds lengths and keeps its original dtype
+        # (float16 is exact only up to 2048), matching BiCifParaformer.inference.
+        if kwargs.get("fp16", False):
+            speech = speech.to(torch.float16)
+        elif kwargs.get("bf16", False):
+            speech = speech.to(torch.bfloat16)
+
         # hotword
         self.hotword_list = self.generate_hotwords_list(
             kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend
         )
 
         # Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        use_flash_attention = kwargs.get("use_flash_attention", False)
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, use_flash_attention=use_flash_attention)
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
 
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index abc5dfd1..942e2496 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -88,7 +88,8 @@ class CTTransformer(torch.nn.Module):
         """
         x = self.embed(text)
         # mask = self._target_mask(input)
-        h, _, _ = self.encoder(x, text_lengths)
+        masks = (~make_pad_mask(text_lengths)[:, None, :]).to(x.device)
+        h, _, _ = self.encoder(x, masks)
         y = self.decoder(h)
         return y, None
 
diff --git a/funasr/models/e_paraformer/decoder.py b/funasr/models/e_paraformer/decoder.py
index 7edd91a2..4b58738b 100644
--- a/funasr/models/e_paraformer/decoder.py
+++ b/funasr/models/e_paraformer/decoder.py
@@ -75,7 +75,7 @@ class DecoderLayerSANM(torch.nn.Module):
         self.reserve_attn = False
         self.attn_mat = []
 
-    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, use_flash_attention=False, cache=None):
         """Compute decoded features.
 
         Args:
@@ -114,7 +114,7 @@ class DecoderLayerSANM(torch.nn.Module):
                 x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
                 self.attn_mat.append(attn_mat)
             else:
-                x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)
+                x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False, use_flash_attention=use_flash_attention)
             x = residual + self.dropout(x_src_attn)
             # x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
 
@@ -359,12 +359,13 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
     def forward(
         self,
         hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
+        memory_mask: torch.Tensor,
         ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
+        tgt_mask: torch.Tensor,
         chunk_mask: torch.Tensor = None,
         return_hidden: bool = False,
         return_both: bool = False,
+        use_flash_attention: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -384,17 +385,15 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
             olens: (batch, )
         """
         tgt = ys_in_pad
-        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
 
         memory = hs_pad
-        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
         if chunk_mask is not None:
             memory_mask = memory_mask * chunk_mask
             if tgt_mask.size(1) != memory_mask.size(1):
                 memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
 
         x = tgt
-        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask, use_flash_attention)
         if self.decoders2 is not None:
             x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
         x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index d5970503..368a41a3 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -220,7 +220,7 @@ class CifPredictorV2(torch.nn.Module):
             alphas = torch.sigmoid(output)
             alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
             if mask is not None:
-                mask = mask.transpose(-1, -2).float()
+                mask = mask.transpose(-1, -2)
                 alphas = alphas * mask
             if mask_chunk_predictor is not None:
                 alphas = alphas * mask_chunk_predictor
@@ -347,7 +347,7 @@ class CifPredictorV2(torch.nn.Module):
         b, t, d = hidden.size()
         tail_threshold = self.tail_threshold
         if mask is not None:
-            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
+            zeros_t = torch.zeros((b, 1), dtype=alphas.dtype, device=alphas.device)
             ones_t = torch.ones_like(zeros_t)
             mask_1 = torch.cat([mask, zeros_t], dim=1)
             mask_2 = torch.cat([ones_t, mask], dim=1)
@@ -689,7 +689,7 @@ def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
     fires[fire_idxs] = 1
     fires = fires + prefix_sum - prefix_sum_floor
     if return_fire_idxs:
-        return fires, fire_idxs
+        return fires.to(dtype), fire_idxs
     return fires
 
 
@@ -731,7 +731,8 @@ def cif_v1(hidden, alphas, threshold):
     frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
     indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
     frame_fires_idxs = indices < batch_len.unsqueeze(1)
-    frame_fires[frame_fires_idxs] = frames
+    # prevent inconsistent length between frames and frame_fires
+    frame_fires[frame_fires_idxs] = frames[:max_label_len]
     return frame_fires, fires
 
 
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index 7edd91a2..5c93ff52 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -75,7 +75,7 @@ class DecoderLayerSANM(torch.nn.Module):
         self.reserve_attn = False
         self.attn_mat = []
 
-    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, use_flash_attention=False, cache=None):
         """Compute decoded features.
 
         Args:
@@ -114,7 +114,7 @@ class DecoderLayerSANM(torch.nn.Module):
                 x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
                 self.attn_mat.append(attn_mat)
             else:
-                x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)
+                x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False, use_flash_attention=use_flash_attention)
             x = residual + self.dropout(x_src_attn)
             # x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
 
@@ -365,6 +365,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         chunk_mask: torch.Tensor = None,
         return_hidden: bool = False,
         return_both: bool = False,
+        use_flash_attention: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -394,7 +395,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
                 memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
 
         x = tgt
-        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask, use_flash_attention)
         if self.decoders2 is not None:
             x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
         x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 85967af3..53136845 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -11,6 +11,7 @@ from torch.cuda.amp import autocast
 from typing import Union, Dict, List, Tuple, Optional
 
 from funasr.register import tables
+from funasr.models.scama import utils as myutils
 from funasr.models.ctc.ctc import CTC
 from funasr.utils import postprocess_utils
 from funasr.metrics.compute_acc import th_accuracy
@@ -259,7 +260,9 @@ class Paraformer(torch.nn.Module):
                 speech, speech_lengths = self.normalize(speech, speech_lengths)
 
         # Forward encoder
-        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        use_flash_attention = kwargs.get("use_flash_attention", False)
+        masks = (~make_pad_mask(speech_lengths)[:, None, :]).to(speech.device)
+        encoder_out, encoder_out_lens, _ = self.encoder(speech, masks, use_flash_attention=use_flash_attention)
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
 
@@ -276,10 +279,12 @@ class Paraformer(torch.nn.Module):
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
 
     def cal_decoder_with_predictor(
-        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, use_flash_attention=False
     ):
 
-        decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
+        tgt_mask = myutils.sequence_mask(ys_pad_lens, dtype=encoder_out.dtype, device=encoder_out.device)[:, :, None]
+        memory_mask = myutils.sequence_mask(encoder_out_lens, dtype=encoder_out.dtype, device=encoder_out.device)[:, None, :]
+        decoder_outs = self.decoder(encoder_out, memory_mask, sematic_embeds, tgt_mask, use_flash_attention=use_flash_attention)
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 47d60cb6..cd5e737d 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -10,6 +10,11 @@ import math
 
 import numpy
 import torch
 from torch import nn
 from typing import Optional, Tuple
+
+try:
+    import torch_npu  # Ascend-only extension used by the fused attention paths
+except ImportError:  # keep this module importable on hosts without torch_npu
+    torch_npu = None
 
@@ -203,6 +204,8 @@ class MultiHeadedAttentionSANM(nn.Module):
             left_padding = left_padding + sanm_shfit
         right_padding = kernel_size - 1 - left_padding
         self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+        self.left_padding = left_padding
+        self.right_padding = right_padding
 
     def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
         b, t, d = inputs.size()
@@ -213,7 +216,7 @@ class MultiHeadedAttentionSANM(nn.Module):
             inputs = inputs * mask
 
         x = inputs.transpose(1, 2)
-        x = self.pad_fn(x)
+        x = F.pad(x, (self.left_padding, self.right_padding))
         x = self.fsmn_block(x)
         x = x.transpose(1, 2)
         x += inputs
@@ -289,7 +292,7 @@ class MultiHeadedAttentionSANM(nn.Module):
 
         return self.linear_out(x)  # (batch, time1, d_model)
 
-    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+    def _forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
         """Compute scaled dot product attention.
 
         Args:
@@ -310,6 +313,51 @@ class MultiHeadedAttentionSANM(nn.Module):
         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
         return att_outs + fsmn_memory
 
+    def forward_with_flash_attention(self, x, mask, mask_shfit_chunk=None):
+        """Compute SANM self-attention with the fused Ascend NPU kernel.
+
+        Mirrors ``_forward`` (attention output plus FSMN memory) but hands the
+        scaled dot-product step to ``torch_npu.npu_fused_infer_attention_score``.
+
+        Args:
+            x: Input features; presumably (batch, time, size) -- TODO confirm.
+            mask: Padding mask, non-zero at valid frames; presumably
+                (batch, 1, time) -- TODO confirm against the encoder.
+            mask_shfit_chunk: Optional mask forwarded to ``forward_fsmn``.
+
+        Returns:
+            Projected attention output added to the FSMN memory.
+        """
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
+        attn_out, _ = torch_npu.npu_fused_infer_attention_score(
+            q,
+            k,
+            v,
+            scale=self.d_k ** (-0.5),
+            # Flip the padding mask (True where mask == 0) and broadcast it over
+            # the query axis.
+            atten_mask=mask.unsqueeze(1).eq(0).expand(-1, -1, q.size(1), -1),
+            input_layout="BSH",
+            num_heads=self.h,
+            sparse_mode=1
+        )
+        attn_outs = self.linear_out(attn_out)
+        return attn_outs + fsmn_memory
+
+    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None, use_flash_attention=False):
+        """Run SANM self-attention, optionally through the fused NPU kernel.
+
+        The fused path is taken only when ``use_flash_attention`` is set and no
+        ``mask_att_chunk_encoder`` is given; otherwise ``_forward`` is used.
+        """
+        # forward_with_flash_attention has no way to apply a chunk attention mask,
+        # so honour it on the reference path instead of silently dropping it.
+        if use_flash_attention and mask_att_chunk_encoder is None:
+            return self.forward_with_flash_attention(x, mask, mask_shfit_chunk)
+        return self._forward(x, mask, mask_shfit_chunk, mask_att_chunk_encoder)
+
     def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
         """Compute scaled dot product attention.
 
@@ -494,6 +520,8 @@ class MultiHeadedAttentionSANMDecoder(nn.Module):
             left_padding = left_padding + sanm_shfit
         right_padding = kernel_size - 1 - left_padding
         self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+        self.left_padding = left_padding
+        self.right_padding = right_padding
         self.kernel_size = kernel_size
 
     def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
@@ -522,7 +550,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module):
         if cache is None:
             # print("in fsmn, cache is None, x", x.size())
 
-            x = self.pad_fn(x)
+            x = F.pad(x, (self.left_padding, self.right_padding))
             if not self.training:
                 cache = x
         else:
@@ -697,7 +725,7 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
             return self.linear_out(x), attn  # (batch, time1, d_model)
         return self.linear_out(x)  # (batch, time1, d_model)
 
-    def forward(self, x, memory, memory_mask, ret_attn=False):
+    def _forward(self, x, memory, memory_mask, ret_attn=False):
         """Compute scaled dot product attention.
 
         Args:
@@ -716,6 +744,49 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
         return self.forward_attention(v_h, scores, memory_mask, ret_attn=ret_attn)
 
+    def forward_with_flash_attention(self, x, memory, memory_mask):
+        """Compute cross-attention with the fused Ascend NPU kernel.
+
+        Args:
+            x: Query-side input; dims 0 and 1 are read as (batch, query_len).
+            memory: Key/value-side input handed to ``forward_qkv``.
+            memory_mask: Mask over memory positions, non-zero at valid frames;
+                presumably (batch, 1, memory_len) -- TODO confirm against callers.
+
+        Returns:
+            ``linear_out`` applied to the merged heads. No attention matrix is
+            produced on this path.
+        """
+        n_batch, q_s = x.shape[0], x.shape[1]
+        q_h, k_h, v_h = self.forward_qkv(x, memory)
+        attn_out, _ = torch_npu.npu_fused_infer_attention_score(
+            q_h,
+            k_h,
+            v_h,
+            scale=self.d_k ** (-0.5),
+            # Flip the mask (True where memory_mask == 0) and broadcast it over
+            # the query axis.
+            atten_mask=memory_mask.unsqueeze(1).eq(0).expand(-1, -1, q_s, -1),
+            input_layout="BNSD",
+            num_heads=self.h,
+            sparse_mode=1
+        )
+        # Merge the heads back into a single (batch, -1, h * d_k) tensor.
+        attn_out = attn_out.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        return self.linear_out(attn_out)
+
+    def forward(self, x, memory, memory_mask, ret_attn=False, use_flash_attention=False):
+        """Run cross-attention, optionally through the fused NPU kernel.
+
+        The fused path is taken only when ``use_flash_attention`` is set and
+        ``ret_attn`` is false; otherwise ``_forward`` is used.
+        """
+        # The fused path returns only the output tensor, so a caller asking for
+        # the attention matrix must go through _forward to get (out, attn).
+        if use_flash_attention and not ret_attn:
+            return self.forward_with_flash_attention(x, memory, memory_mask)
+        return self._forward(x, memory, memory_mask, ret_attn)
+
     def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
         """Compute scaled dot product attention.
 
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index 0d39ca74..b4301cff 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -69,7 +69,7 @@ class EncoderLayerSANM(nn.Module):
         self.stochastic_depth_rate = stochastic_depth_rate
         self.dropout_rate = dropout_rate
 
-    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+    def forward(self, x, mask, use_flash_attention=False, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
         """Compute encoded features.
 
         Args:
@@ -124,6 +124,7 @@ class EncoderLayerSANM(nn.Module):
                         mask,
                         mask_shfit_chunk=mask_shfit_chunk,
                         mask_att_chunk_encoder=mask_att_chunk_encoder,
+                        use_flash_attention=use_flash_attention
                     )
                 )
             else:
@@ -133,6 +134,7 @@ class EncoderLayerSANM(nn.Module):
                         mask,
                         mask_shfit_chunk=mask_shfit_chunk,
                         mask_att_chunk_encoder=mask_att_chunk_encoder,
+                        use_flash_attention=use_flash_attention
                     )
                 )
         if not self.normalize_before:
@@ -361,9 +363,10 @@ class SANMEncoder(nn.Module):
     def forward(
         self,
         xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
+        masks: torch.Tensor,
         prev_states: torch.Tensor = None,
         ctc: CTC = None,
+        use_flash_attention=False
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Embed positions in tensor.
 
@@ -374,7 +377,6 @@ class SANMEncoder(nn.Module):
         Returns:
             position embedded tensor and mask
         """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
         xs_pad = xs_pad * self.output_size() ** 0.5
         if self.embed is None:
             xs_pad = xs_pad
@@ -397,11 +399,11 @@ class SANMEncoder(nn.Module):
             xs_pad = self.embed(xs_pad)
 
         # xs_pad = self.dropout(xs_pad)
-        encoder_outs = self.encoders0(xs_pad, masks)
+        encoder_outs = self.encoders0(xs_pad, masks, use_flash_attention)
         xs_pad, masks = encoder_outs[0], encoder_outs[1]
         intermediate_outs = []
         if len(self.interctc_layer_idx) == 0:
-            encoder_outs = self.encoders(xs_pad, masks)
+            encoder_outs = self.encoders(xs_pad, masks, use_flash_attention)
             xs_pad, masks = encoder_outs[0], encoder_outs[1]
         else:
             for layer_idx, encoder_layer in enumerate(self.encoders):
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 37ce8868..14737b90 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -29,7 +29,7 @@ def cif_wo_hidden(alphas, threshold):
 
 
 def ts_prediction_lfr6_standard(
-    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
+    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3, threshold=1.0
 ):
     if not len(char_list):
         return "", []
@@ -43,7 +43,7 @@ def ts_prediction_lfr6_standard(
     if char_list[-1] == "</s>":
         char_list = char_list[:-1]
     fire_place = (
-        torch.where(peaks >= 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift
+        torch.where(peaks >= threshold - 1e-4)[0].cpu().numpy() + force_time_shift
     )  # total offset
     if len(fire_place) != len(char_list) + 1:
         alphas /= alphas.sum() / (len(char_list) + 1)