@@ -19,7 +19,7 @@ class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
- super(ContBatchNorm3d, self)._check_input_dim(input)
+ # super(ContBatchNorm3d, self)._check_input_dim(input)
def forward(self, input):
self._check_input_dim(input)
@@ -59,7 +59,7 @@ class InputTransition(nn.Module):
out = self.bn1(self.conv1(x))
# split input in to 16 channels
x16 = torch.cat((x, x, x, x, x, x, x, x,
- x, x, x, x, x, x, x, x), 0)
+ x, x, x, x, x, x, x, x), 1)
out = self.relu1(torch.add(out, x16))
return out
@@ -127,10 +127,7 @@ class OutputTransition(nn.Module):
# make channels the last axis
out = out.permute(0, 2, 3, 4, 1).contiguous()
- # flatten
- out = out.view(out.numel() // 2, 2)
- out = self.softmax(out)
- # treat channel 0 as the predicted output
+ out = self.softmax(out, dim=-1)
return out