import unittest
import random
import time
from dataclasses import dataclass
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
@dataclass
class GenerateDataParams:
num_prompts: int = 2
num_heads: int = 32
kv_heads: int = 32
head_size: int = 128
block_size: int = 128
num_blocks: int = 64
@dataclass
class RefFAWithPrefixCacheParams:
output: torch.Tensor
query: torch.Tensor
key_cache: torch.Tensor
value_cache: torch.Tensor
block_tables: torch.Tensor
seq_lens: torch.Tensor
query_lens: torch.Tensor
mask: torch.Tensor
@dataclass
class GenerateDataResult:
query: torch.Tensor
key_cache: torch.Tensor
value_cache: torch.Tensor
block_tables: torch.Tensor
seq_lens: torch.Tensor
query_lens: torch.Tensor
mask: torch.Tensor
ref_output: torch.Tensor
class TestFA(TestCase):
SCALE_TYPE = 0
MAX_SEQ_LEN = 1024
rand_list = np.random.uniform(0.0, 2.0, size=100).astype(np.float32)
rand_list = [x if x > 1.0 else 1.0 for x in rand_list]
def group_matmul(self, head, kv_head, A, B):
group_num = head // kv_head
score = None
for i in range(kv_head):
group_score = torch.matmul(
A[i * group_num: (i + 1) * group_num, :, :],
B[i: (i + 1), :, :],
)
if score is None:
score = group_score
else:
score = torch.cat((score, group_score), dim=0)
return score
def ref_masked_attention(self, query, key, value, scale: float, mask=None):
query_len = query.shape[0]
seq_len = key.shape[0]
query = query * scale
query = query.permute(1, 0, 2)
key = key.permute(1, 2, 0)
score = self.group_matmul(query.shape[0], key.shape[0], query, key)
if mask is not None:
score = score + mask[:, seq_len - query_len:seq_len, :seq_len]
row_max = torch.max(score, dim=-1, keepdim=True).values
score -= row_max
score = score.to(torch.float32)
score = torch.exp(score)
row_sum = torch.sum(score, axis=-1, keepdims=True)
p = score / row_sum
p = p.to(torch.bfloat16)
value = value.permute(1, 0, 2)
out = self.group_matmul(query.shape[0], key.shape[0], p, value)
out = out.permute(1, 0, 2)
return out
def ref_fa_with_prefix_cache(self, params: RefFAWithPrefixCacheParams):
num_prompts = params.query_lens.shape[0]
num_heads = params.query.shape[1]
kv_heads = params.value_cache.shape[2]
head_size = params.key_cache.shape[3]
head_size_v = params.value_cache.shape[3]
block_size = params.value_cache.shape[1]
curr = 0
for i in range(num_prompts):
seq_len = int(params.seq_lens[i])
query_len = int(params.query_lens[i])
querys = params.query[curr:curr + query_len]
block_table = params.block_tables[i]
keys = []
values = []
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = params.key_cache[block_number, block_offset, :, :]
k = k.reshape(kv_heads, head_size)
keys.append(k)
v = params.value_cache[block_number, block_offset, :, :]
v = v.reshape(kv_heads, head_size_v)
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
if self.SCALE_TYPE == 1:
scale = 1.0 * self.rand_list[i] / (head_size ** 0.5)
else:
scale = 1.0 / (head_size ** 0.5)
out = self.ref_masked_attention(querys, keys, values, scale, params.mask)
out = out.reshape(query_len, num_heads, head_size_v)
params.output[curr:curr + query_len] = out
curr += query_len
def generate_data(self, params: GenerateDataParams):
seq_lens = [1024] * params.num_prompts
context_lens = random.choices([128 * n for n in range(1, 8)], k=params.num_prompts)
query_lens = [seq_len - context_len for seq_len, context_len in zip(seq_lens, context_lens)]
num_tokens = sum(query_lens)
head_size_v = params.head_size
query = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(num_tokens, params.num_heads, params.head_size))).to(torch.bfloat16)
key_cache = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(params.num_blocks, params.block_size, params.kv_heads, params.head_size))).to(torch.bfloat16)
value_cache = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(params.num_blocks, params.block_size, params.kv_heads, head_size_v))).to(torch.bfloat16)
max_seq_len = max(seq_lens)
max_num_blocks_per_seq = (max_seq_len + params.block_size - 1) // params.block_size
block_tables = []
for _ in range(params.num_prompts):
block_table = [
random.randint(0, params.num_blocks - 1) for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
query_lens = torch.tensor(query_lens, dtype=torch.int32)
block_tables = torch.tensor(block_tables, dtype=torch.int32)
mask = torch.ones(size=(1, max_seq_len, max_seq_len), dtype=torch.bfloat16)
mask = torch.triu(mask, 1)
mask *= -10000.0
ref_output = torch.zeros((num_tokens, params.num_heads, head_size_v), dtype=torch.bfloat16)
cache_params = RefFAWithPrefixCacheParams(
output=ref_output,
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_tables,
seq_lens=seq_lens,
query_lens=query_lens,
mask=mask
)
self.ref_fa_with_prefix_cache(cache_params)
return GenerateDataResult(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_tables,
seq_lens=seq_lens,
query_lens=query_lens,
mask=mask,
ref_output=ref_output
)
@unittest.skip("skip case")
@SupportedDevices(['Ascend910B'])
def test_flash_attention_qlens(self):
num_heads = 32
num_kv_heads = 16
head_dim = 128
num_prompts = 2
block_num = 64
block_size = 128
data_params = GenerateDataParams(
num_prompts=num_prompts,
num_heads=num_heads,
kv_heads=num_kv_heads,
head_size=head_dim,
block_size=block_size,
num_blocks=block_num
)
result = self.generate_data(data_params)
query = result.query.npu()
key_cache = result.key_cache.npu()
value_cache = result.value_cache.npu()
block_tables = result.block_tables.npu()
gt_output = result.ref_output.npu()
output = torch.empty_like(gt_output)
seq_lens = result.seq_lens
query_lens = result.query_lens
scale = head_dim ** -0.5
mask_compress = torch.ones(size=(128, 128), dtype=torch.bfloat16)
mask_compress = torch.triu(mask_compress, 1)
mask_compress = mask_compress.npu()
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_tables,
mask=mask_compress,
seq_len=query_lens,
context_lens=seq_lens,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
out=output)
self.assertRtolEqual(output, gt_output)
if __name__ == '__main__':
run_tests()