import unittest
import numpy as np
import torch_npu
import torch
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import get_npu_device, SupportedDevices


class TestNPURmsNorm(TestCase):

    def supported_op_exec(self, x, gamma):
        x_fp32 = np.array(x, dtype=np.float32)
        gamma_fp32 = np.array(gamma, dtype=np.float32)

        variance = np.mean(np.power(x_fp32, 2), axis=-1, keepdims=True)
        epsilon = 1e-6
        std = np.sqrt(variance + epsilon)
        rstd = 1 / std
        result_mid = x_fp32 * rstd
        result_fp32 = result_mid * gamma_fp32

        result = np.array(result_fp32, dtype=x.dtype)

        return result, rstd

    def custom_op_exec(self, x, gamma):
        y, rstd = torch_npu.npu_rms_norm(x, gamma)
        return y.cpu().numpy(), rstd.cpu().numpy()

    @SupportedDevices(['Ascend910B'])
    def test_rms_norm(self, device="npu"):
        if device is None:
            device = get_npu_device()
        cpu_input0 = np.random.uniform(0, 100, [1024, 12288]).astype(np.float32)
        cpu_input1 = np.random.uniform(0, 100, [12288]).astype(np.float32)
        npu_input0 = torch.from_numpy(cpu_input0).to(device)
        npu_input1 = torch.from_numpy(cpu_input1).to(device)

        supported_output0, supported_output1 = self.supported_op_exec(cpu_input0, cpu_input1)
        custom_output0, custom_output1 = self.custom_op_exec(npu_input0, npu_input1)
        self.assertRtolEqual(supported_output0, custom_output0)
        self.assertRtolEqual(supported_output1, custom_output1)

    @SupportedDevices(['Ascend910B'])
    def test_rms_norm_mix_dtype(self, device="npu"):
        if device is None:
            device = get_npu_device()
        cpu_input0 = np.random.uniform(0, 100, [1024, 12288]).astype(np.float16)
        cpu_input1 = np.random.uniform(0, 100, [12288]).astype(np.float32)
        npu_input0 = torch.from_numpy(cpu_input0).to(device)
        npu_input1 = torch.from_numpy(cpu_input1).to(device)

        supported_output0, supported_output1 = self.supported_op_exec(cpu_input0, cpu_input1)
        custom_output0, custom_output1 = self.custom_op_exec(npu_input0, npu_input1)
        self.assertRtolEqual(supported_output0, custom_output0)
        self.assertRtolEqual(supported_output1, custom_output1)

if __name__ == "__main__":
    run_tests()