@@ -41,7 +41,7 @@ class SequenceClassificationModel(NNModule):
self.apply(self.init_weights)
self.deberta.apply_state()
- def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
+ def forward(self, input_ids, input_mask, type_ids=None, position_ids=None, labels=None,**kwargs):
outputs = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,
position_ids=position_ids, output_all_encoded_layers=True)
encoder_layers = outputs['hidden_states']
@@ -69,7 +69,8 @@ class SequenceClassificationModel(NNModule):
log_softmax = torch.nn.LogSoftmax(-1)
label_confidence = 1
loss = -((log_softmax(logits)*labels).sum(-1)*label_confidence).mean()
-
+
+ return logits
return {
'logits' : logits,
'loss' : loss