import math
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestFA(TestCase):
    """Check ``torch_npu._npu_flash_attention`` against a NumPy reference.

    The reference implements grouped attention: ``heads`` query heads share
    ``group_num`` key/value heads, each KV head serving ``heads // group_num``
    consecutive query heads.
    """

    def gen_seq_len(self, batch, max_seq, variate_seq=False):
        """Generate per-sequence lengths for a batch.

        Args:
            batch: number of sequences.
            max_seq: the sequence length when ``variate_seq`` is False; an
                approximate upper bound when it is True (random lengths are
                rounded up to a multiple of 16 before the offset is removed).
            variate_seq: if False, every sequence has length ``max_seq``. If
                True, lengths vary: each is a multiple of 16 minus a random
                offset in ``[0, 14]``.

        Returns:
            ``(seqlen, seqlen_aligned, ntokens)``: the ``batch`` sequence
            lengths, the same lengths rounded up to a multiple of 16, and the
            total number of tokens (``seqlen.sum()``).
        """
        if variate_seq:
            num = max_seq // 16
            # Deterministic part: 16, 32, ..., (num - 1) * 16.
            seqlen_aligned = np.arange(1, num) * 16
            remain = batch - seqlen_aligned.size
            if remain > 0:
                # Top up to `batch` entries with random lengths rounded up to a
                # multiple of 16.
                extra = np.random.randint(1, max_seq, size=remain)
                extra = (extra + 15) // 16 * 16
                seqlen_aligned = np.concatenate((seqlen_aligned, extra), 0)
            # Exactly one offset per aligned length so the subtraction lines up.
            sp_list = np.random.randint(0, 15, size=seqlen_aligned.size)
            seqlen = seqlen_aligned - sp_list
            # Keep the last `batch` entries; an explicit start index makes
            # batch == 0 yield empty arrays instead of the whole array.
            start = seqlen_aligned.size - batch
            seqlen = seqlen[start:]
            seqlen_aligned = seqlen_aligned[start:]
        else:
            max_seq_aligned = (max_seq + 15) // 16 * 16
            seqlen = np.full((batch,), max_seq, dtype=np.int32)
            seqlen_aligned = np.full((batch,), max_seq_aligned, dtype=np.int32)
        ntokens = seqlen.sum()
        return seqlen, seqlen_aligned, ntokens

    def group_matmul(self, heads, group_num, A, B):
        """Batched matmul where each ``B`` matrix serves a group of ``A`` matrices.

        ``A[h]`` is multiplied by ``B[h // (heads // group_num)]``.

        Args:
            heads: number of matrices in ``A`` (its leading axis).
            group_num: number of matrices in ``B`` (its leading axis).
            A: array of shape ``(heads, M, K)``.
            B: array of shape ``(group_num, K, N)``.

        Returns:
            float16 array of shape ``(heads, M, N)``; products are computed in
            float32 and then cast.

        Raises:
            ValueError: if ``group_num`` is not positive or does not divide
                ``heads``.
        """
        if group_num <= 0 or heads % group_num != 0:
            raise ValueError(
                f"heads ({heads}) must be divisible by a positive "
                f"group_num ({group_num})")
        group_head = heads // group_num
        rows, inner = A.shape[-2], A.shape[-1]
        # (heads, M, K) -> (group_num, group_head, M, K); B gets a broadcast
        # axis so every head in a group is multiplied by the same B matrix.
        a = A.astype(np.float32).reshape(group_num, group_head, rows, inner)
        b = B.astype(np.float32)[:, np.newaxis, :, :]
        score = np.matmul(a, b).astype(np.float16)
        return score.reshape(heads, rows, B.shape[-1])

    def calc_expect_func(self, batch, seqlen, heads, embed, group_num=32, embed_v=None):
        """Build random inputs and the NumPy reference output for one case.

        Args:
            batch: number of sequences (must be at least 1).
            seqlen: length of every sequence (lengths are fixed because
                ``variate_seq`` is hard-coded to False).
            heads: number of query heads.
            embed: head size of Q and K.
            group_num: number of key/value heads; must divide ``heads``.
            embed_v: head size of V. When None (the default) it is drawn at
                random from ``[1, embed)``.

        Returns:
            Tuple ``(q, k, v, mask, q_len, tor, heads, out_expect)``:

            - ``q``: float16, shape ``(ntokens, heads, embed)``.
            - ``k``: float16, shape ``(ntokens, group_num, embed)``.
            - ``v``: float16, shape ``(ntokens, group_num, embed_v)``.
            - ``mask``: float16 ``(max_s, max_s)`` causal mask, -10000 strictly
              above the diagonal and 0 elsewhere.
            - ``q_len``: int32 per-sequence query lengths.
            - ``tor``: float16 scale ``1 / sqrt(embed)``.
            - ``heads``: passed through unchanged.
            - ``out_expect``: float16, shape ``(ntokens, heads, embed_v)``.

        Raises:
            ValueError: if ``batch`` is not positive, or (from
                ``group_matmul``) if ``group_num`` does not divide ``heads``.
        """
        if batch <= 0:
            raise ValueError(f"batch must be positive, got {batch}")
        is_mask = True
        variate_seq = False
        is_decoder = False
        src_type = 'float16'
        fp32 = True
        if is_decoder:
            q_seqlen, _, q_ntokens = self.gen_seq_len(batch, 1, variate_seq)
            kv_seqlen, _, kv_ntokens = self.gen_seq_len(batch, seqlen, variate_seq)
        else:
            q_seqlen, _, q_ntokens = self.gen_seq_len(batch, seqlen, variate_seq)
            kv_seqlen, kv_ntokens = q_seqlen, q_ntokens
        max_s = np.max(q_seqlen)
        if embed_v is None:
            # NOTE(review): a random V head size makes each run differ; results
            # are only reproducible if the caller seeds NumPy's global RNG.
            embed_v = np.random.randint(1, embed)
        q = np.random.uniform(-1.0, 1.0, size=(q_ntokens, heads * embed)).astype(np.float16)
        k = np.random.uniform(-1.0, 1.0, size=(kv_ntokens, group_num * embed)).astype(np.float16)
        v = np.random.uniform(-1.0, 1.0, size=(kv_ntokens, group_num * embed_v)).astype(np.float16)
        # Causal mask: -10000 strictly above the diagonal, 0 elsewhere.
        # NOTE(review): it is sized by the query lengths only; in decoder mode
        # (kv_s > max_s) the slice below would not cover all keys — revisit
        # before enabling is_decoder.
        mask = np.triu(np.ones(shape=(1, max_s, max_s)).astype(np.float16), 1)
        mask *= -10000.0
        tor = np.float16(1.0 / math.sqrt(1.0 * embed))
        q_offset = 0
        kv_offset = 0
        out_chunks = []
        for idx in range(batch):
            q_s = q_seqlen[idx]
            kv_s = kv_seqlen[idx]
            # Token-major (tokens, n_heads * dim) -> head-major (n_heads, tokens, dim).
            q_slice = q[q_offset:q_offset + q_s].reshape(q_s, heads, embed).transpose(1, 0, 2)
            k_slice = k[kv_offset:kv_offset + kv_s].reshape(kv_s, group_num, embed).transpose(1, 0, 2)
            k_slice_t = k_slice.transpose(0, 2, 1)
            v_slice = v[kv_offset:kv_offset + kv_s].reshape(kv_s, group_num, embed_v).transpose(1, 0, 2)
            score = self.group_matmul(heads, group_num, q_slice, k_slice_t)
            score = score * tor
            if is_mask:
                score = score + mask[:, :q_s, :kv_s]
            # Subtract the row max before exp for a numerically stable softmax.
            score = score - np.max(score, axis=-1, keepdims=True)
            score_exp = np.exp(score.astype(np.float32))
            if not fp32:
                # float16 path: normalise probabilities before multiplying by V.
                score_sum = np.sum(score_exp.astype(np.float16), axis=-1, keepdims=True)
                p = score_exp.astype(np.float16) / score_sum.astype(np.float16)
                out_sub = self.group_matmul(heads, group_num, p, v_slice)
            else:
                # float32-sum path: multiply unnormalised weights by V, then
                # divide by the float32 row sums.
                score_sum = np.sum(score_exp, axis=-1, keepdims=True)
                out_sub = self.group_matmul(heads, group_num, score_exp.astype(np.float16), v_slice)
                out_sub = out_sub / score_sum.astype(np.float16)
            # Back to token-major (q_s, heads, embed_v).
            out_chunks.append(np.ascontiguousarray(out_sub.transpose(1, 0, 2)))
            q_offset += q_s
            kv_offset += kv_s
        out = np.concatenate(out_chunks, 0)
        q = q.astype(src_type).reshape(-1, heads, embed)
        k = k.astype(src_type).reshape(-1, group_num, embed)
        v = v.astype(src_type).reshape(-1, group_num, embed_v)
        mask = mask.astype(src_type).reshape(max_s, max_s)
        q_len = q_seqlen.astype(np.int32)
        out_expect = out.astype(src_type).reshape(-1, heads, embed_v)
        return q, k, v, mask, q_len, tor, heads, out_expect

    @SupportedDevices(['Ascend910B'])
    def test_flash_attention(self):
        """Run the NPU kernel on 16 sequences of 128 tokens and compare it with
        the NumPy reference."""
        kv_head = 32
        q, k, v, mask, q_len, tor, heads, out_expect = self.calc_expect_func(
            16, 128, 32, 128, group_num=kv_head)
        query, key, value, mask_npu = [
            torch.from_numpy(arr).npu() for arr in (q, k, v, mask)]
        # seq_len is passed to the kernel as a CPU tensor.
        seq_len = torch.from_numpy(q_len)
        expected = torch.from_numpy(out_expect)
        out = torch.empty_like(expected).npu()
        torch_npu._npu_flash_attention(
            query, key, value, mask_npu, seq_len, float(tor), heads, kv_head, out)
        self.assertRtolEqual(expected, out.cpu())
if __name__ == '__main__':
    # Executed directly (not imported): run this module's tests through
    # torch_npu's test runner.
    run_tests()