@@ -1,1 +1,15 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# 2021.12.12-Changed for ACL Pytorch
# 2021.06.15-Changed for implementation of TNT model
@@ -98,7 +112,7 @@ class SE(nn.Module):
)
def forward(self, x):
- a = x.mean(dim=1, keepdim=True) # B, 1, C
+ a = x.mean(dim=1, keepdim=True) # B, 1, C
a = self.fc(a)
x = a * x
return x
@@ -122,8 +136,10 @@ class Attention(nn.Module):
def forward(self, x):
B, N, C = x.shape
- qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
- q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple)
+ qk = self.qk(x).reshape(B, N, 2, self.num_heads,
+ self.head_dim).permute(2, 0, 3, 1, 4)
+ # make torchscript happy (cannot use tensor as tuple)
+ q, k = qk[0], qk[1]
v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -139,6 +155,7 @@ class Attention(nn.Module):
class Block(nn.Module):
""" TNT Block
"""
+
def __init__(self, outer_dim, inner_dim, outer_num_heads, inner_num_heads, num_words, mlp_ratio=4.,
qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
norm_layer=nn.LayerNorm, se=0):
@@ -162,7 +179,8 @@ class Block(nn.Module):
self.outer_attn = Attention(
outer_dim, outer_dim, num_heads=outer_num_heads, qkv_bias=qkv_bias,
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
self.outer_norm2 = norm_layer(outer_dim)
self.outer_mlp = Mlp(in_features=outer_dim, hidden_features=int(outer_dim * mlp_ratio),
out_features=outer_dim, act_layer=act_layer, drop=drop)
@@ -174,76 +192,96 @@ class Block(nn.Module):
def forward(self, inner_tokens, outer_tokens):
if self.has_inner:
- inner_tokens = inner_tokens + self.drop_path(self.inner_attn(self.inner_norm1(inner_tokens))) # B*N, k*k, c
- inner_tokens = inner_tokens + self.drop_path(self.inner_mlp(self.inner_norm2(inner_tokens))) # B*N, k*k, c
+ inner_tokens = inner_tokens + \
+ self.drop_path(self.inner_attn(
+ self.inner_norm1(inner_tokens))) # B*N, k*k, c
+ inner_tokens = inner_tokens + \
+ self.drop_path(self.inner_mlp(
+ self.inner_norm2(inner_tokens))) # B*N, k*k, c
B, N, C = outer_tokens.size()
- outer_tokens[:,1:] = outer_tokens[:,1:] + self.proj_norm2(self.proj(self.proj_norm1(inner_tokens.reshape(B, N-1, -1)))) # B, N, C
+ outer_tokens[:, 1:] = outer_tokens[:, 1:] + self.proj_norm2(
+ self.proj(self.proj_norm1(inner_tokens.reshape(B, N-1, -1)))) # B, N, C
if self.se > 0:
- outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))
+ outer_tokens = outer_tokens + \
+ self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))
tmp_ = self.outer_mlp(self.outer_norm2(outer_tokens))
- outer_tokens = outer_tokens + self.drop_path(tmp_ + self.se_layer(tmp_))
+ outer_tokens = outer_tokens + \
+ self.drop_path(tmp_ + self.se_layer(tmp_))
else:
- outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))
- outer_tokens = outer_tokens + self.drop_path(self.outer_mlp(self.outer_norm2(outer_tokens)))
+ outer_tokens = outer_tokens + \
+ self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))
+ outer_tokens = outer_tokens + \
+ self.drop_path(self.outer_mlp(self.outer_norm2(outer_tokens)))
return inner_tokens, outer_tokens
class PatchEmbed(nn.Module):
""" Image to Visual Word Embedding
"""
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, outer_dim=768, inner_dim=24, inner_stride=4):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
- num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ num_patches = (img_size[1] // patch_size[1]) * \
+ (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.inner_dim = inner_dim
- self.num_words = math.ceil(patch_size[0] / inner_stride) * math.ceil(patch_size[1] / inner_stride)
-
+ self.num_words = math.ceil(
+ patch_size[0] / inner_stride) * math.ceil(patch_size[1] / inner_stride)
+
self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
- self.proj = nn.Conv2d(in_chans, inner_dim, kernel_size=7, padding=3, stride=inner_stride)
+ self.proj = nn.Conv2d(in_chans, inner_dim,
+ kernel_size=7, padding=3, stride=inner_stride)
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.unfold(x) # B, Ck2, N
- x = x.transpose(1, 2).reshape(B * self.num_patches, C, *self.patch_size) # B*N, C, 16, 16
- x = self.proj(x) # B*N, C, 8, 8
- x = x.reshape(B * self.num_patches, self.inner_dim, -1).transpose(1, 2) # B*N, 8*8, C
+ x = self.unfold(x) # B, Ck2, N
+ x = x.transpose(1, 2).reshape(B * self.num_patches,
+ C, *self.patch_size) # B*N, C, 16, 16
+ x = self.proj(x) # B*N, C, 8, 8
+ x = x.reshape(B * self.num_patches, self.inner_dim, -
+ 1).transpose(1, 2) # B*N, 8*8, C
return x
class TNT(nn.Module):
""" TNT (Transformer in Transformer) for computer vision
"""
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, outer_dim=768, inner_dim=48,
depth=12, outer_num_heads=12, inner_num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, inner_stride=4, se=0):
super().__init__()
self.num_classes = num_classes
- self.num_features = self.outer_dim = outer_dim # num_features for consistency with other models
+ # num_features for consistency with other models
+ self.num_features = self.outer_dim = outer_dim
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, outer_dim=outer_dim,
inner_dim=inner_dim, inner_stride=inner_stride)
self.num_patches = num_patches = self.patch_embed.num_patches
num_words = self.patch_embed.num_words
-
+
self.proj_norm1 = norm_layer(num_words * inner_dim)
self.proj = nn.Linear(num_words * inner_dim, outer_dim)
self.proj_norm2 = norm_layer(outer_dim)
self.cls_token = nn.Parameter(torch.zeros(1, 1, outer_dim))
- self.outer_tokens = nn.Parameter(torch.zeros(1, num_patches, outer_dim), requires_grad=False)
- self.outer_pos = nn.Parameter(torch.zeros(1, num_patches + 1, outer_dim))
+ self.outer_tokens = nn.Parameter(torch.zeros(
+ 1, num_patches, outer_dim), requires_grad=False)
+ self.outer_pos = nn.Parameter(
+ torch.zeros(1, num_patches + 1, outer_dim))
self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
vanilla_idxs = []
blocks = []
for i in range(depth):
@@ -265,7 +303,8 @@ class TNT(nn.Module):
#self.repr_act = nn.Tanh()
# Classifier head
- self.head = nn.Linear(outer_dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.head = nn.Linear(
+ outer_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.cls_token, std=.02)
trunc_normal_(self.outer_pos, std=.02)
@@ -290,15 +329,19 @@ class TNT(nn.Module):
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
- self.head = nn.Linear(self.outer_dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.head = nn.Linear(
+ self.outer_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
- inner_tokens = self.patch_embed(x) + self.inner_pos # B*N, 8*8, C
-
- outer_tokens = self.proj_norm2(self.proj(self.proj_norm1(inner_tokens.reshape(B, self.num_patches, -1))))
- outer_tokens = torch.cat((self.cls_token.expand(B, -1, -1), outer_tokens), dim=1)
-
+ # inner_tokens = self.patch_embed(x) + self.inner_pos # B*N, 8*8, C
+ inner_tokens = x.reshape(-1, 16, 24)
+
+ outer_tokens = self.proj_norm2(self.proj(self.proj_norm1(
+ inner_tokens.reshape(B, self.num_patches, -1))))
+ outer_tokens = torch.cat(
+ (self.cls_token.expand(B, -1, -1), outer_tokens), dim=1)
+
outer_tokens = outer_tokens + self.outer_pos
outer_tokens = self.pos_drop(outer_tokens)