@@ -24,6 +24,7 @@
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ # Alternative config for 384x384 input: 'num_classes': 1000, 'input_size': (3, 384, 384), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
@@ -84,19 +85,22 @@
self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim)
self.attn_drop = nn.Dropout(attn_drop)
-
+
def im2cswin(self, x):
- B, N, C = x.shape
+ _, B, C, N = x.shape
H = W = int(np.sqrt(N))
- x = x.transpose(-2,-1).contiguous().view(B, C, H, W)
- x = img2windows(x, self.H_sp, self.W_sp)
+ H_sp, W_sp = self.H_sp, self.W_sp
+ x = x.contiguous().view(2*B, C, H, W)
+ x_reshape = x.view(2*B, C, H // H_sp, H_sp, W // W_sp, W_sp)
+ x = x_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C)
x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
return x
+
def get_lepe(self, x, func):
- B, N, C = x.shape
+ _, B, C, N = x.shape
H = W = int(np.sqrt(N))
- x = x.transpose(-2,-1).contiguous().view(B, C, H, W)
+ x = x.contiguous().view(B, C, H, W)
H_sp, W_sp = self.H_sp, self.W_sp
x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
@@ -112,15 +116,16 @@
"""
x: B L C
"""
- q,k,v = qkv[0], qkv[1], qkv[2]
+ qk, v = torch.split(qkv, [2,1], dim=0) # qk: (2, B, C, N) holding q and k stacked on dim 0; v: (1, B, C, N)
### Img2Window
H = W = self.resolution
- B, L, C = q.shape
+ _, B, C, L = qk.shape
assert L == H * W, "flatten img_tokens has wrong size"
- q = self.im2cswin(q)
- k = self.im2cswin(k)
+ qk = self.im2cswin(qk)
+ tmp = qk.shape[0]//2
+ q, k = torch.split(qk, int(tmp), dim=0)
v, lepe = self.get_lepe(v, self.get_v)
q = q * self.scale
@@ -193,11 +198,13 @@
B, L, C = x.shape
assert L == H * W, "flatten img_tokens has wrong size"
img = self.norm1(x)
- qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3)
+ qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 3, 1)
if self.branch_num == 2:
- x1 = self.attns[0](qkv[:,:,:,:C//2])
- x2 = self.attns[1](qkv[:,:,:,C//2:])
+ x1,x2 = torch.split(qkv, C//2, dim=2)
+ x1 = self.attns[0](x1)
+ x2 = self.attns[1](x2)
+
attened_x = torch.cat([x1,x2], dim=2)
else:
attened_x = self.attns[0](qkv)