import torch
import numpy as np
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests

# Bounds of the uniform range from which the test inputs are sampled.
LOWER, UPPER = 0, 1


class TestBinaryCrossEntropy(TestCase):
    """Check ``binary_cross_entropy`` on the NPU against the CPU result."""

    # Distinct seeds keep every generated tensor reproducible while making
    # sure weights are not an element-for-element copy of the predictions.
    WEIGHT_SEED = 4321
    TARGET_SEED = 2021

    def generate_input(self, lower, upper, shape, dtype, seed=1234):
        """Return a CPU tensor of uniform samples from ``[lower, upper)``.

        Args:
            lower: Lower bound of the sampling range.
            upper: Upper bound of the sampling range.
            shape: Shape of the generated tensor.
            dtype: NumPy dtype the samples are cast to.
            seed: Seed of the private random generator; calls with the same
                seed, bounds and shape produce identical values.

        Returns:
            A ``torch.Tensor`` sharing memory with the generated array.
        """
        # A private generator avoids reseeding NumPy's global RNG as a side
        # effect of building test data.
        rng = np.random.RandomState(seed)
        temp = rng.uniform(lower, upper, shape).astype(dtype)
        return torch.from_numpy(temp)

    def generate_target(self, shape, dtype, seed=TARGET_SEED):
        """Return a reproducible CPU tensor of 0/1 labels.

        Args:
            shape: Shape of the generated tensor.
            dtype: NumPy dtype the labels are cast to.
            seed: Seed of the private random generator.
        """
        rng = np.random.RandomState(seed)
        labels = rng.randint(0, 2, shape).astype(dtype)
        return torch.from_numpy(labels)

    def cpu_op_exec(self, predict, target, weight=None, reduction="mean"):
        """Evaluate the loss on the given tensors and return a NumPy array."""
        res = torch.nn.functional.binary_cross_entropy(
            predict, target, weight=weight, reduction=reduction)
        return res.numpy()

    def cpu_op_exec_half(self, predict, target, weight=None, reduction="mean"):
        """Evaluate the loss and return the result cast to float16.

        The float16 test passes float32 copies of its inputs here, so the
        reference is computed in float32 and only the result is rounded.
        """
        res = torch.nn.functional.binary_cross_entropy(
            predict, target, weight=weight, reduction=reduction)
        return res.type(torch.float16).numpy()

    def npu_op_exec(self, predict, target, weight=None, reduction="mean"):
        """Evaluate the loss on the ``"npu"`` device.

        All provided tensors are moved to the device first; the result is
        moved back to the CPU and returned as a NumPy array.
        """
        predict = predict.to("npu")
        target = target.to("npu")
        if weight is not None:
            weight = weight.to("npu")
        res = torch.nn.functional.binary_cross_entropy(
            predict, target, weight=weight, reduction=reduction)
        return res.to("cpu").numpy()

    def test_binary_cross_entropy_float32(self):
        """Compare float32 CPU and NPU results for each configuration."""
        cases = [
            # (input shape, weight shape or None, reduction)
            ((10, 64), None, "mean"),
            ((10, 64), None, "sum"),
            ((10, 64), None, "none"),
            ((10, 64), (10, 1), "mean"),
            ((10, 64), (10, 64), "mean"),
            ((10, 64), (10, 64), "sum"),
            ((10, 64), (10, 64), "none"),
        ]
        for shape, weight_shape, reduction in cases:
            predict = self.generate_input(LOWER, UPPER, shape, np.float32)
            target = self.generate_target(shape, np.float32)
            weight = None
            if weight_shape is not None:
                weight = self.generate_input(
                    LOWER, UPPER, weight_shape, np.float32,
                    seed=self.WEIGHT_SEED)
            cpu_output = self.cpu_op_exec(
                predict, target, weight=weight, reduction=reduction)
            npu_output = self.npu_op_exec(
                predict, target, weight=weight, reduction=reduction)
            self.assertRtolEqual(cpu_output, npu_output)

    def test_binary_cross_entropy_float16(self):
        """Compare float16 NPU results with a float32-computed reference."""
        cases = [
            # (input shape, weight shape or None, reduction)
            ((10, 64), (10, 64), "sum"),
            ((10, 64), (10, 64), "mean"),
            ((10, 64), (10, 64), "none"),
        ]
        for shape, weight_shape, reduction in cases:
            predict = self.generate_input(LOWER, UPPER, shape, np.float16)
            target = self.generate_target(shape, np.float16)
            # float32 copies of the very same values feed the reference.
            predict_32 = predict.type(torch.float32)
            target_32 = target.type(torch.float32)
            weight = None
            weight_32 = None
            if weight_shape is not None:
                weight = self.generate_input(
                    LOWER, UPPER, weight_shape, np.float16,
                    seed=self.WEIGHT_SEED)
                weight_32 = weight.type(torch.float32)

            npu_output = self.npu_op_exec(
                predict, target, weight=weight, reduction=reduction)
            cpu_output = self.cpu_op_exec_half(
                predict_32, target_32, weight=weight_32, reduction=reduction)
            self.assertRtolEqual(cpu_output, npu_output)


# Run the test cases of this module when the file is executed as a script.
if __name__ == "__main__":
    run_tests()