import unittest
import numpy as np
import torch_npu
import torch
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import get_npu_device, SupportedDevices


class TestNPUGemmaRmsNorm(TestCase):
    """Tests for ``torch_npu.npu_gemma_rms_norm``.

    The operator is compared against a NumPy reference (value checks) and
    against its own meta-device execution (shape/dtype checks).
    """

    def supported_op_exec(self, x, gamma):
        """Compute the NumPy reference for the operator.

        Both inputs are converted to float32. ``x`` is multiplied by the
        reciprocal of ``sqrt(mean(x**2) + 1e-6)`` taken over the last axis,
        and then by ``gamma + 1``.

        Args:
            x: NumPy array holding the values to normalize.
            gamma: NumPy array holding the scale values.

        Returns:
            A tuple ``(result, rstd)``: the output converted back to
            ``x.dtype``, and the float32 reciprocal standard deviation with
            the reduced last axis kept.
        """
        eps = 1e-6
        x_f32 = np.array(x, dtype=np.float32)
        scale = np.array(gamma, dtype=np.float32) + 1

        mean_square = np.mean(np.power(x_f32, 2), axis=-1, keepdims=True)
        rstd = 1 / np.sqrt(mean_square + eps)
        # Evaluated left to right: (x * rstd) first, then the (gamma + 1) scale.
        scaled = x_f32 * rstd * scale

        return np.array(scaled, dtype=x.dtype), rstd

    def custom_op_exec(self, x, gamma):
        """Run the operator under test and return both outputs as NumPy arrays."""
        outputs = torch_npu.npu_gemma_rms_norm(x, gamma)
        y, rstd = outputs
        return y.cpu().numpy(), rstd.cpu().numpy()

    def _gemma_random_inputs(self, dtype):
        """Draw ``x`` ([256, 512]) and ``gamma`` ([512]) uniformly from [0, 100)."""
        x = np.random.uniform(0, 100, [256, 512]).astype(dtype)
        gamma = np.random.uniform(0, 100, [512]).astype(dtype)
        return x, gamma

    def _assert_matches_reference(self, dtype, device):
        """Compare operator outputs on ``device`` with the NumPy reference."""
        target = get_npu_device() if device is None else device
        x_cpu, gamma_cpu = self._gemma_random_inputs(dtype)
        x_dev = torch.from_numpy(x_cpu).to(target)
        gamma_dev = torch.from_numpy(gamma_cpu).to(target)

        expected_y, expected_rstd = self.supported_op_exec(x_cpu, gamma_cpu)
        actual_y, actual_rstd = self.custom_op_exec(x_dev, gamma_dev)
        self.assertRtolEqual(expected_y, actual_y)
        self.assertRtolEqual(expected_rstd, actual_rstd)

    def _assert_meta_matches_npu(self, dtype):
        """Check that "meta" execution reports the same shapes and dtypes as "npu"."""
        x_cpu, gamma_cpu = self._gemma_random_inputs(dtype)
        x_npu = torch.from_numpy(x_cpu).to("npu")
        gamma_npu = torch.from_numpy(gamma_cpu).to("npu")
        x_meta = torch.from_numpy(x_cpu).to("meta")
        gamma_meta = torch.from_numpy(gamma_cpu).to("meta")

        npu_y, npu_rstd = torch_npu.npu_gemma_rms_norm(x_npu, gamma_npu)
        meta_y, meta_rstd = torch_npu.npu_gemma_rms_norm(x_meta, gamma_meta)
        self.assertEqual(npu_y.shape, meta_y.shape)
        self.assertEqual(npu_rstd.shape, meta_rstd.shape)
        self.assertEqual(npu_y.dtype, meta_y.dtype)
        self.assertEqual(npu_rstd.dtype, meta_rstd.dtype)

    @unittest.skip("skip test_gemma_rms_norm now")
    @SupportedDevices(['Ascend910B'])
    def test_gemma_rms_norm(self, device="npu"):
        # float32 value check against the NumPy reference (currently skipped).
        self._assert_matches_reference(np.float32, device)

    @unittest.skip("skip test_gemma_rms_norm_fp16 now")
    @SupportedDevices(['Ascend910B'])
    def test_gemma_rms_norm_fp16(self, device="npu"):
        # float16 value check against the NumPy reference (currently skipped).
        self._assert_matches_reference(np.float16, device)

    @SupportedDevices(['Ascend910B'])
    def test_gemma_rms_norm_meta(self):
        # float32: meta-device outputs must agree with NPU outputs in shape/dtype.
        self._assert_meta_matches_npu(np.float32)

    @SupportedDevices(['Ascend910B'])
    def test_gemma_rms_norm_fp16_meta(self):
        # float16: meta-device outputs must agree with NPU outputs in shape/dtype.
        self._assert_meta_matches_npu(np.float16)


# When executed as a script, hand control to the torch_npu test runner.
if __name__ == "__main__":
    run_tests()