import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestNPUFlashAttention(TestCase):
    """Check ``torch_npu.npu_fusion_attention`` against a stock PyTorch reference.

    Every case is a ``[sparse_mode, pre_tokens, next_tokens]`` triple. The
    triple drives both the dense mask used by the reference computation and
    the mask/arguments handed to the fused operator; the two outputs are then
    compared with ``assertRtolEqual(prec=0.01, prec16=0.01)``.
    """

    # Softmax scale applied by both the reference and the fused path.
    # NOTE(review): 0.08838 is ~1/sqrt(128) while the test tensors use a last
    # dim of 256; both paths share the value so the comparison is
    # self-consistent -- confirm the constant is intentional.
    _SCALE = 0.08838
    # Shape of the dense masks built by this test (matches the test tensors).
    _MASK_SHAPE = (1, 8, 256, 256)
    # Side length of the square upper-triangular mask passed to the fused
    # operator for sparse modes 2, 3 and 4.
    _CAUSAL_MASK_SIZE = 2048

    @staticmethod
    def _band_mask(shape, pre_tokens, next_tokens):
        """Return a float64 array that is non-zero where attention is blocked.

        Over the last two dims, entries above the ``next_tokens``-th
        super-diagonal or below the ``pre_tokens``-th sub-diagonal are marked.
        """
        blocked_above = np.triu(np.ones(shape), k=next_tokens + 1)
        blocked_below = np.tril(np.ones(shape), k=-pre_tokens - 1)
        return blocked_above + blocked_below

    def supported_op_exec(self, query, key, value, atten_mask):
        """Compute reference attention using only stock PyTorch operators.

        Args:
            query: Query tensor; multiplied with ``key`` transposed on dims 2/3.
            key: Key tensor.
            value: Value tensor multiplied with the softmax result.
            atten_mask: Tensor added to the scores after scaling by -10000.0,
                so non-zero entries are suppressed by the softmax.

        Returns:
            ``softmax(scale * Q @ K^T + mask * -10000) @ V`` with the softmax
            accumulated in float32 and cast back to float16.
        """
        scores = torch.matmul(query, key.transpose(2, 3)).mul(self._SCALE)
        scores = scores + atten_mask * (-10000.0)
        probs = torch.nn.functional.softmax(
            scores, dim=-1, dtype=torch.float32).to(torch.float16)
        return torch.matmul(probs, value)

    def custom_op_exec(self, query, key, value, sparse_params):
        """Run ``torch_npu.npu_fusion_attention`` for one sparse configuration.

        Args:
            query: Query tensor already placed on the NPU by the caller.
            key: Key tensor on the NPU.
            value: Value tensor on the NPU.
            sparse_params: Sequence whose first three items are
                ``sparse_mode``, ``pre_tokens`` and ``next_tokens``.

        Returns:
            Whatever ``npu_fusion_attention`` returns; callers in this file
            read element ``[0]`` as the attention output.
        """
        sparse_mode, pre_tokens, next_tokens = sparse_params[:3]

        # Mode 1 (and any other mode not listed below) passes no mask at all.
        atten_mask = None
        if sparse_mode == 0:
            dense = self._band_mask(self._MASK_SHAPE, pre_tokens, next_tokens)
            # from_numpy avoids the copy-construct warning that
            # torch.tensor(<tensor>) emits; any non-zero entry becomes True.
            atten_mask = torch.from_numpy(dense).bool().npu()
        elif sparse_mode in (2, 3, 4):
            size = self._CAUSAL_MASK_SIZE
            causal = np.triu(np.ones([size, size]), k=1)
            atten_mask = torch.from_numpy(causal).bool().npu()

        # ``pre_tockens`` / ``next_tockens`` are spelled exactly as the
        # operator's keyword arguments; do not "fix" the spelling.
        return torch_npu.npu_fusion_attention(
            query, key, value, head_num=8, input_layout="BNSD",
            scale=self._SCALE, sparse_mode=sparse_mode,
            atten_mask=atten_mask, pre_tockens=pre_tokens,
            next_tockens=next_tokens)

    def get_atten_mask(self, sparse_mode=0, pre_tokens=65536, next_tokens=65536):
        """Build the dense float16 mask used by the reference computation.

        Args:
            sparse_mode: 0 or 4 build a band mask from ``pre_tokens`` /
                ``next_tokens``; 1 builds an all-zero mask; 2 or 3 build a
                strictly upper-triangular mask.
            pre_tokens: Width of the band below the diagonal (modes 0 and 4).
            next_tokens: Width of the band above the diagonal (modes 0 and 4).

        Returns:
            A float16 tensor of shape ``_MASK_SHAPE``; non-zero marks blocked
            positions.

        Raises:
            ValueError: If ``sparse_mode`` is not one of 0-4 (previously this
                silently produced an empty tensor).
        """
        shape = self._MASK_SHAPE
        if sparse_mode in (0, 4):
            dense = self._band_mask(shape, pre_tokens, next_tokens)
        elif sparse_mode == 1:
            dense = np.zeros(shape)
        elif sparse_mode in (2, 3):
            dense = np.triu(np.ones(shape), k=1)
        else:
            raise ValueError(f"unsupported sparse_mode: {sparse_mode!r}")
        return torch.tensor(dense).to(torch.float16)

    def check_result(self, query, key, value, sparse_params):
        """Compare reference and fused attention for one sparse configuration.

        The reference runs on the tensors as given; the fused operator runs on
        ``.npu()`` copies of them.
        """
        sparse_mode, pre_tokens, next_tokens = sparse_params[:3]
        atten_mask = self.get_atten_mask(sparse_mode, pre_tokens, next_tokens)
        output = self.supported_op_exec(query, key, value, atten_mask)
        fa_result = self.custom_op_exec(
            query.npu(), key.npu(), value.npu(), sparse_params)
        self.assertRtolEqual(output, fa_result[0], prec=0.01, prec16=0.01)

    @SupportedDevices(['Ascend910B'])
    def test_npu_flash_attention(self, device="npu"):
        """Exercise sparse modes 0-4 on random float16 inputs."""
        query = torch.randn(1, 8, 256, 256, dtype=torch.float16)
        key = torch.randn(1, 8, 256, 256, dtype=torch.float16)
        value = torch.randn(1, 8, 256, 256, dtype=torch.float16)

        # Each entry: [sparse_mode, pre_tokens, next_tokens].
        sparse_params_list = [
            [0, 128, 128],
            [1, 65536, 65536],
            [2, 65536, 0],
            [3, 65536, 0],
            [4, 128, 128],
        ]

        for sparse_params in sparse_params_list:
            # subTest reports which configuration failed and keeps running
            # the remaining ones.
            with self.subTest(sparse_params=sparse_params):
                self.check_result(query, key, value, sparse_params)
# Run the tests through torch_npu's runner when this file is executed directly.
if __name__ == "__main__":
    run_tests()