05360171创建于 2022年3月18日历史提交
diff --git a/tlt/models/layers.py b/tlt/models/layers.py
index 599db83..3032fa4 100644
--- a/tlt/models/layers.py
+++ b/tlt/models/layers.py
@@ -118,10 +118,20 @@ class Attention(nn.Module):
                 padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                 float("-inf"),
             )
-            attn_float = attn.softmax(dim=-1, dtype=torch.float32)
+            # attn_float = attn.softmax(dim=-1, dtype=torch.float32)
+            # ************************** #
+            attn = attn.transpose(1, -1)
+            attn_float = attn.softmax(dim=1, dtype=torch.float32)
+            attn = attn.transpose(1, -1)
+            # ************************** #
             attn = attn_float.type_as(attn)
         else:
-            attn = attn.softmax(dim=-1)
+            # attn = attn.softmax(dim=-1)
+            # ***************************** #
+            attn = attn.transpose(1, -1)
+            attn = attn.softmax(dim=1)
+            attn = attn.transpose(1, -1)
+            # ***************************** #
         attn = self.attn_drop(attn)
 
         x = (attn @ v).transpose(1, 2).reshape(B, N, self.head_dim* self.num_heads)
diff --git a/validate.py b/validate.py
index fb33924..24f5c8b 100755
--- a/validate.py
+++ b/validate.py
@@ -109,6 +109,7 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                     help='Valid label indices txt file for validation of partial label space')
 
 
+
 def validate(args):
     # might as well try to validate something
     args.pretrained = args.pretrained or not args.checkpoint
@@ -163,7 +164,7 @@ def validate(args):
 
     model = model.cuda()
     if args.apex_amp:
-        model = amp.initialize(model, opt_level='O1')
+        model = amp.initialize(model, opt_level='O2')
 
     if args.channels_last:
         model = model.to(memory_format=torch.channels_last)