import unittest
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestAtbMLAPO(TestCase):
@unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip")
@SupportedDevices(['Ascend910B'])
def test_atb_mlapo_20(self):
token_num = 8
head_num = 128
N_7168 = 7168
block_num = 192
block_size = 128
dtype = torch.float16
device = 'npu'
input1 = torch.randn((token_num, N_7168), dtype=dtype, device=device)
gamma0 = torch.randn((N_7168), dtype=dtype, device=device)
beta0 = torch.randn((N_7168), dtype=dtype, device=device)
quant_scale0 = torch.randn((1,), dtype=dtype, device=device)
quant_offset0 = torch.randint(
0, 7, (1,), dtype=torch.int8, device=device)
wdqkv = torch.randint(0, 7, (1, 224, 2112, 32),
dtype=torch.int8, device=device)
wdqkv = torch_npu.npu_format_cast(wdqkv, 29)
de_scale0 = torch.randint(
0, 7, (2112, ), dtype=torch.int64, device=device)
bias0 = torch.randint(0, 7, (2112, ), dtype=torch.int32, device=device)
gamma1 = torch.randn((1536), dtype=dtype, device=device)
beta1 = torch.randn((1536), dtype=dtype, device=device)
quant_scale1 = torch.randn((1,), dtype=dtype, device=device)
quant_offset1 = torch.randint(
0, 7, (1,), dtype=torch.int8, device=device)
wuq = torch.randint(0, 7, (1, 48, head_num * 192, 32),
dtype=torch.int8, device=device)
wuq = torch_npu.npu_format_cast(wuq, 29)
de_scale1 = torch.randint(
0, 7, (head_num * 192, ), dtype=torch.int64, device=device)
bias1 = torch.randint(0, 7, (head_num * 192, ),
dtype=torch.int32, device=device)
gamma2 = torch.randn((512), dtype=dtype, device=device)
cos = torch.randn((token_num, 64), dtype=dtype, device=device)
sin = torch.randn((token_num, 64), dtype=dtype, device=device)
wuk = torch.randn((head_num, 128, 512), dtype=dtype, device=device)
wuk = torch_npu.npu_format_cast(wuk, 29)
kv_cache = torch.randint(
0, 7, (block_num, head_num * 512 // 32, block_size, 32), dtype=torch.int8, device=device)
kv_cache = torch_npu.npu_format_cast(kv_cache, 29)
kv_cache_rope = torch.randn(
(block_num, head_num * 64 // 16, block_size, 16), dtype=dtype, device=device)
kv_cache_rope = torch_npu.npu_format_cast(kv_cache_rope, 29)
slotmapping = torch.randint(
0, 7, (token_num,), dtype=torch.int32, device=device)
ctkv_scale = torch.randn((1,), dtype=dtype, device=device)
qnope_scale = torch.randn((head_num), dtype=dtype, device=device)
q_out0 = torch.randint(
0, 7, (token_num, head_num, 512), dtype=torch.int8, device=device)
kv_cache_out0 = torch.randint(
0, 7, (block_num, head_num * 512 // 32, block_size, 32), dtype=torch.int8, device=device)
kv_cache_out0 = torch_npu.npu_format_cast(kv_cache_out0, 29)
q_out1 = torch.randn((token_num, head_num, 64),
dtype=dtype, device=device)
kv_cache_out1 = torch.randn(
(block_num, head_num * 64 // 16, block_size, 16), dtype=dtype, device=device)
kv_cache_out1 = torch_npu.npu_format_cast(kv_cache_out1, 29)
torch_npu.atb.npu_mla_preprocess(
input1, gamma0, beta0, wdqkv, de_scale0,
gamma1, beta1, wuq, de_scale1,
gamma2, cos, sin, wuk, kv_cache, kv_cache_rope, slotmapping,
quant_scale0=quant_scale0,
quant_offset0=quant_offset0,
bias0=bias0,
quant_scale1=quant_scale0,
quant_offset1=quant_offset1,
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="int8_nzcache",
quant_mode="per_tensor_quant_asymm",
q_out0=q_out0,
kv_cache_out0=kv_cache_out0,
q_out1=q_out1,
kv_cache_out1=kv_cache_out1,
)
torch.npu.synchronize()
@unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip")
@SupportedDevices(['Ascend910B'])
def test_atb_mlapo_20_model_case(self):
token_num = 1
head_num = 128
N_7168 = 7168
block_num = 192
block_size = 128
dtype = torch.bfloat16
device = 'npu:0'
input1 = torch.randn((token_num, N_7168), dtype=dtype, device=device)
gamma0 = torch.randn((N_7168), dtype=dtype, device=device)
beta0 = torch.randn((N_7168), dtype=dtype, device=device)
quant_scale0 = torch.randn((1,), dtype=dtype, device=device)
quant_offset0 = torch.randint(
0, 7, (1,), dtype=torch.int8, device=device)
wdqkv = torch.randint(0, 7, (1, 224, 2112, 32),
dtype=torch.int8, device=device)
wdqkv = torch_npu.npu_format_cast(wdqkv, 29)
de_scale0 = torch.rand((2112, ), dtype=torch.float, device=device)
bias0 = torch.randint(0, 7, (2112, ), dtype=torch.int32, device=device)
gamma1 = torch.randn((1536), dtype=dtype, device=device)
beta1 = torch.randn((1536), dtype=dtype, device=device)
quant_scale1 = torch.randn((1,), dtype=dtype, device=device)
quant_offset1 = torch.randint(
0, 7, (1,), dtype=torch.int8, device=device)
wuq = torch.randint(0, 7, (1, 48, head_num * 192, 32),
dtype=torch.int8, device=device)
wuq = torch_npu.npu_format_cast(wuq, 29)
de_scale1 = torch.rand(
(head_num * 192, ), dtype=torch.float, device=device)
bias1 = torch.randint(0, 7, (head_num * 192, ),
dtype=torch.int32, device=device)
gamma2 = torch.randn((512), dtype=dtype, device=device)
cos = torch.randn((token_num, 64), dtype=dtype, device=device)
sin = torch.randn((token_num, 64), dtype=dtype, device=device)
wuk = torch.randn((head_num, 128, 512), dtype=dtype, device=device)
wuk = torch_npu.npu_format_cast(wuk, 29)
kv_cache = torch.randint(
0, 7, (block_num, head_num * 512 // 32, block_size, 32), dtype=torch.int8, device=device)
kv_cache = torch_npu.npu_format_cast(kv_cache, 29)
kv_cache_rope = torch.randn(
(block_num, head_num * 64 // 16, block_size, 16), dtype=dtype, device=device)
kv_cache_rope = torch_npu.npu_format_cast(kv_cache_rope, 29)
slotmapping = torch.randint(
0, 7, (token_num,), dtype=torch.int32, device=device)
ctkv_scale = torch.randn((1,), dtype=dtype, device=device)
qnope_scale = torch.randn((head_num), dtype=dtype, device=device)
q_out0 = torch.randint(
0, 7, (token_num, head_num, 512), dtype=torch.int8, device=device)
kv_cache_out0 = torch.randint(
0, 7, (block_num, head_num * 512 // 32, block_size, 32), dtype=torch.int8, device=device)
kv_cache_out0 = torch_npu.npu_format_cast(kv_cache_out0, 29)
q_out1 = torch.randn((token_num, head_num, 64),
dtype=dtype, device=device)
kv_cache_out1 = torch.randn(
(block_num, head_num*64//16, block_size, 16), dtype=dtype, device=device)
kv_cache_out1 = torch_npu.npu_format_cast(kv_cache_out1, 29)
torch_npu.atb.npu_mla_preprocess(
input1, gamma0, beta0, wdqkv, de_scale0,
gamma1, beta1, wuq, de_scale1,
gamma2, cos, sin, wuk, kv_cache, kv_cache_rope, slotmapping,
quant_scale0=quant_scale0,
quant_offset0=quant_offset0,
bias0=bias0,
quant_scale1=quant_scale0,
quant_offset1=quant_offset1,
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="int8_nzcache",
quant_mode="per_tensor_quant_asymm",
q_out0=q_out0,
kv_cache_out0=kv_cache_out0,
q_out1=q_out1,
kv_cache_out1=kv_cache_out1,
)
out = torch_npu.atb.npu_mla_preprocess(
input1, gamma0, beta0, wdqkv, de_scale0,
gamma1, beta1, wuq, de_scale1,
gamma2, cos, sin, wuk, kv_cache, kv_cache_rope, slotmapping,
quant_scale0=quant_scale0,
quant_offset0=quant_offset0,
bias0=bias0,
quant_scale1=quant_scale0,
quant_offset1=quant_offset1,
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="int8_nzcache",
quant_mode="per_tensor_quant_asymm"
)
self.assertEqual(out[0], q_out0)
self.assertEqual(out[1], kv_cache_out0)
self.assertEqual(out[2], q_out1)
self.assertEqual(out[3], kv_cache_out1)
if __name__ == "__main__":
run_tests()