05360171创建于 2022年3月18日历史提交
diff --git a/experiments/cityscapes/seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml b/experiments/cityscapes/seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml
index a51ed0b..72f6008 100644
--- a/experiments/cityscapes/seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml
+++ b/experiments/cityscapes/seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml
@@ -10,7 +10,7 @@ PRINT_FREQ: 10
 
 DATASET:
   DATASET: cityscapes
-  ROOT: data/
+  ROOT: HRNet-Semantic-Segmentation/data/
   TEST_SET: 'list/cityscapes/val.lst'
   TRAIN_SET: 'list/cityscapes/train.lst'
   NUM_CLASSES: 19
diff --git a/lib/datasets/cityscapes.py b/lib/datasets/cityscapes.py
index dec23d2..7b3854c 100644
--- a/lib/datasets/cityscapes.py
+++ b/lib/datasets/cityscapes.py
@@ -41,7 +41,7 @@ class Cityscapes(BaseDataset):
         self.multi_scale = multi_scale
         self.flip = flip
         
-        self.img_list = [line.strip().split() for line in open(root+list_path)]
+        self.img_list = [line.strip().split() for line in open('HRNet-Semantic-Segmentation/data/'+list_path)]
 
         self.files = self.read_files()
         if num_samples:
@@ -63,7 +63,7 @@ class Cityscapes(BaseDataset):
                                         1.0166, 0.9969, 0.9754, 1.0489,
                                         0.8786, 1.0023, 0.9539, 0.9843, 
                                         1.1116, 0.9037, 1.0865, 1.0955, 
-                                        1.0865, 1.1529, 1.0507]).cuda()
+                                        1.0865, 1.1529, 1.0507])
     
     def read_files(self):
         files = []
@@ -100,10 +100,10 @@ class Cityscapes(BaseDataset):
     def __getitem__(self, index):
         item = self.files[index]
         name = item["name"]
-        # image = cv2.imread(os.path.join(self.root,'cityscapes',item["img"]),
-        #                    cv2.IMREAD_COLOR)
-        image = cv2.imread(os.path.join(self.root, item["img"]),
+        image = cv2.imread(os.path.join(self.root,'cityscapes',item["img"]),
                            cv2.IMREAD_COLOR)
+        # image = cv2.imread(os.path.join(self.root, item["img"]),
+        #                    cv2.IMREAD_COLOR)
         size = image.shape
 
         if 'test' in self.list_path:
@@ -112,10 +112,10 @@ class Cityscapes(BaseDataset):
 
             return image.copy(), np.array(size), name
 
-        # label = cv2.imread(os.path.join(self.root,'cityscapes',item["label"]),
-        #                    cv2.IMREAD_GRAYSCALE)
-        label = cv2.imread(os.path.join(self.root, item["label"]),
+        label = cv2.imread(os.path.join(self.root,'cityscapes',item["label"]),
                            cv2.IMREAD_GRAYSCALE)
+        # label = cv2.imread(os.path.join(self.root, item["label"]),
+        #                    cv2.IMREAD_GRAYSCALE)
         label = self.convert_label(label)
 
         image, label = self.gen_sample(image, label, 
diff --git a/lib/models/bn_helper.py b/lib/models/bn_helper.py
index 4c8538e..9b9e445 100644
--- a/lib/models/bn_helper.py
+++ b/lib/models/bn_helper.py
@@ -7,5 +7,5 @@ if torch.__version__.startswith('0'):
     BatchNorm2d_class = InPlaceABNSync
     relu_inplace = False
 else:
-    BatchNorm2d_class = BatchNorm2d = torch.nn.SyncBatchNorm
+    BatchNorm2d_class = BatchNorm2d = torch.nn.BatchNorm2d
     relu_inplace = True
\ No newline at end of file
diff --git a/lib/models/seg_hrnet_ocr.py b/lib/models/seg_hrnet_ocr.py
index da19d23..f6f8e60 100644
--- a/lib/models/seg_hrnet_ocr.py
+++ b/lib/models/seg_hrnet_ocr.py
@@ -663,8 +663,8 @@ class HighResolutionNet(nn.Module):
             logger.info('=> loading pretrained model {}'.format(pretrained))
             model_dict = self.state_dict()
             pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in pretrained_dict.items()}  
-            print(set(model_dict) - set(pretrained_dict))            
-            print(set(pretrained_dict) - set(model_dict))            
+            # print(set(model_dict) - set(pretrained_dict))            
+            # print(set(pretrained_dict) - set(model_dict))            
             pretrained_dict = {k: v for k, v in pretrained_dict.items()
                                if k in model_dict.keys()}
             # for k, _ in pretrained_dict.items():
@@ -678,6 +678,6 @@ class HighResolutionNet(nn.Module):
 
 def get_seg_model(cfg, **kwargs):
     model = HighResolutionNet(cfg, **kwargs)
-    model.init_weights(cfg.MODEL.PRETRAINED)
+    # model.init_weights(cfg.MODEL.PRETRAINED)
 
     return model