import random
import os
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
# Fix both random number generators so the randomly generated test data is
# reproducible from run to run.
random.seed(1)
np.random.seed(1)

# Module-level constants.
MAX_SEQ_LEN = 1024
kv_head_num = 32
class PagedInputData:
    """Plain container bundling the inputs of one paged-attention call.

    Every argument is stored unchanged on the instance under the same name;
    no validation, copying or conversion is performed.
    """

    def __init__(self, query, key_cache, value_cache, block_tables, context_lens, mask):
        (self.query, self.key_cache, self.value_cache,
         self.block_tables, self.context_lens, self.mask) = (
            query, key_cache, value_cache, block_tables, context_lens, mask)
class TestPagedAttentionAlibi(TestCase):
    """Accuracy tests for ``torch_npu.atb._npu_paged_attention_v2`` with an ALiBi-style mask.

    Each test generates random query / paged KV-cache inputs, computes a CPU
    reference result with ``ref_single_query_cached_kv_attention``, runs the
    NPU operator and compares the two.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): "0" presumably requests non-blocking kernel launches;
        # confirm against the Ascend runtime documentation.
        os.environ["ASCEND_LAUNCH_BLOCKING"] = "0"

    def group_mm_torch(self, heads, group_num, A, B):
        """Grouped batched matmul where consecutive groups of heads in ``A`` share one slice of ``B``.

        Args:
            heads: number of heads (first dimension) of ``A``.
            group_num: number of groups (first dimension) of ``B``.
            A: 3-D tensor ``(heads, m, k)``.
            B: 3-D tensor ``(group_num, k, n)``.

        Returns:
            float32 tensor ``(heads, m, n)``. Heads
            ``[i * heads // group_num, (i + 1) * heads // group_num)`` of ``A``
            are multiplied with ``B[i]``.
        """
        group_head = heads // group_num
        group_scores = [
            torch.matmul(
                A[i * group_head:(i + 1) * group_head, :, :].to(torch.float32),
                B[i:(i + 1), :, :].to(torch.float32),
            )
            for i in range(group_num)
        ]
        # One concatenation instead of re-concatenating inside the loop.
        return torch.cat(group_scores, 0)

    def ref_masked_attention(
        self,
        query,
        key,
        value,
        scale: float,
        alibi_bias,
        mask_data_type=torch.bfloat16
    ):
        """Reference attention for a single query token with an additive bias.

        Args:
            query: ``(1, num_heads, head_size)`` query of one token.
            key: ``(context_len, kv_heads, head_size)`` gathered keys.
            value: ``(context_len, kv_heads, head_size_v)`` gathered values.
            scale: factor applied to the ``q @ k^T`` scores before the bias.
            alibi_bias: additive bias; must broadcast to ``(num_heads, 1, context_len)``.
            mask_data_type: dtype the raw scores are rounded through and the
                softmax probabilities are cast to before ``p @ v``.

        Returns:
            float32 tensor ``(1, num_heads, head_size_v)``.
        """
        query = torch.permute(query, (1, 0, 2))   # (num_heads, 1, head_size)
        key = torch.permute(key, (1, 2, 0))       # (kv_heads, head_size, context_len)
        # Round the raw scores through mask_data_type (presumably to mimic the device precision).
        sim = self.group_mm_torch(query.shape[0], key.shape[0], query, key).to(mask_data_type)
        sim = sim.to(torch.float32) * scale
        sim = sim + alibi_bias.to(torch.float32)
        # Numerically stable softmax over the context axis, computed in NumPy float32.
        sim = sim.numpy()
        sim -= np.max(sim, axis=-1, keepdims=True)
        sim = np.exp(sim)
        p = sim / np.sum(sim, axis=-1, keepdims=True)
        p = torch.from_numpy(p).to(mask_data_type)
        value = torch.permute(value, (1, 0, 2))   # (kv_heads, context_len, head_size_v)
        out = self.group_mm_torch(query.shape[0], key.shape[0], p, value)
        return torch.permute(out, (1, 0, 2))

    def ref_single_query_cached_kv_attention(
        self,
        output,
        query,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        mask,
        mask_dim=4,
        mask_data_type=torch.bfloat16
    ) -> None:
        """Compute the paged-attention reference result on CPU, writing it into ``output``.

        Args:
            output: preallocated ``(num_tokens, num_heads, head_size_v)`` tensor, filled in place.
            query: ``(num_tokens, num_heads, head_size)``.
            key_cache: ``(num_blocks, block_size, kv_heads, head_size)``.
            value_cache: ``(num_blocks, block_size, kv_heads, head_size_v)``.
            block_tables: per-token sequence of block indices into the caches.
            context_lens: per-token number of cached positions to attend to.
            mask: additive bias, ``(num_tokens, num_heads, 1, max_len)`` when
                ``mask_dim == 4`` or ``(num_tokens, 1, max_len)`` when ``mask_dim == 3``.
            mask_dim: rank of ``mask`` (3 or 4).
            mask_data_type: dtype used for intermediate rounding (see ``ref_masked_attention``).

        Raises:
            ValueError: if ``mask_dim`` is not 3 or 4.
        """
        if mask_dim not in (3, 4):
            raise ValueError(f"unsupported mask_dim {mask_dim}; expected 3 or 4")
        num_heads = query.shape[1]
        kv_heads = value_cache.shape[2]
        head_size = key_cache.shape[3]
        head_size_v = value_cache.shape[3]
        block_size = value_cache.shape[1]
        scale = np.float32(1.0 / (head_size ** 0.5))
        for i in range(query.shape[0]):
            q = query[i].view(1, num_heads, head_size)
            block_table = block_tables[i]
            context_len = int(context_lens[i])
            keys = []
            values = []
            # Gather the token's keys/values position by position through its block table.
            for j in range(context_len):
                block_number = int(block_table[j // block_size])
                block_offset = j % block_size
                keys.append(key_cache[block_number, block_offset, :, :].reshape(kv_heads, head_size))
                values.append(value_cache[block_number, block_offset, :, :].reshape(kv_heads, head_size_v))
            keys = torch.stack(keys, dim=0)
            values = torch.stack(values, dim=0)
            # The Ellipsis covers both the 4-D and the 3-D mask layouts.
            bias = mask[i, ..., :context_len]
            out = self.ref_masked_attention(q, keys, values, scale, bias, mask_data_type)
            output[i] = out.reshape(num_heads, head_size_v)

    def calc_data(self, num_tokens, num_heads, kv_heads, head_size, block_size, num_blocks, k_seqlen, dtype,
                  mask_dim=4, mask_data_type=torch.bfloat16):
        """Generate random paged-attention inputs and the matching CPU reference output.

        Args:
            num_tokens: number of query tokens (one sequence per token).
            num_heads: number of query heads.
            kv_heads: number of key/value heads.
            head_size: key/query head size; the value head size is drawn at
                random from ``[1, head_size)``.
            block_size: positions per KV-cache block.
            num_blocks: number of blocks in the KV caches.
            k_seqlen: context length used for every sequence.
            dtype: unused; kept for call compatibility.
            mask_dim: 4 for a per-head ALiBi bias, 3 for a head-shared mask.
            mask_data_type: dtype of the generated tensors.

        Returns:
            ``(query, key_cache, value_cache, block_tables, context_lens, mask, golden_out)``.
            ``block_tables`` and ``context_lens`` are int32 NumPy arrays; the
            others are torch tensors in ``mask_data_type``.

        Raises:
            ValueError: if ``mask_dim`` is not 3 or 4.
        """
        if mask_dim not in (3, 4):
            raise ValueError(f"unsupported mask_dim {mask_dim}; expected 3 or 4")
        head_size_v = np.random.randint(1, head_size)
        query = torch.from_numpy(
            np.random.uniform(-1.0, 1.0, size=(num_tokens, num_heads, head_size))).to(mask_data_type)
        key_cache = torch.from_numpy(
            np.random.uniform(-1.0, 1.0, size=(num_blocks, block_size, kv_heads, head_size))).to(mask_data_type)
        value_cache = torch.from_numpy(
            np.random.uniform(-1.0, 1.0, size=(num_blocks, block_size, kv_heads, head_size_v))).to(mask_data_type)
        # Every sequence uses the same length k_seqlen.
        context_lens = [k_seqlen] * num_tokens
        max_context_len = max(context_lens)
        max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
        block_tables = [
            [random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq)]
            for _ in range(num_tokens)
        ]
        if mask_dim == 4:
            # ALiBi: per-head slope times the (non-positive) offset from the last position.
            alibi_slopes = np.random.random(num_heads).astype(np.float16)
            mask = np.zeros((num_tokens, num_heads, 1, max_context_len), dtype=np.float16)
            for i, context_len in enumerate(context_lens):
                position_ids = np.arange(context_len).astype(np.int32)
                alibi_bias = (position_ids - context_len + 1).astype(np.float16)
                mask[i, :, :, :context_len] = alibi_slopes.reshape(-1, 1, 1) * alibi_bias.reshape(1, 1, -1)
        else:
            # Token i masks out its first i positions.
            mask = np.zeros((num_tokens, 1, max_context_len), dtype=np.float16)
            for i in range(num_tokens):
                mask[i, :, :i] = -10000
        mask = torch.from_numpy(mask).to(mask_data_type)
        ref_output = torch.zeros((num_tokens, num_heads, head_size_v)).to(mask_data_type)
        self.ref_single_query_cached_kv_attention(
            ref_output,
            query,
            key_cache,
            value_cache,
            block_tables,
            context_lens,
            mask,
            mask_dim,
            mask_data_type)
        return (query, key_cache, value_cache,
                np.array(block_tables).astype(np.int32),
                np.array(context_lens).astype(np.int32),
                mask, ref_output)

    def compare_output_data(self, out, golden, ratios):
        """Print accuracy statistics and check ``out`` against ``golden``.

        Args:
            out: operator output, 3-D; compared in float32.
            golden: reference output of the same shape.
            ratios: ``[rel, abs, strict_rel, strict_abs]`` tolerances used only
                for the printed error-rate statistics.

        Returns:
            0-dim bool tensor that is True when every element satisfies
            ``|out - golden| <= max(|golden|, 1) * err``. ``err`` is 2**-7 or
            2**-6 for bfloat16 and 2**-8 or 2**-7 otherwise. The smaller value
            applies when ``out.shape[2] * self.max_context_len + 4 < 2048``.

        Reads ``self.data_type``, ``self.max_context_len`` and, if present,
        ``self.is_int8_flag``.
        """
        golden = golden.to(torch.float32)
        out = out.to(torch.float32)
        total_elements = out.numel()
        diff = torch.abs(golden - out)
        max_diff = diff.max().item()
        limit_error = torch.maximum(torch.abs(golden * ratios[0]), torch.tensor(ratios[1]))
        strict_limit_error = torch.maximum(torch.abs(golden * ratios[2]), torch.tensor(ratios[3]))
        error_count = torch.gt(diff, limit_error).sum().item()
        strict_error_count = torch.gt(diff, strict_limit_error).sum().item()
        strict_error_rate = float(strict_error_count) / total_elements
        print(f"maxDiff {max_diff}")
        # print() does not interpret %-specifiers, so the values are formatted explicitly.
        print("1/1000 Accuracy is %f" % (1 - float(error_count) / total_elements))
        print("5/1000 Accuracy is %f" % (1 - strict_error_rate))
        # Not every test sets is_int8_flag, so treat a missing attribute as False.
        if self.data_type == torch.bfloat16 or getattr(self, "is_int8_flag", False):
            print("accuracy is correct: %r" % (strict_error_rate <= ratios[2]))
        else:
            print("accuracy is correct: %r" % (strict_error_rate <= ratios[0]))
        calc_times = out.shape[2] * self.max_context_len + 4
        if self.data_type == torch.bfloat16:
            error = 2 ** (-7) if calc_times < 2048 else 2 ** (-6)
        else:
            error = 2 ** (-8) if calc_times < 2048 else 2 ** (-7)
        error_threshold = torch.clamp(torch.abs(golden), min=1) * error
        return (diff <= error_threshold).all()

    def _prepare_inputs(self, data_type):
        """Set the shared test configuration, generate inputs and move them to the NPU.

        Sets the shape attributes, ``self.golden_out`` (CPU reference) and the
        NPU tensors ``query``, ``keyCache``, ``valueCache``, ``blockTables``,
        ``mask`` and ``attnOut``. ``contextLens`` is kept as a CPU tensor.
        """
        self.num_tokens = 1
        self.num_heads = 32
        self.kv_heads = 32
        self.block_size = 128
        self.head_size = 288
        self.num_blocks = 64
        self.k_seqlen = 128
        self.tor = 1.0 / (self.head_size ** 0.5)
        self.dtype = "float16"
        self.mask_dim = 4
        self.data_type = data_type
        self.is_int8_flag = False
        self.max_context_len = self.k_seqlen
        (self.q, self.key_cache, self.value_cache, self.block_tables, self.contex_lens,
         self.alib_mask, self.golden_out) = self.calc_data(
            self.num_tokens, self.num_heads, self.kv_heads, self.head_size, self.block_size,
            self.num_blocks, self.k_seqlen, self.dtype, self.mask_dim, self.data_type)
        self.data = (self.q, self.key_cache, self.value_cache, torch.from_numpy(self.block_tables),
                     torch.from_numpy(self.contex_lens), self.alib_mask, self.golden_out)
        self.in_tensors = [tensor.npu() for tensor in self.data]
        self.query = self.in_tensors[0]
        self.keyCache = self.in_tensors[1]
        self.valueCache = self.in_tensors[2]
        self.blockTables = self.in_tensors[3]
        self.contextLens = self.in_tensors[4].cpu()
        self.mask = self.in_tensors[5]
        self.attnOut = torch.empty_like(self.golden_out).npu()

    @SupportedDevices(['Ascend910B'])
    def test_paged_attention_v2_bf16(self):
        self._prepare_inputs(torch.bfloat16)
        torch_npu.atb._npu_paged_attention_v2(
            self.query, self.keyCache, self.blockTables, self.contextLens,
            value_cache=self.valueCache, mask=self.mask, num_kv_heads=self.kv_heads,
            num_heads=self.num_heads, scale_value=self.tor, mask_type=2, out=self.attnOut)
        ratios = [0.001, 0.001, 0.005, 0.005]
        # compare_output_data returns the verdict; it must be asserted, not discarded.
        self.assertTrue(self.compare_output_data(self.attnOut.cpu(), self.golden_out, ratios))

    @SupportedDevices(['Ascend910B'])
    def test_paged_attention_v2(self):
        self._prepare_inputs(torch.float16)
        torch_npu.atb._npu_paged_attention_v2(
            self.query, self.keyCache, self.blockTables, self.contextLens,
            value_cache=self.valueCache, mask=self.mask, num_kv_heads=self.kv_heads,
            num_heads=self.num_heads, scale_value=self.tor, mask_type=2, out=self.attnOut)
        self.assertRtolEqual(self.golden_out, self.attnOut)
        ratios = [0.001, 0.001, 0.005, 0.005]
        self.assertTrue(self.compare_output_data(self.attnOut.cpu(), self.golden_out, ratios))

    @SupportedDevices(['Ascend910B'])
    def test_paged_attention_v2_aclgraph_update(self):
        self._prepare_inputs(torch.float16)
        # Capture the graph with the full context length, then update the captured task to a
        # shorter length; the replayed output must match an eager run with the shorter length.
        self.contextLens = torch.tensor([self.k_seqlen], dtype=torch.int32)
        context_lens_new = torch.tensor([self.k_seqlen // 2], dtype=torch.int32)
        graph_out = torch.empty_like(self.golden_out).npu()
        torch_npu.atb._npu_paged_attention_v2(
            self.query, self.keyCache, self.blockTables, context_lens_new,
            value_cache=self.valueCache, mask=self.mask, num_kv_heads=self.kv_heads,
            num_heads=self.num_heads, scale_value=self.tor, mask_type=2, out=self.attnOut)
        g = torch.npu.NPUGraph()
        event = torch.npu.ExternalEvent()
        update_stream = torch.npu.Stream()
        workspace = torch_npu.atb._npu_paged_attention_v2_get_workspace(
            self.query, self.keyCache, self.blockTables, context_lens_new,
            value_cache=self.valueCache, mask=self.mask, num_kv_heads=self.kv_heads,
            num_heads=self.num_heads, scale_value=self.tor, mask_type=2, out=graph_out)
        with torch.npu.graph(g):
            stream = torch.npu.current_stream()
            event.wait(stream)
            event.reset(stream)
            torch.npu.graph_task_group_begin(stream)
            torch_npu.atb._npu_paged_attention_v2(
                self.query, self.keyCache, self.blockTables, self.contextLens,
                value_cache=self.valueCache, mask=self.mask, num_kv_heads=self.kv_heads,
                num_heads=self.num_heads, scale_value=self.tor, mask_type=2,
                workspace=workspace, out=graph_out)
            handle = torch.npu.graph_task_group_end(stream)
        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu.atb._npu_paged_attention_v2(
                self.query, self.keyCache, self.blockTables, context_lens_new,
                value_cache=self.valueCache, mask=self.mask, num_kv_heads=self.kv_heads,
                num_heads=self.num_heads, scale_value=self.tor, mask_type=2,
                workspace=workspace, out=graph_out)
            torch.npu.graph_task_update_end(update_stream)
            event.record(update_stream)
        g.replay()
        self.assertEqual(graph_out.cpu(), self.attnOut.cpu())
if __name__ == "__main__":
    # Hand control to the torch_npu test runner when this file is executed directly.
    run_tests()