diff --git a/mmseg/models/backbones/swin_transformer.py b/mmseg/models/backbones/swin_transformer.py
index 94aef1a..77f3086 100644
--- a/mmseg/models/backbones/swin_transformer.py
+++ b/mmseg/models/backbones/swin_transformer.py
@@ -220,7 +220,8 @@ class SwinTransformerBlock(nn.Module):
 
         # cyclic shift
         if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            shifted_x = torch.cat((x[:, self.shift_size:,:,:], x[:, :self.shift_size,:,:]), dim=1)
+            shifted_x = torch.cat((shifted_x[:, :,self.shift_size:,:], shifted_x[:, :,:self.shift_size,:]), dim=2)
             attn_mask = mask_matrix
         else:
             shifted_x = x
@@ -239,7 +240,8 @@ class SwinTransformerBlock(nn.Module):
 
         # reverse cyclic shift
         if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+            x = torch.cat((shifted_x[:, -self.shift_size:,:,:], shifted_x[:, :-self.shift_size,:,:]), dim=1)
+            x = torch.cat((x[:, :,-self.shift_size:,:], x[:, :, :-self.shift_size,:]), dim=2)
         else:
             x = shifted_x