import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestAtbAttention(TestCase):
def compare_output_data(self, out, golden, ratios):
    """Compare ``out`` with ``golden`` and report accuracy statistics.

    Two "old standard" mismatch statistics are printed for information only.
    The returned verdict is the "new standard" element-wise check
    ``|golden - out| <= max(|golden|, 1) * error``, where ``error`` depends on
    ``self.data_type`` and on ``self.head_size_qk * self.max_context_len + 4``.

    Args:
        out: Tensor produced by the operator under test.
        golden: Reference tensor with the same number of elements as ``out``.
        ratios: Four numbers ``[rel_a, abs_a, rel_b, abs_b]``; each
            ``(rel, abs)`` pair defines a per-element limit
            ``max(|golden * rel|, abs)``.

    Returns:
        bool: True when every element satisfies the new-standard bound.
    """
    golden = golden.flatten().to(torch.float32)
    out = out.flatten().to(torch.float32)
    out_len = out.shape[0]
    diff = torch.abs(golden - out)
    max_diff = diff.max().item()

    def mismatch_ratio(rel, floor):
        # Fraction of elements whose error exceeds max(|golden * rel|, floor).
        limit = torch.maximum(torch.abs(golden * rel), torch.tensor(floor))
        return float(torch.gt(diff, limit).sum().item()) / out_len

    first_ratio = mismatch_ratio(ratios[0], ratios[1])
    second_ratio = mismatch_ratio(ratios[2], ratios[3])

    # The previous code passed the value as a second print() argument, so the
    # "%f"/"%r" placeholders were printed literally; format them for real.
    print(f"maxDiff {max_diff}")
    print(f"1/1000 Accuracy is {1 - first_ratio:f}")
    print(f"5/1000 Accuracy is {1 - second_ratio:f}")

    is_bf16 = self.data_type == torch.bfloat16
    old_limit = ratios[2] if (is_bf16 or self.is_int8_flag) else ratios[0]
    print(f"accuracy is correct in old standard: {second_ratio <= old_limit!r}")

    # New standard: the tolerance loosens by one bit for bfloat16 and by one
    # more bit once the accumulation length reaches 2048.
    calc_times = self.head_size_qk * self.max_context_len + 4
    if is_bf16:
        error = 2 ** (-7) if calc_times < 2048 else 2 ** (-6)
    else:
        error = 2 ** (-8) if calc_times < 2048 else 2 ** (-7)
    error_threshold = torch.clamp(torch.abs(golden), min=1) * error
    res = (diff <= error_threshold).all().item()
    print(f"accuracy is correct in new standard: {res!r}")
    return res
def get_alibi_slopes(self, n_heads):
    """Return a 1-D tensor of ``n_heads`` slope values.

    The first ``n`` entries (``n`` = largest power of two not exceeding
    ``n_heads``) are ``(2 ** (-8 / n)) ** k`` for ``k = 1..n``. When
    ``n_heads`` is not a power of two, the remaining entries are
    ``(2 ** (-4 / n)) ** k`` for odd ``k = 1, 3, 5, ...``.
    """
    pow2_heads = 2 ** math.floor(math.log2(n_heads))
    base = 2.0 ** (-8.0 / pow2_heads)
    slopes = torch.pow(base, torch.arange(1, pow2_heads + 1))
    if pow2_heads >= n_heads:
        return slopes

    # Fill the heads beyond the power-of-two prefix using odd exponents only.
    extra_base = 2.0 ** (-4.0 / pow2_heads)
    odd_exponents = torch.arange(1, 1 + 2 * (n_heads - pow2_heads), 2)
    return torch.cat([slopes, torch.pow(extra_base, odd_exponents)])
def group_mm_torch(self, heads, group_num, A, B, is_k):
    """Grouped batched matmul of ``A`` against ``B`` in float32.

    ``A`` is split along dim 0 into ``group_num`` chunks of
    ``heads // group_num`` rows; chunk ``g`` is multiplied with ``B[g:g+1]``
    and the per-group results are concatenated along dim 0.

    When ``self.is_int8_flag`` is set, ``B[g:g+1]`` is first converted to
    int32, optionally shifted by ``self.offset1``/``self.offset2`` (when
    ``self.has_bias``) and scaled by ``self.de_scale1_fp32``/
    ``self.de_scale2_fp32``; ``is_k`` selects the "1" set and additionally
    transposes the last two dims of the scaled operand.

    Returns:
        The concatenated float32 result, or None when ``group_num`` is 0.
    """
    heads_per_group = heads // group_num
    pieces = []
    for g in range(group_num):
        lhs = A[g * heads_per_group: (g + 1) * heads_per_group, :, :].to(torch.float32)
        if not self.is_int8_flag:
            rhs = B[g:(g + 1), :, :].to(torch.float32)
        else:
            quantized = B[g: (g + 1), :, :, ]
            head_dim = quantized.shape[2]
            lo, hi = g * head_dim, (g + 1) * head_dim
            # Identity matmul followed by a cast yields the int32 copy.
            identity = torch.eye(quantized.shape[1]).to(torch.float32)
            as_int32 = torch.matmul(identity, quantized.to(torch.float32)).to(torch.int32)
            if is_k:
                if self.has_bias:
                    as_int32 = as_int32 + self.offset1[lo:hi]
                rhs = as_int32.to(torch.float32) * self.de_scale1_fp32[lo:hi]
                rhs = torch.permute(rhs, (0, 2, 1))
            else:
                if self.has_bias:
                    as_int32 = as_int32 + self.offset2[lo:hi]
                rhs = as_int32.to(torch.float32) * self.de_scale2_fp32[lo:hi]
        pieces.append(torch.matmul(lhs, rhs))

    if not pieces:
        return None
    if len(pieces) == 1:
        return pieces[0]
    return torch.cat(pieces, 0)
def process_deq_scale(self, deq_scale) -> np.ndarray:
    """Reinterpret the raw bytes of ``deq_scale`` as uint32 words, widened to uint64.

    The values are bit patterns, not numeric conversions: a float32 ``1.0``
    becomes ``0x3F800000``.
    """
    raw_bytes = deq_scale.tobytes()
    return np.frombuffer(raw_bytes, dtype=np.uint32).astype(np.uint64)
def softmax(self, sim):
    """Softmax over the last dimension of a torch tensor.

    The row maximum is subtracted before exponentiation so large inputs do
    not overflow.
    """
    shifted = sim - sim.max(dim=-1, keepdim=True).values
    weights = shifted.exp()
    return weights / weights.sum(dim=-1, keepdim=True)
def softmax_numpy(self, sim):
    """Softmax over the last axis, computed in NumPy.

    Args:
        sim: Torch tensor; it is moved to CPU and converted to a NumPy array.

    Returns:
        tuple: ``(probabilities, lse)`` as NumPy arrays, where ``lse`` is
        ``row_max + log(sum(exp(sim - row_max)))`` with the last axis kept
        as size 1.
    """
    scores = sim.cpu().numpy()
    peak = np.max(scores, axis=-1, keepdims=True)
    exps = np.exp(scores - peak)
    total = np.sum(exps, axis=-1, keepdims=True)
    return exps / total, peak + np.log(total)
def shape_nd_to_nz(self, shape, dtype='float16'):
    """Map an ND shape ``[..., a, b]`` to ``[..., ceil(b/16), ceil(a/16), 16, 16]``.

    ``dtype`` is accepted but not used; the tile size is fixed at 16 x 16.
    """
    assert len(shape) >= 2
    *leading, rows, cols = shape
    tile_rows, tile_cols = 16, 16
    return [*leading, math.ceil(cols / tile_cols), math.ceil(rows / tile_rows), tile_rows, tile_cols]
def gen_axes_for_transpose(self, offset, base):
    """Build a permutation: axes ``0..offset-1`` unchanged, then ``base`` shifted by ``offset``."""
    return list(range(offset)) + [offset + axis for axis in base]
def convert_nd_to_nz(self, x):
    """Rearrange the last two dims of ``x`` into 16 x 16 tiles.

    ``x`` of shape ``[..., a, b]`` is reshaped to ``[..., m1, 16, n1, 16]``
    and permuted to ``[..., n1, m1, 16, 16]``, where ``m1`` and ``n1`` come
    from ``shape_nd_to_nz``. The reshape only succeeds when the element
    count matches, i.e. when ``a`` and ``b`` are multiples of 16.
    """
    leading_rank = len(x.shape) - 2
    permutation = self.gen_axes_for_transpose(leading_rank, [2, 0, 1, 3])
    nz_shape = self.shape_nd_to_nz(x.shape, dtype=x.dtype)
    n1, m1, m0, n0 = nz_shape[-4:]
    tiled_shape = nz_shape[:-4] + [m1, m0, n1, n0]
    return x.reshape(tiled_shape).permute(*permutation)
def ref_masked_attention(self,
                         query,
                         key,
                         value,
                         scale: float,
                         alibi_bias,
                         mask_data_type=torch.bfloat16
                         ):
    """Reference attention for one query position.

    Scores are computed with ``group_mm_torch``, multiplied by ``scale``,
    optionally offset by ``alibi_bias``, and normalised with
    ``softmax_numpy``. The value product is evaluated twice: once with the
    probabilities cast to ``mask_data_type`` and once with the uncast
    probabilities.

    Returns:
        tuple: ``(out, out_high, sim_out, lse)``, each with its first two
        dims swapped back by ``permute(1, 0, 2)``. ``out`` uses the cast
        probabilities, ``out_high`` the uncast ones, ``sim_out`` is the
        unscaled float32 score tensor and ``lse`` is cast to
        ``mask_data_type``.
    """
    q_heads_first = torch.permute(query, (1, 0, 2))
    # Non-int8 keys are pre-transposed here; int8 keys are transposed inside
    # group_mm_torch after scaling.
    key_axes = (1, 0, 2) if self.is_int8_flag else (1, 2, 0)
    k_heads_first = torch.permute(key, key_axes)
    num_q_heads = q_heads_first.shape[0]
    num_kv_groups = k_heads_first.shape[0]

    raw_scores = self.group_mm_torch(num_q_heads, num_kv_groups, q_heads_first, k_heads_first, 1)
    sim_out = raw_scores.to(torch.float32)
    scaled_scores = raw_scores.to(torch.float32) * scale
    if alibi_bias is not None:
        scaled_scores = scaled_scores + alibi_bias.to(torch.float32)

    probs_np, lse_np = self.softmax_numpy(scaled_scores)
    probs_cast = torch.from_numpy(probs_np).to(mask_data_type)
    probs_high = torch.from_numpy(probs_np)
    lse = torch.permute(torch.from_numpy(lse_np).to(mask_data_type), (1, 0, 2))

    v_heads_first = torch.permute(value, (1, 0, 2))
    out = self.group_mm_torch(num_q_heads, num_kv_groups, probs_cast, v_heads_first, 0)
    out_high = self.group_mm_torch(num_q_heads, num_kv_groups, probs_high, v_heads_first, 0)

    token_first = (1, 0, 2)
    return (
        torch.permute(out, token_first),
        torch.permute(out_high, token_first),
        torch.permute(sim_out, token_first),
        lse,
    )
def ref_single_query_cached_kv_attention(self,
                                         sim,
                                         output,
                                         true_out,
                                         lse,
                                         query,
                                         key_cache,
                                         value_cache,
                                         block_tables,
                                         context_lens,
                                         mask,
                                         mask_dim=4,
                                         mask_data_type=torch.bfloat16
                                         ) -> None:
    """Fill the reference buffers in place, one sequence at a time.

    For every non-zero entry of ``context_lens`` the keys and values of that
    sequence are gathered from the block caches through ``block_tables``,
    passed to ``ref_masked_attention`` and the results written to
    ``output`` (cast to ``mask_data_type``), ``true_out`` and ``sim``.
    ``lse`` is written only when ``mask_dim`` is neither 3 nor 4.

    Sequences with a context length of 0 are skipped and do not consume an
    output row: rows are indexed by a counter that advances only for
    processed sequences.
    """
    mask_index_coff = 1
    if self.compressHead:
        # Treat every (token, kv_head) pair as its own row.
        rows = self.num_tokens * self.kv_heads
        heads_per_kv = self.num_heads // self.kv_heads
        query = query.view(rows, heads_per_kv, self.head_size_qk)
        output = output.view(rows, heads_per_kv, self.head_size_vo)
        true_out = true_out.view(rows, heads_per_kv, self.head_size_vo)
        if mask_dim == 4:
            mask = mask.view(mask.shape[0] * self.kv_heads, heads_per_kv, 1, self.max_context_len)
        else:
            mask_index_coff = self.kv_heads

    num_heads = query.shape[1]
    block_size = value_cache.shape[1]
    kv_heads = value_cache.shape[2]
    head_size_qk = key_cache.shape[3]
    head_size_vo = value_cache.shape[3]
    scale = np.float32(1.0 / (head_size_qk ** 0.5))

    row = 0
    for seq_idx, context_len in enumerate(context_lens):
        if context_len == 0:
            continue
        context_len = int(context_len)
        block_table = block_tables[seq_idx]
        q = query[row].view(1, num_heads, head_size_qk)

        # Resolve each position to its (block, offset) slot in the caches.
        slots = [
            (int(block_table[pos // block_size]), pos % block_size)
            for pos in range(context_len)
        ]
        keys = torch.stack(
            [key_cache[blk, off, :, :].reshape(kv_heads, head_size_qk) for blk, off in slots],
            dim=0,
        )
        values = torch.stack(
            [value_cache[blk, off, :, :].reshape(kv_heads, head_size_vo) for blk, off in slots],
            dim=0,
        )

        if mask_dim == 4:
            seq_mask = mask[seq_idx, :, :, :context_len]
        elif mask_dim == 3:
            seq_mask = mask[seq_idx // mask_index_coff, :, :context_len]
        else:
            seq_mask = mask
        out, out_high, sim_out, lse_i = self.ref_masked_attention(q, keys, values, scale, seq_mask,
                                                                  mask_data_type)
        out = out.reshape(num_heads, head_size_vo)
        if mask_dim != 4 and mask_dim != 3:
            lse[row] = lse_i.reshape(num_heads, 1).to(mask_data_type)

        output[row] = out.to(mask_data_type)
        true_out[row] = out_high.reshape(num_heads, head_size_vo)
        sim[row] = sim_out.reshape(1, num_heads * context_len)
        row += 1
def calc_data(self, num_tokens, num_heads, kv_heads, head_size_qk, head_size_vo, block_size, num_blocks, k_seqlen,
              dtype, mask_dim=0, mask_data_type=torch.bfloat16,
              dynamic_batch=False, dynamic_seqlen=None, is_int8_flag=False, has_bias=False,
              compressHead=False, is_kv_combined=True, is_nz_in=False):
    """Generate random inputs and compute the CPU reference results.

    Random query and paged key/value caches are built, the reference
    attention is evaluated with ``ref_single_query_cached_kv_attention``,
    and everything a test needs is stored on ``self``:
    ``q_split1``/``q_split2``, ``key_cache_split1``/``key_cache_split2``,
    ``block_tables``, ``contex_lens``, ``alib_mask``, ``golden_out``,
    ``true_out`` and ``lse``.

    Args:
        num_tokens: Number of query rows.
        num_heads: Number of query heads.
        kv_heads: Number of key/value heads.
        head_size_qk: Last dim of query and key; it is split into 512 + 64
            columns, so it must be 576.
        head_size_vo: Last dim of value and of the output.
        block_size: Positions per cache block.
        num_blocks: Number of cache blocks.
        k_seqlen: Context length used for every token unless
            ``dynamic_batch`` is set.
        dtype: Torch dtype of the query, the reference output and, unless
            int8, the caches.
        mask_dim: 4 builds a per-head position-bias mask, 3 builds a
            ``-10000`` mask, anything else means no mask.
        mask_data_type: Dtype the mask and intermediate results are cast to.
        dynamic_batch: Use ``dynamic_seqlen`` as the per-token context lengths.
        dynamic_seqlen: Context lengths; required when ``dynamic_batch``.
        is_int8_flag: Generate int8 caches plus de-quantisation parameters.
        has_bias: Whether int8 offsets are applied during de-quantisation.
        compressHead: Lay the caches out as one single-head block group per
            kv head.
        is_kv_combined: Value cache is the first ``head_size_vo`` columns of
            the key cache instead of independent random data.
        is_nz_in: Additionally convert the split key caches with
            ``convert_nd_to_nz``.

    Raises:
        ValueError: ``dynamic_batch`` is set but ``dynamic_seqlen`` is None.
    """
    self.num_heads = num_heads
    self.kv_heads = kv_heads
    self.num_tokens = num_tokens
    self.compressHead = compressHead
    self.head_size_qk = head_size_qk
    self.head_size_vo = head_size_vo
    self.data_type = dtype
    self.is_int8_flag = is_int8_flag
    self.has_bias = has_bias
    print(
        f'input info: {num_tokens}, {num_heads}, {kv_heads}, {head_size_qk}, {head_size_vo}, {block_size}, {num_blocks}, {k_seqlen}, {dtype}')

    query = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(num_tokens, num_heads, head_size_qk))).to(dtype)

    # int8 caches use a wider value range so the truncating cast keeps
    # non-zero entries.
    kv_range, kv_type = (4.0, torch.int8) if is_int8_flag else (1.0, dtype)
    if compressHead:
        cache_blocks, cache_heads = num_blocks * kv_heads, 1
    else:
        cache_blocks, cache_heads = num_blocks, kv_heads
    key_cache = torch.from_numpy(
        np.random.uniform(-kv_range, kv_range, size=(cache_blocks, block_size, cache_heads, head_size_qk))).to(
        kv_type)
    if is_kv_combined:
        value_cache = key_cache[:, :, :, :head_size_vo]
    else:
        value_cache = torch.from_numpy(
            np.random.uniform(-kv_range, kv_range, size=(cache_blocks, block_size, cache_heads, head_size_vo))).to(
            kv_type)

    if dynamic_batch:
        if dynamic_seqlen is None:
            raise ValueError("dynamic_seqlen must be provided when dynamic_batch is True")
        context_lens = dynamic_seqlen
    else:
        context_lens = [k_seqlen] * num_tokens
    max_context_len = max(context_lens)
    self.max_context_len = max_context_len
    batch = len(context_lens)

    if mask_dim == 4:
        mask = np.zeros((batch, num_heads, 1, self.max_context_len), dtype=np.float32)
        alibi_slopes = self.get_alibi_slopes(num_heads)
        for i, context_len in enumerate(context_lens):
            if context_len == 0:
                continue
            position_ids = np.arange(context_len).astype(np.int32)
            alibi_bias = (position_ids - context_len + 1).astype(np.float32)
            alibi_bias = alibi_slopes.reshape(-1, 1, 1) * alibi_bias.reshape(1, 1, -1)
            mask[i, :, :, :context_len] = alibi_bias
        mask = torch.from_numpy(mask).to(mask_data_type)
    elif mask_dim == 3:
        mask = np.zeros((batch, 1, max_context_len), dtype=np.float16)
        for i in range(batch):
            mask[i, :, :i] = -10000
        mask = torch.from_numpy(mask).to(mask_data_type)
    else:
        mask = None

    if compressHead:
        # One context-length entry per (token, kv_head) row.
        context_lens = [val for val in context_lens for _ in range(kv_heads)]
    batch = len(context_lens)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    # Every sequence gets its own contiguous run of block ids.
    block_tables = [
        [i * max_num_blocks_per_seq + blk for blk in range(max_num_blocks_per_seq)]
        for i in range(batch)
    ]

    if is_int8_flag:
        # group_mm_torch slices the "1" arrays per kv head with the key width
        # and the "2" arrays with the value width, so each is sized with its
        # own head size (the previous code used an undefined ``head_size``).
        k_scale_len = kv_heads * head_size_qk
        v_scale_len = kv_heads * head_size_vo
        de_scale1_fp32 = np.random.randint(-1, 2, size=k_scale_len).astype(np.float32)
        de_scale1_int64 = self.process_deq_scale(de_scale1_fp32)
        de_scale2_fp32 = np.random.randint(-1, 2, size=v_scale_len).astype(np.float32)
        de_scale2_int64 = self.process_deq_scale(de_scale2_fp32)
        offset1 = np.random.randint(-20, 20, size=k_scale_len).astype(np.int32)
        offset2 = np.random.randint(-20, 20, size=v_scale_len).astype(np.int32)
        self.de_scale1_int64 = torch.tensor(list(de_scale1_int64), dtype=torch.int64)
        self.de_scale2_int64 = torch.tensor(list(de_scale2_int64), dtype=torch.int64)
        self.de_scale1_fp32 = torch.from_numpy(de_scale1_fp32)
        self.de_scale2_fp32 = torch.from_numpy(de_scale2_fp32)
        self.offset1 = torch.from_numpy(offset1)
        self.offset2 = torch.from_numpy(offset2)

    shape_out = (num_tokens, num_heads, head_size_vo)
    ref_output = torch.zeros(shape_out, dtype=dtype)
    true_out = torch.zeros(shape_out, dtype=torch.float32)
    sim = torch.zeros((num_tokens, num_heads * k_seqlen), dtype=torch.float32)
    lse = torch.zeros((num_tokens, num_heads, 1), dtype=dtype)
    self.ref_single_query_cached_kv_attention(
        sim,
        ref_output,
        true_out,
        lse,
        query,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        mask,
        mask_dim,
        mask_data_type
    )

    # The operator under test takes query and key cache as separate
    # 512-wide and 64-wide parts.
    self.q_split1, self.q_split2 = torch.split(query, [512, 64], dim=2)
    self.key_cache_split1, self.key_cache_split2 = torch.split(key_cache, [512, 64], dim=3)
    if is_nz_in:
        flat_split1 = self.key_cache_split1.reshape(num_blocks, block_size, -1)
        flat_split2 = self.key_cache_split2.reshape(num_blocks, block_size, -1)
        key_cache_split1_nz = self.convert_nd_to_nz(flat_split1)
        key_cache_split2_nz = self.convert_nd_to_nz(flat_split2)
        self.key_cache_split1 = key_cache_split1_nz.to(mask_data_type).reshape(num_blocks, -1, block_size, 16)
        self.key_cache_split2 = key_cache_split2_nz.to(mask_data_type).reshape(num_blocks, -1, block_size, 16)

    self.block_tables = np.array(block_tables).astype(np.int32)
    self.contex_lens = np.array(context_lens).astype(np.int32)
    self.alib_mask = mask
    self.golden_out = ref_output
    self.true_out = true_out
    self.lse = lse
def golden_calc(self, in_tensors):
    """Return ``[golden_out, lse]`` from the stored reference results.

    ``in_tensors`` is accepted for interface compatibility and not used.
    ``golden_out`` is returned as an independent copy of ``self.golden_out``;
    ``self.lse`` is returned as-is.
    """
    reference = self.golden_out
    if isinstance(reference, torch.Tensor):
        # torch.tensor(<tensor>) emits a copy-construct UserWarning;
        # clone().detach() is the recommended equivalent.
        golden_out = reference.clone().detach()
    else:
        golden_out = torch.tensor(reference)
    return [golden_out, self.lse]
def golden_compare(self, out_tensors, golden_tensors):
    """Return True when the outputs pass either comparison method.

    The attention output (index 0) is always checked. The LSE output
    (index 1) is checked only when ``self.is_ring`` is truthy. Each tensor
    passes if ``compare_cv`` or ``compare_output_data`` accepts it.
    """
    # NOTE(review): ``compare_cv`` is not defined or imported in the visible
    # part of this file - confirm where it is expected to come from.
    out_ok_cv = compare_cv(self.true_out, golden_tensors[0], out_tensors[0])
    out_ok_legacy = self.compare_output_data(out_tensors[0], golden_tensors[0], [0.001, 0.001, 0.005, 0.005])
    if self.is_ring:
        lse_ok_cv = compare_cv(golden_tensors[1], golden_tensors[1], out_tensors[1])
        lse_ok_legacy = self.compare_output_data(out_tensors[1], golden_tensors[1],
                                                 [0.001, 0.001, 0.005, 0.005])
    else:
        lse_ok_cv = True
        lse_ok_legacy = True
    output_passes = out_ok_cv or out_ok_legacy
    lse_passes = lse_ok_cv or lse_ok_legacy
    return output_passes and lse_passes
@SupportedDevices(['Ascend910B'])
def test_atb_attention_with_lse(self):
    """Run the op with returned outputs and compare both against the CPU reference."""
    num_tokens, num_heads, kv_heads = 32, 32, 1
    block_size, num_blocks, k_seqlen = 128, 64, 256
    head_size_qk, head_size_vo = 576, 512
    tor = 1.0 / (head_size_qk ** 0.5)
    mask_dim = 0
    dtype = torch.float16
    self.is_ring = 0

    self.calc_data(num_tokens, num_heads, kv_heads, head_size_qk, head_size_vo, block_size, num_blocks, k_seqlen,
                   dtype, mask_dim, dtype,
                   is_kv_combined=True, is_nz_in=False)

    # Only the context lengths stay on the host; all other tensors go to the NPU.
    op_inputs = (
        self.q_split1.npu(),
        self.q_split2.npu(),
        self.key_cache_split1.npu(),
        self.key_cache_split2.npu(),
        torch.tensor(self.block_tables).int().npu(),
        torch.tensor(self.contex_lens).int(),
        num_heads,
        tor,
        kv_heads,
        True,
    )
    go_output, lse_output = torch_npu.atb.npu_multi_head_latent_attention(
        *op_inputs,
        calc_type="calc_type_ring"
    )

    self.assertRtolEqual(go_output, self.golden_out)
    self.assertRtolEqual(lse_output, self.lse)
@SupportedDevices(['Ascend910B'])
def test_atb_attention_with_lse_out(self):
    """Run the op with preallocated ``output``/``lse`` buffers and compare against the CPU reference."""
    num_tokens, num_heads, kv_heads = 32, 32, 1
    block_size, num_blocks, k_seqlen = 128, 64, 256
    head_size_qk, head_size_vo = 576, 512
    tor = 1.0 / (head_size_qk ** 0.5)
    mask_dim = 0
    dtype = torch.float16
    self.is_ring = 0

    self.calc_data(num_tokens, num_heads, kv_heads, head_size_qk, head_size_vo, block_size, num_blocks, k_seqlen,
                   dtype, mask_dim, dtype,
                   is_kv_combined=True, is_nz_in=False)

    # Destination buffers the operator writes into.
    go_output = torch.empty([num_tokens, num_heads, 512], dtype=torch.float16).npu()
    lse_output = torch.empty([num_tokens, num_heads, 1], dtype=torch.float16).npu()

    op_inputs = (
        self.q_split1.npu(),
        self.q_split2.npu(),
        self.key_cache_split1.npu(),
        self.key_cache_split2.npu(),
        torch.tensor(self.block_tables).int().npu(),
        torch.tensor(self.contex_lens).int(),
        num_heads,
        tor,
        kv_heads,
        True,
    )
    torch_npu.atb.npu_multi_head_latent_attention(
        *op_inputs,
        calc_type="calc_type_ring",
        output=go_output,
        lse=lse_output
    )

    self.assertRtolEqual(go_output, self.golden_out)
    self.assertRtolEqual(lse_output, self.lse)
@SupportedDevices(['Ascend910B'])
def test_atb_attention_with_lse_out_aclgraph(self):
    """Check that a captured-and-updated NPU graph reproduces the eager result.

    The op is first run eagerly with the new context lengths (1024 per
    sequence). It is then captured in an NPU graph with the original
    context lengths, the captured task group is updated to the new context
    lengths on a separate stream, and the graph is replayed. The eager and
    graph outputs must be equal.
    """
    num_tokens, num_heads, kv_heads = 32, 32, 1
    block_size, num_blocks, k_seqlen = 128, 64, 256
    head_size_qk, head_size_vo = 576, 512
    tor = 1.0 / (head_size_qk ** 0.5)
    mask_dim = 0
    dtype = torch.float16
    self.is_ring = 0

    self.calc_data(num_tokens, num_heads, kv_heads, head_size_qk, head_size_vo, block_size, num_blocks, k_seqlen,
                   dtype, mask_dim, dtype,
                   is_kv_combined=True, is_nz_in=False)

    def empty_npu(last_dim):
        # Uninitialised fp16 buffer of shape [num_tokens, num_heads, last_dim] on the NPU.
        return torch.empty([num_tokens, num_heads, last_dim], dtype=torch.float16).npu()

    go_output = empty_npu(512)
    lse_output = empty_npu(1)
    go_output_graph = empty_npu(512)
    lse_output_graph = empty_npu(1)

    context_lens_new = [1024] * len(self.contex_lens)
    context_lens_new_tensor = torch.tensor(context_lens_new).int()

    q_nope = self.q_split1.npu()
    q_rope = self.q_split2.npu()
    ctkv = self.key_cache_split1.npu()
    k_rope = self.key_cache_split2.npu()
    block_tables = torch.tensor(self.block_tables).int().npu()
    context_lens = torch.tensor(self.contex_lens).int()

    attention = torch_npu.atb.npu_multi_head_latent_attention
    leading_args = (q_nope, q_rope, ctkv, k_rope, block_tables)
    trailing_args = (num_heads, tor, kv_heads, True)

    # Eager run with the new context lengths.
    attention(
        *leading_args,
        context_lens_new_tensor,
        *trailing_args,
        calc_type="calc_type_ring",
        output=go_output,
        lse=lse_output
    )

    g = torch.npu.NPUGraph()
    event = torch.npu.ExternalEvent()
    update_stream = torch.npu.Stream()
    handle = None
    workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
        *leading_args,
        context_lens_new_tensor,
        *trailing_args,
        calc_type="calc_type_ring")

    # Capture with the original context lengths. The graph waits on the
    # external event that is recorded after the update below.
    with torch.npu.graph(g):
        stream = torch.npu.current_stream()
        event.wait(stream)
        event.reset(stream)
        torch.npu.graph_task_group_begin(stream)
        attention(
            *leading_args,
            context_lens,
            *trailing_args,
            calc_type="calc_type_ring",
            workspace=workspace,
            output=go_output_graph,
            lse=lse_output_graph
        )
        handle = torch.npu.graph_task_group_end(stream)

    # Update the captured task group to use the new context lengths.
    with torch.npu.stream(update_stream):
        torch.npu.graph_task_update_begin(update_stream, handle)
        attention(
            *leading_args,
            context_lens_new_tensor,
            *trailing_args,
            calc_type="calc_type_ring",
            workspace=workspace,
            output=go_output_graph,
            lse=lse_output_graph
        )
        torch.npu.graph_task_update_end(update_stream)
        event.record(update_stream)

    g.replay()
    self.assertEqual(go_output.cpu(), go_output_graph.cpu())
    self.assertEqual(lse_output.cpu(), lse_output_graph.cpu())
# Script entry point: hand control to the torch_npu test runner.
if __name__ == "__main__":
    run_tests()