import math
import unittest
import numpy as np
import torch

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestGroupNormSilu(TestCase):
    """Compare ``torch_npu.npu_group_norm_silu`` against a GroupNorm + SiLU reference.

    The reference path calls ``torch.ops.aten.native_group_norm`` and applies
    ``torch.nn.functional.silu`` to its first output; the fused path calls
    ``torch_npu.npu_group_norm_silu``. Both run on the same NPU tensors.
    """

    def supported_op_exec(self, x, gama, beta, group, eps):
        """Reference result: GroupNorm via ``aten.native_group_norm`` followed by SiLU.

        Args:
            x: input tensor laid out as ``(N, C, *spatial)``; any number of
                trailing spatial dims (including none) is accepted.
            gama: per-channel affine weight (gamma).
            beta: per-channel affine bias.
            group: number of groups the channel dimension is split into.
            eps: value added to the variance for numerical stability.

        Returns:
            A list of the three ``native_group_norm`` outputs with SiLU applied
            to the first (normalized) one only. The other two (presumably the
            per-group mean and rstd — confirm) are passed through unchanged.
        """
        n, c = x.shape[0], x.shape[1]
        # native_group_norm expects the spatial extent flattened into a single
        # HxW count; multiplying every trailing dim (math.prod(()) == 1) makes
        # this work for (N, C), (N, C, L), (N, C, H, W), ... not just 4-D input.
        hxw = math.prod(x.shape[2:])
        res = list(torch.ops.aten.native_group_norm(x, gama, beta, n, c, hxw, group, eps))
        res[0] = torch.nn.functional.silu(res[0])
        return res

    def custom_op_exec(self, x, gama, beta, group, eps):
        """Fused kernel under test; arguments mirror ``supported_op_exec``."""
        return torch_npu.npu_group_norm_silu(x, gama, beta, group, eps)

    def _compare_with_reference(self, x_dtype, param_dtype=torch.float32,
                                shape=(24, 320, 48, 48), group=32, eps=1e-4):
        """Build random NPU inputs, run both paths and assert they match.

        Args:
            x_dtype: dtype of the input activation tensor.
            param_dtype: dtype of the gamma/beta affine parameters.
            shape: input shape ``(N, C, *spatial)``; ``C`` sizes gamma/beta.
            group: number of groups passed to both implementations.
            eps: epsilon passed to both implementations.
        """
        channels = shape[1]
        x = torch.randn(*shape, dtype=x_dtype).npu()
        gamma = torch.randn(channels, dtype=param_dtype).npu()
        beta = torch.randn(channels, dtype=param_dtype).npu()

        supported_output = self.supported_op_exec(x, gamma, beta, group, eps)
        custom_output = self.custom_op_exec(x, gamma, beta, group, eps)
        self.assertRtolEqual(supported_output, custom_output)

    @SupportedDevices(['Ascend910B'])
    def test_npu_(self, device="npu"):
        """fp32 input with fp32 affine parameters on Ascend910B."""
        self._compare_with_reference(torch.float32)

    @SupportedDevices(['Ascend950'])
    def test_npu_950_(self, device="npu"):
        """fp16 input with fp32 affine parameters on Ascend950."""
        # NOTE(review): mixing fp16 activations with fp32 gamma/beta looks like a
        # deliberate mixed-precision case for this device — confirm intent.
        self._compare_with_reference(torch.float16)


if __name__ == "__main__":
    run_tests()