import os
import shutil
import unittest
import numpy
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
MAX_VALUE_WITH_INT8 = 127
MIN_VALUE_WITH_INT8 = -128
MAX_VALUE_WITH_FLOAT8E5M2 = 57344
MAX_VALUE_WITH_FLOAT8E4M3 = 448
MAX_VALUE_WITH_FLOAT4E2M1 = 6
MAX_VALUE_WITH_FLOAT4E1M2 = 1.75
class TestNPUDequantSwigluQuant(TestCase):
def numpy_float8_e5m2(self):
try:
from ml_dtypes import float8_e5m2
return float8_e5m2
except ModuleNotFoundError:
raise RuntimeError("ml_dtypes is needed to support float8_e5m2 dtype!!! "
"Please install with `pip3 install ml-dtypes`")
def numpy_float8_e4m3fn(self):
try:
from ml_dtypes import float8_e4m3fn
return float8_e4m3fn
except ModuleNotFoundError:
raise RuntimeError("ml_dtypes is needed to support float8_e4m3fn dtype!!! "
"Please install with `pip3 install ml-dtypes`")
def numpy_float4_e2m1(self):
try:
from en_dtypes import float4_e2m1
return float4_e2m1
except ModuleNotFoundError:
raise RuntimeError("en_dtypes is needed to support float4_e2m1 dtype!!! "
"Please install with `pip3 install en-dtypes`")
def numpy_float4_e1m2(self):
try:
from en_dtypes import float4_e1m2
return float4_e1m2
except ModuleNotFoundError:
raise RuntimeError("en_dtypes is needed to support float4_e1m2 dtype!!! "
"Please install with `pip3 install en-dtypes`")
def get_max_num(self, dst_type):
if dst_type == 1:
return MAX_VALUE_WITH_INT8
if dst_type == 291:
return MAX_VALUE_WITH_FLOAT8E5M2
if dst_type == 292:
return MAX_VALUE_WITH_FLOAT8E4M3
if dst_type == 296:
return MAX_VALUE_WITH_FLOAT4E2M1
if dst_type == 297:
return MAX_VALUE_WITH_FLOAT4E1M2
def transform_output(self, dst_type, round_mode, input):
dst_type_map = {1: numpy.int8, 291: self.numpy_float8_e5m2(), 292: self.numpy_float8_e4m3fn(),
296: self.numpy_float4_e2m1(), 297: self.numpy_float4_e1m2()}
round_mode_map = {0: numpy.rint, 1: numpy.round, 2: numpy.floor, 3: numpy.ceil, 4: numpy.trunc}
if dst_type == 1:
input = round_mode_map[round_mode](input)
tmp = input.to(torch.int8).numpy()
return tmp
elif dst_type == 291 or dst_type == 292:
tmp = input.numpy().astype(dst_type_map[dst_type])
return tmp
def golden_dequant_swiglu_quant_torch(
self,
x,
weight_scale,
activate_scale,
bias,
quant_scale,
quant_offset,
group_index,
activate_left,
quant_mode,
dst_type=1,
round_mode=0,
activate_dim=-1,
swiglu_mode=0,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0,
):
x_dtype = x.dtype
if len(x.shape) > 2:
x = x.reshape(-1, x.shape[-1])
if weight_scale is not None and len(weight_scale.shape) == 1:
weight_scale = weight_scale.reshape(1, -1)
if activate_scale is not None and len(activate_scale.shape) >= 1:
activate_scale = activate_scale.reshape(-1, 1)
if quant_offset is None:
quant_offset = torch.tensor([0], dtype=torch.float32)
if group_index is None:
group_index = numpy.array([x.shape[0]])
if quant_mode == 1:
if quant_scale is not None and len(quant_scale.shape) == 1:
quant_scale = quant_scale.reshape(1, -1)
if quant_mode == 1:
quant_mode = "dynamic"
elif quant_mode == 0:
quant_mode = "static"
res_y = torch.zeros([x.shape[0], x.shape[1] // 2], dtype=torch.float32)
res_scale = torch.zeros([x.shape[0]], dtype=torch.float32)
offset = 0
for g_idx in range(group_index.shape[0]):
groupIdx = group_index[g_idx]
x_tensor = x[offset: (offset + groupIdx)]
if "int32" in str(x_dtype):
if bias is not None and bias.dtype is torch.int32:
x_tensor = torch.add(x_tensor, bias[g_idx])
x_tensor = x_tensor.to(torch.float32)
res = torch.mul(x_tensor, weight_scale[g_idx].to(torch.float32))
if activate_scale is not None:
res = torch.mul(res, activate_scale[offset: (offset + groupIdx)].to(torch.float32))
if bias is not None and bias.dtype is not torch.int32:
res = torch.add(res, bias.to(torch.float32))
else:
res = x_tensor
if swiglu_mode == 1:
self_tensor = res[..., ::2]
other = res[..., 1::2]
else:
out = torch.chunk(res, 2, dim=-1)
if activate_left:
self_tensor = out[0]
other = out[1]
else:
self_tensor = out[1]
other = out[0]
if swiglu_mode == 1:
self_tensor = self_tensor.clamp(min=None, max=clamp_limit)
other = other.clamp(min=-clamp_limit, max=clamp_limit)
self_tensor = self_tensor * torch.sigmoid(glu_alpha * self_tensor)
output = self_tensor * (other + glu_bias)
else:
output = torch.nn.functional.silu(self_tensor) * other
if quant_scale is not None:
if quant_mode == "static":
if(len(quant_scale.shape) == 1):
quant_scale = quant_scale.unsqueeze(1).expand(-1, output.shape[1])
if(len(quant_offset.shape) == 1):
quant_offset = quant_offset.unsqueeze(1).expand(-1, output.shape[1])
output = torch.div(output, quant_scale[g_idx].to(torch.float32))
output = torch.add(output, quant_offset[g_idx].to(torch.float32))
scale_out = torch.tensor(0.0)
else:
output = torch.mul(output, quant_scale[g_idx].to(torch.float32))
absd = torch.abs(output)
max_values = torch.amax(absd, dim=-1)
scale_out = max_values / 127
max_values = 127 / max_values
output = output * max_values.unsqueeze(1)
else:
if quant_mode == "static":
output = output + quant_offset
scale_out = torch.tensor(0.0)
else:
absd = torch.abs(output)
max_values = torch.amax(absd, dim=-1)
max_value = self.get_max_num(dst_type)
scale_out = max_values / max_value
max_values = max_value / max_values
output = output * max_values.unsqueeze(1)
if dst_type == 1:
output = torch.clamp(output, min=MIN_VALUE_WITH_INT8, max=MAX_VALUE_WITH_INT8)
else:
output = torch.clamp(output, min=self.get_max_num(dst_type) * -1, max=self.get_max_num(dst_type))
res_y[offset: (offset + groupIdx)] = output
res_scale[offset: (offset + groupIdx)] = scale_out
offset = offset + groupIdx
return self.transform_output(dst_type, round_mode, res_y), res_scale
@unittest.skip("Skip until CANN is updated to 8.3.RC1 to support aclnnDequantSwigluQuantV2")
@SupportedDevices(["Ascend910B"])
def test_npu_dequant_swiglu_quant_1(self, device="npu"):
swiglu_mode = 0
bias = None
quant_offset = None
x_shape = [4608, 2048]
x = torch.randint(-10, 10, x_shape, dtype=torch.int32)
weight_scale = torch.randn(x_shape[1], dtype=torch.float32)
activate_scale = None
quant_scale = torch.randn((1, x_shape[1] // 2), dtype=torch.float32)
group_index = torch.tensor([x.shape[0]])
quant_mode = 1
if quant_mode == 0:
quant_offset = torch.randn((1, x_shape[1] // 2), dtype=torch.float32)
y_cpu, scale_cpu = self.golden_dequant_swiglu_quant_torch(
x,
weight_scale,
activate_scale,
bias,
quant_scale,
quant_offset,
group_index,
activate_left=True,
quant_mode=quant_mode,
swiglu_mode=swiglu_mode,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0,
)
group_index_npu = group_index.npu() if group_index is not None else None
bias_npu = bias.npu() if bias is not None else None
if quant_offset is not None:
quant_offset = quant_offset.npu()
y_npu, scale_npu = torch_npu.npu_dequant_swiglu_quant(
x.npu(),
weight_scale=weight_scale.npu(),
activation_scale=activate_scale,
bias=bias_npu,
quant_scale=quant_scale.npu(),
quant_offset=quant_offset,
group_index=group_index_npu,
activate_left=True,
quant_mode=quant_mode,
swiglu_mode=swiglu_mode,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0,
)
self.assertRtolEqual(y_cpu.numpy(), y_npu.cpu().numpy())
self.assertRtolEqual(scale_cpu.numpy(), scale_npu.cpu().numpy())
@unittest.skip("Skip until CANN is updated to 8.3.RC1 to support aclnnDequantSwigluQuantV2")
@SupportedDevices(["Ascend910B"])
def test_npu_dequant_swiglu_quant_2(self, device="npu"):
swiglu_mode = 1
bias = None
quant_offset = None
x_shape = [4608, 2048]
x = torch.randint(-10, 10, x_shape, dtype=torch.int32)
weight_scale = torch.randn(x_shape[1], dtype=torch.float32)
activate_scale = torch.randn((x_shape[0], 1), dtype=torch.float32)
quant_scale = torch.randn((1, x_shape[1] // 2), dtype=torch.float32)
group_index = torch.tensor([x.shape[0]])
quant_mode = 1
if quant_mode == 0:
quant_offset = torch.randn((1, x_shape[1] // 2), dtype=torch.float32)
y_cpu, scale_cpu = self.golden_dequant_swiglu_quant_torch(
x,
weight_scale,
activate_scale,
bias,
quant_scale,
quant_offset,
group_index,
activate_left=True,
quant_mode=quant_mode,
swiglu_mode=swiglu_mode,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0,
)
group_index_npu = group_index.npu() if group_index is not None else None
bias_npu = bias.npu() if bias is not None else None
if quant_offset is not None:
quant_offset = quant_offset.npu()
y_npu, scale_npu = torch_npu.npu_dequant_swiglu_quant(
x.npu(),
weight_scale=weight_scale.npu(),
activation_scale=activate_scale.npu(),
bias=bias_npu,
quant_scale=quant_scale.npu(),
quant_offset=quant_offset,
group_index=group_index_npu,
activate_left=True,
quant_mode=quant_mode,
swiglu_mode=swiglu_mode,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0,
)
self.assertRtolEqual(y_cpu.numpy(), y_npu.cpu().numpy())
self.assertRtolEqual(scale_cpu.numpy(), scale_npu.cpu().numpy())
@unittest.skip("Skip until CANN is updated to 8.3.RC1 to support aclnnDequantSwigluQuantV2")
@SupportedDevices(["Ascend910B"])
def test_npu_dequant_swiglu_quant_swiglu_mode2(self, device="npu"):
swiglu_mode = 2
bias = None
quant_offset = None
x_shape = [4608, 2048]
x = torch.randint(-10, 10, x_shape, dtype=torch.int32)
weight_scale = torch.randn(x_shape[1], dtype=torch.float32)
activate_scale = torch.randn((x_shape[0], 1), dtype=torch.float32)
quant_scale = torch.randn((1, x_shape[1] // 2), dtype=torch.float32)
group_index = torch.tensor([x.shape[0]])
quant_mode = 1
if quant_mode == 0:
quant_offset = torch.randn((1, x_shape[1] // 2), dtype=torch.float32)
y_cpu, scale_cpu = self.golden_dequant_swiglu_quant_torch(
x,
weight_scale,
activate_scale,
bias,
quant_scale,
quant_offset,
group_index,
activate_left=True,
quant_mode=quant_mode,
swiglu_mode=swiglu_mode,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0,
)
group_index_npu = group_index.npu() if group_index is not None else None
bias_npu = bias.npu() if bias is not None else None
if quant_offset is not None:
quant_offset = quant_offset.npu()
y_npu, scale_npu = torch_npu.npu_dequant_swiglu_quant(
x.npu(),
weight_scale=weight_scale.npu(),
activation_scale=activate_scale.npu(),
bias=bias_npu,
quant_scale=quant_scale.npu(),
quant_offset=quant_offset,
group_index=group_index_npu,
activate_left=True,
quant_mode=quant_mode,
swiglu_mode=swiglu_mode,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0,
)
self.assertRtolEqual(y_cpu.numpy(), y_npu.cpu().numpy())
self.assertRtolEqual(scale_cpu.numpy(), scale_npu.cpu().numpy())
@unittest.skip("Skip until CANN is updated to 8.3.RC1 to support aclnnDequantSwigluQuantV2")
@SupportedDevices(["Ascend910B"])
def test_npu_dequant_swiglu_quant_3(self, device="npu"):
swiglu_mode = 0
bias = None
quant_offset = None
x_shape = [4608, 2048]
x = torch.randint(-10, 10, x_shape, dtype=torch.int32)
weight_scale = torch.randn(x_shape[1], dtype=torch.float32)
activate_scale = torch.randn((x_shape[0], 1), dtype=torch.float32)
quant_scale = torch.randn((1, x_shape[1] // 2), dtype=torch.float32)
group_index = torch.tensor([x.shape[0]])
quant_mode = 0
if quant_mode == 0:
quant_offset = torch.randn((1, x_shape[1] // 2), dtype=torch.float32)
y_cpu, _ = self.golden_dequant_swiglu_quant_torch(
x,
weight_scale,
activate_scale,
bias,
quant_scale,
quant_offset,
group_index,
activate_left=True,
quant_mode=quant_mode,
swiglu_mode=swiglu_mode,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0,
)
group_index_npu = group_index.npu() if group_index is not None else None
bias_npu = bias.npu() if bias is not None else None
if quant_offset is not None:
quant_offset = quant_offset.npu()
y_npu, scale_npu = torch_npu.npu_dequant_swiglu_quant(
x.npu(),
weight_scale=weight_scale.npu(),
activation_scale=activate_scale.npu(),
bias=bias_npu,
quant_scale=quant_scale.npu(),
quant_offset=quant_offset,
group_index=group_index_npu,
activate_left=True,
quant_mode=quant_mode,
swiglu_mode=swiglu_mode,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0,
)
self.assertRtolEqual(y_cpu.numpy(), y_npu.cpu().numpy())
@SupportedDevices(["Ascend950"])
def test_npu_dequant_swiglu_quant_4(self, device="npu"):
x_shape = [8, 4]
x = torch.randint(-10, 10, x_shape, dtype=torch.int32)
weight_scale = torch.randn([4], dtype=torch.float32)
activate_scale = None
bias = torch.randn([4], dtype=torch.float32)
quant_scale = None
quant_offset = None
group_index = None
activate_left = False
quant_mode = 1
dst_type = 291
round_mode = 0
activate_dim = -1
y_cpu, _ = self.golden_dequant_swiglu_quant_torch(
x,
weight_scale,
activate_scale,
bias,
quant_scale,
quant_offset,
group_index,
activate_left,
quant_mode,
dst_type,
round_mode,
activate_dim,
swiglu_mode=0,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0)
y_npu, _ = torch_npu.npu_dequant_swiglu_quant(
x.npu(),
weight_scale=weight_scale.npu(),
activation_scale=activate_scale,
bias=bias.npu(),
quant_scale=quant_scale,
quant_offset=quant_offset,
group_index=group_index,
activate_left=activate_left,
quant_mode=quant_mode,
dst_type=dst_type,
round_mode=round_mode,
activate_dim=activate_dim,
swiglu_mode=0,
clamp_limit=7.0,
glu_alpha=1.702,
glu_bias=1.0)
self.assertRtolEqual(y_cpu.astype(numpy.float32), y_npu.to(torch.float32).cpu().numpy())
if __name__ == "__main__":
run_tests()