--- export-onnx.py	2024-03-08 15:33:12.340000000 +0800
+++ export-onnx.py	2024-03-08 15:33:12.340000000 +0800
@@ -27,10 +27,10 @@
 
 2. Export the model to ONNX
 
-./zipformer/export-onnx.py \
-  --tokens $repo/data/lang_bpe_500/tokens.txt \
+python3 ./export-onnx.py \
+  --tokens $repo/data/lang_char/tokens.txt \
   --use-averaged-model 0 \
-  --epoch 99 \
+  --epoch 12 \
   --avg 1 \
   --exp-dir $repo/exp \
   --num-encoder-layers "2,2,3,4,3,2" \
@@ -92,7 +92,7 @@
     parser.add_argument(
         "--epoch",
         type=int,
-        default=28,
+        default=12,
         help="""It specifies the checkpoint to use for averaging.
         Note: Epoch counts from 0.
         You can specify --avg to use more checkpoints for model averaging.""",
@@ -111,7 +111,7 @@
     parser.add_argument(
         "--avg",
         type=int,
-        default=15,
+        default=1,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch' and '--iter'",
@@ -120,7 +120,7 @@
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
-        default=True,
+        default=False,
         help="Whether to load averaged model. Currently it only supports "
         "using --epoch. If True, it would decode with the averaged model "
         "over the epoch range from `epoch-avg` (excluded) to `epoch`."
@@ -131,7 +131,7 @@
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="zipformer/exp",
+        default="/path/to/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/",
         help="""It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
@@ -140,7 +140,7 @@
     parser.add_argument(
         "--tokens",
         type=str,
-        default="data/lang_bpe_500/tokens.txt",
+        default="/path/to/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/data/lang_char/tokens.txt",
         help="Path to the tokens.txt",
     )
 
@@ -298,7 +298,7 @@
     x_lens = torch.tensor([100], dtype=torch.int64)
 
     encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
-
+    encoder_model.save(str(encoder_filename).replace("onnx", "pt"))
     torch.onnx.export(
         encoder_model,
         (x, x_lens),
@@ -352,7 +352,9 @@
     context_size = decoder_model.decoder.context_size
     vocab_size = decoder_model.decoder.vocab_size
 
-    y = torch.zeros(10, context_size, dtype=torch.int64)
+    y = torch.zeros(1, context_size, dtype=torch.int64)
+    ts_decoder_model = torch.jit.trace(decoder_model, y)
+    ts_decoder_model.save(str(decoder_filename).replace("onnx", "pt"))
     decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
@@ -393,8 +395,10 @@
     joiner_dim = joiner_model.output_linear.weight.shape[1]
     logging.info(f"joiner dim: {joiner_dim}")
 
-    projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
-    projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+    projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
+    projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
+    ts_joiner_model = torch.jit.trace(joiner_model, (projected_encoder_out, projected_decoder_out))
+    ts_joiner_model.save(str(joiner_filename).replace("onnx", "pt"))
 
     torch.onnx.export(
         joiner_model,