diff --git a/src/models/data_loader.py b/src/models/data_loader.py
index 0511cb1..9e9a55b 100644
--- a/src/models/data_loader.py
+++ b/src/models/data_loader.py
@@ -28,10 +28,12 @@ class Batch(object):
labels = torch.tensor(self._pad(pre_labels, 0))
segs = torch.tensor(self._pad(pre_segs, 0))
- mask = 1 - (src == 0)
+ # mask = 1 - (src == 0)
+ mask = (src == 0)
clss = torch.tensor(self._pad(pre_clss, -1))
- mask_cls = 1 - (clss == -1)
+ # mask_cls = 1 - (clss == -1)
+ mask_cls = (clss == -1)
clss[clss == -1] = 0
setattr(self, 'clss', clss.to(device))