diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py

index 9e54de7..92d5c64 100644

--- a/mmseg/models/backbones/vit.py

+++ b/mmseg/models/backbones/vit.py

@@ -353,7 +353,7 @@ class VisionTransformer(nn.Module):

                 nn.init.constant_(m.bias, 0)

                 nn.init.constant_(m.weight, 1.0)

 

-        if not self.random_init:

+        if not self.random_init and pretrained != True:

             self.default_cfg = default_cfgs[self.model_name]

 

             if self.model_name in ['vit_small_patch16_224', 'vit_base_patch16_224']:

diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py

index 1baa800..4ff2c00 100644

--- a/mmseg/models/segmentors/encoder_decoder.py

+++ b/mmseg/models/segmentors/encoder_decoder.py

@@ -248,7 +248,7 @@ class EncoderDecoder(BaseSegmentor):

             seg_logit = self.slide_inference(img, img_meta, rescale)

         else:

             seg_logit = self.whole_inference(img, img_meta, rescale)

-        output = F.softmax(seg_logit, dim=1)

+        output = seg_logit

         flip = img_meta[0]['flip']

         if flip:

             flip_direction = img_meta[0]['flip_direction']

@@ -263,7 +263,7 @@ class EncoderDecoder(BaseSegmentor):

     def simple_test(self, img, img_meta, rescale=True):

         """Simple test with single image."""

         seg_logit = self.inference(img, img_meta, rescale)

-        seg_pred = seg_logit.argmax(dim=1)

+        seg_pred = seg_logit

         if torch.onnx.is_in_onnx_export():

             # our inference backend only support 4D output

             seg_pred = seg_pred.unsqueeze(0)

diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py

index f22d4d3..5e2c0d0 100644

--- a/tools/pytorch2onnx.py

+++ b/tools/pytorch2onnx.py

@@ -178,7 +178,7 @@ if __name__ == '__main__':

         raise ValueError('invalid input shape')

 

     cfg = mmcv.Config.fromfile(args.config)

-    cfg.model.pretrained = None

+    cfg.model.pretrained = True

 

     # build the model and load checkpoint

     segmentor = build_segmentor(