@@ -118,10 +118,20 @@ class Attention(nn.Module):
padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
- attn_float = attn.softmax(dim=-1, dtype=torch.float32)
+ # attn_float = attn.softmax(dim=-1, dtype=torch.float32)
+ # ************************** #
+ attn = attn.transpose(1, -1)
+ attn_float = attn.softmax(dim=1, dtype=torch.float32)
+ attn = attn.transpose(1, -1)
+ # ************************** #
attn = attn_float.type_as(attn)
else:
- attn = attn.softmax(dim=-1)
+ # attn = attn.softmax(dim=-1)
+ # ***************************** #
+ attn = attn.transpose(1, -1)
+ attn = attn.softmax(dim=1)
+ attn = attn.transpose(1, -1)
+ # ***************************** #
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.head_dim* self.num_heads)
@@ -109,6 +109,7 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
help='Valid label indices txt file for validation of partial label space')
+
def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
@@ -163,7 +164,7 @@ def validate(args):
model = model.cuda()
if args.apex_amp:
- model = amp.initialize(model, opt_level='O1')
+ model = amp.initialize(model, opt_level='O2')
if args.channels_last:
model = model.to(memory_format=torch.channels_last)