import unittest
import math
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
QUANT_TYPE_MIN_MAP = {1: -128, 290: -32769, 291: -57345, 292: -449}
QUANT_TYPE_MAX_MAP = {1: 127, 290: 32768, 291: 57344, 292: 448}
class TestNPURmsNormQuant(TestCase):
def numpy_float8_e4m3fn(self):
try:
from ml_dtypes import float8_e4m3fn
return float8_e4m3fn
except ModuleNotFoundError:
raise RuntimeError("ml_dtypes is needed to support float8_e4m3fn dtype!!! "
"Please install with `pip3 install ml-dtypes`")
def numpy_hifloat8(self):
try:
from en_dtypes import hifloat8
return hifloat8
except ModuleNotFoundError:
raise RuntimeError("en_dtypes is needed to support hifloat8 dtype!!! "
"Please install with `pip3 install en-dtypes`")
except ImportError:
raise RuntimeError("Please upgrade en_dtypes to v0.0.3 at least to support hifloat8 dtype!!! "
"Command is `pip3 install --upgrade en-dtypes`")
def numpy_float8_e5m2(self):
try:
from ml_dtypes import float8_e5m2
return float8_e5m2
except ModuleNotFoundError:
raise RuntimeError("ml_dtypes is needed to support float8_e5m2 dtype!!! "
"Please install with `pip3 install ml-dtypes`")
def compare(self, a, b, benchmark):
diff_abs = torch.abs(a - b)
max_diff_abs, _ = torch.max(diff_abs, dim=0)
if max_diff_abs.item() == 0:
return True
else:
rel_error = 0
abs_error = 0
for i in range(a.shape[0]):
yes_no = (a[i] == 0 and b[i].item() != 0)
no_yes = (a[i] != 0 and b[i].item() == 0)
if a[i].item() == 0 and b[i].item() == 0:
diff_rel_item = 0
elif yes_no or no_yes:
diff_rel_item = 1
elif a[i] != 0 and b[i].item() != 0:
diff_rel_item = diff_abs[i].item() / abs(a[i].item())
if abs(a[i].item()) < 1 and diff_abs[i].item() > benchmark:
abs_error += 1
elif abs(a[i].item()) >= 1 and diff_rel_item > benchmark:
rel_error += 1
if (rel_error + abs_error) > 10:
break
if (rel_error + abs_error) > 0:
return False
else:
return True
def npu_rms_norm_quant_golden(self, x, gamma, beta, scale,
offset, epsilon=1e-06, dst_dtype=1):
x_fp32 = x.float()
input_gamma_fp32 = gamma.float()
input_beta_fp32 = beta.float()
tensor_scales = scale.float()
offset = offset.float()
ori_shape = x.shape
len_shape_x = len(x_fp32.shape)
len_shape_gamma = len(gamma.shape)
axis = len_shape_x - len_shape_gamma
variance = torch.mean(torch.pow(x_fp32, 2), axis=-1, keepdims=True)
std = torch.sqrt(variance + epsilon)
rstd = 1 / std
result_mid = x_fp32 * rstd
y_array = result_mid * input_gamma_fp32 + input_beta_fp32
y = y_array.type(torch.float32)
if dst_dtype == 1:
y1 = torch.quantize_per_tensor(y, tensor_scales, offset, torch.qint8)
y1_np = y1.int_repr().detach().clone().cpu().numpy()
return torch.tensor(y1_np).type(torch.float16).type(torch.int8).reshape(ori_shape)
else:
dst_dtype_map = {290: self.numpy_hifloat8(), 291: self.numpy_float8_e5m2(),
292: self.numpy_float8_e4m3fn()}
y = y / tensor_scales + offset
y1_np = y.detach().clone().cpu().numpy()
y1_np = np.clip(y1_np, QUANT_TYPE_MIN_MAP[dst_dtype], QUANT_TYPE_MAX_MAP[dst_dtype]).astype(
dst_dtype_map[dst_dtype], copy=False)
return torch.tensor(y1_np.reshape(x.shape).astype(np.float32))
@unittest.skip("skip until CANN is updated to support aclnnRmsNormQuant")
@SupportedDevices(['Ascend910B'])
def test_npu_rms_norm_quant(self):
torch.manual_seed(42)
np.random.seed(42)
shape_list = [
[[16, ], [16, ]],
[[16, ], [1, 16]],
[[1, 16], [16, ]],
[[1, 16], [1, 16]],
[[1, 1, 16], [16, ]],
[[1, 1, 16], [1, 16]],
[[1, 1, 1, 16], [16, ]],
[[1, 1, 1, 16], [1, 16]],
[[1, 1, 1, 1, 16], [16, ]],
[[1, 1, 1, 1, 16], [1, 16]],
[[1, 1, 1, 1, 1, 16], [16, ]],
[[1, 1, 1, 1, 1, 16], [1, 16]],
[[1, 1, 1, 1, 1, 1, 16], [16, ]],
[[1, 1, 1, 1, 1, 1, 16], [1, 16]],
[[1, 1, 1, 1, 1, 1, 1, 16], [16, ]],
[[1, 1, 1, 1, 1, 1, 1, 16], [1, 16]],
]
benchmark_int8 = 1
for x_shape, quant_shape in shape_list:
D = x_shape[-1]
x = torch.randn(x_shape, dtype=torch.float16)
if quant_shape == [D, ]:
gamma = torch.randn(D, dtype=torch.float16)
beta = torch.randn(D, dtype=torch.float16)
elif quant_shape == [1, D]:
gamma = torch.randn(1, D, dtype=torch.float16)
beta = torch.randn(1, D, dtype=torch.float16)
scale = (torch.rand(1, dtype=torch.float16) * 0.8 + 0.2)
offset = torch.randint(-5, 6, (1, ), dtype=torch.int8)
x_npu = x.npu()
gamma_npu = gamma.npu()
beta_npu = beta.npu()
scale_npu = scale.npu()
offset_npu = offset.npu()
y_ref = self.npu_rms_norm_quant_golden(x, gamma, beta, scale, offset, epsilon=1e-6)
y_npu = torch_npu.npu_rms_norm_quant(x_npu, gamma_npu, beta_npu, scale_npu, offset_npu, epsilon=1e-6)
y_ref_flat = y_ref.reshape(1, y_ref.numel())[0].cpu()
y_npu_flat = y_npu.reshape(1, y_npu.numel())[0].cpu()
self.assertTrue(self.compare(y_ref_flat, y_npu_flat, benchmark_int8))
@unittest.skip("skip until CANN is updated to support aclnnRmsNormQuant")
@SupportedDevices(['Ascend910B'])
def test_npu_rms_norm_quant_bf16(self):
shape_list = [
[[16, ], [16, ]],
[[16, ], [1, 16]],
[[1, 16], [16, ]],
[[1, 16], [1, 16]],
[[1, 1, 16], [16, ]],
[[1, 1, 16], [1, 16]],
[[1, 1, 1, 16], [16, ]],
[[1, 1, 1, 16], [1, 16]],
[[1, 1, 1, 1, 16], [16, ]],
[[1, 1, 1, 1, 16], [1, 16]],
[[1, 1, 1, 1, 1, 16], [16, ]],
[[1, 1, 1, 1, 1, 16], [1, 16]],
[[1, 1, 1, 1, 1, 1, 16], [16, ]],
[[1, 1, 1, 1, 1, 1, 16], [1, 16]],
[[1, 1, 1, 1, 1, 1, 1, 16], [16, ]],
[[1, 1, 1, 1, 1, 1, 1, 16], [1, 16]],
]
benchmark_int8 = 1
for x_shape, quant_shape in shape_list:
D = x_shape[-1]
x = torch.randn(x_shape, dtype=torch.bfloat16)
if quant_shape == [D, ]:
gamma = torch.randn(D, dtype=torch.bfloat16)
beta = torch.randn(D, dtype=torch.bfloat16)
elif quant_shape == [1, D]:
gamma = torch.randn(1, D, dtype=torch.bfloat16)
beta = torch.randn(1, D, dtype=torch.bfloat16)
scale = (torch.rand(1, dtype=torch.bfloat16) * 0.8 + 0.2)
offset = torch.randint(-5, 6, (1, ), dtype=torch.int8)
x_npu = x.npu()
gamma_npu = gamma.npu()
beta_npu = beta.npu()
scale_npu = scale.npu()
offset_npu = offset.npu()
y_ref = self.npu_rms_norm_quant_golden(x, gamma, beta, scale, offset, epsilon=1e-6)
y_npu = torch_npu.npu_rms_norm_quant(x_npu, gamma_npu, beta_npu, scale_npu, offset_npu, epsilon=1e-6)
y_ref_flat = y_ref.reshape(1, y_ref.numel())[0].cpu()
y_npu_flat = y_npu.reshape(1, y_npu.numel())[0].cpu()
self.assertTrue(self.compare(y_ref_flat, y_npu_flat, benchmark_int8))
@SupportedDevices(['Ascend910_95', 'Ascend950'])
def test_npu_rms_norm_quant_float8_e5m2(self):
x_shape = [4, 32]
quant_shape = [32]
x = torch.randn(x_shape, dtype=torch.float16)
gamma = torch.randn(quant_shape, dtype=torch.float16)
beta = torch.randn(quant_shape, dtype=torch.float16)
scale = (torch.rand(1, dtype=torch.float16) * 0.8 + 0.2)
offset = torch.randint(-5, 6, (1,), dtype=torch.int8)
eps = 1e-6
out_dtype = 290
x_npu = x.npu()
gamma_npu = gamma.npu()
beta_npu = beta.npu()
scale_npu = scale.npu()
offset_npu = offset.npu()
y_ref = self.npu_rms_norm_quant_golden(x, gamma, beta, scale, offset, epsilon=eps, dst_dtype=out_dtype)
y_npu = torch_npu.npu_rms_norm_quant(x_npu, gamma_npu, beta_npu, scale_npu, offset_npu, epsilon=eps,
dst_dtype=out_dtype)
benchmark_float32 = 1e-8
y_ref_flat = y_ref.reshape(1, y_ref.numel())[0].cpu()
y_npu_flat = y_npu.to(torch.float32).reshape(1, y_npu.numel())[0].cpu()
print(y_ref_flat)
print(y_npu_flat)
self.assertTrue(self.compare(y_ref_flat, y_npu_flat, benchmark_float32))
if __name__ == "__main__":
run_tests()