diff --git a/vnet.py b/vnet.py
index 993403c..de1787e 100644
--- a/vnet.py
+++ b/vnet.py
@@ -19,7 +19,7 @@ class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
         if input.dim() != 5:
             raise ValueError('expected 5D input (got {}D input)'
                              .format(input.dim()))
-        super(ContBatchNorm3d, self)._check_input_dim(input)
+        # super(ContBatchNorm3d, self)._check_input_dim(input)
 
     def forward(self, input):
         self._check_input_dim(input)
@@ -59,7 +59,7 @@ class InputTransition(nn.Module):
         out = self.bn1(self.conv1(x))
         # split input in to 16 channels
         x16 = torch.cat((x, x, x, x, x, x, x, x,
-                         x, x, x, x, x, x, x, x), 0)
+                         x, x, x, x, x, x, x, x), 1)
         out = self.relu1(torch.add(out, x16))
         return out
 
@@ -127,10 +127,7 @@ class OutputTransition(nn.Module):
 
         # make channels the last axis
         out = out.permute(0, 2, 3, 4, 1).contiguous()
-        # flatten
-        out = out.view(out.numel() // 2, 2)
-        out = self.softmax(out)
-        # treat channel 0 as the predicted output
+        out = self.softmax(out, dim=-1)
         return out