import unittest
import numpy as np
import torch_npu
import torch
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import get_npu_device, SupportedDevices
class TestNPUGemmaRmsNorm(TestCase):
    """Tests for ``torch_npu.npu_gemma_rms_norm``.

    Accuracy tests compare the NPU kernel against a NumPy reference
    implementation. Meta tests check that running the op on ``meta`` tensors
    reports the same output shapes and dtypes as running it on the NPU.
    """

    def supported_op_exec(self, x, gamma, epsilon=1e-6):
        """Compute the Gemma-style RMSNorm reference on the CPU with NumPy.

        Computes ``y = x / sqrt(mean(x**2, axis=-1) + epsilon) * (gamma + 1)``
        in float32, then casts ``y`` back to ``x.dtype``.

        Args:
            x: NumPy array, normalized over its last axis.
            gamma: NumPy array of scale weights, broadcast against ``x``.
            epsilon: Added to the variance before the square root. Defaults
                to 1e-6, the value this reference has always used.

        Returns:
            Tuple ``(y, rstd)``. ``y`` has ``x.dtype``. ``rstd`` is the float32
            reciprocal standard deviation, with the last axis kept as size 1.
        """
        # Always compute in float32 so that fp16 inputs get a stable reference.
        x_fp32 = np.array(x, dtype=np.float32)
        gamma_fp32 = np.array(gamma, dtype=np.float32)
        variance = np.mean(np.power(x_fp32, 2), axis=-1, keepdims=True)
        std = np.sqrt(variance + epsilon)
        rstd = 1 / std
        result_mid = x_fp32 * rstd
        # Gemma variant: the learned weight is applied as (1 + gamma).
        gamma_mid = gamma_fp32 + 1
        result_fp32 = result_mid * gamma_mid
        result = np.array(result_fp32, dtype=x.dtype)
        return result, rstd

    def custom_op_exec(self, x, gamma):
        """Run ``torch_npu.npu_gemma_rms_norm`` and return its outputs as NumPy arrays.

        Args:
            x: Input tensor on the NPU.
            gamma: Weight tensor on the NPU.

        Returns:
            Tuple ``(y, rstd)`` copied back to host NumPy arrays.
        """
        y, rstd = torch_npu.npu_gemma_rms_norm(x, gamma)
        return y.cpu().numpy(), rstd.cpu().numpy()

    @staticmethod
    def _make_inputs(np_dtype):
        """Build a random (256, 512) input and a (512,) weight, both cast to ``np_dtype``."""
        cpu_input0 = np.random.uniform(0, 100, [256, 512]).astype(np_dtype)
        cpu_input1 = np.random.uniform(0, 100, [512]).astype(np_dtype)
        return cpu_input0, cpu_input1

    def _run_accuracy_case(self, np_dtype, device):
        """Compare the NPU kernel with the NumPy reference for one input dtype.

        Args:
            np_dtype: NumPy dtype used for both the input and the weight.
            device: Target device string. ``None`` selects ``get_npu_device()``.
        """
        if device is None:
            device = get_npu_device()
        cpu_input0, cpu_input1 = self._make_inputs(np_dtype)
        npu_input0 = torch.from_numpy(cpu_input0).to(device)
        npu_input1 = torch.from_numpy(cpu_input1).to(device)
        supported_output0, supported_output1 = self.supported_op_exec(cpu_input0, cpu_input1)
        custom_output0, custom_output1 = self.custom_op_exec(npu_input0, npu_input1)
        self.assertRtolEqual(supported_output0, custom_output0)
        self.assertRtolEqual(supported_output1, custom_output1)

    def _run_meta_case(self, np_dtype):
        """Check that meta-device outputs match the NPU outputs in shape and dtype.

        Args:
            np_dtype: NumPy dtype used for both the input and the weight.
        """
        cpu_input0, cpu_input1 = self._make_inputs(np_dtype)
        npu_input0 = torch.from_numpy(cpu_input0).to("npu")
        npu_input1 = torch.from_numpy(cpu_input1).to("npu")
        meta_input0 = torch.from_numpy(cpu_input0).to("meta")
        meta_input1 = torch.from_numpy(cpu_input1).to("meta")
        npu_outputs = torch_npu.npu_gemma_rms_norm(npu_input0, npu_input1)
        meta_outputs = torch_npu.npu_gemma_rms_norm(meta_input0, meta_input1)
        # Without this check, zip() would silently ignore an extra or missing output.
        self.assertEqual(len(npu_outputs), len(meta_outputs))
        for npu_output, meta_output in zip(npu_outputs, meta_outputs):
            self.assertEqual(npu_output.shape, meta_output.shape)
            self.assertEqual(npu_output.dtype, meta_output.dtype)

    @unittest.skip("skip test_gemma_rms_norm now")
    @SupportedDevices(['Ascend910B'])
    def test_gemma_rms_norm(self, device="npu"):
        self._run_accuracy_case(np.float32, device)

    @unittest.skip("skip test_gemma_rms_norm_fp16 now")
    @SupportedDevices(['Ascend910B'])
    def test_gemma_rms_norm_fp16(self, device="npu"):
        self._run_accuracy_case(np.float16, device)

    @SupportedDevices(['Ascend910B'])
    def test_gemma_rms_norm_meta(self):
        self._run_meta_case(np.float32)

    @SupportedDevices(['Ascend910B'])
    def test_gemma_rms_norm_fp16_meta(self):
        self._run_meta_case(np.float16)
if __name__ == "__main__":
    # Run the suite through torch_npu's runner when this file is executed directly.
    run_tests()