diff --git a/unilm/trocr/generator.py b/unilm/trocr/generator.py
index d3266ce..cca3b34 100644
--- a/unilm/trocr/generator.py
+++ b/unilm/trocr/generator.py
@@ -7,7 +7,11 @@ from fairseq.sequence_generator import SequenceGenerator
 from torch import Tensor
 
 class TextRecognitionGenerator(SequenceGenerator):
-
+    def forward(self, imgs):
+        sample = {"net_input": {"imgs": imgs}}
+        y = self._generate(sample)
+        return y
+    
     def _generate(
         self,
         sample: Dict[str, Dict[str, Tensor]],
@@ -34,12 +38,14 @@ class TextRecognitionGenerator(SequenceGenerator):
             # "src_tokens": [],
             # "src_lengths": [],        
         encoder_outs = self.model.forward_encoder(net_input)  # T x B x C
-        src_lengths = encoder_outs[0]['encoder_padding_mask'][0].eq(0).long().sum(dim=1) # B
-        src_tokens = encoder_outs[0]['encoder_padding_mask'][0]  # B x T
+        src_lengths = encoder_outs[0]['encoder_padding_mask'][0].eq(0).long().sum(dim=1).to(torch.int32) # B
+        src_tokens = encoder_outs[0]['encoder_padding_mask'][0].to(torch.int32)  # B x T
 
         # bsz: total number of sentences in beam
         # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
         bsz, src_len = src_tokens.size()[:2]
+        bsz = bsz.to(torch.int32)
+        src_len = src_len.to(torch.int32)
         beam_size = self.beam_size
 
         if constraints is not None and not self.search.supports_constraints:
@@ -63,7 +69,6 @@ class TextRecognitionGenerator(SequenceGenerator):
             self.min_len <= max_len
         ), "min_len cannot be larger than max_len, please adjust these!"
 
-
         # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
         new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
         new_order = new_order.to(src_tokens.device).long()
@@ -75,13 +80,8 @@ class TextRecognitionGenerator(SequenceGenerator):
         scores = (
             torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
         )  # +1 for eos; pad is never chosen for scoring
-        tokens = (
-            torch.zeros(bsz * beam_size, max_len + 2)
-            .to(src_tokens)
-            .long()
-            .fill_(self.pad)
-        )  # +2 for eos and pad
-        tokens[:, 0] = self.eos if bos_token is None else bos_token
+        tokens = torch.ones(bsz * beam_size, max_len + 1).to(torch.int32)
+        tokens = torch.nn.functional.pad(tokens, (1, 0), value=self.eos)
         attn: Optional[Tensor] = None
 
         # A list that indicates candidates that should be ignored.
@@ -123,8 +123,15 @@ class TextRecognitionGenerator(SequenceGenerator):
             original_batch_idxs = sample["id"]
         else:
             original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
-
-        for step in range(max_len + 1):  # one extra step for EOS marker
+        loop_size = 30
+        cand_bbsz_idx_out = torch.empty(1, 1, cand_size).type(torch.int32)
+        cand_scores_out = torch.empty(1, 1, cand_size).type(torch.float)
+        scores_out = torch.empty(1, beam_size, 201).type(torch.float)
+        attn_out = torch.empty(1, beam_size, 578, 202).type(torch.float)
+        eos_mask_out = torch.empty(1, 1, cand_size).type(torch.bool)
+        tokens_out = torch.empty(1, beam_size, 202).type(torch.int32)
+
+        for step in range(loop_size):  # one extra step for EOS marker
             # reorder decoder internal states based on the prev choice of beams
             if reorder_state is not None:
                 if batch_idxs is not None:
@@ -158,8 +165,8 @@ class TextRecognitionGenerator(SequenceGenerator):
 
             lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
 
-            lprobs[:, self.pad] = -math.inf  # never select pad
-            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
+            y = torch.zeros(10, 1) - math.inf
+            lprobs = torch.cat((lprobs[:, 0].unsqueeze(1), y, lprobs[:, 2:]), dim=1)
 
             # handle max length constraint
             if step >= max_len:
@@ -177,7 +184,8 @@ class TextRecognitionGenerator(SequenceGenerator):
                 )
             elif step < self.min_len:
                 # minimum length constraint (does not apply if using prefix_tokens)
-                lprobs[:, self.eos] = -math.inf
+                y = torch.zeros(10, 1) - math.inf
+                lprobs = torch.cat((lprobs[:, :self.eos], y, lprobs[:, self.eos + 1:]), dim=1)
 
             # Record attention scores, only support avg_attn_scores is a Tensor
             if avg_attn_scores is not None:
@@ -185,7 +193,7 @@ class TextRecognitionGenerator(SequenceGenerator):
                     attn = torch.empty(
                         bsz * beam_size, avg_attn_scores.size(1), max_len + 2
                     ).to(scores)
-                attn[:, :, step + 1].copy_(avg_attn_scores)
+                attn = torch.cat((attn[:, :, :step + 1], avg_attn_scores.unsqueeze(2), attn[:, :, (step + 2):]), dim=2)
 
             scores = scores.type_as(lprobs)
             eos_bbsz_idx = torch.empty(0).to(
@@ -218,7 +226,6 @@ class TextRecognitionGenerator(SequenceGenerator):
             # finalize hypotheses that end in eos
             # Shape of eos_mask: (batch size, beam size)
             eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
-            eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
 
             # only consider eos when it's among the top beam_size indices
             # Now we know what beam item(s) to finish
@@ -226,82 +233,22 @@ class TextRecognitionGenerator(SequenceGenerator):
             eos_bbsz_idx = torch.masked_select(
                 cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
             )
-
-            finalized_sents: List[int] = []
-            if eos_bbsz_idx.numel() > 0:
-                eos_scores = torch.masked_select(
-                    cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
-                )
-
-                finalized_sents = self.finalize_hypos(
-                    step,
-                    eos_bbsz_idx,
-                    eos_scores,
-                    tokens,
-                    scores,
-                    finalized,
-                    finished,
-                    beam_size,
-                    attn,
-                    src_lengths,
-                    max_len,
-                )
-                num_remaining_sent -= len(finalized_sents)
-
-            assert num_remaining_sent >= 0
-            if num_remaining_sent == 0:
-                break
-            if self.search.stop_on_max_len and step >= max_len:
-                break
-            assert step < max_len, f"{step} < {max_len}"
-
-            # Remove finalized sentences (ones for which {beam_size}
-            # finished hypotheses have been generated) from the batch.
-            if len(finalized_sents) > 0:
-                new_bsz = bsz - len(finalized_sents)
-
-                # construct batch_idxs which holds indices of batches to keep for the next pass
-                batch_mask = torch.ones(
-                    bsz, dtype=torch.bool, device=cand_indices.device
-                )
-                batch_mask[finalized_sents] = False
-                # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
-                batch_idxs = torch.arange(
-                    bsz, device=cand_indices.device
-                ).masked_select(batch_mask)
-
-                # Choose the subset of the hypothesized constraints that will continue
-                self.search.prune_sentences(batch_idxs)
-
-                eos_mask = eos_mask[batch_idxs]
-                cand_beams = cand_beams[batch_idxs]
-                bbsz_offsets.resize_(new_bsz, 1)
-                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
-                cand_scores = cand_scores[batch_idxs]
-                cand_indices = cand_indices[batch_idxs]
-
-                if prefix_tokens is not None:
-                    prefix_tokens = prefix_tokens[batch_idxs]
-                src_lengths = src_lengths[batch_idxs]
-                cands_to_ignore = cands_to_ignore[batch_idxs]
-
-                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
-                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
-                if attn is not None:
-                    attn = attn.view(bsz, -1)[batch_idxs].view(
-                        new_bsz * beam_size, attn.size(1), -1
-                    )
-                bsz = new_bsz
-            else:
-                batch_idxs = None
+            cand_bbsz_idx_out = torch.cat((cand_bbsz_idx_out, cand_bbsz_idx.unsqueeze(0)), dim = 0)
+            cand_scores_out = torch.cat((cand_scores_out, cand_scores.unsqueeze(0)), dim = 0)
+            scores_out = torch.cat((scores_out, scores.unsqueeze(0)), dim = 0)
+            attn_out = torch.cat((attn_out, attn.unsqueeze(0)), dim = 0)
+            eos_mask_out = torch.cat((eos_mask_out, eos_mask.unsqueeze(0)), dim = 0)
+            tokens_out = torch.cat((tokens_out, tokens.unsqueeze(0)), dim = 0)
+            batch_idxs = None
 
             # Set active_mask so that values > cand_size indicate eos hypos
             # and values < cand_size indicate candidate active hypos.
             # After, the min values per row are the top candidate active hypos
 
             # Rewrite the operator since the element wise or is not supported in torchscript.
-
-            eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
+            y = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
+            eos_mask = torch.cat((y, eos_mask[:, beam_size:]), dim = 1)
+            # eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
             active_mask = torch.add(
                 eos_mask.type_as(cand_offsets) * cand_size,
                 cand_offsets[: eos_mask.size(1)],
@@ -328,48 +275,49 @@ class TextRecognitionGenerator(SequenceGenerator):
             active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
             active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
 
-            active_bbsz_idx = active_bbsz_idx.view(-1)
-            active_scores = active_scores.view(-1)
+            active_bbsz_idx = active_bbsz_idx.view(-1).to(torch.int32)
+            active_scores = active_scores.view(-1).to(torch.int32)
 
             # copy tokens and scores for active hypotheses
 
             # Set the tokens for each beam (can select the same row more than once)
-            tokens[:, : step + 1] = torch.index_select(
+            y = torch.index_select(
                 tokens[:, : step + 1], dim=0, index=active_bbsz_idx
             )
+            tokens = torch.cat((y, tokens[:, step + 1 :]), dim = 1)
             # Select the next token for each of them
-            tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
+            tokens = tokens.view(bsz, beam_size, -1)
+            y = torch.gather(
                 cand_indices, dim=1, index=active_hypos
             )
+            tokens = torch.cat((tokens[:, :, :step + 1], y.unsqueeze(2), tokens[:, :, (step + 2):]), dim = 2)
+            tokens = tokens.view(beam_size, -1)
             if step > 0:
-                scores[:, :step] = torch.index_select(
+                y = torch.index_select(
                     scores[:, :step], dim=0, index=active_bbsz_idx
                 )
-            scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
+                scores = torch.cat((y, scores[:, step:]), dim = 1)
+            scores = scores.view(bsz, beam_size, -1)
+            y = torch.gather(
                 cand_scores, dim=1, index=active_hypos
             )
+            if step == 0:
+                scores = torch.cat((y.unsqueeze(2), scores[:, :, 1:]), dim = 2)
+            elif step == 1 :
+                scores = torch.cat((scores[:, :, 0].unsqueeze(2), y.unsqueeze(2), scores[:, :, 2:]), dim = 2)
+            else :
+                scores = torch.cat((scores[:, :, :step], y.unsqueeze(2), scores[:, :, (step + 1):]), dim = 2)
+            scores = scores.view(beam_size, -1)            
 
             # Update constraints based on which candidates were selected for the next beam
             self.search.update_constraints(active_hypos)
 
             # copy attention for active hypotheses
             if attn is not None:
-                attn[:, :, : step + 2] = torch.index_select(
+                y = torch.index_select(
                     attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
                 )
-
+                attn = torch.cat((y, attn[:, :, (step + 2):]), dim = 2)
             # reorder incremental state in decoder
             reorder_state = active_bbsz_idx
-
-        # sort by score descending
-        for sent in range(len(finalized)):
-            scores = torch.tensor(
-                [float(elem["score"].item()) for elem in finalized[sent]]
-            )
-            _, sorted_scores_indices = torch.sort(scores, descending=True)
-            finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
-            finalized[sent] = torch.jit.annotate(
-                List[Dict[str, Tensor]], finalized[sent]
-            )
-        return finalized
-
+        return cand_bbsz_idx_out, eos_mask_out, cand_scores_out, tokens_out, scores_out, attn_out
diff --git a/unilm/trocr/task.py b/unilm/trocr/task.py
index 8925276..a071d70 100644
--- a/unilm/trocr/task.py
+++ b/unilm/trocr/task.py
@@ -9,8 +9,8 @@ try:
     from .data import SROIETextRecognitionDataset, SyntheticTextRecognitionDataset
     from .data_aug import build_data_aug
 except:
-    from data import SROIETextRecognitionDataset, SyntheticTextRecognitionDataset
-    from data_aug import build_data_aug
+    from .data import SROIETextRecognitionDataset, SyntheticTextRecognitionDataset
+    from .data_aug import build_data_aug
 
 import logging
 
@@ -260,3 +260,74 @@ class SROIETextRecognitionTask(LegacyFairseqTask):
             search_strategy=search_strategy,
             **extra_gen_cls_kwargs,
         )
+    
+    def inference_step(
+        self, generator, models, sample, prefix_tokens=None, constraints=None
+    ):
+        import torch
+        from torch import Tensor
+        from typing import Dict, List
+        import numpy as np
+        max_len = 200
+        src_lengths = torch.tensor([578])
+        num_remaining_sent = 1
+        finalized = torch.jit.annotate(
+                List[List[Dict[str, Tensor]]],
+                [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(1)],
+        )
+        finished = [False]
+        path = "../../out/"
+        result_path = [d.path for d in os.scandir(path) if d.is_dir()]
+        file_path_base = os.path.join(result_path[0], "tfm_img_" + str(sample['id'].item()))
+
+        cand_bbsz_idx_out = np.fromfile(file_path_base + "_0.bin", dtype=np.int32).reshape([31,1,20])
+        eos_mask_out = np.fromfile(file_path_base + "_1.bin", dtype=bool).reshape([31,1,20])
+        cand_scores_out = np.fromfile(file_path_base + "_2.bin", dtype=np.float32).reshape([31,1,20])
+        tokens_out = np.fromfile(file_path_base + "_3.bin", dtype=np.int32).reshape([31,10,202])
+        scores_out = np.fromfile(file_path_base + "_4.bin", dtype=np.float32).reshape([31,10,201])
+        attn_out = np.fromfile(file_path_base + "_5.bin", dtype=np.float32).reshape([31,10,578,202])
+
+        for i in range(0, 30):
+            k = i+1
+            cand_bbsz_idx = torch.from_numpy(cand_bbsz_idx_out[k])
+            eos_mask = torch.from_numpy(eos_mask_out[k])
+            cand_scores = torch.from_numpy(cand_scores_out[k])
+            tokens = torch.from_numpy(tokens_out[k])
+            scores = torch.from_numpy(scores_out[k])
+            attn = torch.from_numpy(attn_out[k])
+            eos_bbsz_idx = torch.masked_select(
+                        cand_bbsz_idx[:, :10], mask=eos_mask[:, :10]
+                    )
+            finalized_sents: List[int] = []
+            if eos_bbsz_idx.numel() > 0:
+                    eos_scores = torch.masked_select(
+                        cand_scores[:, :10], mask=eos_mask[:, :10]
+                    )
+                    finalized_sents = generator.finalize_hypos(
+                        i,
+                        eos_bbsz_idx.to(torch.int32),
+                        eos_scores,
+                        tokens,
+                        scores,
+                        finalized,
+                        finished,
+                        10,
+                        attn,
+                        src_lengths,
+                        max_len,
+                    )
+                    num_remaining_sent -= len(finalized_sents)
+                    assert num_remaining_sent >= 0
+                    if num_remaining_sent == 0:
+                        break
+
+        for sent in range(len(finalized)):
+                    scores = torch.tensor(
+                        [float(elem["score"].item()) for elem in finalized[sent]]
+                    )
+                    _, sorted_scores_indices = torch.sort(scores, descending=True)
+                    finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
+                    finalized[sent] = torch.jit.annotate(
+                        List[Dict[str, Tensor]], finalized[sent]
+                    )
+        return finalized
diff --git a/fairseq/fairseq/search.py b/fairseq/fairseq/search.py
index a71e7801..8ac6344c 100644
--- a/fairseq/fairseq/search.py
+++ b/fairseq/fairseq/search.py
@@ -115,16 +115,19 @@ class BeamSearch(Search):
         original_batch_idxs: Optional[Tensor] = None,
     ):
         bsz, beam_size, vocab_size = lprobs.size()
+        if torch.is_tensor(beam_size):
+            beam_size = beam_size.item()
+            vocab_size = vocab_size.item()
 
         if step == 0:
             # at the first step all hypotheses are equally likely, so use
             # only the first beam
-            lprobs = lprobs[:, ::beam_size, :].contiguous()
+            lprobs = lprobs[0][0].contiguous()
         else:
             # make probs contain cumulative scores for each hypothesis
             assert scores is not None
-            lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
+            lprobs = lprobs + scores[:, :, step - 1].reshape(1, 10, 1)

         top_prediction = torch.topk(
             lprobs.view(bsz, -1),
             k=min(
@@ -137,7 +140,9 @@ class BeamSearch(Search):
         scores_buf = top_prediction[0]
         indices_buf = top_prediction[1]
         # Project back into relative indices and beams
-        beams_buf = torch.div(indices_buf, vocab_size, rounding_mode="trunc")
+        beams_buf = torch.div(indices_buf, vocab_size).to(torch.int32)
+        vocab_size = torch.tensor([vocab_size]).to(torch.int32)
+        indices_buf = indices_buf.to(torch.int32)
         indices_buf = indices_buf.fmod(vocab_size)
 
         # At this point, beams_buf and indices_buf are single-dim and contain relative indices
diff --git a/fairseq/fairseq/sequence_generator.py b/fairseq/fairseq/sequence_generator.py
index 976adbca..0f019951 100644
--- a/fairseq/fairseq/sequence_generator.py
+++ b/fairseq/fairseq/sequence_generator.py
@@ -821,7 +821,7 @@ class EnsembleModel(nn.Module):
                     elif attn_holder is not None:
                         attn = attn_holder[0]
                 if attn is not None:
-                    attn = attn[:, -1, :]
+                    attn = attn.reshape(10, 578)
 
             decoder_out_tuple = (
                 decoder_out[0][:, -1:, :].div_(temperature),
@@ -830,7 +830,7 @@ class EnsembleModel(nn.Module):
             probs = model.get_normalized_probs(
                 decoder_out_tuple, log_probs=True, sample=None
             )
-            probs = probs[:, -1, :]
+            probs = probs.reshape(10, 64044)
             if self.models_size == 1:
                 return probs, attn
 
diff --git a/fairseq/setup.py b/fairseq/setup.py
index c5591915..e8a64b4a 100644
--- a/fairseq/setup.py
+++ b/fairseq/setup.py
@@ -218,10 +218,8 @@ def do_setup(package_data):
             'numpy; python_version>="3.7"',
             "regex",
             "sacrebleu>=1.4.12",
-            "torch",
             "tqdm",
             "bitarray",
-            "torchaudio>=0.8.0",
         ],
         dependency_links=dependency_links,
         packages=find_packages(

diff --git a/fairseq/fairseq/file_utils.py b/fairseq/fairseq/file_utils.py
index b99da2e8..15b60fae 100644
--- a/fairseq/fairseq/file_utils.py
+++ b/fairseq/fairseq/file_utils.py
@@ -265,7 +265,7 @@ def http_get(url, temp_file):
     import requests
     from tqdm import tqdm

-    req = request_wrap_timeout(partial(requests.get, url, stream=True), url)
+    req = request_wrap_timeout(partial(requests.get, url, stream=True, verify=False), url)
     content_length = req.headers.get("Content-Length")
     total = int(content_length) if content_length is not None else None
     progress = tqdm(unit="B", total=total)
@@ -297,7 +297,7 @@ def get_from_cache(url, cache_dir=None):
             import requests

             response = request_wrap_timeout(
-                partial(requests.head, url, allow_redirects=True), url
+                partial(requests.head, url, allow_redirects=True, verify=False), url
             )
             if response.status_code != 200:
                 etag = None