@@ -168,13 +168,13 @@
num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim)
else:
- encoder_out, encoder_mask = self.encoder(
+ encoder_out, encoder_mask, encoder_t = self.encoder(
speech,
speech_lengths,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim)
- return encoder_out, encoder_mask
+ return encoder_out, encoder_mask, encoder_t
def recognize(
self,
@@ -361,7 +361,7 @@
assert batch_size == 1
# Let's assume B = batch_size and N = beam_size
# 1. Encoder forward and get CTC score
- encoder_out, encoder_mask = self._forward_encoder(
+ encoder_out, encoder_mask, encoder_t = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
@@ -409,7 +409,7 @@
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
- return hyps, encoder_out
+ return hyps, encoder_out, encoder_t
def ctc_prefix_beam_search(
self,
@@ -485,7 +485,7 @@
# For attention rescoring we only support batch_size=1
assert batch_size == 1
# encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
- hyps, encoder_out = self._ctc_prefix_beam_search(
+ hyps, encoder_out, encoder_t = self._ctc_prefix_beam_search(
speech, speech_lengths, beam_size, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
@@ -510,7 +510,7 @@
r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
self.ignore_id)
- decoder_out, r_decoder_out, _ = self.decoder(
+ decoder_out, r_decoder_out, _, decoder_t = self.decoder(
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
reverse_weight) # (beam_size, max_hyps_len, vocab_size)
decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
@@ -539,7 +539,7 @@
if score > best_score:
best_score = score
best_index = i
- return hyps[best_index][0]
+ return hyps[best_index][0], encoder_t+decoder_t
@torch.jit.export
def subsampling_rate(self) -> int:
@@ -13,6 +13,7 @@
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.mask import (subsequent_mask, make_pad_mask)
+import time
class TransformerDecoder(torch.nn.Module):
"""Base class of Transfomer decoder module.
@@ -252,13 +253,14 @@
if use_output_layer is True,
olens: (batch, )
"""
+ st = time.time()
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
ys_in_lens)
r_x = torch.tensor(0.0)
if reverse_weight > 0.0:
r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad,
ys_in_lens)
- return l_x, r_x, olens
+ return l_x, r_x, olens, time.time()-st
def forward_one_step(
self,
@@ -26,6 +26,7 @@
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
+import time
class BaseEncoder(torch.nn.Module):
def __init__(
@@ -146,6 +147,7 @@
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
"""
+ st = time.time()
masks = ~make_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
@@ -164,7 +166,7 @@
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
- return xs, masks
+ return xs, masks, time.time()-st
def forward_chunk(
self,
@@ -139,8 +139,11 @@
model = model.to(device)
model.eval()
+ total_t = 0
+ total_batch = 0
with torch.no_grad(), open(args.result_file, 'w') as fout:
for batch_idx, batch in enumerate(test_data_loader):
+ total_batch += 1
keys, feats, target, feats_lengths, target_lengths = batch
feats = feats.to(device)
target = target.to(device)
@@ -177,7 +180,7 @@
hyps = [hyp]
elif args.mode == 'attention_rescoring':
assert (feats.size(0) == 1)
- hyp = model.attention_rescoring(
+ hyp, exe_t = model.attention_rescoring(
feats,
feats_lengths,
args.beam_size,
@@ -187,6 +190,8 @@
simulate_streaming=args.simulate_streaming,
reverse_weight=args.reverse_weight)
hyps = [hyp]
+ total_t += exe_t
+ print(exe_t)
for i, key in enumerate(keys):
content = ''
for w in hyps[i]:
@@ -195,3 +200,7 @@
content += char_dict[w]
logging.info('{} {}'.format(key, content))
fout.write('{} {}\n'.format(key, content))
+ print("mean_fps: ", 1/(total_t/total_batch))
+ print("mean_time: ", total_t/total_batch)
+ fout.write("mean_time: "+str(total_t/total_batch))
+ fout.write("mean_fps: "+str(1/(total_t/total_batch)))