3d5b6ea5创建于 2023年1月12日历史提交
diff --git a/src/models/data_loader.py b/src/models/data_loader.py
index 0511cb1..9e9a55b 100644
--- a/src/models/data_loader.py
+++ b/src/models/data_loader.py
@@ -28,10 +28,12 @@ class Batch(object):
 
             labels = torch.tensor(self._pad(pre_labels, 0))
             segs = torch.tensor(self._pad(pre_segs, 0))
-            mask = 1 - (src == 0)
+            # mask = 1 - (src == 0)
+            mask = (src == 0)
 
             clss = torch.tensor(self._pad(pre_clss, -1))
-            mask_cls = 1 - (clss == -1)
+            # mask_cls = 1 - (clss == -1)
+            mask_cls = (clss == -1)
             clss[clss == -1] = 0
 
             setattr(self, 'clss', clss.to(device))