--- ./CSWin-Transformer/models/cswin.py	2022-09-03 11:59:22.539626278 +0800
+++ cswin_modify.py	2022-09-03 13:07:33.947676406 +0800
@@ -24,6 +24,7 @@
     return {
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        # 'num_classes': 1000, 'input_size': (3, 384, 384), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'patch_embed.proj', 'classifier': 'head',
@@ -84,19 +85,22 @@
         self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim)
 
         self.attn_drop = nn.Dropout(attn_drop)
-
+    
     def im2cswin(self, x):
-        B, N, C = x.shape
+        _, B, C, N = x.shape
         H = W = int(np.sqrt(N))
-        x = x.transpose(-2,-1).contiguous().view(B, C, H, W)
-        x = img2windows(x, self.H_sp, self.W_sp)
+        H_sp, W_sp = self.H_sp, self.W_sp
+        x = x.contiguous().view(2*B, C, H, W)
+        x_reshape = x.view(2*B, C, H // H_sp, H_sp, W // W_sp, W_sp)
+        x = x_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C)
         x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
         return x
 
+
     def get_lepe(self, x, func):
-        B, N, C = x.shape
+        _, B, C, N = x.shape
         H = W = int(np.sqrt(N))
-        x = x.transpose(-2,-1).contiguous().view(B, C, H, W)
+        x = x.contiguous().view(B, C, H, W)
 
         H_sp, W_sp = self.H_sp, self.W_sp
         x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
@@ -112,15 +116,16 @@
         """
         x: B L C
         """
-        q,k,v = qkv[0], qkv[1], qkv[2]
+        qk, v = torch.split(qkv, [2,1], dim=0) # qk: 2 B C N;  v: 1 B C N
 
         ### Img2Window
         H = W = self.resolution
-        B, L, C = q.shape
+        _, B, C, L = qk.shape
         assert L == H * W, "flatten img_tokens has wrong size"
         
-        q = self.im2cswin(q)
-        k = self.im2cswin(k)
+        qk = self.im2cswin(qk)
+        tmp = qk.shape[0]//2
+        q, k = torch.split(qk, int(tmp), dim=0)
         v, lepe = self.get_lepe(v, self.get_v)
 
         q = q * self.scale
@@ -193,11 +198,13 @@
         B, L, C = x.shape
         assert L == H * W, "flatten img_tokens has wrong size"
         img = self.norm1(x)
-        qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3)
+        qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 3, 1)
         
         if self.branch_num == 2:
-            x1 = self.attns[0](qkv[:,:,:,:C//2])
-            x2 = self.attns[1](qkv[:,:,:,C//2:])
+            x1,x2 = torch.split(qkv, C//2, dim=2)
+            x1 = self.attns[0](x1)
+            x2 = self.attns[1](x2)
+
             attened_x = torch.cat([x1,x2], dim=2)
         else:
             attened_x = self.attns[0](qkv)