diff -Naru a/configs/_base_/recog_datasets/academic_test.py b/configs/_base_/recog_datasets/academic_test.py
@@ -54,4 +54,4 @@
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6
-test_list = [test1, test2, test3, test4, test5, test6]
+test_list = [test1]
diff -Naru a/mmocr/models/textrecog/decoders/nrtr_decoder.py b/mmocr/models/textrecog/decoders/nrtr_decoder.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -87,8 +88,7 @@
def get_subsequent_mask(seq):
"""For masking out the subsequent info."""
len_s = seq.size(1)
- subsequent_mask = 1 - torch.triu(
- torch.ones((len_s, len_s), device=seq.device), diagonal=1)
+ subsequent_mask = 1 - torch.from_numpy(np.triu(torch.ones((len_s, len_s)), 1))
subsequent_mask = subsequent_mask.unsqueeze(0).bool()
return subsequent_mask
@@ -156,7 +156,7 @@
init_target_seq = torch.full((N, self.max_seq_len + 1),
self.padding_idx,
device=out_enc.device,
- dtype=torch.long)
+ dtype=torch.int32)
# bsz * seq_len
init_target_seq[:, 0] = self.start_idx