@@ -107,14 +107,15 @@ class ShiftViTBlock(nn.Module):
def shift_feat(x, n_div):
B, C, H, W = x.shape
g = C // n_div
- out = torch.zeros_like(x)
+
+ tensor_list = [None, None, None, None, None]
+ tensor_list[0] = torch.cat((x[:, g*0:g*1, :, 1:], torch.zeros(B, g, H, 1)), dim=3)
+ tensor_list[1] = torch.cat((torch.zeros(B, g, H, 1), x[:, g*1:g*2, :, :-1]), dim=3)
+ tensor_list[2] = torch.cat((x[:, g*2:g*3, 1:, :], torch.zeros(B, g, 1, W)), dim=2)
+ tensor_list[3] = torch.cat((torch.zeros(B, g, 1, W), x[:, g*3:g*4, :-1, :]), dim=2)
+ tensor_list[4] = x[:, g*4:, :, :]
+ out = torch.cat(tensor_list, dim=1)
- out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:] # shift left
- out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1] # shift right
- out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :] # shift up
- out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :] # shift down
-
- out[:, g * 4:, :, :] = x[:, g * 4:, :, :] # no shift
return out