import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestNPUCrossEntropyLoss(TestCase):
    """Check ``torch_npu.npu_cross_entropy_loss`` against a reference built from basic torch ops.

    The forward test compares both the loss and the log-probabilities; the
    backward test compares the gradient w.r.t. the logits.
    """

    # Problem sizes exercised by both tests, given as [N, C].
    _SHAPES = [[4096, 8080], [4096, 8192]]
    _DTYPES = [torch.float32, torch.float16, torch.bfloat16]

    def cross_entropy_loss_golden(self, N, C, predictions, targets, weight, ignore_index, label_smoothing, reduction):
        """Reference cross-entropy computed in float32 from elementary torch ops.

        Args:
            N: Number of samples. Kept for interface compatibility; the
                per-sample mask is now derived from ``targets`` directly.
            C: Number of classes (size of dim 1 of ``predictions``).
            predictions: Logits indexed as ``[sample, class]``.
            targets: Integer class index per sample.
            weight: Optional per-class weight; ``None`` means all ones.
            ignore_index: Target value whose samples contribute nothing.
            label_smoothing: Smoothing factor mixed into the loss.
            reduction: ``"mean"``, ``"sum"``; any other value yields the
                unreduced per-sample loss.

        Returns:
            ``(loss, log_softmax_probs)``, both cast back to ``predictions.dtype``.
        """
        if weight is None:
            # Follow the input's device instead of hard-coding the NPU.
            weight = torch.ones((C,), device=predictions.device)

        # Numerically stable log-softmax, evaluated in float32.
        predictions_cast = predictions.to(torch.float32)
        predictions_max = torch.max(predictions_cast, dim=1)[0].unsqueeze(-1)
        shifted = predictions_cast - predictions_max
        log_softmax_probs = shifted - torch.log(torch.sum(torch.exp(shifted), -1, keepdim=True))

        # Mask ignored samples whatever the sign of ignore_index, and replace
        # their label with 0 so the gathers below never see an invalid index.
        # The substituted entries are zeroed by the mask afterwards.
        valid = targets != ignore_index
        ignore_mask = valid.to(torch.float32)
        safe_targets = targets.masked_fill(~valid, 0)

        picked_log_probs = torch.gather(log_softmax_probs, 1, safe_targets.reshape(-1, 1)).reshape(-1)
        weight_yn = torch.gather(weight, 0, safe_targets)
        loss_out = -picked_log_probs * weight_yn * ignore_mask

        # Label-smoothing term: weighted sum of -log p over every class.
        smooth_loss = -torch.sum(log_softmax_probs * weight.unsqueeze(0), -1, keepdim=False)
        smooth_loss = smooth_loss * ignore_mask

        if reduction == "mean":
            # Normalise by the total weight of the samples that were not ignored.
            weight_total = torch.sum(weight_yn * ignore_mask, -1, keepdim=False)
            mean_out = torch.sum(loss_out, -1, keepdim=False) / weight_total
            ret = ((1 - label_smoothing) * mean_out
                   + torch.sum(smooth_loss, -1, keepdim=False) / weight_total * label_smoothing / C).reshape(1)
        elif reduction == "sum":
            sum_out = torch.sum(loss_out, -1, keepdim=False)
            ret = (1 - label_smoothing) * sum_out + torch.sum(smooth_loss, -1, keepdim=False) * label_smoothing / C
            ret = ret.reshape(1)
        else:
            ret = (1 - label_smoothing) * loss_out + smooth_loss * label_smoothing / C

        ret = ret.to(predictions.dtype)
        log_softmax_probs = log_softmax_probs.to(predictions.dtype)
        return ret, log_softmax_probs

    def cross_entropy_loss_custom(self, inputs, target, weight, reduction, ignore_index=-100, label_smoothing=0.0):
        """Call the NPU operator and return its first two outputs ``(loss, log_probs)``."""
        # The operator returns four values; the last two are not checked here.
        loss, log_probs, _, _ = torch_npu.npu_cross_entropy_loss(inputs, target, weight,
                                                                 reduction=reduction, ignore_index=ignore_index,
                                                                 label_smoothing=label_smoothing)
        return loss, log_probs

    def _build_cases(self, reduction_list):
        """Return every ``[shape, dtype, reduction]`` combination, shape-major."""
        return [
            [shape, dtype, reduction]
            for shape in self._SHAPES
            for dtype in self._DTYPES
            for reduction in reduction_list
        ]

    def _make_inputs(self, N, C, dtype):
        """Create random logits and ``arange`` targets on the NPU, plus clones for the reference."""
        input_npu = torch.randn(N, C).to(dtype).npu()
        target_npu = torch.arange(0, N).npu()
        return input_npu, target_npu, input_npu.clone(), target_npu.clone()

    def _assert_close_for_dtype(self, expected, actual, dtype):
        """Compare two tensors with the tolerance this suite uses for ``dtype``."""
        # NOTE(review): the float16 tolerance is passed positionally while the
        # bfloat16 one goes through ``prec16`` - confirm against assertRtolEqual
        # which argument is actually consulted for half-precision inputs.
        if dtype == torch.bfloat16:
            self.assertRtolEqual(expected, actual, prec16=2**(-6))
        elif dtype == torch.float16:
            self.assertRtolEqual(expected, actual, 2**(-7))
        else:
            self.assertRtolEqual(expected, actual, 2**(-10))

    @unittest.skip("Skipping test_npu_cross_entropy_loss for now")
    @SupportedDevices(['Ascend910B'])
    def test_npu_cross_entropy_loss(self):
        """Forward check: loss and log-probabilities must match the reference."""
        for shape, dtype, reduction in self._build_cases(["mean", "sum", "none"]):
            N, C = shape
            input_npu, target_npu, input_golden, target_golden = self._make_inputs(N, C, dtype)
            golden0, golden1 = self.cross_entropy_loss_golden(N, C, input_golden,
                                                              target_golden, weight=None, ignore_index=-100,
                                                              label_smoothing=0.0, reduction=reduction)
            out0, out1 = self.cross_entropy_loss_custom(input_npu, target_npu, None, reduction=reduction)
            self._assert_close_for_dtype(golden0, out0, dtype)
            self._assert_close_for_dtype(golden1, out1, dtype)

    @unittest.skip("Skipping test_npu_cross_entropy_loss_backward for now")
    @SupportedDevices(['Ascend910B'])
    def test_npu_cross_entropy_loss_backward(self):
        """Backward check: gradient w.r.t. the logits must match the reference."""
        for shape, dtype, reduction in self._build_cases(["mean", "none"]):
            N, C = shape
            input_npu, target_npu, input_golden, target_golden = self._make_inputs(N, C, dtype)
            # Both tensors are leaves, so each accumulates its own .grad.
            input_golden.requires_grad_()
            input_npu.requires_grad_()
            golden0, _ = self.cross_entropy_loss_golden(N, C, input_golden,
                                                        target_golden, weight=None, ignore_index=-100,
                                                        label_smoothing=0.0, reduction=reduction)
            out0, _ = self.cross_entropy_loss_custom(input_npu, target_npu, None, reduction=reduction)
            # Reduce to a scalar so backward() also works for reduction="none".
            golden0.sum().backward()
            out0.sum().backward()
            golden_grad = input_golden.grad.detach()
            input_grad = input_npu.grad.detach()
            self._assert_close_for_dtype(golden_grad, input_grad, dtype)
# Run the suite through the torch_npu test runner when executed as a script.
if __name__ == "__main__":
    run_tests()