import torch
import torch_npu
import unittest
import numpy as np
import torch.nn.functional as F
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
def non_quant_golden_with_offset(x, weight, scale, perTokenScale, groupList, bias, offset):
    """NumPy reference for the grouped matmul with per-quant-group scale and optional offset.

    Args:
        x: Activation rows. Each token occupies two consecutive rows that are
            recombined as ``row0 * 16 + row1`` after the matmul.
        weight: Array of shape ``(groupNum, k, n)``.
        scale: Per-quant-group scale indexed as ``scale[group, quant_group]``;
            ``scale.shape[1]`` gives the number of quant groups along ``k``.
        perTokenScale: Per-token multiplier, split along axis 0 by ``groupList``.
        groupList: Number of tokens per group (non-cumulative; the cumulative
            sum is taken here).
        bias: Per-group bias, ``bias[group]`` is reshaped to ``(1, n)``.
        offset: Optional per-quant-group offset indexed as
            ``offset[group, quant_group]``; skipped when ``None``.

    Returns:
        A ``torch.float32`` tensor holding the per-group results concatenated
        along axis 0.
    """
    num_groups, k, n = weight.shape
    num_quant_groups = scale.shape[1]
    chunk = k // num_quant_groups
    acc_dtype = np.float16

    boundaries = np.cumsum(groupList)
    # x carries two rows per token, so its split points are doubled.
    x_parts = np.split(x, boundaries * 2, axis=0)
    token_scale_parts = np.split(perTokenScale, boundaries, axis=0)
    weight_blocks = weight.reshape(num_groups, num_quant_groups, chunk, n).astype(np.int32)

    per_group_results = []
    for g in range(num_groups):
        rows = x_parts[g]
        lhs = rows.reshape(-1, num_quant_groups, chunk).astype(np.int32)

        # Accumulate the scaled partial products in float16, one quant group at a time.
        acc = np.zeros([lhs.shape[0], n], dtype=acc_dtype)
        for q in range(num_quant_groups):
            partial = np.matmul(lhs[:, q, :], weight_blocks[g, q, ...])
            partial = partial.astype(np.float32) * scale[g, q].reshape(1, -1)
            acc = (acc.astype(acc_dtype) + partial.astype(acc_dtype)).astype(acc_dtype)

        # Merge each pair of rows back into a single token row and add the bias.
        pairs = acc.reshape(-1, 2, n).astype(np.float32)
        result = pairs[:, 0, :] * 16 + pairs[:, 1, :] + bias[g].reshape(1, n)

        if offset is not None:
            correction = np.zeros(result.shape, dtype=np.float32)
            merged = rows.astype(np.int32).reshape(-1, 2, k)
            merged = merged[:, 0, :] * 16 + merged[:, 1, :] + 8
            merged = merged.astype(np.float16).reshape(-1, num_quant_groups, chunk)
            for q in range(num_quant_groups):
                # Row sums of the activation times the offset row of this quant group.
                row_sum = merged[:, q, :].sum(axis=1, keepdims=True).astype(np.float32)
                contribution = np.matmul(row_sum, offset[g, q].reshape(1, -1))
                correction += contribution.astype(np.float32)
            result += correction

        result = result * token_scale_parts[g]
        per_group_results.append(result)

    stacked = np.concatenate(per_group_results, axis=0)
    return torch.from_numpy(stacked).to(torch.float32)
def combine_func(x, logits, residual, residScale, sourceRow, topK, offset):
    """Weight rows by ``logits``, reduce every ``topK`` routed rows, and add the residual.

    Args:
        x: 2-D tensor of per-row matmul results.
        logits: One weight per row of ``x``; reshaped to a column before multiplying.
        residual: Tensor added (after conversion to float32 and scaling) to a
            contiguous slice of output rows.
        residScale: Multiplier applied to ``residual``.
        sourceRow: Sort keys; rows of ``x`` are reordered by ``np.argsort`` of
            this array before being reduced.
        topK: Number of consecutive reordered rows summed into one output row.
        offset: First output row that receives the residual.

    Returns:
        The reduced tensor with ``x.shape[-1]`` columns.
    """
    hidden = x.shape[-1]
    weighted = x * logits.reshape(-1, 1)
    order = np.argsort(sourceRow)
    grouped = weighted[order].reshape(-1, topK, hidden)
    out = grouped.sum(axis=1)
    stop = offset + residual.shape[0]
    out[offset:stop, :] += residScale * residual.to(torch.float32)
    return out
class TestGroupedMatmulFinalizeRouting(TestCase):
    """Check ``torch_npu.npu_grouped_matmul_finalize_routing`` against CPU references."""

    @staticmethod
    def _make_routing_inputs(m, batch, topK, group_num):
        """Draw random router scores and derive the ``logit`` / ``row_index`` inputs.

        Returns:
            ``(logit, row_index)``: ``logit`` is the softmax over the ``topK``
            selected scores of each row, flattened to ``m`` elements;
            ``row_index`` is ``argsort(routing.flatten()) // topK`` as int64.
        """
        logit_ori = torch.normal(0, 0.1, (batch, group_num), dtype=torch.float32)
        # Ascending argsort: the last topK columns are the topK largest scores.
        routing = torch.argsort(logit_ori, 1)[:, -topK:]
        logit = F.softmax(
            logit_ori[torch.arange(batch).reshape(-1, 1).repeat(1, topK), routing],
            dim=1,
            dtype=torch.float32
        ).reshape(m)
        row_index = (torch.argsort(routing.reshape(-1)) // topK).to(torch.int64)
        return logit, row_index

    def supported_op_exec(self,
                          topK, x, weight, group_list, scale, pertoken_scale,
                          shared_input=None, logit=None, row_index=None,
                          shared_input_scale=1, shared_input_offset=0):
        """CPU reference for the int8 x int8 (W8A8) path.

        ``x`` and ``pertoken_scale`` are split along dim 0 by the sizes in
        ``group_list``; each piece is multiplied with ``weight[i]`` in int32 and
        scaled by ``scale[i]`` and its per-token scale. The concatenated rows
        are weighted by ``logit``, reordered by ``argsort(row_index)`` and every
        ``topK`` consecutive rows are summed. When ``shared_input`` is given it
        is scaled by ``shared_input_scale`` and added to the output rows starting
        at ``shared_input_offset``.

        ``logit`` and ``row_index`` default to ``None`` for signature
        compatibility but are required: both are used unconditionally.
        """
        sizes = group_list.tolist()
        x_split = torch.split(x, sizes, dim=0)
        pertoken_scale_split = torch.split(pertoken_scale, sizes, dim=0)

        mm_outs = []
        for i, (x_i, token_scale_i) in enumerate(zip(x_split, pertoken_scale_split)):
            mm = torch.matmul(x_i.to(torch.int32), weight[i].to(torch.int32))
            mm = mm.to(torch.float32) * scale[i].to(torch.float32) * token_scale_i
            mm_outs.append(mm)
        mm_out = torch.cat(mm_outs, dim=0)

        # The routing/reduction step is the same with or without a shared input.
        out = mm_out * logit.reshape(-1, 1)
        index = torch.argsort(row_index, dim=0)
        out = out[index].reshape(-1, topK, mm_out.shape[-1]).sum(dim=1)
        if shared_input is not None:
            out[shared_input_offset:shared_input_offset + shared_input.shape[0], :] += \
                shared_input_scale * shared_input.to(torch.float32)
        return out

    def supported_a8w4_op_exec(self, topK, x_in, weight_in, groupList_in, scale_in,
                               bias_in, offset, perTokenScale_in, residual, logits,
                               sourceRow, residScale, shared_input_offset=0):
        """CPU reference for the A8W4 path, built on the module-level golden helpers.

        Each row of ``x_in`` is expanded into two rows, ``x // 16`` and
        ``(x & 0x0F) - 8``, which ``non_quant_golden_with_offset`` recombines
        after the matmul. Every element of ``scale_in`` is truncated to uint32
        and those bits are reinterpreted as float32 before use.
        """
        groupNum = groupList_in.shape[0]
        m = x_in.shape[0]
        k = x_in.shape[1]
        n = scale_in.shape[2]
        weight = weight_in.astype(np.int8).reshape(groupNum, k, n)

        x_rows = x_in.reshape(m, 1, k)
        xC12 = np.concatenate([x_rows // 16, (x_rows & 0x0F) - 8], axis=1).reshape(m * 2, k)

        # Reinterpret the stored bit patterns as float32 without touching the memory.
        scale_fp32 = scale_in.astype(np.uint32).view(np.float32)

        mm_out = non_quant_golden_with_offset(xC12, weight, scale_fp32, perTokenScale_in,
                                              groupList_in, bias_in, offset)
        return combine_func(mm_out, logits, residual, residScale, sourceRow, topK, shared_input_offset)

    @SupportedDevices(["Ascend910B"])
    def test_npu_grouped_matmul_finalize_routing_1(self, device="npu"):
        """W8A8 case with a shared input and per-token scale."""
        m, k, n, batch, topK, group_num, shared_input_scale = 576, 2048, 7168, 72, 8, 8, 1
        x = torch.randint(-10, 10, (m, k), dtype=torch.int8)
        weight = torch.randint(-10, 10, (group_num, k, n), dtype=torch.int8)
        scale = torch.normal(0, 0.01, (group_num, n), dtype=torch.float32)
        pertoken_scale = torch.normal(0, 0.01, (m, 1), dtype=torch.float32)
        group_list = torch.tensor([batch] * group_num, dtype=torch.int64)
        logit, row_index = self._make_routing_inputs(m, batch, topK, group_num)
        shared_input = torch.normal(0, 0.1, (batch // 4, n), dtype=torch.bfloat16)
        shared_input_offset = batch // 2
        output_bs = batch

        supported_output = self.supported_op_exec(topK, x, weight, group_list, scale,
                                                  pertoken_scale, shared_input, logit, row_index,
                                                  shared_input_scale, shared_input_offset)

        # Format id 29 is presumably the NZ layout (per the variable name) - confirm.
        weightNz = torch_npu.npu_format_cast(weight.npu(), 29)
        pertoken_scale = pertoken_scale.reshape(m)
        custom_output = torch_npu.npu_grouped_matmul_finalize_routing(
            x.npu(), weightNz, group_list.npu(), scale=scale.npu(),
            pertoken_scale=pertoken_scale.npu(), shared_input=shared_input.npu(),
            logit=logit.npu(), row_index=row_index.npu(),
            shared_input_offset=shared_input_offset, output_bs=output_bs
        ).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip.")
    @SupportedDevices(["Ascend910B"])
    def test_npu_grouped_matmul_finalize_routing_sharedinput_none_grouplist_cumsum(self, device="npu"):
        """W8A8 case without a shared input, passing a cumulative ``group_list``."""
        m, k, n, batch, topK, group_num, shared_input_scale = 576, 2048, 7168, 72, 8, 8, 1
        x = torch.randint(-10, 10, (m, k), dtype=torch.int8)
        weight = torch.randint(-10, 10, (group_num, k, n), dtype=torch.int8)
        scale = torch.normal(0, 0.01, (group_num, n), dtype=torch.float32)
        pertoken_scale = torch.normal(0, 0.01, (m, 1), dtype=torch.float32)
        group_list = torch.tensor([batch] * group_num, dtype=torch.int64)
        logit, row_index = self._make_routing_inputs(m, batch, topK, group_num)
        shared_input_offset = batch // 2
        output_bs = batch

        # The reference consumes per-group counts, so it runs before the cumsum below.
        supported_output = self.supported_op_exec(topK, x, weight, group_list, scale,
                                                  pertoken_scale, logit=logit, row_index=row_index,
                                                  shared_input_scale=shared_input_scale,
                                                  shared_input_offset=shared_input_offset)

        group_list_type = 0
        group_list = torch.cumsum(group_list, dim=0)
        weightNz = torch_npu.npu_format_cast(weight.npu(), 29)
        pertoken_scale = pertoken_scale.reshape(m)
        custom_output = torch_npu.npu_grouped_matmul_finalize_routing(
            x.npu(), weightNz, group_list.npu(), scale=scale.npu(),
            pertoken_scale=pertoken_scale.npu(), shared_input=None,
            logit=logit.npu(), row_index=row_index.npu(),
            shared_input_offset=shared_input_offset, output_bs=output_bs, group_list_type=group_list_type
        ).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @unittest.skip("Skip temporary. The kernel is not supported.")
    @SupportedDevices(["Ascend910B"])
    def test_npu_grouped_matmul_finalize_routing_w8a8_support_none_pertoken_scale(self, device="npu"):
        """W8A8 case where the op receives ``pertoken_scale=None``."""
        m, k, n, batch, topK, group_num, shared_input_scale = 576, 2048, 7168, 72, 8, 8, 1
        x = torch.randint(-10, 10, (m, k), dtype=torch.int8)
        weight = torch.randint(-10, 10, (group_num, k, n), dtype=torch.int8)
        scale = torch.normal(0, 0.01, (group_num, n), dtype=torch.float32)
        # The reference uses an all-ones per-token scale to mirror the op's None input.
        pertoken_scale = torch.ones((m, 1), dtype=torch.float32)
        group_list = torch.tensor([batch] * group_num, dtype=torch.int64)
        logit, row_index = self._make_routing_inputs(m, batch, topK, group_num)
        shared_input = torch.normal(0, 0.1, (batch // 4, n), dtype=torch.bfloat16)
        shared_input_offset = batch // 2
        output_bs = batch

        supported_output = self.supported_op_exec(topK, x, weight, group_list, scale,
                                                  pertoken_scale, shared_input, logit, row_index,
                                                  shared_input_scale, shared_input_offset)

        weightNz = torch_npu.npu_format_cast(weight.npu(), 29)
        custom_output = torch_npu.npu_grouped_matmul_finalize_routing(
            x.npu(), weightNz, group_list.npu(), scale=scale.npu(),
            pertoken_scale=None, shared_input=shared_input.npu(),
            logit=logit.npu(), row_index=row_index.npu(),
            shared_input_offset=shared_input_offset, output_bs=output_bs
        ).to("cpu")
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @unittest.skip("Skip temporary. The kernel is not supported.")
    @SupportedDevices(["Ascend910B"])
    def test_npu_grouped_matmul_finalize_routing_a8w4(self, device="npu"):
        """A8W4 case with bias, offset and a packed int64 scale tensor."""
        m, k, n, group_num = 8, 2048, 7168, 8
        batch = m // group_num
        quantGroupSize = k
        topK = 8
        x = torch.randint(-5, 5, (m, k), dtype=torch.int8)
        weight = torch.randint(-5, 5, (group_num, k, n), dtype=torch.int32)

        # Round the scale through float16, then store each float32 bit pattern in
        # every other uint32 slot so that pairs of slots read back as one int64.
        scale_np = np.random.normal(0, 0.01, (group_num, 1, n)).astype(np.float32)
        perGroupScale = np.ones([group_num, k // quantGroupSize, n]).astype(np.float32)
        scale_bits = (scale_np * perGroupScale).astype(np.float16).astype(np.float32).view(np.uint32)
        scale_packed = np.zeros((group_num, k // quantGroupSize, n * 2), dtype=np.uint32)
        scale_packed[..., ::2] = scale_bits
        scale = torch.from_numpy(scale_packed.view(np.int64))

        bias = torch.normal(0, 0.01, (group_num, n), dtype=torch.float32)
        offset = torch.randint(-5, 5, (group_num, k // quantGroupSize, n), dtype=torch.float32)
        pertoken_scale = torch.normal(0, 0.01, (m, 1), dtype=torch.float32)
        group_list = torch.tensor([batch] * group_num, dtype=torch.int64)
        logit, row_index = self._make_routing_inputs(m, batch, topK, group_num)
        shared_input = torch.normal(0, 0.1, (max(batch // 4, 1), n), dtype=torch.bfloat16)
        shared_input_scale = 1
        shared_input_offset = 0
        output_bs = batch

        weight_quant = torch_npu.npu_quantize(weight.float().npu(), torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False)
        custom_output = torch_npu.npu_grouped_matmul_finalize_routing(
            x.npu(), weight_quant, group_list.npu(), scale=scale.npu(),
            bias=bias.npu(), offset=offset.npu(),
            pertoken_scale=pertoken_scale.reshape(m).npu(), shared_input=shared_input.npu(),
            logit=logit.npu(), row_index=row_index.npu(), shared_input_weight=shared_input_scale,
            shared_input_offset=shared_input_offset, output_bs=output_bs
        ).to("cpu")

        supported_output = self.supported_a8w4_op_exec(
            topK, x.numpy(), weight.numpy(), group_list.numpy(), scale.numpy(), bias.numpy(),
            offset.numpy(), pertoken_scale.numpy(), shared_input, logit.numpy(), row_index.numpy(),
            shared_input_scale, shared_input_offset
        )
        self.assertRtolEqual(supported_output, custom_output, 0.001)
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()