--- zipformer.py	2024-03-08 15:18:44.384000000 +0800
+++ zipformer.py	2024-03-08 15:18:42.056000000 +0800
@@ -1415,7 +1415,7 @@
         self.length_factor = length_factor
         self.extend_pe(torch.tensor(0.0).expand(max_len))
 
-    def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+    def extend_pe(self, x: Tensor, left_context_len: int = 0) -> Tensor:
         """Reset the positional encodings."""
         T = x.size(0) + left_context_len
 
@@ -1423,8 +1423,7 @@
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(0) >= T * 2 - 1:
-                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
+                return self.pe.to(dtype=x.dtype, device=x.device)
 
         # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
         x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
@@ -1458,12 +1457,13 @@
         cosines = (x_atan * freqs).cos()
         sines = (x_atan * freqs).sin()
 
-        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
-        pe[:, 0::2] = cosines
-        pe[:, 1::2] = sines
-        pe[:, -1] = 1.0  # for bias.
+        cos_shape0 = cosines.shape[0]
+        bias_one = torch.ones(cos_shape0, 1)
+        pe = torch.cat((cosines.unsqueeze(2), sines.unsqueeze(2)), dim=2)
+        pe = pe.reshape(cos_shape0, -1)
+        pe = torch.cat((pe[:, :-1], bias_one), dim=1)
 
-        self.pe = pe.to(dtype=x.dtype)
+        return pe.to(dtype=x.dtype)
 
     def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
         """Create positional encoding.
@@ -1475,14 +1475,14 @@
         Returns:
             positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
         """
-        self.extend_pe(x, left_context_len)
+        pe = self.extend_pe(x, left_context_len)
         x_size_left = x.size(0) + left_context_len
         # length of positive side: x.size(0) + left_context_len
         # length of negative side: x.size(0)
-        pos_emb = self.pe[
-            self.pe.size(0) // 2
+        pos_emb = pe[
+            pe.size(0) // 2
             - x_size_left
-            + 1 : self.pe.size(0) // 2  # noqa E203
+            + 1 : pe.size(0) // 2  # noqa E203
             + x.size(0),
             :,
         ]