import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
PER_BLOCK_Q_SIZE = 128
PER_BLOCK_KV_SIZE = 256
FP8_MAX = 448.0
class TestNPUQuantFlashAttentionV2(TestCase):
def npu_block_quant(self, tensor, scale, PER_BLOCK_SIZE=128):
dim1 = tensor.shape[0]
dim2 = tensor.shape[1]
dim3 = tensor.shape[2]
dim4 = tensor.shape[3]
quanted_tensor = torch.zeros([dim1, dim2, dim3, dim4]).to(torch.float32)
for b in range(dim1):
for n in range(dim2):
for s in range(0, dim3, PER_BLOCK_SIZE):
s_start = s // PER_BLOCK_SIZE * PER_BLOCK_SIZE
s_end = min(s // PER_BLOCK_SIZE * PER_BLOCK_SIZE + PER_BLOCK_SIZE, dim3)
quanted_tensor[b, n, s_start:s_end, :] = tensor[b, n, s_start:s_end, :] * scale[
b, n, s // PER_BLOCK_SIZE, 0]
return quanted_tensor
def calc_quant_scale(self, tensor, PER_BLOCK_SIZE=128):
dim1 = tensor.shape[0]
dim2 = tensor.shape[1]
dim3 = tensor.shape[2]
dim4 = tensor.shape[3]
scale_for_tensor = torch.ones([dim1, dim2, math.ceil(dim3 / PER_BLOCK_SIZE), 1]).to(torch.float32)
for b in range(dim1):
for n in range(dim2):
for s in range(0, dim3, PER_BLOCK_SIZE):
s_start = s // PER_BLOCK_SIZE * PER_BLOCK_SIZE
s_end = min(s // PER_BLOCK_SIZE * PER_BLOCK_SIZE + PER_BLOCK_SIZE, dim3)
chunk = tensor[b, n, s_start:s_end, :]
max_val = torch.max(torch.abs(chunk))
epsilon = 1e-8
scale_for_tensor[b, n, s // PER_BLOCK_SIZE, 0] = FP8_MAX / (max_val + epsilon)
return scale_for_tensor
def supported_op_exec_quant(self, query, key, value, d_scale_q, d_scale_k, d_scale_v):
scale = 0.08838
query = query.to(torch.float32)
key = key.to(torch.float32)
value = value.to(torch.float32)
d_scale_qf = torch.zeros(1, 8, 256, 1, dtype=torch.float32)
d_scale_kf = torch.zeros(1, 8, 256, 1, dtype=torch.float32)
d_scale_vf = torch.zeros(1, 8, 256, 1, dtype=torch.float32)
for i in range(256):
d_scale_qf[:, :, i, :] = d_scale_q[:, :, i // PER_BLOCK_Q_SIZE, :]
d_scale_kf[:, :, i, :] = d_scale_k[:, :, i // PER_BLOCK_KV_SIZE, :]
d_scale_vf[:, :, i, :] = d_scale_v[:, :, i // PER_BLOCK_KV_SIZE, :]
qk = torch.matmul(query * d_scale_qf, (key * d_scale_kf).transpose(2, 3)).mul(scale)
max = torch.max(qk, dim=-1, keepdim=True)[0]
softmax_res = torch.exp(qk - max).to(torch.float8_e4m3fn).to(torch.float32)
sum = torch.sum(softmax_res, dim=-1, keepdim=True)
attention_out = torch.matmul(softmax_res, value * d_scale_vf)
attention_out_res = attention_out / sum
return attention_out_res
def custom_op_exec_quant(self, query, key, value, d_scale_q, d_scale_k, d_scale_v, p_scale):
scale = 0.08838
return torch_npu.npu_quant_fusion_attention(
query, key, value, head_num=8, input_layout="BNSD", scale=scale,
d_scale_q=d_scale_q, d_scale_k=d_scale_k, d_scale_v=d_scale_v, p_scale=p_scale)
@SupportedDevices(['Ascend950'])
def test_npu_quant_flash_attention_with_fp8(self, device="npu"):
query = torch.randn(1, 8, 256, 128, dtype=torch.float16)
key = torch.randn(1, 8, 256, 128, dtype=torch.float16)
value = torch.randn(1, 8, 256, 128, dtype=torch.float16)
scale_q = self.calc_quant_scale(query, 128).to(torch.float32)
scale_k = self.calc_quant_scale(key, 256).to(torch.float32)
scale_v = self.calc_quant_scale(value, 256).to(torch.float32)
p_scale = torch.randn(1, dtype=torch.float32)
d_scale_q = 1 / scale_q
d_scale_k = 1 / scale_k
d_scale_v = 1 / scale_v
query_fp8 = self.npu_block_quant(query, scale_q, 128).to(torch.float8_e4m3fn)
key_fp8 = self.npu_block_quant(key, scale_k, 256).to(torch.float8_e4m3fn)
value_fp8 = self.npu_block_quant(value, scale_v, 256).to(torch.float8_e4m3fn)
output = self.supported_op_exec_quant(query_fp8.to(torch.float32), key_fp8.to(torch.float32), value_fp8.to(torch.float32),
d_scale_q, d_scale_k, d_scale_v)
fa_result = self.custom_op_exec_quant(query_fp8.npu(), key_fp8.npu(), value_fp8.npu(), d_scale_q.npu(), d_scale_k.npu(),
d_scale_v.npu(), p_scale.npu())
self.assertRtolEqual(output.half(), fa_result[0], prec=0.01, prec16=0.01)
@SupportedDevices(['Ascend950'])
def test_npu_quant_flash_attention_with_hifp8(self, device="npu"):
scale = 0.08838
query = torch.ones(1, 57600, 5, 128, dtype=torch.uint8).npu()
key = torch.ones(1, 57600, 5, 128, dtype=torch.uint8).npu()
value = torch.ones(1, 57600, 5, 128, dtype=torch.uint8).npu()
d_scale_q = torch.ones(1, 5, 450, 1, dtype=torch.float32).npu()
d_scale_k = torch.ones(1, 5, 225, 1, dtype=torch.float32).npu()
d_scale_v = torch.ones(1, 5, 113, 1, dtype=torch.float32).npu()
query_dtype = torch_npu.hifloat8
p_scale = torch.ones(1, dtype=torch.float32).npu()
custom_output = torch_npu.npu_quant_fusion_attention(
query, key, value, head_num=5, input_layout="BSND", scale=scale,
d_scale_q=d_scale_q, d_scale_k=d_scale_k, d_scale_v=d_scale_v, p_scale=p_scale, query_dtype=query_dtype)
golden_output = torch.ones(2, 1, 1, 128, dtype=torch.float16).npu()
res = custom_output[0].equal(golden_output)
self.assertRtolEqual(res, True)
if __name__ == "__main__":
run_tests()