import torch
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests

import mx_driving


def sigmoid_focal_loss(logit, target, gamma=2.0, alpha=0.25, weight=None, reduction='mean'):
    logit_size = logit.shape
    output = torch.zeros_like(logit)
    for i in range(logit_size[0]):
        target_i = target[i].item()
        for j in range(logit_size[1]):
            sigmoid_x = torch.sigmoid(logit[i, j])
            entropy_p = torch.pow(1 - sigmoid_x, gamma) * torch.log(sigmoid_x)
            entropy_n = torch.pow(sigmoid_x, gamma) * torch.log(1 - sigmoid_x)
            if j == target_i:
                output[i, j] += -alpha * entropy_p
            else:
                output[i, j] += (alpha - 1) * entropy_n
            if weight is not None:
                output[i, j] *= weight[target_i]

    if reduction == 'mean':
        output = output.sum() / logit.shape[0]
    elif reduction == 'sum':
        output = output.sum()
    return output


def sigmoid_focal_loss_grad(logit, target, grad_output, gamma=2.0, alpha=0.25, weight=None, reduction='mean'):
    logit_size = logit.shape
    grad_input = torch.zeros_like(logit)
    for i in range(logit_size[0]):
        target_i = target[i].item()
        for j in range(logit_size[1]):
            sigmoid_x = torch.sigmoid(logit[i, j])
            entropy_p = (
                alpha * torch.pow(1 - sigmoid_x, gamma) * (gamma * sigmoid_x * torch.log(sigmoid_x) - (1.0 - sigmoid_x))
            )
            entropy_n = (
                (1 - alpha)
                * torch.pow(sigmoid_x, gamma)
                * (sigmoid_x - gamma * (1 - sigmoid_x) * torch.log(1 - sigmoid_x))
            )
            if j == target_i:
                grad_input[i, j] += entropy_p
            else:
                grad_input[i, j] += entropy_n
            if weight is not None:
                grad_input[i, j] *= weight[target_i]
    grad_input *= grad_output
    if reduction == 'mean':
        grad_input /= logit_size[0]
    return grad_input


@golden_data_cache(__file__)
def gen_data(N, NC):
    logit = torch.rand(N, NC, dtype=torch.float32) * 10 - 5
    logit.requires_grad = True
    target = torch.randint(low=0, high=NC, size=(N,), dtype=torch.int64)
    weight = torch.rand(NC, dtype=torch.float32) * 10 - 5
    return logit, target, weight


class TestSigmoidFocalLoss(TestCase):
    def test_sigmoid_focal_loss(self):
        N_list = [1, 10, 79]
        NC_list = [1, 10, 79]
        for N in N_list:
            for NC in NC_list:
                logit, target, weight = gen_data(N, NC)
                logit_npu, target_npu, weight_npu = logit.npu(), target.npu(), weight.npu()
                output_golden = sigmoid_focal_loss(logit_npu, target_npu, 2.0, 0.25, weight_npu, 'mean')
                grad_golden = sigmoid_focal_loss_grad(
                    logit_npu, target_npu, torch.ones_like(output_golden), 2.0, 0.25, weight_npu, 'mean'
                )
                torch.npu.synchronize()

                output_mxdriving = mx_driving.sigmoid_focal_loss(logit_npu, target_npu, 2.0, 0.25, weight_npu, 'mean')
                output_mxdriving.backward()
                grad_mxdriving = logit.grad
                torch.npu.synchronize()
                self.assertRtolEqual(output_golden.cpu(), output_mxdriving.cpu())
                self.assertRtolEqual(grad_golden.cpu(), grad_mxdriving.cpu())


if __name__ == "__main__":
    run_tests()