diff --git a/mmseg/models/decode_heads/psp_head.py b/mmseg/models/decode_heads/psp_head.py
index 6990676..9103864 100644
--- a/mmseg/models/decode_heads/psp_head.py
+++ b/mmseg/models/decode_heads/psp_head.py
@@ -49,7 +49,19 @@ class PPM(nn.ModuleList):
         """Forward function."""
         ppm_outs = []
         for ppm in self:
-            ppm_out = ppm(x)
+            if ppm[0].output_size == 2:
+                y = torch.cat([x, x[:, :, x.size(2)//2: x.size(2)//2 + 1, :]], dim=2)
+                y = torch.cat([y, y[:, :, :, y.size(3)//2: y.size(3)//2 + 1]], dim=3)
+                ppm_out = nn.AvgPool2d(kernel_size=(8, 8), stride=(8, 8))(y)
+                ppm_out = nn.AvgPool2d(kernel_size=(4, 4), stride=(4, 4))(ppm_out)
+                ppm_out = ppm[1:](ppm_out)
+            elif ppm[0].output_size == 3:
+                ppm_out = nn.AvgPool2d(kernel_size=(3, 3), stride=(3, 3))(x)
+                ppm_out = nn.AvgPool2d(kernel_size=(7, 7), stride=(7, 7))(ppm_out)
+                ppm_out = ppm[1:](ppm_out)
+            else:
+                ppm_out = ppm(x)
+
             upsampled_ppm_out = resize(
                 ppm_out,
                 size=x.size()[2:],
diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py
index e0ce8df..88f9d91 100644
--- a/mmseg/models/segmentors/encoder_decoder.py
+++ b/mmseg/models/segmentors/encoder_decoder.py
@@ -271,7 +271,6 @@ class EncoderDecoder(BaseSegmentor):
             seg_pred = seg_logit.argmax(dim=1)
         if torch.onnx.is_in_onnx_export():
             # our inference backend only support 4D output
-            seg_pred = seg_pred.unsqueeze(0)
             return seg_pred
         seg_pred = seg_pred.cpu().numpy()
         # unravel batch dim