diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py
index 73990fa..68c8299 100644
--- a/wenet/transformer/asr_model.py
+++ b/wenet/transformer/asr_model.py
@@ -245,7 +245,7 @@ class ASRModel(torch.nn.Module):
             top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
             top_k_logp = mask_finished_scores(top_k_logp, end_flag)
             top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
-            # 2.3 Second beam prune: select topk score with history
+            # 2.3 Seconde beam prune: select topk score with history
             scores = scores + top_k_logp  # (B*N, N), broadcast add
             scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
             scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
@@ -570,13 +570,12 @@ class ASRModel(torch.nn.Module):
     def forward_encoder_chunk(
         self,
         xs: torch.Tensor,
-        offset: int,
-        required_cache_size: int,
+        offset: torch.Tensor,
+        required_cache_size: torch.Tensor,
         subsampling_cache: Optional[torch.Tensor] = None,
-        elayers_output_cache: Optional[List[torch.Tensor]] = None,
-        conformer_cnn_cache: Optional[List[torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
-               List[torch.Tensor]]:
+        elayers_output_cache: Optional[torch.Tensor] = None,
+        conformer_cnn_cache: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """ Export interface for c++ call, give input chunk xs, and return
             output from time 0 to current chunk.
 
@@ -675,6 +674,10 @@ class ASRModel(torch.nn.Module):
         r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
         return decoder_out, r_decoder_out
 
+    @torch.jit.export
+    def test(self,) -> str:
+        return "test"
+
 
 def init_asr_model(configs):
     if configs['cmvn_file'] is not None:
diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index f41f7e4..40c1a57 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -57,8 +57,7 @@ class TransformerDecoder(torch.nn.Module):
         if input_layer == "embed":
             self.embed = torch.nn.Sequential(
                 torch.nn.Embedding(vocab_size, attention_dim),
-                PositionalEncoding(attention_dim, positional_dropout_rate),
-            )
+                PositionalEncoding(attention_dim, positional_dropout_rate))
         else:
             raise ValueError(f"only 'embed' is supported: {input_layer}")
 
@@ -81,6 +80,10 @@ class TransformerDecoder(torch.nn.Module):
                 concat_after,
             ) for _ in range(self.num_blocks)
         ])
+        self.onnx_mode = False
+
+    def set_onnx_mode(self, onnx_mode=False):
+        self.onnx_mode = onnx_mode
 
     def forward(
         self,
@@ -111,13 +114,15 @@ class TransformerDecoder(torch.nn.Module):
         tgt = ys_in_pad
 
         # tgt_mask: (B, 1, L)
-        tgt_mask = (~make_pad_mask(ys_in_lens).unsqueeze(1)).to(tgt.device)
+        tgt_mask = (~make_pad_mask(ys_in_lens, ys_in_pad).unsqueeze(1)).to(tgt.device)
         # m: (1, L, L)
         m = subsequent_mask(tgt_mask.size(-1),
                             device=tgt_mask.device).unsqueeze(0)
         # tgt_mask: (B, L, L)
-        tgt_mask = tgt_mask & m
-        x, _ = self.embed(tgt)
+        # tgt_mask = tgt_mask & m
+        tgt_mask = torch.mul(tgt_mask, m)
+        x = self.embed[0](tgt)
+        x, _ = self.embed[1](x, onnx_mode=self.onnx_mode)
         for layer in self.decoders:
             x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
                                                      memory_mask)
@@ -225,6 +230,13 @@ class BiTransformerDecoder(torch.nn.Module):
             self_attention_dropout_rate, src_attention_dropout_rate,
             input_layer, use_output_layer, normalize_before, concat_after)
 
+        self.onnx_mode = False
+
+    def set_onnx_mode(self, onnx_mode=False):
+        self.onnx_mode = onnx_mode
+        self.left_decoder.set_onnx_mode(onnx_mode)
+        self.right_decoder.set_onnx_mode(onnx_mode)
+
     def forward(
         self,
         memory: torch.Tensor,
@@ -252,6 +264,7 @@ class BiTransformerDecoder(torch.nn.Module):
                     if use_output_layer is True,
                 olens: (batch, )
         """
+        reverse_weight = 0.3
         l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
                                           ys_in_lens)
         r_x = torch.tensor(0.0)
diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py
index 25bb281..59dd174 100644
--- a/wenet/transformer/decoder_layer.py
+++ b/wenet/transformer/decoder_layer.py
@@ -17,7 +17,7 @@ class DecoderLayer(nn.Module):
         size (int): Input dimension.
         self_attn (torch.nn.Module): Self-attention module instance.
             `MultiHeadedAttention` instance can be used as the argument.
-        src_attn (torch.nn.Module): Inter-attention module instance.
+        src_attn (torch.nn.Module): Self-attention module instance.
             `MultiHeadedAttention` instance can be used as the argument.
         feed_forward (torch.nn.Module): Feed-forward module instance.
             `PositionwiseFeedForward` instance can be used as the argument.
@@ -61,7 +61,8 @@ class DecoderLayer(nn.Module):
         tgt_mask: torch.Tensor,
         memory: torch.Tensor,
         memory_mask: torch.Tensor,
-        cache: Optional[torch.Tensor] = None
+        cache: Optional[torch.Tensor] = None,
+        onnx_mode: Optional[bool] = False
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute decoded features.
 
diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py
index a47afd9..0a6794c 100644
--- a/wenet/transformer/embedding.py
+++ b/wenet/transformer/embedding.py
@@ -9,6 +9,7 @@ import math
 from typing import Tuple
 
 import torch
+from wenet.transformer.slice_helper import slice_helper2
 
 
 class PositionalEncoding(torch.nn.Module):
@@ -45,7 +46,8 @@ class PositionalEncoding(torch.nn.Module):
 
     def forward(self,
                 x: torch.Tensor,
-                offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+                offset: torch.Tensor = torch.tensor(0),
+                onnx_mode: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
         """Add positional encoding.
 
         Args:
@@ -56,13 +58,21 @@ class PositionalEncoding(torch.nn.Module):
             torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
             torch.Tensor: for compatibility to RelPositionalEncoding
         """
-        assert offset + x.size(1) < self.max_len
+        # assert offset + x.size(1) < self.max_len
         self.pe = self.pe.to(x.device)
-        pos_emb = self.pe[:, offset:offset + x.size(1)]
+        # pos_emb = self.pe[:, offset:offset + x.size(1)]
+        if onnx_mode:
+            pos_emb = slice_helper2(self.pe, offset, offset + x.size(1))
+        else:
+            pos_emb = self.pe[:, offset:offset + x.size(1)]
         x = x * self.xscale + pos_emb
         return self.dropout(x), self.dropout(pos_emb)
 
-    def position_encoding(self, offset: int, size: int) -> torch.Tensor:
+    def position_encoding(self, 
+                            offset: torch.Tensor, 
+                            size: torch.Tensor,
+                            onnx_mode: bool = False,
+                            ) -> torch.Tensor:
         """ For getting encoding in a streaming fashion
 
         Attention!!!!!
@@ -79,7 +89,12 @@ class PositionalEncoding(torch.nn.Module):
             torch.Tensor: Corresponding encoding
         """
         assert offset + size < self.max_len
-        return self.dropout(self.pe[:, offset:offset + size])
+        if onnx_mode:
+            # pe = torch.cat([self.pe[:, [0]], slice_helper2(self.pe, offset, offset + size - 1)], dim=1)
+            pe = slice_helper2(self.pe, offset, offset + size)
+        else:
+            pe = self.pe[:, offset:offset + size]
+        return self.dropout(pe)
 
 
 class RelPositionalEncoding(PositionalEncoding):
@@ -96,7 +111,8 @@ class RelPositionalEncoding(PositionalEncoding):
 
     def forward(self,
                 x: torch.Tensor,
-                offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+                offset: torch.Tensor,
+                onnx_mode: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute positional encoding.
         Args:
             x (torch.Tensor): Input tensor (batch, time, `*`).
@@ -104,10 +120,16 @@ class RelPositionalEncoding(PositionalEncoding):
             torch.Tensor: Encoded tensor (batch, time, `*`).
             torch.Tensor: Positional embedding tensor (1, time, `*`).
         """
-        assert offset + x.size(1) < self.max_len
+        # assert offset + x.size(1) < self.max_len
         self.pe = self.pe.to(x.device)
         x = x * self.xscale
-        pos_emb = self.pe[:, offset:offset + x.size(1)]
+        if onnx_mode:
+            # end = offset.item() + x.size(1)
+            # pos_emb = torch.index_select(self.pe, 1, torch.tensor(range(x.size(1))))
+            pos_emb = slice_helper2(self.pe, offset, offset + x.size(1))
+            # pos_emb = slice_helper3(pos_emb, x.size(1))
+        else:
+            pos_emb = self.pe[:, offset:offset + x.size(1)]
         return self.dropout(x), self.dropout(pos_emb)
 
 
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index e342ed4..9b4f968 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -6,6 +6,8 @@
 """Encoder definition."""
 from typing import Tuple, List, Optional
 
+import numpy as np
+import onnxruntime
 import torch
 from typeguard import check_argument_types
 
@@ -18,6 +20,7 @@ from wenet.transformer.embedding import NoPositionalEncoding
 from wenet.transformer.encoder_layer import TransformerEncoderLayer
 from wenet.transformer.encoder_layer import ConformerEncoderLayer
 from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from wenet.transformer.slice_helper import slice_helper3, get_next_cache_start
 from wenet.transformer.subsampling import Conv2dSubsampling4
 from wenet.transformer.subsampling import Conv2dSubsampling6
 from wenet.transformer.subsampling import Conv2dSubsampling8
@@ -26,6 +29,8 @@ from wenet.utils.common import get_activation
 from wenet.utils.mask import make_pad_mask
 from wenet.utils.mask import add_optional_chunk_mask
 
+def to_numpy(x):
+    return x.detach().numpy()
 
 class BaseEncoder(torch.nn.Module):
     def __init__(
@@ -116,10 +121,14 @@ class BaseEncoder(torch.nn.Module):
         self.static_chunk_size = static_chunk_size
         self.use_dynamic_chunk = use_dynamic_chunk
         self.use_dynamic_left_chunk = use_dynamic_left_chunk
+        self.onnx_mode = False
 
     def output_size(self) -> int:
         return self._output_size
 
+    def set_onnx_mode(self, onnx_mode=False):
+        self.onnx_mode = onnx_mode
+
     def forward(
         self,
         xs: torch.Tensor,
@@ -130,7 +139,7 @@ class BaseEncoder(torch.nn.Module):
         """Embed positions in tensor.
 
         Args:
-            xs: padded input tensor (B, T, D)
+            xs: padded input tensor (B, L, D)
             xs_lens: input length (B)
             decoding_chunk_size: decoding chunk size for dynamic chunk
                 0: default for training, use random dynamic chunk.
@@ -141,16 +150,18 @@ class BaseEncoder(torch.nn.Module):
                 >=0: use num_decoding_left_chunks
                 <0: use all left chunks
         Returns:
-            encoder output tensor xs, and subsampled masks
-            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
-            masks: torch.Tensor batch padding mask after subsample
-                (B, 1, T' ~= T/subsample_rate)
+            encoder output tensor, lens and mask
         """
-        masks = ~make_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, T)
+        decoding_chunk_size = 1
+        num_decoding_left_chunks = 1
+        self.use_dynamic_chunk = False
+        self.use_dynamic_left_chunk = False
+        self.static_chunk_size = 0
+        masks = ~make_pad_mask(xs_lens, xs).unsqueeze(1)  # (B, 1, L)
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
         xs, pos_emb, masks = self.embed(xs, masks)
-        mask_pad = masks  # (B, 1, T/subsample_rate)
+        mask_pad = masks
         chunk_masks = add_optional_chunk_mask(xs, masks,
                                               self.use_dynamic_chunk,
                                               self.use_dynamic_left_chunk,
@@ -169,13 +180,12 @@ class BaseEncoder(torch.nn.Module):
     def forward_chunk(
         self,
         xs: torch.Tensor,
-        offset: int,
-        required_cache_size: int,
+        offset_tensor: torch.Tensor = torch.tensor(0),
+        required_cache_size_tensor: torch.Tensor = torch.tensor(0),
         subsampling_cache: Optional[torch.Tensor] = None,
-        elayers_output_cache: Optional[List[torch.Tensor]] = None,
-        conformer_cnn_cache: Optional[List[torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
-               List[torch.Tensor]]:
+        elayers_output_cache: Optional[torch.Tensor] = None,
+        conformer_cnn_cache: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """ Forward just one chunk
 
         Args:
@@ -199,6 +209,7 @@ class BaseEncoder(torch.nn.Module):
             List[torch.Tensor]: conformer cnn cache
 
         """
+        required_cache_size_tensor = torch.tensor(-1)
         assert xs.size(0) == 1
         # tmp_masks is just for interface compatibility
         tmp_masks = torch.ones(1,
@@ -208,30 +219,53 @@ class BaseEncoder(torch.nn.Module):
         tmp_masks = tmp_masks.unsqueeze(1)
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
-        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+        # if self.onnx_mode:
+        #     offset_tensor = offset_tensor - torch.tensor(1)
+        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset_tensor, self.onnx_mode)
         if subsampling_cache is not None:
             cache_size = subsampling_cache.size(1)
             xs = torch.cat((subsampling_cache, xs), dim=1)
         else:
             cache_size = 0
-        pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1))
-        if required_cache_size < 0:
-            next_cache_start = 0
-        elif required_cache_size == 0:
-            next_cache_start = xs.size(1)
+        # if self.onnx_mode:
+        #     cache_size = cache_size - 1
+        # if self.onnx_mode:
+        #     # subsampling_cache append dummy var, remove it here
+        #     xs = xs[:, 1:, :]
+        #     cache_size = cache_size - 1
+        if isinstance(xs.size(1), int):
+            xs_size_1 = torch.tensor(xs.size(1))
         else:
-            next_cache_start = max(xs.size(1) - required_cache_size, 0)
-        r_subsampling_cache = xs[:, next_cache_start:, :]
+            xs_size_1 = xs.size(1).clone().detach()
+        pos_emb = self.embed.position_encoding(offset_tensor - cache_size, 
+                                            xs_size_1,
+                                            self.onnx_mode)
+        next_cache_start = get_next_cache_start(required_cache_size_tensor, xs)
+        r_subsampling_cache = slice_helper3(xs, next_cache_start)
+        # if self.onnx_mode:
+        #     next_cache_start_1 = get_next_cache_start(required_cache_size_tensor, xs)
+        #     r_subsampling_cache = slice_helper3(xs, next_cache_start_1)
+        # else:
+        #     required_cache_size = required_cache_size_tensor.detach().item()
+        #     if required_cache_size < 0:
+        #         next_cache_start = 0
+        #     elif required_cache_size == 0:
+        #         next_cache_start = xs.size(1)
+        #     else:
+        #         next_cache_start = max(xs.size(1) - required_cache_size, 0)
+        #     r_subsampling_cache = xs[:, next_cache_start:, :]
         # Real mask for transformer/conformer layers
         masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
         masks = masks.unsqueeze(1)
-        r_elayers_output_cache = []
-        r_conformer_cnn_cache = []
+        r_elayers_output_cache = None
+        r_conformer_cnn_cache = None
         for i, layer in enumerate(self.encoders):
             if elayers_output_cache is None:
                 attn_cache = None
             else:
                 attn_cache = elayers_output_cache[i]
+            # if self.onnx_mode and attn_cache is not None:
+            #     attn_cache = attn_cache[:, 1:, :]
             if conformer_cnn_cache is None:
                 cnn_cache = None
             else:
@@ -240,13 +274,32 @@ class BaseEncoder(torch.nn.Module):
                                          masks,
                                          pos_emb,
                                          output_cache=attn_cache,
-                                         cnn_cache=cnn_cache)
-            r_elayers_output_cache.append(xs[:, next_cache_start:, :])
-            r_conformer_cnn_cache.append(new_cnn_cache)
+                                         cnn_cache=cnn_cache,
+                                         onnx_mode=self.onnx_mode)
+            if self.onnx_mode:
+                layer_output_cache = slice_helper3(xs, next_cache_start)
+            else:
+                layer_output_cache = xs[:, next_cache_start:, :]
+            if i == 0:
+                r_elayers_output_cache = layer_output_cache.unsqueeze(0)
+                r_conformer_cnn_cache = new_cnn_cache.unsqueeze(0)
+            else:
+                # r_elayers_output_cache.append(xs[:, next_cache_start:, :])
+                r_elayers_output_cache = torch.cat((r_elayers_output_cache, layer_output_cache.unsqueeze(0)), 0)
+                # r_conformer_cnn_cache.append(new_cnn_cache)
+                r_conformer_cnn_cache = torch.cat((r_conformer_cnn_cache, new_cnn_cache.unsqueeze(0)), 0)
         if self.normalize_before:
             xs = self.after_norm(xs)
-
-        return (xs[:, cache_size:, :], r_subsampling_cache,
+        if self.onnx_mode:
+            cache_size = cache_size - 1
+            if isinstance(cache_size, int):
+                cache_size_1 = torch.tensor(cache_size)
+            else:
+                cache_size_1 = cache_size.clone().detach()
+            output = slice_helper3(xs, cache_size_1)
+        else:
+            output = xs[:, cache_size:, :]
+        return (output, r_subsampling_cache,
                 r_elayers_output_cache, r_conformer_cnn_cache)
 
     def forward_chunk_by_chunk(
@@ -290,24 +343,54 @@ class BaseEncoder(torch.nn.Module):
         decoding_window = (decoding_chunk_size - 1) * subsampling + context
         num_frames = xs.size(1)
         subsampling_cache: Optional[torch.Tensor] = None
-        elayers_output_cache: Optional[List[torch.Tensor]] = None
-        conformer_cnn_cache: Optional[List[torch.Tensor]] = None
+        elayers_output_cache: Optional[torch.Tensor] = None
+        conformer_cnn_cache: Optional[torch.Tensor] = None
         outputs = []
         offset = 0
         required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+        print("required_cache_size:", required_cache_size)
+        encoder_session = onnxruntime.InferenceSession("onnx/encoder.onnx")
+
+        subsampling_cache_onnx = torch.zeros(1, 1, 256, requires_grad=False)
+        elayers_output_cache_onnx = torch.zeros(12, 1, 1, 256, requires_grad=False)
+        conformer_cnn_cache_onnx = torch.zeros(12, 1, 256, 7, requires_grad=False)
 
         # Feed forward overlap input step by step
         for cur in range(0, num_frames - context + 1, stride):
             end = min(cur + decoding_window, num_frames)
             chunk_xs = xs[:, cur:end, :]
+            
+            if offset > 0:
+                offset = offset - 1
             (y, subsampling_cache, elayers_output_cache,
-             conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset,
-                                                       required_cache_size,
+             conformer_cnn_cache) = self.forward_chunk(chunk_xs, torch.tensor(offset),
+                                                       torch.tensor(required_cache_size),
                                                        subsampling_cache,
                                                        elayers_output_cache,
                                                        conformer_cnn_cache)
-            outputs.append(y)
+            
+            offset = offset + 1
+            encoder_inputs = {
+                encoder_session.get_inputs()[0].name: chunk_xs.numpy(),
+                encoder_session.get_inputs()[1].name: np.array(offset),
+                encoder_session.get_inputs()[2].name: subsampling_cache_onnx.numpy(),
+                encoder_session.get_inputs()[3].name: elayers_output_cache_onnx.numpy(),
+                encoder_session.get_inputs()[4].name: conformer_cnn_cache_onnx.numpy(),
+            }
+            ort_outs = encoder_session.run(None, encoder_inputs)
+            y_onnx, subsampling_cache_onnx, elayers_output_cache_onnx, conformer_cnn_cache_onnx = \
+                torch.from_numpy(ort_outs[0][:, 1:, :]), torch.from_numpy(ort_outs[1]), \
+                torch.from_numpy(ort_outs[2]), torch.from_numpy(ort_outs[3])
+
+            np.testing.assert_allclose(to_numpy(y), ort_outs[0][:, 1:, :], rtol=1e-03, atol=1e-03)
+            np.testing.assert_allclose(to_numpy(subsampling_cache), ort_outs[1][:, 1:, :], rtol=1e-03, atol=1e-03)
+            np.testing.assert_allclose(to_numpy(elayers_output_cache), ort_outs[2][:, :, 1:, :], rtol=1e-03, atol=1e-03)
+            np.testing.assert_allclose(to_numpy(conformer_cnn_cache), ort_outs[3], rtol=1e-03, atol=1e-03)
+
+            outputs.append(y_onnx)
+            # outputs.append(y)
             offset += y.size(1)
+            # break
         ys = torch.cat(outputs, 1)
         masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool)
         masks = masks.unsqueeze(1)
diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py
index db8696d..0be079c 100644
--- a/wenet/transformer/encoder_layer.py
+++ b/wenet/transformer/encoder_layer.py
@@ -9,6 +9,7 @@ from typing import Optional, Tuple
 
 import torch
 from torch import nn
+from wenet.transformer.slice_helper import slice_helper
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -53,6 +54,9 @@ class TransformerEncoderLayer(nn.Module):
         # concat_linear may be not used in forward fuction,
         # but will be saved in the *.pt
         self.concat_linear = nn.Linear(size + size, size)
+    
+    def set_onnx_mode(self, onnx_mode=False):
+        self.onnx_mode = onnx_mode
 
     def forward(
         self,
@@ -92,9 +96,14 @@ class TransformerEncoderLayer(nn.Module):
             assert output_cache.size(2) == self.size
             assert output_cache.size(1) < x.size(1)
             chunk = x.size(1) - output_cache.size(1)
-            x_q = x[:, -chunk:, :]
-            residual = residual[:, -chunk:, :]
-            mask = mask[:, -chunk:, :]
+            if self.onnx_mode:
+                x_q = slice_helper(x, chunk)
+                residual = slice_helper(residual, chunk)
+                mask = slice_helper(mask, chunk)
+            else:
+                x_q = x[:, -chunk:, :]
+                residual = residual[:, -chunk:, :]
+                mask = mask[:, -chunk:, :]
 
         if self.concat_after:
             x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
@@ -184,6 +193,7 @@ class ConformerEncoderLayer(nn.Module):
         mask_pad: Optional[torch.Tensor] = None,
         output_cache: Optional[torch.Tensor] = None,
         cnn_cache: Optional[torch.Tensor] = None,
+        onnx_mode: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute encoded features.
 
@@ -193,7 +203,6 @@ class ConformerEncoderLayer(nn.Module):
             pos_emb (torch.Tensor): positional encoding, must not be None
                 for ConformerEncoderLayer.
             mask_pad (torch.Tensor): batch padding mask used for conv module.
-                (#batch, 1,time)
             output_cache (torch.Tensor): Cache tensor of the output
                 (#batch, time2, size), time2 < time in x.
             cnn_cache (torch.Tensor): Convolution cache in conformer layer
@@ -202,6 +211,14 @@ class ConformerEncoderLayer(nn.Module):
             torch.Tensor: Mask tensor (#batch, time).
         """
 
+        if onnx_mode:
+            x = x[:, 1:, :]
+            mask = mask[:, :, 1:]
+            # pos_emb_ = pos_emb[:, 1:, :]
+            pos_emb_ = pos_emb[:, :-1, :]
+        else:
+            pos_emb_ = pos_emb
+
         # whether to use macaron style
         if self.feed_forward_macaron is not None:
             residual = x
@@ -223,12 +240,26 @@ class ConformerEncoderLayer(nn.Module):
             assert output_cache.size(0) == x.size(0)
             assert output_cache.size(2) == self.size
             assert output_cache.size(1) < x.size(1)
-            chunk = x.size(1) - output_cache.size(1)
-            x_q = x[:, -chunk:, :]
-            residual = residual[:, -chunk:, :]
-            mask = mask[:, -chunk:, :]
 
-        x_att = self.self_attn(x_q, x, x, mask, pos_emb)
+            # chunk = x.size(1) - output_cache.size(1)
+            if onnx_mode:
+                chunk = x.size(1) - output_cache.size(1) + 1
+                if isinstance(chunk, int):
+                    chunk_1 = torch.tensor(chunk)
+                else:
+                    chunk_1 = chunk.clone().detach()
+                # chunk = torch.tensor(chunk)
+                # print(type(chunk))
+                x_q = slice_helper(x, chunk_1)
+                residual = slice_helper(residual, chunk_1)
+                mask = slice_helper(mask, chunk_1)
+            else:
+                chunk = x.size(1) - output_cache.size(1)
+                x_q = x[:, -chunk:, :]
+                residual = residual[:, -chunk:, :]
+                mask = mask[:, -chunk:, :]
+
+        x_att = self.self_attn(x_q, x, x, mask, pos_emb_)
         if self.concat_after:
             x_concat = torch.cat((x, x_att), dim=-1)
             x = residual + self.concat_linear(x_concat)
diff --git a/wenet/transformer/subsampling.py b/wenet/transformer/subsampling.py
index b890f70..a978424 100644
--- a/wenet/transformer/subsampling.py
+++ b/wenet/transformer/subsampling.py
@@ -16,8 +16,11 @@ class BaseSubsampling(torch.nn.Module):
         self.right_context = 0
         self.subsampling_rate = 1
 
-    def position_encoding(self, offset: int, size: int) -> torch.Tensor:
-        return self.pos_enc.position_encoding(offset, size)
+    def position_encoding(self, 
+                        offset: torch.Tensor, 
+                        size: torch.Tensor,
+                        onnx_mode: bool = False) -> torch.Tensor:
+        return self.pos_enc.position_encoding(offset, size, onnx_mode)
 
 
 class LinearNoSubsampling(BaseSubsampling):
@@ -89,16 +92,17 @@ class Conv2dSubsampling4(BaseSubsampling):
             torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
         self.pos_enc = pos_enc_class
         # The right context for every conv layer is computed by:
-        # (kernel_size - 1) * frame_rate_of_this_layer
+        # (kernel_size - 1) / 2 * stride  * frame_rate_of_this_layer
         self.subsampling_rate = 4
-        # 6 = (3 - 1) * 1 + (3 - 1) * 2
+        # 6 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2
         self.right_context = 6
 
     def forward(
             self,
             x: torch.Tensor,
             x_mask: torch.Tensor,
-            offset: int = 0
+            offset: torch.Tensor = torch.tensor(0),
+            onnx_mode: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Subsample x.
 
@@ -118,7 +122,7 @@ class Conv2dSubsampling4(BaseSubsampling):
         x = self.conv(x)
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
-        x, pos_emb = self.pos_enc(x, offset)
+        x, pos_emb = self.pos_enc(x, offset, onnx_mode)
         return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
 
 
@@ -143,9 +147,9 @@ class Conv2dSubsampling6(BaseSubsampling):
         self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
                                       odim)
         self.pos_enc = pos_enc_class
-        # 10 = (3 - 1) * 1 + (5 - 1) * 2
+        # 14 = (3 - 1) / 2 * 2 * 1 + (5 - 1) / 2 * 3 * 2
         self.subsampling_rate = 6
-        self.right_context = 10
+        self.right_context = 14
 
     def forward(
             self,
@@ -198,7 +202,7 @@ class Conv2dSubsampling8(BaseSubsampling):
             odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
         self.pos_enc = pos_enc_class
         self.subsampling_rate = 8
-        # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
+        # 14 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2 + (3 - 1) / 2 * 2 * 4
         self.right_context = 14
 
     def forward(
diff --git a/wenet/utils/mask.py b/wenet/utils/mask.py
index c2bb50f..d23bd95 100644
--- a/wenet/utils/mask.py
+++ b/wenet/utils/mask.py
@@ -5,6 +5,15 @@
 
 import torch
 
+def tril_onnx(x, diagonal: torch.Tensor = torch.tensor(0)):
+     m,n = x.shape[0], x.shape[1]
+     arange = torch.arange(n, device = x.device)
+     mask = arange.expand(m, n)
+     mask_maker = torch.arange(m, device = x.device).unsqueeze(-1)
+     if diagonal:
+         mask_maker = mask_maker + diagonal
+     mask = mask <= mask_maker
+     return mask * x
 
 def subsequent_mask(
         size: int,
@@ -35,13 +44,17 @@ def subsequent_mask(
          [1, 1, 0],
          [1, 1, 1]]
     """
-    ret = torch.ones(size, size, device=device, dtype=torch.bool)
-    return torch.tril(ret, out=ret)
+    # ret = torch.ones(size, size, device=device, dtype=torch.bool)
+    # return torch.tril(ret, out=ret)
+    # to export onnx, we change the code as follows
+    ret = torch.ones(size, size, device=device)
+    #return torch.tril(ret, out=ret)
+    return tril_onnx(ret)
 
 
 def subsequent_chunk_mask(
-        size: int,
-        chunk_size: int,
+        size: torch.tensor(0),
+        chunk_size: torch.tensor(0),
         num_left_chunks: int = -1,
         device: torch.device = torch.device("cpu"),
 ) -> torch.Tensor:
@@ -67,6 +80,18 @@ def subsequent_chunk_mask(
          [1, 1, 1, 1]]
     """
     ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+    row_index = torch.arange(size, device = device)
+    index  = row_index.expand(size, size)
+    expand_size = torch.ones((size), device = device)*size
+    #expand_size = expand_size.long()
+    if num_left_chunks < 0:
+        start1 = torch.tensor(0)
+    else:
+        start1 = torch.max((torch.floor_divide(row_index, chunk_size)-num_left_chunks).float()*chunk_size, torch.tensor(0.0)).long().view(size,1)
+    ending = torch.min((torch.floor_divide(row_index, chunk_size)+1).float()*chunk_size, expand_size.float()).long().view(size,1)
+    ret[torch.where(index < ending)] = True
+    ret[torch.where(index < start1)] = False
+    '''
     for i in range(size):
         if num_left_chunks < 0:
             start = 0
@@ -74,6 +99,8 @@ def subsequent_chunk_mask(
             start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
         ending = min((i // chunk_size + 1) * chunk_size, size)
         ret[i, start:ending] = True
+    print("ret:", ret)
+    '''
     return ret
 
 
@@ -107,18 +134,18 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
     """
     # Whether to use chunk mask or not
     if use_dynamic_chunk:
-        max_len = xs.size(1)
+        max_len = xs.shape[1]
         if decoding_chunk_size < 0:
             chunk_size = max_len
             num_left_chunks = -1
         elif decoding_chunk_size > 0:
-            chunk_size = decoding_chunk_size
+            chunk_size = torch.tensor(decoding_chunk_size)
             num_left_chunks = num_decoding_left_chunks
         else:
             # chunk size is either [1, 25] or full context(max_len).
             # Since we use 4 times subsampling and allow up to 1s(100 frames)
             # delay, the maximum frame is 100 / 4 = 25.
-            chunk_size = torch.randint(1, max_len, (1, )).item()
+            chunk_size = torch.randint(1, max_len, (1, ))
             num_left_chunks = -1
             if chunk_size > max_len // 2:
                 chunk_size = max_len
@@ -128,14 +155,14 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                     max_left_chunks = (max_len - 1) // chunk_size
                     num_left_chunks = torch.randint(0, max_left_chunks,
                                                     (1, )).item()
-        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+        chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
                                             num_left_chunks,
                                             xs.device)  # (L, L)
         chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
         chunk_masks = masks & chunk_masks  # (B, L, L)
     elif static_chunk_size > 0:
         num_left_chunks = num_decoding_left_chunks
-        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+        chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
                                             num_left_chunks,
                                             xs.device)  # (L, L)
         chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
@@ -145,7 +172,7 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
     return chunk_masks
 
 
-def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
+def make_pad_mask(lengths: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
     """Make mask tensor containing indices of padded part.
 
     See description of make_non_pad_mask.
@@ -162,8 +189,11 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
                  [0, 0, 0, 1, 1],
                  [0, 0, 1, 1, 1]]
     """
-    batch_size = int(lengths.size(0))
-    max_len = int(lengths.max().item())
+    # batch_size = int(lengths.size(0))
+    # max_len = int(lengths.max().item())
+    # to export the decoder onnx and avoid the constant fold
+    batch_size = xs.shape[0]
+    max_len = xs.shape[1]
     seq_range = torch.arange(0,
                              max_len,
                              dtype=torch.int64,