@@ -245,7 +245,7 @@ class ASRModel(torch.nn.Module):
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
- # 2.3 Second beam prune: select topk score with history
+ # 2.3 Seconde beam prune: select topk score with history
scores = scores + top_k_logp # (B*N, N), broadcast add
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
@@ -570,13 +570,12 @@ class ASRModel(torch.nn.Module):
def forward_encoder_chunk(
self,
xs: torch.Tensor,
- offset: int,
- required_cache_size: int,
+ offset: torch.Tensor,
+ required_cache_size: torch.Tensor,
subsampling_cache: Optional[torch.Tensor] = None,
- elayers_output_cache: Optional[List[torch.Tensor]] = None,
- conformer_cnn_cache: Optional[List[torch.Tensor]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
- List[torch.Tensor]]:
+ elayers_output_cache: Optional[torch.Tensor] = None,
+ conformer_cnn_cache: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
@@ -675,6 +674,10 @@ class ASRModel(torch.nn.Module):
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
return decoder_out, r_decoder_out
+ @torch.jit.export
+ def test(self,) -> str:
+ return "test"
+
def init_asr_model(configs):
if configs['cmvn_file'] is not None:
@@ -57,8 +57,7 @@ class TransformerDecoder(torch.nn.Module):
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
- PositionalEncoding(attention_dim, positional_dropout_rate),
- )
+ PositionalEncoding(attention_dim, positional_dropout_rate))
else:
raise ValueError(f"only 'embed' is supported: {input_layer}")
@@ -81,6 +80,10 @@ class TransformerDecoder(torch.nn.Module):
concat_after,
) for _ in range(self.num_blocks)
])
+ self.onnx_mode = False
+
+ def set_onnx_mode(self, onnx_mode=False):
+ self.onnx_mode = onnx_mode
def forward(
self,
@@ -111,13 +114,15 @@ class TransformerDecoder(torch.nn.Module):
tgt = ys_in_pad
# tgt_mask: (B, 1, L)
- tgt_mask = (~make_pad_mask(ys_in_lens).unsqueeze(1)).to(tgt.device)
+ tgt_mask = (~make_pad_mask(ys_in_lens, ys_in_pad).unsqueeze(1)).to(tgt.device)
# m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1),
device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
- tgt_mask = tgt_mask & m
- x, _ = self.embed(tgt)
+ # tgt_mask = tgt_mask & m
+ tgt_mask = torch.mul(tgt_mask, m)
+ x = self.embed[0](tgt)
+ x, _ = self.embed[1](x, onnx_mode=self.onnx_mode)
for layer in self.decoders:
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
memory_mask)
@@ -225,6 +230,13 @@ class BiTransformerDecoder(torch.nn.Module):
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before, concat_after)
+ self.onnx_mode = False
+
+ def set_onnx_mode(self, onnx_mode=False):
+ self.onnx_mode = onnx_mode
+ self.left_decoder.set_onnx_mode(onnx_mode)
+ self.right_decoder.set_onnx_mode(onnx_mode)
+
def forward(
self,
memory: torch.Tensor,
@@ -252,6 +264,7 @@ class BiTransformerDecoder(torch.nn.Module):
if use_output_layer is True,
olens: (batch, )
"""
+ reverse_weight = 0.3
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
ys_in_lens)
r_x = torch.tensor(0.0)
@@ -17,7 +17,7 @@ class DecoderLayer(nn.Module):
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
- src_attn (torch.nn.Module): Inter-attention module instance.
+ src_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
@@ -61,7 +61,8 @@ class DecoderLayer(nn.Module):
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
- cache: Optional[torch.Tensor] = None
+ cache: Optional[torch.Tensor] = None,
+ onnx_mode: Optional[bool] = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute decoded features.
@@ -9,6 +9,7 @@ import math
from typing import Tuple
import torch
+from wenet.transformer.slice_helper import slice_helper2
class PositionalEncoding(torch.nn.Module):
@@ -45,7 +46,8 @@ class PositionalEncoding(torch.nn.Module):
def forward(self,
x: torch.Tensor,
- offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+ offset: torch.Tensor = torch.tensor(0),
+ onnx_mode: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
"""Add positional encoding.
Args:
@@ -56,13 +58,21 @@ class PositionalEncoding(torch.nn.Module):
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
torch.Tensor: for compatibility to RelPositionalEncoding
"""
- assert offset + x.size(1) < self.max_len
+ # assert offset + x.size(1) < self.max_len
self.pe = self.pe.to(x.device)
- pos_emb = self.pe[:, offset:offset + x.size(1)]
+ # pos_emb = self.pe[:, offset:offset + x.size(1)]
+ if onnx_mode:
+ pos_emb = slice_helper2(self.pe, offset, offset + x.size(1))
+ else:
+ pos_emb = self.pe[:, offset:offset + x.size(1)]
x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb)
- def position_encoding(self, offset: int, size: int) -> torch.Tensor:
+ def position_encoding(self,
+ offset: torch.Tensor,
+ size: torch.Tensor,
+ onnx_mode: bool = False,
+ ) -> torch.Tensor:
""" For getting encoding in a streaming fashion
Attention!!!!!
@@ -79,7 +89,12 @@ class PositionalEncoding(torch.nn.Module):
torch.Tensor: Corresponding encoding
"""
assert offset + size < self.max_len
- return self.dropout(self.pe[:, offset:offset + size])
+ if onnx_mode:
+ # pe = torch.cat([self.pe[:, [0]], slice_helper2(self.pe, offset, offset + size - 1)], dim=1)
+ pe = slice_helper2(self.pe, offset, offset + size)
+ else:
+ pe = self.pe[:, offset:offset + size]
+ return self.dropout(pe)
class RelPositionalEncoding(PositionalEncoding):
@@ -96,7 +111,8 @@ class RelPositionalEncoding(PositionalEncoding):
def forward(self,
x: torch.Tensor,
- offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+ offset: torch.Tensor,
+ onnx_mode: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
@@ -104,10 +120,16 @@ class RelPositionalEncoding(PositionalEncoding):
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
- assert offset + x.size(1) < self.max_len
+ # assert offset + x.size(1) < self.max_len
self.pe = self.pe.to(x.device)
x = x * self.xscale
- pos_emb = self.pe[:, offset:offset + x.size(1)]
+ if onnx_mode:
+ # end = offset.item() + x.size(1)
+ # pos_emb = torch.index_select(self.pe, 1, torch.tensor(range(x.size(1))))
+ pos_emb = slice_helper2(self.pe, offset, offset + x.size(1))
+ # pos_emb = slice_helper3(pos_emb, x.size(1))
+ else:
+ pos_emb = self.pe[:, offset:offset + x.size(1)]
return self.dropout(x), self.dropout(pos_emb)
@@ -6,6 +6,8 @@
"""Encoder definition."""
from typing import Tuple, List, Optional
+import numpy as np
+import onnxruntime
import torch
from typeguard import check_argument_types
@@ -18,6 +20,7 @@ from wenet.transformer.embedding import NoPositionalEncoding
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from wenet.transformer.slice_helper import slice_helper3, get_next_cache_start
from wenet.transformer.subsampling import Conv2dSubsampling4
from wenet.transformer.subsampling import Conv2dSubsampling6
from wenet.transformer.subsampling import Conv2dSubsampling8
@@ -26,6 +29,8 @@ from wenet.utils.common import get_activation
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
+def to_numpy(x):
+ return x.detach().numpy()
class BaseEncoder(torch.nn.Module):
def __init__(
@@ -116,10 +121,14 @@ class BaseEncoder(torch.nn.Module):
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
+ self.onnx_mode = False
def output_size(self) -> int:
return self._output_size
+ def set_onnx_mode(self, onnx_mode=False):
+ self.onnx_mode = onnx_mode
+
def forward(
self,
xs: torch.Tensor,
@@ -130,7 +139,7 @@ class BaseEncoder(torch.nn.Module):
"""Embed positions in tensor.
Args:
- xs: padded input tensor (B, T, D)
+ xs: padded input tensor (B, L, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
@@ -141,16 +150,18 @@ class BaseEncoder(torch.nn.Module):
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
- encoder output tensor xs, and subsampled masks
- xs: padded output tensor (B, T' ~= T/subsample_rate, D)
- masks: torch.Tensor batch padding mask after subsample
- (B, 1, T' ~= T/subsample_rate)
+ encoder output tensor, lens and mask
"""
- masks = ~make_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T)
+ decoding_chunk_size = 1
+ num_decoding_left_chunks = 1
+ self.use_dynamic_chunk = False
+ self.use_dynamic_left_chunk = False
+ self.static_chunk_size = 0
+ masks = ~make_pad_mask(xs_lens, xs).unsqueeze(1) # (B, 1, L)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
- mask_pad = masks # (B, 1, T/subsample_rate)
+ mask_pad = masks
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
@@ -169,13 +180,12 @@ class BaseEncoder(torch.nn.Module):
def forward_chunk(
self,
xs: torch.Tensor,
- offset: int,
- required_cache_size: int,
+ offset_tensor: torch.Tensor = torch.tensor(0),
+ required_cache_size_tensor: torch.Tensor = torch.tensor(0),
subsampling_cache: Optional[torch.Tensor] = None,
- elayers_output_cache: Optional[List[torch.Tensor]] = None,
- conformer_cnn_cache: Optional[List[torch.Tensor]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
- List[torch.Tensor]]:
+ elayers_output_cache: Optional[torch.Tensor] = None,
+ conformer_cnn_cache: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" Forward just one chunk
Args:
@@ -199,6 +209,7 @@ class BaseEncoder(torch.nn.Module):
List[torch.Tensor]: conformer cnn cache
"""
+ required_cache_size_tensor = torch.tensor(-1)
assert xs.size(0) == 1
# tmp_masks is just for interface compatibility
tmp_masks = torch.ones(1,
@@ -208,30 +219,53 @@ class BaseEncoder(torch.nn.Module):
tmp_masks = tmp_masks.unsqueeze(1)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
- xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+ # if self.onnx_mode:
+ # offset_tensor = offset_tensor - torch.tensor(1)
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset_tensor, self.onnx_mode)
if subsampling_cache is not None:
cache_size = subsampling_cache.size(1)
xs = torch.cat((subsampling_cache, xs), dim=1)
else:
cache_size = 0
- pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1))
- if required_cache_size < 0:
- next_cache_start = 0
- elif required_cache_size == 0:
- next_cache_start = xs.size(1)
+ # if self.onnx_mode:
+ # cache_size = cache_size - 1
+ # if self.onnx_mode:
+ # # subsampling_cache append dummy var, remove it here
+ # xs = xs[:, 1:, :]
+ # cache_size = cache_size - 1
+ if isinstance(xs.size(1), int):
+ xs_size_1 = torch.tensor(xs.size(1))
else:
- next_cache_start = max(xs.size(1) - required_cache_size, 0)
- r_subsampling_cache = xs[:, next_cache_start:, :]
+ xs_size_1 = xs.size(1).clone().detach()
+ pos_emb = self.embed.position_encoding(offset_tensor - cache_size,
+ xs_size_1,
+ self.onnx_mode)
+ next_cache_start = get_next_cache_start(required_cache_size_tensor, xs)
+ r_subsampling_cache = slice_helper3(xs, next_cache_start)
+ # if self.onnx_mode:
+ # next_cache_start_1 = get_next_cache_start(required_cache_size_tensor, xs)
+ # r_subsampling_cache = slice_helper3(xs, next_cache_start_1)
+ # else:
+ # required_cache_size = required_cache_size_tensor.detach().item()
+ # if required_cache_size < 0:
+ # next_cache_start = 0
+ # elif required_cache_size == 0:
+ # next_cache_start = xs.size(1)
+ # else:
+ # next_cache_start = max(xs.size(1) - required_cache_size, 0)
+ # r_subsampling_cache = xs[:, next_cache_start:, :]
# Real mask for transformer/conformer layers
masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
masks = masks.unsqueeze(1)
- r_elayers_output_cache = []
- r_conformer_cnn_cache = []
+ r_elayers_output_cache = None
+ r_conformer_cnn_cache = None
for i, layer in enumerate(self.encoders):
if elayers_output_cache is None:
attn_cache = None
else:
attn_cache = elayers_output_cache[i]
+ # if self.onnx_mode and attn_cache is not None:
+ # attn_cache = attn_cache[:, 1:, :]
if conformer_cnn_cache is None:
cnn_cache = None
else:
@@ -240,13 +274,32 @@ class BaseEncoder(torch.nn.Module):
masks,
pos_emb,
output_cache=attn_cache,
- cnn_cache=cnn_cache)
- r_elayers_output_cache.append(xs[:, next_cache_start:, :])
- r_conformer_cnn_cache.append(new_cnn_cache)
+ cnn_cache=cnn_cache,
+ onnx_mode=self.onnx_mode)
+ if self.onnx_mode:
+ layer_output_cache = slice_helper3(xs, next_cache_start)
+ else:
+ layer_output_cache = xs[:, next_cache_start:, :]
+ if i == 0:
+ r_elayers_output_cache = layer_output_cache.unsqueeze(0)
+ r_conformer_cnn_cache = new_cnn_cache.unsqueeze(0)
+ else:
+ # r_elayers_output_cache.append(xs[:, next_cache_start:, :])
+ r_elayers_output_cache = torch.cat((r_elayers_output_cache, layer_output_cache.unsqueeze(0)), 0)
+ # r_conformer_cnn_cache.append(new_cnn_cache)
+ r_conformer_cnn_cache = torch.cat((r_conformer_cnn_cache, new_cnn_cache.unsqueeze(0)), 0)
if self.normalize_before:
xs = self.after_norm(xs)
-
- return (xs[:, cache_size:, :], r_subsampling_cache,
+ if self.onnx_mode:
+ cache_size = cache_size - 1
+ if isinstance(cache_size, int):
+ cache_size_1 = torch.tensor(cache_size)
+ else:
+ cache_size_1 = cache_size.clone().detach()
+ output = slice_helper3(xs, cache_size_1)
+ else:
+ output = xs[:, cache_size:, :]
+ return (output, r_subsampling_cache,
r_elayers_output_cache, r_conformer_cnn_cache)
def forward_chunk_by_chunk(
@@ -290,24 +343,54 @@ class BaseEncoder(torch.nn.Module):
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.size(1)
subsampling_cache: Optional[torch.Tensor] = None
- elayers_output_cache: Optional[List[torch.Tensor]] = None
- conformer_cnn_cache: Optional[List[torch.Tensor]] = None
+ elayers_output_cache: Optional[torch.Tensor] = None
+ conformer_cnn_cache: Optional[torch.Tensor] = None
outputs = []
offset = 0
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+ print("required_cache_size:", required_cache_size)
+ encoder_session = onnxruntime.InferenceSession("onnx/encoder.onnx")
+
+ subsampling_cache_onnx = torch.zeros(1, 1, 256, requires_grad=False)
+ elayers_output_cache_onnx = torch.zeros(12, 1, 1, 256, requires_grad=False)
+ conformer_cnn_cache_onnx = torch.zeros(12, 1, 256, 7, requires_grad=False)
# Feed forward overlap input step by step
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
+
+ if offset > 0:
+ offset = offset - 1
(y, subsampling_cache, elayers_output_cache,
- conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset,
- required_cache_size,
+ conformer_cnn_cache) = self.forward_chunk(chunk_xs, torch.tensor(offset),
+ torch.tensor(required_cache_size),
subsampling_cache,
elayers_output_cache,
conformer_cnn_cache)
- outputs.append(y)
+
+ offset = offset + 1
+ encoder_inputs = {
+ encoder_session.get_inputs()[0].name: chunk_xs.numpy(),
+ encoder_session.get_inputs()[1].name: np.array(offset),
+ encoder_session.get_inputs()[2].name: subsampling_cache_onnx.numpy(),
+ encoder_session.get_inputs()[3].name: elayers_output_cache_onnx.numpy(),
+ encoder_session.get_inputs()[4].name: conformer_cnn_cache_onnx.numpy(),
+ }
+ ort_outs = encoder_session.run(None, encoder_inputs)
+ y_onnx, subsampling_cache_onnx, elayers_output_cache_onnx, conformer_cnn_cache_onnx = \
+ torch.from_numpy(ort_outs[0][:, 1:, :]), torch.from_numpy(ort_outs[1]), \
+ torch.from_numpy(ort_outs[2]), torch.from_numpy(ort_outs[3])
+
+ np.testing.assert_allclose(to_numpy(y), ort_outs[0][:, 1:, :], rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose(to_numpy(subsampling_cache), ort_outs[1][:, 1:, :], rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose(to_numpy(elayers_output_cache), ort_outs[2][:, :, 1:, :], rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose(to_numpy(conformer_cnn_cache), ort_outs[3], rtol=1e-03, atol=1e-03)
+
+ outputs.append(y_onnx)
+ # outputs.append(y)
offset += y.size(1)
+ # break
ys = torch.cat(outputs, 1)
masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool)
masks = masks.unsqueeze(1)
@@ -9,6 +9,7 @@ from typing import Optional, Tuple
import torch
from torch import nn
+from wenet.transformer.slice_helper import slice_helper
class TransformerEncoderLayer(nn.Module):
@@ -53,6 +54,9 @@ class TransformerEncoderLayer(nn.Module):
# concat_linear may be not used in forward fuction,
# but will be saved in the *.pt
self.concat_linear = nn.Linear(size + size, size)
+
+ def set_onnx_mode(self, onnx_mode=False):
+ self.onnx_mode = onnx_mode
def forward(
self,
@@ -92,9 +96,14 @@ class TransformerEncoderLayer(nn.Module):
assert output_cache.size(2) == self.size
assert output_cache.size(1) < x.size(1)
chunk = x.size(1) - output_cache.size(1)
- x_q = x[:, -chunk:, :]
- residual = residual[:, -chunk:, :]
- mask = mask[:, -chunk:, :]
+ if self.onnx_mode:
+ x_q = slice_helper(x, chunk)
+ residual = slice_helper(residual, chunk)
+ mask = slice_helper(mask, chunk)
+ else:
+ x_q = x[:, -chunk:, :]
+ residual = residual[:, -chunk:, :]
+ mask = mask[:, -chunk:, :]
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
@@ -184,6 +193,7 @@ class ConformerEncoderLayer(nn.Module):
mask_pad: Optional[torch.Tensor] = None,
output_cache: Optional[torch.Tensor] = None,
cnn_cache: Optional[torch.Tensor] = None,
+ onnx_mode: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
@@ -193,7 +203,6 @@ class ConformerEncoderLayer(nn.Module):
pos_emb (torch.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
mask_pad (torch.Tensor): batch padding mask used for conv module.
- (#batch, 1,time)
output_cache (torch.Tensor): Cache tensor of the output
(#batch, time2, size), time2 < time in x.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
@@ -202,6 +211,14 @@ class ConformerEncoderLayer(nn.Module):
torch.Tensor: Mask tensor (#batch, time).
"""
+ if onnx_mode:
+ x = x[:, 1:, :]
+ mask = mask[:, :, 1:]
+ # pos_emb_ = pos_emb[:, 1:, :]
+ pos_emb_ = pos_emb[:, :-1, :]
+ else:
+ pos_emb_ = pos_emb
+
# whether to use macaron style
if self.feed_forward_macaron is not None:
residual = x
@@ -223,12 +240,26 @@ class ConformerEncoderLayer(nn.Module):
assert output_cache.size(0) == x.size(0)
assert output_cache.size(2) == self.size
assert output_cache.size(1) < x.size(1)
- chunk = x.size(1) - output_cache.size(1)
- x_q = x[:, -chunk:, :]
- residual = residual[:, -chunk:, :]
- mask = mask[:, -chunk:, :]
- x_att = self.self_attn(x_q, x, x, mask, pos_emb)
+ # chunk = x.size(1) - output_cache.size(1)
+ if onnx_mode:
+ chunk = x.size(1) - output_cache.size(1) + 1
+ if isinstance(chunk, int):
+ chunk_1 = torch.tensor(chunk)
+ else:
+ chunk_1 = chunk.clone().detach()
+ # chunk = torch.tensor(chunk)
+ # print(type(chunk))
+ x_q = slice_helper(x, chunk_1)
+ residual = slice_helper(residual, chunk_1)
+ mask = slice_helper(mask, chunk_1)
+ else:
+ chunk = x.size(1) - output_cache.size(1)
+ x_q = x[:, -chunk:, :]
+ residual = residual[:, -chunk:, :]
+ mask = mask[:, -chunk:, :]
+
+ x_att = self.self_attn(x_q, x, x, mask, pos_emb_)
if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
@@ -16,8 +16,11 @@ class BaseSubsampling(torch.nn.Module):
self.right_context = 0
self.subsampling_rate = 1
- def position_encoding(self, offset: int, size: int) -> torch.Tensor:
- return self.pos_enc.position_encoding(offset, size)
+ def position_encoding(self,
+ offset: torch.Tensor,
+ size: torch.Tensor,
+ onnx_mode: bool = False) -> torch.Tensor:
+ return self.pos_enc.position_encoding(offset, size, onnx_mode)
class LinearNoSubsampling(BaseSubsampling):
@@ -89,16 +92,17 @@ class Conv2dSubsampling4(BaseSubsampling):
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
self.pos_enc = pos_enc_class
# The right context for every conv layer is computed by:
- # (kernel_size - 1) * frame_rate_of_this_layer
+ # (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer
self.subsampling_rate = 4
- # 6 = (3 - 1) * 1 + (3 - 1) * 2
+ # 6 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2
self.right_context = 6
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
- offset: int = 0
+ offset: torch.Tensor = torch.tensor(0),
+ onnx_mode: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
@@ -118,7 +122,7 @@ class Conv2dSubsampling4(BaseSubsampling):
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
- x, pos_emb = self.pos_enc(x, offset)
+ x, pos_emb = self.pos_enc(x, offset, onnx_mode)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
@@ -143,9 +147,9 @@ class Conv2dSubsampling6(BaseSubsampling):
self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
odim)
self.pos_enc = pos_enc_class
- # 10 = (3 - 1) * 1 + (5 - 1) * 2
+ # 14 = (3 - 1) / 2 * 2 * 1 + (5 - 1) / 2 * 3 * 2
self.subsampling_rate = 6
- self.right_context = 10
+ self.right_context = 14
def forward(
self,
@@ -198,7 +202,7 @@ class Conv2dSubsampling8(BaseSubsampling):
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
self.pos_enc = pos_enc_class
self.subsampling_rate = 8
- # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
+ # 14 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2 + (3 - 1) / 2 * 2 * 4
self.right_context = 14
def forward(
@@ -5,6 +5,15 @@
import torch
+def tril_onnx(x, diagonal: torch.Tensor = torch.tensor(0)):
+ m,n = x.shape[0], x.shape[1]
+ arange = torch.arange(n, device = x.device)
+ mask = arange.expand(m, n)
+ mask_maker = torch.arange(m, device = x.device).unsqueeze(-1)
+ if diagonal:
+ mask_maker = mask_maker + diagonal
+ mask = mask <= mask_maker
+ return mask * x
def subsequent_mask(
size: int,
@@ -35,13 +44,17 @@ def subsequent_mask(
[1, 1, 0],
[1, 1, 1]]
"""
- ret = torch.ones(size, size, device=device, dtype=torch.bool)
- return torch.tril(ret, out=ret)
+ # ret = torch.ones(size, size, device=device, dtype=torch.bool)
+ # return torch.tril(ret, out=ret)
+ # to export onnx, we change the code as follows
+ ret = torch.ones(size, size, device=device)
+ #return torch.tril(ret, out=ret)
+ return tril_onnx(ret)
def subsequent_chunk_mask(
- size: int,
- chunk_size: int,
+ size: torch.tensor(0),
+ chunk_size: torch.tensor(0),
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
@@ -67,6 +80,18 @@ def subsequent_chunk_mask(
[1, 1, 1, 1]]
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+ row_index = torch.arange(size, device = device)
+ index = row_index.expand(size, size)
+ expand_size = torch.ones((size), device = device)*size
+ #expand_size = expand_size.long()
+ if num_left_chunks < 0:
+ start1 = torch.tensor(0)
+ else:
+ start1 = torch.max((torch.floor_divide(row_index, chunk_size)-num_left_chunks).float()*chunk_size, torch.tensor(0.0)).long().view(size,1)
+ ending = torch.min((torch.floor_divide(row_index, chunk_size)+1).float()*chunk_size, expand_size.float()).long().view(size,1)
+ ret[torch.where(index < ending)] = True
+ ret[torch.where(index < start1)] = False
+ '''
for i in range(size):
if num_left_chunks < 0:
start = 0
@@ -74,6 +99,8 @@ def subsequent_chunk_mask(
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size, size)
ret[i, start:ending] = True
+ print("ret:", ret)
+ '''
return ret
@@ -107,18 +134,18 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
"""
# Whether to use chunk mask or not
if use_dynamic_chunk:
- max_len = xs.size(1)
+ max_len = xs.shape[1]
if decoding_chunk_size < 0:
chunk_size = max_len
num_left_chunks = -1
elif decoding_chunk_size > 0:
- chunk_size = decoding_chunk_size
+ chunk_size = torch.tensor(decoding_chunk_size)
num_left_chunks = num_decoding_left_chunks
else:
# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
- chunk_size = torch.randint(1, max_len, (1, )).item()
+ chunk_size = torch.randint(1, max_len, (1, ))
num_left_chunks = -1
if chunk_size > max_len // 2:
chunk_size = max_len
@@ -128,14 +155,14 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
- chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+ chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
- chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+ chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
@@ -145,7 +172,7 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
return chunk_masks
-def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
+def make_pad_mask(lengths: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
@@ -162,8 +189,11 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
- batch_size = int(lengths.size(0))
- max_len = int(lengths.max().item())
+ # batch_size = int(lengths.size(0))
+ # max_len = int(lengths.max().item())
+ # to export the decoder onnx and avoid the constant fold
+ batch_size = xs.shape[0]
+ max_len = xs.shape[1]
seq_range = torch.arange(0,
max_len,
dtype=torch.int64,