@@ -49,7 +49,19 @@ class PPM(nn.ModuleList):
"""Forward function."""
ppm_outs = []
for ppm in self:
- ppm_out = ppm(x)
+ if ppm[0].output_size == 2:
+ y = torch.cat([x, x[:, :, x.size(2)//2: x.size(2)//2 + 1, :]], dim=2)
+ y = torch.cat([y, y[:, :, :, y.size(3)//2: y.size(3)//2 + 1]], dim=3)
+ ppm_out = nn.AvgPool2d(kernel_size=(8, 8), stride=(8, 8))(y)
+ ppm_out = nn.AvgPool2d(kernel_size=(4, 4), stride=(4, 4))(ppm_out)
+ ppm_out = ppm[1:](ppm_out)
+ elif ppm[0].output_size == 3:
+ ppm_out = nn.AvgPool2d(kernel_size=(3, 3), stride=(3, 3))(x)
+ ppm_out = nn.AvgPool2d(kernel_size=(7, 7), stride=(7, 7))(ppm_out)
+ ppm_out = ppm[1:](ppm_out)
+ else:
+ ppm_out = ppm(x)
+
upsampled_ppm_out = resize(
ppm_out,
size=x.size()[2:],
@@ -271,7 +271,6 @@ class EncoderDecoder(BaseSegmentor):
seg_pred = seg_logit.argmax(dim=1)
if torch.onnx.is_in_onnx_export():
# our inference backend only support 4D output
- seg_pred = seg_pred.unsqueeze(0)
return seg_pred
seg_pred = seg_pred.cpu().numpy()
# unravel batch dim