import unittest
import numpy as np
import torch_npu
import torch
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import get_npu_device, SupportedDevices
# Module-level side effect: runs at import time, before any test in this file,
# and sets the NPU compile mode with jit_compile=False.
# NOTE(review): presumably this makes operators use pre-built binaries instead
# of on-the-fly (JIT) compilation - confirm against the torch_npu documentation.
torch.npu.set_compile_mode(jit_compile=False)
class TestNPUDeepNorm(TestCase):
    """Check ``torch_npu.npu_deep_norm`` against a NumPy reference computation."""

    # Single source of truth for the hyper-parameters, shared by the NumPy
    # reference and the NPU call so the two sides cannot drift apart.
    ALPHA = 0.3
    EPSILON = 1e-6

    def supported_op_exec(self, x, gx, beta, gamma, alpha=ALPHA, epsilon=EPSILON):
        """Compute the reference result with NumPy.

        The computation is ``new_x = alpha * x + gx`` followed by a
        normalization over the trailing ``len(gamma.shape)`` axes of
        ``new_x``, then scaled by ``gamma`` and shifted by ``beta``.

        Args:
            x: NumPy array, scaled by ``alpha`` before the addition.
            gx: NumPy array added to ``alpha * x``.
            beta: NumPy array added after scaling by ``gamma``.
            gamma: NumPy array; its rank selects how many trailing axes
                are reduced.
            alpha: Scale applied to ``x``. Defaults to ``ALPHA``.
            epsilon: Value added to the variance before the square root.
                Defaults to ``EPSILON``.

        Returns:
            Tuple ``(mean, rstd, y)``; ``mean`` and ``rstd`` keep the reduced
            axes as size-1 dimensions (``keepdims=True``).
        """
        x_ndim = len(x.shape)
        # Reduce over the last len(gamma.shape) axes of the input.
        reduce_axis = tuple(range(x_ndim - len(gamma.shape), x_ndim))

        new_x = alpha * x + gx
        mean = np.mean(new_x, axis=reduce_axis, keepdims=True)
        diff = new_x - mean
        variance = np.mean(np.power(diff, 2), axis=reduce_axis, keepdims=True)
        rstd = 1 / np.sqrt(variance + epsilon)
        y = diff * rstd * gamma + beta
        return mean, rstd, y

    def custom_op_exec(self, x, gx, beta, gamma, alpha=ALPHA, epsilon=EPSILON):
        """Run ``torch_npu.npu_deep_norm`` and return its outputs as NumPy arrays.

        Args:
            x, gx, beta, gamma: Tensors forwarded positionally to the op.
            alpha: Forwarded as ``float(alpha)``. Defaults to ``ALPHA``.
            epsilon: Forwarded unchanged. Defaults to ``EPSILON``.

        Returns:
            Tuple ``(mean, rstd, y)`` with each output copied to the CPU and
            converted to a NumPy array.
        """
        mean, rstd, y = torch_npu.npu_deep_norm(x, gx, beta, gamma, float(alpha), epsilon)
        return mean.cpu().numpy(), rstd.cpu().numpy(), y.cpu().numpy()

    # NOTE(review): SupportedDevices presumably limits this test to the listed
    # device types - confirm against torch_npu.testing.common_utils.
    @SupportedDevices(['Ascend910B'])
    def test_deep_norm(self, device="npu"):
        """Compare the NPU op with the NumPy reference on random float32 inputs."""
        if device is None:
            device = get_npu_device()

        # Draw order (x, gx, beta, gamma) is kept as-is so the sequence of
        # values taken from NumPy's global RNG is unchanged.
        cpu_input_x = np.random.uniform(0, 100, [1024, 2, 12288]).astype(np.float32)
        cpu_input_gx = np.random.uniform(0, 100, [1024, 2, 12288]).astype(np.float32)
        cpu_input_beta = np.random.uniform(0, 100, [2, 12288]).astype(np.float32)
        cpu_input_gamma = np.random.uniform(0, 100, [2, 12288]).astype(np.float32)

        npu_input_x = torch.from_numpy(cpu_input_x).to(device)
        npu_input_gx = torch.from_numpy(cpu_input_gx).to(device)
        npu_input_beta = torch.from_numpy(cpu_input_beta).to(device)
        npu_input_gamma = torch.from_numpy(cpu_input_gamma).to(device)

        supported_mean, supported_rstd, supported_y = self.supported_op_exec(
            cpu_input_x, cpu_input_gx, cpu_input_beta, cpu_input_gamma
        )
        custom_mean, custom_rstd, custom_y = self.custom_op_exec(
            npu_input_x, npu_input_gx, npu_input_beta, npu_input_gamma
        )

        self.assertRtolEqual(supported_mean, custom_mean)
        self.assertRtolEqual(supported_rstd, custom_rstd)
        self.assertRtolEqual(supported_y, custom_y)
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()