"""Decoders and output normalization for CTC.
Authors
* Mirco Ravanelli 2020
* Aku Rouhe 2020
* Sung-Lin Yeh 2020
"""
import torch
from itertools import groupby
from speechbrain.dataio.dataio import length_to_mask
class CTCPrefixScorer:
    """This class implements the CTC prefix scorer of Algorithm 2 in
    reference: https://www.merl.com/publications/docs/TR2017-190.pdf.
    Official implementation: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py

    Note that the constructor masks the padded frames of ``x`` in place, so
    the tensor passed by the caller is modified.

    Arguments
    ---------
    x : torch.Tensor
        The encoder states, indexed as (batch, time, vocab): ``x.size(1)`` is
        used as the number of frames and ``x.size(-1)`` as the vocabulary
        size. The values are combined with ``cumsum``/``logsumexp``, i.e.
        they are treated as log-probabilities.
    enc_lens : torch.Tensor
        The actual length (in frames) of each enc_states sequence.
        ``enc_lens - 1`` is later used as an index, so an integer tensor is
        expected.
    batch_size : int
        The size of the batch.
    beam_size : int
        The width of beam.
    blank_index : int
        The index of the blank token.
    eos_index : int
        The index of the end-of-sequence (eos) token.
    ctc_window_size: int
        Compute the ctc scores over the time frames using windowing based on attention peaks.
        If 0, no windowing applied.
    """

    def __init__(
        self,
        x,
        enc_lens,
        batch_size,
        beam_size,
        blank_index,
        eos_index,
        ctc_window_size=0,
    ):
        self.blank_index = blank_index
        self.eos_index = eos_index
        self.max_enc_len = x.size(1)
        self.batch_size = batch_size
        self.beam_size = beam_size
        self.vocab_size = x.size(-1)
        self.device = x.device
        self.minus_inf = -1e20
        self.last_frame_index = enc_lens - 1
        self.ctc_window_size = ctc_window_size

        # Mask the frames beyond enc_lens: every token becomes impossible
        # there, except the blank which gets log-prob 0 so that padded frames
        # can only be consumed by emitting blanks.
        mask = 1 - length_to_mask(enc_lens)
        mask = mask.unsqueeze(-1).expand(-1, -1, x.size(-1)).eq(1)
        x.masked_fill_(mask, self.minus_inf)
        # The column reset here must be the blank one (it used to be the
        # hard-coded column 0, which is only right when blank_index == 0).
        x[:, :, self.blank_index] = x[:, :, self.blank_index].masked_fill_(
            mask[:, :, 0], 0
        )

        # self.x[0] holds the per-token scores, self.x[1] the blank score
        # broadcast over the vocabulary; both are laid out as (time, batch, vocab).
        xnb = x.transpose(0, 1)
        xb = (
            xnb[:, :, self.blank_index]
            .unsqueeze(2)
            .expand(-1, -1, self.vocab_size)
        )
        self.x = torch.stack([xnb, xb])

        # Offsets used to flatten (batch, beam) and (batch, vocab) indices.
        self.beam_offset = (
            torch.arange(batch_size, device=self.device) * self.beam_size
        )
        self.cand_offset = (
            torch.arange(batch_size, device=self.device) * self.vocab_size
        )

    def forward_step(self, g, state, candidates=None, attn=None):
        """This method is one step of forwarding operation
        for the prefix ctc scorer.

        Arguments
        ---------
        g : torch.Tensor
            The tensor of prefix label sequences, h = g + c. One row per
            hypothesis (batch_size * beam_size rows); ``g.size(1)`` is the
            current prefix length.
        state : tuple
            Previous ctc states ``(r_prev, psi_prev)`` as returned by
            ``permute_mem``, or None at the first decoding step.
        candidates : torch.Tensor
            (batch_size * beam_size, ctc_beam_size), The topk candidates for rescoring.
            The ctc_beam_size is set as 2 * beam_size. If given, performing partial ctc scoring.
        attn : torch.Tensor
            Attention weights used to locate the scoring window when
            ``ctc_window_size`` is not 0. The peak is taken over dim 1
            (presumably the time axis -- confirm against the caller).

        Returns
        -------
        torch.Tensor
            ``psi - psi_prev``: the score increment of each token, with shape
            (batch_size * beam_size, vocab_size).
        tuple
            ``(r, psi, scoring_table)``, the memory to be passed to
            ``permute_mem`` once the beam search has selected its paths.
        """
        prefix_length = g.size(1)
        last_char = [gi[-1] for gi in g] if prefix_length > 0 else [0] * len(g)
        # Remembered on the instance because permute_mem needs it as well.
        self.num_candidates = (
            self.vocab_size if candidates is None else candidates.size(-1)
        )

        if state is None:
            # r_prev[t, 0] / r_prev[t, 1]: scores of the prefix ending at
            # frame t with a non-blank / blank. An empty prefix can only be
            # produced by blanks, hence the cumulative sum of blank scores.
            r_prev = torch.full(
                (self.max_enc_len, 2, self.batch_size, self.beam_size),
                self.minus_inf,
                device=self.device,
            )
            r_prev[:, 1] = torch.cumsum(
                self.x[0, :, :, self.blank_index], 0
            ).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, self.batch_size * self.beam_size)
            psi_prev = 0.0
        else:
            r_prev, psi_prev = state

        if candidates is not None:
            # scoring_table[i, token] is the position of `token` among the
            # candidates of hypothesis i, or -1 when it is not a candidate.
            scoring_table = torch.full(
                (self.batch_size * self.beam_size, self.vocab_size),
                -1,
                dtype=torch.long,
                device=self.device,
            )
            col_index = torch.arange(
                self.batch_size * self.beam_size, device=self.device
            ).unsqueeze(1)
            scoring_table[col_index, candidates] = torch.arange(
                self.num_candidates, device=self.device
            )
            # Gather only the scores of the candidate tokens.
            scoring_index = (
                candidates
                + self.cand_offset.unsqueeze(1)
                .repeat(1, self.beam_size)
                .view(-1, 1)
            ).view(-1)
            x_inflate = torch.index_select(
                self.x.view(2, -1, self.batch_size * self.vocab_size),
                2,
                scoring_index,
            ).view(2, -1, self.batch_size * self.beam_size, self.num_candidates)
        else:
            scoring_table = None
            # Full scoring: replicate the scores for every beam.
            x_inflate = (
                self.x.unsqueeze(3)
                .repeat(1, 1, 1, self.beam_size, 1)
                .view(
                    2, -1, self.batch_size * self.beam_size, self.num_candidates
                )
            )

        # Forward variables of the extended prefixes h = g + c.
        r = torch.full(
            (
                self.max_enc_len,
                2,
                self.batch_size * self.beam_size,
                self.num_candidates,
            ),
            self.minus_inf,
            device=self.device,
        )
        if prefix_length == 0:
            r[0, 0] = x_inflate[0, 0]

        r_sum = torch.logsumexp(r_prev, 1)
        phi = r_sum.unsqueeze(2).repeat(1, 1, self.num_candidates)

        # When the new token repeats the last one of the prefix, only the
        # blank-ending paths may be extended.
        if candidates is not None:
            for i in range(self.batch_size * self.beam_size):
                pos = scoring_table[i, last_char[i]]
                if pos != -1:
                    phi[:, i, pos] = r_prev[:, 1, i]
        else:
            for i in range(self.batch_size * self.beam_size):
                phi[:, i, last_char[i]] = r_prev[:, 1, i]

        # Frame range to score; with windowing it is bounded by the attention
        # peaks of all hypotheses, widened by ctc_window_size on both sides.
        if self.ctc_window_size == 0 or attn is None:
            start = max(1, prefix_length)
            end = self.max_enc_len
        else:
            _, attn_peak = torch.max(attn, dim=1)
            max_frame = torch.max(attn_peak).item() + self.ctc_window_size
            min_frame = torch.min(attn_peak).item() - self.ctc_window_size
            start = max(max(1, prefix_length), int(min_frame))
            end = min(self.max_enc_len, int(max_frame))

        for t in range(start, end):
            rnb_prev = r[t - 1, 0]
            rb_prev = r[t - 1, 1]
            # Row 0 feeds the non-blank update, row 1 the blank update.
            r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view(
                2, 2, self.batch_size * self.beam_size, self.num_candidates
            )
            r[t] = torch.logsumexp(r_, 1) + x_inflate[:, t]

        # Prefix scores: accumulate over the frames at which the new token
        # can be emitted for the first time.
        psi_init = r[start - 1, 0].unsqueeze(0)
        phix = torch.cat((phi[0].unsqueeze(0), phi[:-1]), dim=0) + x_inflate[0]
        psi_scored = torch.logsumexp(
            torch.cat((phix[start:end], psi_init), dim=0), dim=0
        )
        if candidates is not None:
            # Tokens that were not rescored keep the minus_inf score.
            psi = torch.full(
                (self.batch_size * self.beam_size, self.vocab_size),
                self.minus_inf,
                device=self.device,
            )
            for i in range(self.batch_size * self.beam_size):
                psi[i, candidates[i]] = psi_scored[i]
        else:
            psi = psi_scored

        # The eos score is the score of the prefix itself at the last valid
        # frame of the utterance the hypothesis belongs to.
        for i in range(self.batch_size * self.beam_size):
            psi[i, self.eos_index] = r_sum[
                self.last_frame_index[i // self.beam_size], i
            ]

        # The blank token is never a valid output label.
        psi[:, self.blank_index] = self.minus_inf

        return psi - psi_prev, (r, psi, scoring_table)

    def permute_mem(self, memory, index):
        """This method permutes the CTC model memory
        to synchronize the memory index with the current output.

        Arguments
        ---------
        memory : No limit
            The memory variable to be permuted, i.e. the
            ``(r, psi, scoring_table)`` tuple returned by ``forward_step``.
        index : torch.Tensor
            The index of the previous path. Each entry encodes a
            (beam, token) pair as ``beam * vocab_size + token``.

        Return
        ------
        The variable of the memory being permuted, as the ``(r, psi)`` state
        expected by the next ``forward_step`` call.
        """
        r, psi, scoring_table = memory

        # Flatten (batch, beam * vocab) indices into the whole batch.
        best_index = (
            index
            + (self.beam_offset.unsqueeze(1).expand_as(index) * self.vocab_size)
        ).view(-1)

        # The score of the selected path becomes the reference score of all
        # of its extensions.
        psi = torch.index_select(psi.view(-1), dim=0, index=best_index)
        psi = (
            psi.view(-1, 1)
            .repeat(1, self.vocab_size)
            .view(self.batch_size * self.beam_size, self.vocab_size)
        )

        if scoring_table is not None:
            # With partial scoring, r is indexed by candidate position rather
            # than by token id, so the selection has to be translated.
            effective_index = (
                index // self.vocab_size + self.beam_offset.view(-1, 1)
            ).view(-1)
            selected_vocab = (index % self.vocab_size).view(-1)
            score_index = scoring_table[effective_index, selected_vocab]
            # Tokens that were not rescored fall back to the first candidate.
            score_index[score_index == -1] = 0
            best_index = score_index + effective_index * self.num_candidates

        r = torch.index_select(
            r.view(
                -1, 2, self.batch_size * self.beam_size * self.num_candidates
            ),
            dim=-1,
            index=best_index,
        )
        r = r.view(-1, 2, self.batch_size * self.beam_size)

        return r, psi
def filter_ctc_output(string_pred, blank_id=-1):
    """Collapse a frame-level CTC prediction into its label sequence.

    Consecutive repetitions are merged into a single symbol first, then
    every occurrence of the blank symbol is dropped.

    Arguments
    ---------
    string_pred : list
        The frame-level outputs (strings or ints) predicted by the CTC system.
    blank_id : int, string
        The symbol representing the blank. It is compared by equality with
        the elements of ``string_pred``.

    Returns
    -------
    list
        The prediction without repetitions and without blank symbols.

    Raises
    ------
    ValueError
        If ``string_pred`` is not a python list.

    Example
    -------
    >>> string_pred = ['a','a','blank','b','b','blank','c']
    >>> string_out = filter_ctc_output(string_pred, blank_id='blank')
    >>> print(string_out)
    ['a', 'b', 'c']
    """
    if not isinstance(string_pred, list):
        raise ValueError("filter_ctc_out can only filter python lists")

    # groupby yields one key per run of equal consecutive symbols, which is
    # exactly the CTC merge rule; blanks are removed only after merging so
    # that they still separate genuine repetitions.
    merged = (symbol for symbol, _ in groupby(string_pred))
    return [symbol for symbol in merged if symbol != blank_id]
def ctc_greedy_decode(probabilities, seq_lens, blank_id=-1):
    """Greedy decode a batch of probabilities and apply CTC rules.

    For every sequence the best-scoring token of each valid frame is
    selected, then repetitions and blanks are removed with
    ``filter_ctc_output``.

    Arguments
    ---------
    probabilities : torch.tensor
        Output probabilities (or log-probabilities) from the network. The
        tensor is read as [batch, time, vocabulary]: dimension 1 is scaled by
        the relative lengths and the maximum is taken over the last one.
    seq_lens : torch.tensor
        Relative true sequence lengths (to deal with padded inputs), one
        value per batch item: the longest sequence has length 1.0, others a
        value between zero and one.
    blank_id : int, string
        The blank symbol/index. Default: -1. If a negative number is given,
        it is assumed to mean counting down from the maximum possible index,
        so that -1 refers to the maximum possible index.

    Returns
    -------
    list
        Outputs as Python list of lists, with "ragged" dimensions; padding
        has been removed.

    Example
    -------
    >>> import torch
    >>> probs = torch.tensor([[[0.3, 0.7], [0.0, 0.0]],
    ...                       [[0.2, 0.8], [0.9, 0.1]]])
    >>> lens = torch.tensor([0.51, 1.0])
    >>> blank_id = 0
    >>> ctc_greedy_decode(probs, lens, blank_id)
    [[1], [1]]
    """
    # Resolve a negative blank index against the size of the vocabulary axis.
    if isinstance(blank_id, int) and blank_id < 0:
        blank_id += probabilities.shape[-1]

    max_frames = probabilities.shape[1]
    decoded = []
    for frame_scores, relative_len in zip(probabilities, seq_lens):
        # Convert the relative length into a number of valid frames and
        # ignore the padded ones.
        num_frames = int(torch.round(relative_len * max_frames))
        valid_scores = frame_scores.narrow(0, 0, num_frames)
        best_tokens = torch.max(valid_scores, dim=1).indices
        decoded.append(
            filter_ctc_output(best_tokens.tolist(), blank_id=blank_id)
        )
    return decoded