@@ -1415,7 +1415,7 @@
self.length_factor = length_factor
self.extend_pe(torch.tensor(0.0).expand(max_len))
- def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> Tensor:
"""Reset the positional encodings."""
T = x.size(0) + left_context_len
@@ -1423,8 +1423,7 @@
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= T * 2 - 1:
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
- return
+ return self.pe.to(dtype=x.dtype, device=x.device)
# if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
@@ -1458,12 +1457,13 @@
cosines = (x_atan * freqs).cos()
sines = (x_atan * freqs).sin()
- pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
- pe[:, 0::2] = cosines
- pe[:, 1::2] = sines
- pe[:, -1] = 1.0 # for bias.
+ cos_shape0 = cosines.shape[0]
+ bias_one = torch.ones(cos_shape0, 1)
+ pe = torch.cat((cosines.unsqueeze(2), sines.unsqueeze(2)), dim=2)
+ pe = pe.reshape(cos_shape0, -1)
+ pe = torch.cat((pe[:, :-1], bias_one), dim=1)
- self.pe = pe.to(dtype=x.dtype)
+ return pe.to(dtype=x.dtype)
def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
"""Create positional encoding.
@@ -1475,14 +1475,14 @@
Returns:
positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
"""
- self.extend_pe(x, left_context_len)
+ pe = self.extend_pe(x, left_context_len)
x_size_left = x.size(0) + left_context_len
# length of positive side: x.size(0) + left_context_len
# length of negative side: x.size(0)
- pos_emb = self.pe[
- self.pe.size(0) // 2
+ pos_emb = pe[
+ pe.size(0) // 2
- x_size_left
- + 1 : self.pe.size(0) // 2 # noqa E203
+ + 1 : pe.size(0) // 2 # noqa E203
+ x.size(0),
:,
]