@@ -7,7 +7,11 @@ from fairseq.sequence_generator import SequenceGenerator
from torch import Tensor
class TextRecognitionGenerator(SequenceGenerator):
-
+    def forward(self, imgs):
+        """Wrap ``imgs`` in a ``net_input`` sample and return ``_generate``'s output."""
+        net_sample = {"net_input": {"imgs": imgs}}
+        return self._generate(net_sample)
+
def _generate(
self,
sample: Dict[str, Dict[str, Tensor]],
@@ -34,12 +38,14 @@ class TextRecognitionGenerator(SequenceGenerator):
# "src_tokens": [],
# "src_lengths": [],
encoder_outs = self.model.forward_encoder(net_input) # T x B x C
- src_lengths = encoder_outs[0]['encoder_padding_mask'][0].eq(0).long().sum(dim=1) # B
- src_tokens = encoder_outs[0]['encoder_padding_mask'][0] # B x T
+ src_lengths = encoder_outs[0]['encoder_padding_mask'][0].eq(0).long().sum(dim=1).to(torch.int32) # B
+ src_tokens = encoder_outs[0]['encoder_padding_mask'][0].to(torch.int32) # B x T
# bsz: total number of sentences in beam
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
bsz, src_len = src_tokens.size()[:2]
+ bsz = bsz.to(torch.int32)
+ src_len = src_len.to(torch.int32)
beam_size = self.beam_size
if constraints is not None and not self.search.supports_constraints:
@@ -63,7 +69,6 @@ class TextRecognitionGenerator(SequenceGenerator):
self.min_len <= max_len
), "min_len cannot be larger than max_len, please adjust these!"
-
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
@@ -75,13 +80,8 @@ class TextRecognitionGenerator(SequenceGenerator):
scores = (
torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
) # +1 for eos; pad is never chosen for scoring
- tokens = (
- torch.zeros(bsz * beam_size, max_len + 2)
- .to(src_tokens)
- .long()
- .fill_(self.pad)
- ) # +2 for eos and pad
- tokens[:, 0] = self.eos if bos_token is None else bos_token
+ tokens = torch.ones(bsz * beam_size, max_len + 1).to(torch.int32)
+ tokens = torch.nn.functional.pad(tokens, (1, 0), value=self.eos)
attn: Optional[Tensor] = None
# A list that indicates candidates that should be ignored.
@@ -123,8 +123,15 @@ class TextRecognitionGenerator(SequenceGenerator):
original_batch_idxs = sample["id"]
else:
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
-
- for step in range(max_len + 1): # one extra step for EOS marker
+ loop_size = 30
+ cand_bbsz_idx_out = torch.empty(1, 1, cand_size).type(torch.int32)
+ cand_scores_out = torch.empty(1, 1, cand_size).type(torch.float)
+ scores_out = torch.empty(1, beam_size, 201).type(torch.float)
+ attn_out = torch.empty(1, beam_size, 578, 202).type(torch.float)
+ eos_mask_out = torch.empty(1, 1, cand_size).type(torch.bool)
+ tokens_out = torch.empty(1, beam_size, 202).type(torch.int32)
+
+ for step in range(loop_size): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
if batch_idxs is not None:
@@ -158,8 +165,8 @@ class TextRecognitionGenerator(SequenceGenerator):
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
- lprobs[:, self.pad] = -math.inf # never select pad
- lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
+ y = torch.zeros(10, 1) - math.inf
+ lprobs = torch.cat((lprobs[:, 0].unsqueeze(1), y, lprobs[:, 2:]), dim=1)
# handle max length constraint
if step >= max_len:
@@ -177,7 +184,8 @@ class TextRecognitionGenerator(SequenceGenerator):
)
elif step < self.min_len:
# minimum length constraint (does not apply if using prefix_tokens)
- lprobs[:, self.eos] = -math.inf
+ y = torch.zeros(10, 1) - math.inf
+ lprobs = torch.cat((lprobs[:, :self.eos], y, lprobs[:, self.eos + 1:]), dim=1)
# Record attention scores, only support avg_attn_scores is a Tensor
if avg_attn_scores is not None:
@@ -185,7 +193,7 @@ class TextRecognitionGenerator(SequenceGenerator):
attn = torch.empty(
bsz * beam_size, avg_attn_scores.size(1), max_len + 2
).to(scores)
- attn[:, :, step + 1].copy_(avg_attn_scores)
+ attn = torch.cat((attn[:, :, :step + 1], avg_attn_scores.unsqueeze(2), attn[:, :, (step + 2):]), dim=2)
scores = scores.type_as(lprobs)
eos_bbsz_idx = torch.empty(0).to(
@@ -218,7 +226,6 @@ class TextRecognitionGenerator(SequenceGenerator):
# finalize hypotheses that end in eos
# Shape of eos_mask: (batch size, beam size)
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
- eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
# only consider eos when it's among the top beam_size indices
# Now we know what beam item(s) to finish
@@ -226,82 +233,22 @@ class TextRecognitionGenerator(SequenceGenerator):
eos_bbsz_idx = torch.masked_select(
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
)
-
- finalized_sents: List[int] = []
- if eos_bbsz_idx.numel() > 0:
- eos_scores = torch.masked_select(
- cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
- )
-
- finalized_sents = self.finalize_hypos(
- step,
- eos_bbsz_idx,
- eos_scores,
- tokens,
- scores,
- finalized,
- finished,
- beam_size,
- attn,
- src_lengths,
- max_len,
- )
- num_remaining_sent -= len(finalized_sents)
-
- assert num_remaining_sent >= 0
- if num_remaining_sent == 0:
- break
- if self.search.stop_on_max_len and step >= max_len:
- break
- assert step < max_len, f"{step} < {max_len}"
-
- # Remove finalized sentences (ones for which {beam_size}
- # finished hypotheses have been generated) from the batch.
- if len(finalized_sents) > 0:
- new_bsz = bsz - len(finalized_sents)
-
- # construct batch_idxs which holds indices of batches to keep for the next pass
- batch_mask = torch.ones(
- bsz, dtype=torch.bool, device=cand_indices.device
- )
- batch_mask[finalized_sents] = False
- # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
- batch_idxs = torch.arange(
- bsz, device=cand_indices.device
- ).masked_select(batch_mask)
-
- # Choose the subset of the hypothesized constraints that will continue
- self.search.prune_sentences(batch_idxs)
-
- eos_mask = eos_mask[batch_idxs]
- cand_beams = cand_beams[batch_idxs]
- bbsz_offsets.resize_(new_bsz, 1)
- cand_bbsz_idx = cand_beams.add(bbsz_offsets)
- cand_scores = cand_scores[batch_idxs]
- cand_indices = cand_indices[batch_idxs]
-
- if prefix_tokens is not None:
- prefix_tokens = prefix_tokens[batch_idxs]
- src_lengths = src_lengths[batch_idxs]
- cands_to_ignore = cands_to_ignore[batch_idxs]
-
- scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
- tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
- if attn is not None:
- attn = attn.view(bsz, -1)[batch_idxs].view(
- new_bsz * beam_size, attn.size(1), -1
- )
- bsz = new_bsz
- else:
- batch_idxs = None
+ cand_bbsz_idx_out = torch.cat((cand_bbsz_idx_out, cand_bbsz_idx.unsqueeze(0)), dim = 0)
+ cand_scores_out = torch.cat((cand_scores_out, cand_scores.unsqueeze(0)), dim = 0)
+ scores_out = torch.cat((scores_out, scores.unsqueeze(0)), dim = 0)
+ attn_out = torch.cat((attn_out, attn.unsqueeze(0)), dim = 0)
+ eos_mask_out = torch.cat((eos_mask_out, eos_mask.unsqueeze(0)), dim = 0)
+ tokens_out = torch.cat((tokens_out, tokens.unsqueeze(0)), dim = 0)
+ batch_idxs = None
# Set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
# Rewrite the operator since the element wise or is not supported in torchscript.
-
- eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
+ y = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
+ eos_mask = torch.cat((y, eos_mask[:, beam_size:]), dim = 1)
+ # The two lines above are the out-of-place equivalent of: eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
active_mask = torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[: eos_mask.size(1)],
@@ -328,48 +275,49 @@ class TextRecognitionGenerator(SequenceGenerator):
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
- active_bbsz_idx = active_bbsz_idx.view(-1)
- active_scores = active_scores.view(-1)
+ active_bbsz_idx = active_bbsz_idx.view(-1).to(torch.int32)
+ active_scores = active_scores.view(-1).to(torch.int32)
# copy tokens and scores for active hypotheses
# Set the tokens for each beam (can select the same row more than once)
- tokens[:, : step + 1] = torch.index_select(
+ y = torch.index_select(
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
)
+ tokens = torch.cat((y, tokens[:, step + 1 :]), dim = 1)
# Select the next token for each of them
- tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
+ tokens = tokens.view(bsz, beam_size, -1)
+ y = torch.gather(
cand_indices, dim=1, index=active_hypos
)
+ tokens = torch.cat((tokens[:, :, :step + 1], y.unsqueeze(2), tokens[:, :, (step + 2):]), dim = 2)
+ tokens = tokens.view(beam_size, -1)
if step > 0:
- scores[:, :step] = torch.index_select(
+ y = torch.index_select(
scores[:, :step], dim=0, index=active_bbsz_idx
)
- scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
+ scores = torch.cat((y, scores[:, step:]), dim = 1)
+ scores = scores.view(bsz, beam_size, -1)
+ y = torch.gather(
cand_scores, dim=1, index=active_hypos
)
+ if step == 0:
+ scores = torch.cat((y.unsqueeze(2), scores[:, :, 1:]), dim = 2)
+ elif step == 1 :
+ scores = torch.cat((scores[:, :, 0].unsqueeze(2), y.unsqueeze(2), scores[:, :, 2:]), dim = 2)
+ else :
+ scores = torch.cat((scores[:, :, :step], y.unsqueeze(2), scores[:, :, (step + 1):]), dim = 2)
+ scores = scores.view(beam_size, -1)
# Update constraints based on which candidates were selected for the next beam
self.search.update_constraints(active_hypos)
# copy attention for active hypotheses
if attn is not None:
- attn[:, :, : step + 2] = torch.index_select(
+ y = torch.index_select(
attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
)
-
+ attn = torch.cat((y, attn[:, :, (step + 2):]), dim = 2)
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
-
- # sort by score descending
- for sent in range(len(finalized)):
- scores = torch.tensor(
- [float(elem["score"].item()) for elem in finalized[sent]]
- )
- _, sorted_scores_indices = torch.sort(scores, descending=True)
- finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
- finalized[sent] = torch.jit.annotate(
- List[Dict[str, Tensor]], finalized[sent]
- )
- return finalized
-
+ return cand_bbsz_idx_out, eos_mask_out, cand_scores_out, tokens_out, scores_out, attn_out
@@ -9,8 +9,8 @@ try:
from .data import SROIETextRecognitionDataset, SyntheticTextRecognitionDataset
from .data_aug import build_data_aug
except:
- from data import SROIETextRecognitionDataset, SyntheticTextRecognitionDataset
- from data_aug import build_data_aug
+ from .data import SROIETextRecognitionDataset, SyntheticTextRecognitionDataset
+ from .data_aug import build_data_aug
import logging
@@ -260,3 +260,99 @@ class SROIETextRecognitionTask(LegacyFairseqTask):
search_strategy=search_strategy,
**extra_gen_cls_kwargs,
)
+
+    def inference_step(
+        self, generator, models, sample, prefix_tokens=None, constraints=None
+    ):
+        """Replay dumped beam-search steps and return the finalized hypotheses.
+
+        Reads ``tfm_img_<id>_{0..5}.bin`` from a sub-directory of ``../../out/``
+        and calls ``generator.finalize_hypos`` for every recorded step until the
+        single sentence is finished.  ``models``, ``prefix_tokens`` and
+        ``constraints`` are accepted for interface compatibility but unused.
+
+        Returns:
+            The ``finalized`` list (one entry per sentence), each entry sorted
+            by descending hypothesis score.
+
+        Raises:
+            FileNotFoundError: if ``../../out/`` contains no sub-directory.
+        """
+        import os
+        from typing import Dict, List
+
+        import numpy as np
+        import torch
+        from torch import Tensor
+
+        # Fixed dimensions of the dumped tensors; keep in sync with the dump.
+        beam_size, cand_size, num_steps = 10, 20, 30
+        max_len, src_len = 200, 578
+
+        src_lengths = torch.tensor([src_len])
+        num_remaining_sent = 1
+        finalized = torch.jit.annotate(
+            List[List[Dict[str, Tensor]]],
+            [torch.jit.annotate(List[Dict[str, Tensor]], [])],
+        )
+        finished = [False]
+
+        path = "../../out/"
+        # NOTE(review): os.scandir yields entries in arbitrary order, so with
+        # several sub-directories the one picked here is not deterministic.
+        result_path = [d.path for d in os.scandir(path) if d.is_dir()]
+        if not result_path:
+            raise FileNotFoundError(f"no result directory found under {path!r}")
+        file_path_base = os.path.join(
+            result_path[0], "tfm_img_" + str(sample["id"].item())
+        )
+
+        def load(index, dtype, shape):
+            # Each file holds num_steps + 1 slots; slot 0 is never read below.
+            raw = np.fromfile(f"{file_path_base}_{index}.bin", dtype=dtype)
+            return raw.reshape([num_steps + 1, *shape])
+
+        cand_bbsz_idx_out = load(0, np.int32, [1, cand_size])
+        eos_mask_out = load(1, bool, [1, cand_size])
+        cand_scores_out = load(2, np.float32, [1, cand_size])
+        tokens_out = load(3, np.int32, [beam_size, max_len + 2])
+        scores_out = load(4, np.float32, [beam_size, max_len + 1])
+        attn_out = load(5, np.float32, [beam_size, src_len, max_len + 2])
+
+        for step in range(num_steps):
+            k = step + 1
+            eos_mask = torch.from_numpy(eos_mask_out[k])[:, :beam_size]
+            cand_bbsz_idx = torch.from_numpy(cand_bbsz_idx_out[k])[:, :beam_size]
+            eos_bbsz_idx = torch.masked_select(cand_bbsz_idx, mask=eos_mask)
+            if eos_bbsz_idx.numel() == 0:
+                continue  # no hypothesis ended at this step
+
+            cand_scores = torch.from_numpy(cand_scores_out[k])[:, :beam_size]
+            finalized_sents = generator.finalize_hypos(
+                step,
+                eos_bbsz_idx.to(torch.int32),
+                torch.masked_select(cand_scores, mask=eos_mask),
+                torch.from_numpy(tokens_out[k]),
+                torch.from_numpy(scores_out[k]),
+                finalized,
+                finished,
+                beam_size,
+                torch.from_numpy(attn_out[k]),
+                src_lengths,
+                max_len,
+            )
+            num_remaining_sent -= len(finalized_sents)
+            assert num_remaining_sent >= 0
+            if num_remaining_sent == 0:
+                break
+
+        for sent, hypos in enumerate(finalized):
+            hypo_scores = torch.tensor(
+                [float(elem["score"].item()) for elem in hypos]
+            )
+            _, sorted_scores_indices = torch.sort(hypo_scores, descending=True)
+            finalized[sent] = torch.jit.annotate(
+                List[Dict[str, Tensor]],
+                [hypos[ssi] for ssi in sorted_scores_indices],
+            )
+        return finalized
@@ -115,16 +115,19 @@ class BeamSearch(Search):
original_batch_idxs: Optional[Tensor] = None,
):
bsz, beam_size, vocab_size = lprobs.size()
+ if torch.is_tensor(beam_size):
+ beam_size = beam_size.item()
+ vocab_size = vocab_size.item()
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
- lprobs = lprobs[:, ::beam_size, :].contiguous()
+ lprobs = lprobs[0][0].contiguous()
else:
# make probs contain cumulative scores for each hypothesis
assert scores is not None
- lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
+ lprobs = lprobs + scores[:, :, step - 1].reshape(1, 10, 1)
top_prediction = torch.topk(
lprobs.view(bsz, -1),
k=min(
@@ -137,7 +140,9 @@ class BeamSearch(Search):
scores_buf = top_prediction[0]
indices_buf = top_prediction[1]
# Project back into relative indices and beams
- beams_buf = torch.div(indices_buf, vocab_size, rounding_mode="trunc")
+ beams_buf = torch.div(indices_buf, vocab_size).to(torch.int32)
+ vocab_size = torch.tensor([vocab_size]).to(torch.int32)
+ indices_buf = indices_buf.to(torch.int32)
indices_buf = indices_buf.fmod(vocab_size)
# At this point, beams_buf and indices_buf are single-dim and contain relative indices
@@ -821,7 +821,7 @@ class EnsembleModel(nn.Module):
elif attn_holder is not None:
attn = attn_holder[0]
if attn is not None:
- attn = attn[:, -1, :]
+ attn = attn.reshape(10, 578)
decoder_out_tuple = (
decoder_out[0][:, -1:, :].div_(temperature),
@@ -830,7 +830,7 @@ class EnsembleModel(nn.Module):
probs = model.get_normalized_probs(
decoder_out_tuple, log_probs=True, sample=None
)
- probs = probs[:, -1, :]
+ probs = probs.reshape(10, 64044)
if self.models_size == 1:
return probs, attn
@@ -218,10 +218,8 @@ def do_setup(package_data):
'numpy; python_version>="3.7"',
"regex",
"sacrebleu>=1.4.12",
- "torch",
"tqdm",
"bitarray",
- "torchaudio>=0.8.0",
],
dependency_links=dependency_links,
packages=find_packages(
@@ -265,7 +265,7 @@ def http_get(url, temp_file):
import requests
from tqdm import tqdm
- req = request_wrap_timeout(partial(requests.get, url, stream=True), url)
+ req = request_wrap_timeout(partial(requests.get, url, stream=True, verify=False), url)
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
@@ -297,7 +297,7 @@ def get_from_cache(url, cache_dir=None):
import requests
response = request_wrap_timeout(
- partial(requests.head, url, allow_redirects=True), url
+ partial(requests.head, url, allow_redirects=True, verify=False), url
)
if response.status_code != 200:
etag = None