@@ -220,7 +220,8 @@ class SwinTransformerBlock(nn.Module):
# cyclic shift
if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ shifted_x = torch.cat((x[:, self.shift_size:,:,:], x[:, :self.shift_size,:,:]), dim=1)
+ shifted_x = torch.cat((shifted_x[:, :,self.shift_size:,:], shifted_x[:, :,:self.shift_size,:]), dim=2)
attn_mask = mask_matrix
else:
shifted_x = x
@@ -239,7 +240,8 @@ class SwinTransformerBlock(nn.Module):
# reverse cyclic shift
if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ x = torch.cat((shifted_x[:, -self.shift_size:,:,:], shifted_x[:, :-self.shift_size,:,:]), dim=1)
+ x = torch.cat((x[:, :,-self.shift_size:,:], x[:, :, :-self.shift_size,:]), dim=2)
else:
x = shifted_x