import unittest

import numpy as np
import torch
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import get_npu_device, SupportedDevices


class TestNPUAddRmsNorm(TestCase):
    """Compare ``torch_npu.npu_add_rms_norm`` against a NumPy reference."""

    def supported_op_exec(self, x1, x2, gamma, epsilon=1e-6):
        """Compute the reference outputs with NumPy.

        The inputs are summed, the mean of squares is taken over the last
        axis, and the sum is scaled by the reciprocal root of that mean
        (plus ``epsilon``) and then by ``gamma``.

        For float16 inputs the intermediate math runs in float32; the
        normalized result and the sum are cast back to float16 on return,
        while ``rstd`` is returned without a cast.

        Returns:
            Tuple ``(result, rstd, x)`` where ``x`` is ``x1 + x2``.
        """
        is_half = x1.dtype == np.float16

        summed = x1 + x2
        if is_half:
            # The addition itself happens in float16; only the later math
            # is promoted to float32.
            summed = summed.astype(np.float32)

        mean_square = np.mean(np.power(summed, 2), axis=-1, keepdims=True)
        rstd = np.divide(1, np.sqrt(mean_square + epsilon))
        normalized = summed * rstd
        scaled = normalized * gamma

        if is_half:
            return scaled.astype(np.float16), rstd, summed.astype(np.float16)
        return scaled, rstd, summed

    def custom_op_exec(self, x1, x2, gamma, epsilon=1e-6):
        """Run ``torch_npu.npu_add_rms_norm`` and return NumPy arrays.

        Returns:
            Tuple ``(y, rstd, x)`` with each output moved to the CPU and
            converted to a NumPy array.
        """
        y, rstd, x = torch_npu.npu_add_rms_norm(x1, x2, gamma, epsilon)
        return tuple(out.cpu().numpy() for out in (y, rstd, x))

    def _compare_with_reference(self, dtype, rtol):
        """Build random inputs of ``dtype`` and check NPU vs. reference.

        Two ``[1024, 12288]`` inputs and a ``[12288]`` gamma are drawn from
        ``uniform(0, 100)``. Each of the three outputs is compared with
        ``assertRtolEqual`` using ``rtol``.
        """
        x1_host = np.random.uniform(0, 100, [1024, 12288]).astype(dtype)
        x2_host = np.random.uniform(0, 100, [1024, 12288]).astype(dtype)
        gamma_host = np.random.uniform(0, 100, [12288]).astype(dtype)

        x1_dev, x2_dev, gamma_dev = (
            torch.from_numpy(host).npu()
            for host in (x1_host, x2_host, gamma_host)
        )

        expected = self.supported_op_exec(x1_host, x2_host, gamma_host)
        actual = self.custom_op_exec(x1_dev, x2_dev, gamma_dev)

        # Order of comparison: normalized result, rstd, then the sum.
        for want, got in zip(expected, actual):
            self.assertRtolEqual(want, got, rtol)

    @SupportedDevices(['Ascend910B'])
    def test_add_rms_norm_fp32(self, device="npu"):
        """Check float32 inputs with tolerance 0.0001."""
        # The resolved device is not used by the comparison below; the
        # lookup is retained unchanged.
        device = get_npu_device() if device is None else device
        self._compare_with_reference(np.float32, 0.0001)

    @SupportedDevices(['Ascend910B'])
    def test_add_rms_norm_fp16(self):
        """Check float16 inputs with tolerance 0.001."""
        self._compare_with_reference(np.float16, 0.001)


# When executed directly as a script, hand control to the run_tests helper
# imported from torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()