@@ -107,14 +107,15 @@ class ShiftViTBlock(nn.Module):
     def shift_feat(x, n_div):
         B, C, H, W = x.shape
         g = C // n_div
-        out = torch.zeros_like(x)
+        
+        tensor_list = [None, None, None, None, None]
+        tensor_list[0] = torch.cat((x[:, g*0:g*1, :, 1:], torch.zeros(B, g, H, 1)), dim=3)
+        tensor_list[1] = torch.cat((torch.zeros(B, g, H, 1), x[:, g*1:g*2, :, :-1]), dim=3)
+        tensor_list[2] = torch.cat((x[:, g*2:g*3, 1:, :], torch.zeros(B, g, 1, W)), dim=2)
+        tensor_list[3] = torch.cat((torch.zeros(B, g, 1, W), x[:, g*3:g*4, :-1, :]), dim=2)
+        tensor_list[4] = x[:, g*4:, :, :]
+        out = torch.cat(tensor_list, dim=1)
 
-        out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:]  # shift left
-        out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1]  # shift right
-        out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :]  # shift up
-        out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :]  # shift down
-
-        out[:, g * 4:, :, :] = x[:, g * 4:, :, :]  # no shift
         return out