@@ -353,7 +353,7 @@ class VisionTransformer(nn.Module):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
- if not self.random_init:
+ if not self.random_init and pretrained != True:
self.default_cfg = default_cfgs[self.model_name]
if self.model_name in ['vit_small_patch16_224', 'vit_base_patch16_224']:
@@ -248,7 +248,7 @@ class EncoderDecoder(BaseSegmentor):
seg_logit = self.slide_inference(img, img_meta, rescale)
else:
seg_logit = self.whole_inference(img, img_meta, rescale)
- output = F.softmax(seg_logit, dim=1)
+ output = seg_logit
flip = img_meta[0]['flip']
if flip:
flip_direction = img_meta[0]['flip_direction']
@@ -263,7 +263,7 @@ class EncoderDecoder(BaseSegmentor):
def simple_test(self, img, img_meta, rescale=True):
"""Simple test with single image."""
seg_logit = self.inference(img, img_meta, rescale)
- seg_pred = seg_logit.argmax(dim=1)
+ seg_pred = seg_logit
if torch.onnx.is_in_onnx_export():
# our inference backend only support 4D output
seg_pred = seg_pred.unsqueeze(0)
@@ -178,7 +178,7 @@ if __name__ == '__main__':
raise ValueError('invalid input shape')
cfg = mmcv.Config.fromfile(args.config)
- cfg.model.pretrained = None
+ cfg.model.pretrained = True
# build the model and load checkpoint
segmentor = build_segmentor(