from dataclasses import dataclass
import random
import numpy as np
import torch
import torch_npu
from torch_npu.testing.common_utils import SupportedDevices
from torch_npu.testing.testcase import TestCase, run_tests
@dataclass
class PAAttentionParamsNumpy:
query: np.ndarray
key_cache: np.ndarray
value_cache: np.ndarray
block_table: np.ndarray
context_lens: np.ndarray
@dataclass
class PAAttentionParamsTensor:
query: torch.Tensor
key_cache: torch.Tensor
value_cache: torch.Tensor
block_table: torch.Tensor
class TestPAAclgraphUpdate(TestCase):
num_blocks = 64
num_tokens = 2
block_size = 128
kv_heads = 16
head_size = 288
num_heads = 32
head_size_v = 96
scale = 0.38888
def group_matmul(self, head, kv_head, A, B):
group_num = head // kv_head
score = []
for i in range(kv_head):
group_A = A[i * group_num: (i + 1) * group_num]
group_B = B[i: i + 1]
score.append(np.matmul(group_A, group_B))
return np.concatenate(score, axis=0)
def ref_masked_attention(self, query, key, value):
"""参考注意力计算"""
query = query * self.scale
query = query.transpose(1, 0, 2)
key = key.transpose(1, 2, 0)
sim = self.group_matmul(query.shape[0], key.shape[0], query, key)
sim = sim - np.max(sim, axis=-1, keepdims=True)
exp_sim = np.exp(sim.astype(np.float32))
p = exp_sim / np.sum(exp_sim, axis=-1, keepdims=True)
p = p.astype(np.float16)
value = value.transpose(1, 0, 2)
out = self.group_matmul(p.shape[0], key.shape[0], p, value)
return out.transpose(1, 0, 2)
def golden_attention_impl(self, params_np):
output = np.zeros((self.num_tokens, self.num_heads, self.head_size_v), dtype=np.float16)
for i in range(self.num_tokens):
seq_blocks = params_np.block_table[i]
context_len = params_np.context_lens[i]
keys = []
values = []
for pos in range(context_len):
block_id = seq_blocks[pos // self.block_size]
offset = pos % self.block_size
keys.append(params_np.key_cache[block_id, offset].reshape(self.kv_heads, -1))
values.append(params_np.value_cache[block_id, offset].reshape(self.kv_heads, -1))
out = self.ref_masked_attention(
params_np.query[i:i + 1],
np.stack(keys),
np.stack(values)
)
output[i] = out.reshape(self.num_heads, -1)
return output
def preprocess(self):
"""生成测试输入数据"""
query_np = np.random.uniform(-1, 1, (self.num_tokens, self.num_heads, self.head_size)).astype(np.float16)
key_cache_np = np.random.uniform(-1, 1, (self.num_blocks, self.block_size, self.kv_heads, self.head_size)).astype(np.float16)
value_cache_np = np.random.uniform(-1, 1, (self.num_blocks, self.block_size, self.kv_heads, self.head_size_v)).astype(np.float16)
max_blocks_per_seq = (1024 + self.block_size - 1) // self.block_size
block_table_np = np.array([
[random.randint(0, self.num_blocks - 1) for _ in range(max_blocks_per_seq)]
for _ in range(self.num_tokens)
], dtype=np.int32)
context_lens_np = np.full(self.num_tokens, random.randint(1, 1024), dtype=np.int32)
params_np = PAAttentionParamsNumpy(query_np, key_cache_np, value_cache_np, block_table_np, context_lens_np)
golden_output = self.golden_attention_impl(params_np)
golden_output = torch.from_numpy(golden_output)
query = torch.from_numpy(query_np).npu()
key_cache = torch.from_numpy(key_cache_np).npu()
value_cache = torch.from_numpy(value_cache_np).npu()
block_table = torch.from_numpy(block_table_np).npu()
self.context_lens = torch.from_numpy(context_lens_np)
self.output = torch.zeros_like(query[:, :, :self.head_size_v]).npu()
params_tensor = PAAttentionParamsTensor(query, key_cache, value_cache, block_table)
return params_tensor, golden_output
def atb_paged_attention(self, query, key_cache, value_cache, block_table):
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=self.kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_table,
context_lens=self.context_lens,
out=self.output,
)
@SupportedDevices(['Ascend910B'])
def test_paged_attention_compile_with_aclgraph(self):
params, golden_output = self.preprocess()
compiled_fn = torch.compile(self.atb_paged_attention, backend="inductor", dynamic=False, options={"triton.cudagraphs": True})
compiled_fn(params.query, params.key_cache, params.value_cache, params.block_table)
self.assertRtolEqual(self.output, golden_output, prec16=0.01)
params, golden_output = self.preprocess()
compiled_fn(params.query, params.key_cache, params.value_cache, params.block_table)
self.assertRtolEqual(self.output, golden_output, prec16=0.01)
if __name__ == "__main__":
run_tests()