05360171创建于 2022年3月18日历史提交
diff --git a/tnt_pytorch/tnt.py b/tnt_pytorch/tnt.py
index 9bff5fd..d29e120 100644
--- a/tnt_pytorch/tnt.py
+++ b/tnt_pytorch/tnt.py
@@ -1,1 +1,15 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# 2021.12.12-Changed for ACL Pytorch
 # 2021.06.15-Changed for implementation of TNT model
@@ -98,7 +112,7 @@ class SE(nn.Module):
         )
 
     def forward(self, x):
-        a = x.mean(dim=1, keepdim=True) # B, 1, C
+        a = x.mean(dim=1, keepdim=True)  # B, 1, C
         a = self.fc(a)
         x = a * x
         return x
@@ -122,8 +136,10 @@ class Attention(nn.Module):
 
     def forward(self, x):
         B, N, C = x.shape
-        qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k = qk[0], qk[1]   # make torchscript happy (cannot use tensor as tuple)
+        qk = self.qk(x).reshape(B, N, 2, self.num_heads,
+                                self.head_dim).permute(2, 0, 3, 1, 4)
+        # make torchscript happy (cannot use tensor as tuple)
+        q, k = qk[0], qk[1]
         v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -139,6 +155,7 @@ class Attention(nn.Module):
 class Block(nn.Module):
     """ TNT Block
     """
+
     def __init__(self, outer_dim, inner_dim, outer_num_heads, inner_num_heads, num_words, mlp_ratio=4.,
                  qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                  norm_layer=nn.LayerNorm, se=0):
@@ -162,7 +179,8 @@ class Block(nn.Module):
         self.outer_attn = Attention(
             outer_dim, outer_dim, num_heads=outer_num_heads, qkv_bias=qkv_bias,
             qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
         self.outer_norm2 = norm_layer(outer_dim)
         self.outer_mlp = Mlp(in_features=outer_dim, hidden_features=int(outer_dim * mlp_ratio),
                              out_features=outer_dim, act_layer=act_layer, drop=drop)
@@ -174,76 +192,96 @@ class Block(nn.Module):
 
     def forward(self, inner_tokens, outer_tokens):
         if self.has_inner:
-            inner_tokens = inner_tokens + self.drop_path(self.inner_attn(self.inner_norm1(inner_tokens))) # B*N, k*k, c
-            inner_tokens = inner_tokens + self.drop_path(self.inner_mlp(self.inner_norm2(inner_tokens))) # B*N, k*k, c
+            inner_tokens = inner_tokens + \
+                self.drop_path(self.inner_attn(
+                    self.inner_norm1(inner_tokens)))  # B*N, k*k, c
+            inner_tokens = inner_tokens + \
+                self.drop_path(self.inner_mlp(
+                    self.inner_norm2(inner_tokens)))  # B*N, k*k, c
             B, N, C = outer_tokens.size()
-            outer_tokens[:,1:] = outer_tokens[:,1:] + self.proj_norm2(self.proj(self.proj_norm1(inner_tokens.reshape(B, N-1, -1)))) # B, N, C
+            outer_tokens[:, 1:] = outer_tokens[:, 1:] + self.proj_norm2(
+                self.proj(self.proj_norm1(inner_tokens.reshape(B, N-1, -1))))  # B, N, C
         if self.se > 0:
-            outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))
+            outer_tokens = outer_tokens + \
+                self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))
             tmp_ = self.outer_mlp(self.outer_norm2(outer_tokens))
-            outer_tokens = outer_tokens + self.drop_path(tmp_ + self.se_layer(tmp_))
+            outer_tokens = outer_tokens + \
+                self.drop_path(tmp_ + self.se_layer(tmp_))
         else:
-            outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))
-            outer_tokens = outer_tokens + self.drop_path(self.outer_mlp(self.outer_norm2(outer_tokens)))
+            outer_tokens = outer_tokens + \
+                self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))
+            outer_tokens = outer_tokens + \
+                self.drop_path(self.outer_mlp(self.outer_norm2(outer_tokens)))
         return inner_tokens, outer_tokens
 
 
 class PatchEmbed(nn.Module):
     """ Image to Visual Word Embedding
     """
+
     def __init__(self, img_size=224, patch_size=16, in_chans=3, outer_dim=768, inner_dim=24, inner_stride=4):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
-        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        num_patches = (img_size[1] // patch_size[1]) * \
+            (img_size[0] // patch_size[0])
         self.img_size = img_size
         self.patch_size = patch_size
         self.num_patches = num_patches
         self.inner_dim = inner_dim
-        self.num_words = math.ceil(patch_size[0] / inner_stride) * math.ceil(patch_size[1] / inner_stride)
-        
+        self.num_words = math.ceil(
+            patch_size[0] / inner_stride) * math.ceil(patch_size[1] / inner_stride)
+
         self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
-        self.proj = nn.Conv2d(in_chans, inner_dim, kernel_size=7, padding=3, stride=inner_stride)
+        self.proj = nn.Conv2d(in_chans, inner_dim,
+                              kernel_size=7, padding=3, stride=inner_stride)
 
     def forward(self, x):
         B, C, H, W = x.shape
         # FIXME look at relaxing size constraints
         assert H == self.img_size[0] and W == self.img_size[1], \
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-        x = self.unfold(x) # B, Ck2, N
-        x = x.transpose(1, 2).reshape(B * self.num_patches, C, *self.patch_size) # B*N, C, 16, 16
-        x = self.proj(x) # B*N, C, 8, 8
-        x = x.reshape(B * self.num_patches, self.inner_dim, -1).transpose(1, 2) # B*N, 8*8, C
+        x = self.unfold(x)  # B, Ck2, N
+        x = x.transpose(1, 2).reshape(B * self.num_patches,
+                                      C, *self.patch_size)  # B*N, C, 16, 16
+        x = self.proj(x)  # B*N, C, 8, 8
+        x = x.reshape(B * self.num_patches, self.inner_dim, -
+                      1).transpose(1, 2)  # B*N, 8*8, C
         return x
 
 
 class TNT(nn.Module):
     """ TNT (Transformer in Transformer) for computer vision
     """
+
     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, outer_dim=768, inner_dim=48,
                  depth=12, outer_num_heads=12, inner_num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, inner_stride=4, se=0):
         super().__init__()
         self.num_classes = num_classes
-        self.num_features = self.outer_dim = outer_dim  # num_features for consistency with other models
+        # num_features for consistency with other models
+        self.num_features = self.outer_dim = outer_dim
 
         self.patch_embed = PatchEmbed(
             img_size=img_size, patch_size=patch_size, in_chans=in_chans, outer_dim=outer_dim,
             inner_dim=inner_dim, inner_stride=inner_stride)
         self.num_patches = num_patches = self.patch_embed.num_patches
         num_words = self.patch_embed.num_words
-        
+
         self.proj_norm1 = norm_layer(num_words * inner_dim)
         self.proj = nn.Linear(num_words * inner_dim, outer_dim)
         self.proj_norm2 = norm_layer(outer_dim)
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, outer_dim))
-        self.outer_tokens = nn.Parameter(torch.zeros(1, num_patches, outer_dim), requires_grad=False)
-        self.outer_pos = nn.Parameter(torch.zeros(1, num_patches + 1, outer_dim))
+        self.outer_tokens = nn.Parameter(torch.zeros(
+            1, num_patches, outer_dim), requires_grad=False)
+        self.outer_pos = nn.Parameter(
+            torch.zeros(1, num_patches + 1, outer_dim))
         self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dim))
         self.pos_drop = nn.Dropout(p=drop_rate)
 
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
         vanilla_idxs = []
         blocks = []
         for i in range(depth):
@@ -265,7 +303,8 @@ class TNT(nn.Module):
         #self.repr_act = nn.Tanh()
 
         # Classifier head
-        self.head = nn.Linear(outer_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(
+            outer_dim, num_classes) if num_classes > 0 else nn.Identity()
 
         trunc_normal_(self.cls_token, std=.02)
         trunc_normal_(self.outer_pos, std=.02)
@@ -290,15 +329,19 @@ class TNT(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool=''):
         self.num_classes = num_classes
-        self.head = nn.Linear(self.outer_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(
+            self.outer_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         B = x.shape[0]
-        inner_tokens = self.patch_embed(x) + self.inner_pos # B*N, 8*8, C
-        
-        outer_tokens = self.proj_norm2(self.proj(self.proj_norm1(inner_tokens.reshape(B, self.num_patches, -1))))        
-        outer_tokens = torch.cat((self.cls_token.expand(B, -1, -1), outer_tokens), dim=1)
-        
+        # inner_tokens = self.patch_embed(x) + self.inner_pos # B*N, 8*8, C
+        inner_tokens = x.reshape(-1, 16, 24)
+
+        outer_tokens = self.proj_norm2(self.proj(self.proj_norm1(
+            inner_tokens.reshape(B, self.num_patches, -1))))
+        outer_tokens = torch.cat(
+            (self.cls_token.expand(B, -1, -1), outer_tokens), dim=1)
+
         outer_tokens = outer_tokens + self.outer_pos
         outer_tokens = self.pos_drop(outer_tokens)