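"""
System test for the pypto paged-attention kernel (``op_page_attention``),
checked against a PyTorch reference implementation (``op_page_attention_golden``).
"""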
import pypto
import torch
import math
from st.pypto_test import TestBuilder
def op_page_attention(params, q_nope, k_nope_cache, v_nope_cache, q_rope, k_rope_cache,
                      block_table, act_seqs, attention_out):
    """Build the pypto paged-attention kernel (query/key split into nope and rope parts).

    For every batch ``b_idx`` and every tile of ``n_tile`` query rows, the kernel
    walks the KV blocks listed in ``block_table[b_idx]`` and computes
    ``softmax(q @ k^T * (d_n + d_r) ** -0.5) @ v`` block by block. It uses an
    online-softmax update (running row max ``mi``, running row sum ``li``,
    unnormalised output ``oi``). The normalised tile is assembled into
    ``attention_out``.

    Args:
        params: dict with ``block_size``, ``tile_config`` (object exposing
            ``head_num_q_tile`` and the c1/v1/c2/v2 tile shapes),
            ``max_unroll_times`` and ``is_nz_format``.
        q_nope: 2-D query nope part. Row ``b_idx * n_q + h`` is query row ``h``
            of batch ``b_idx``, where ``n_q = q_nope.shape[0] // batch``.
        k_nope_cache: 2-D paged key (nope) cache. Block ``i`` starts at row
            ``i * block_size``.
        v_nope_cache: 2-D paged value cache, same block layout.
        q_rope: 2-D query rope part, same row layout as ``q_nope``.
        k_rope_cache: 2-D paged key (rope) cache, same block layout.
        block_table: 2-D ``[batch, blocks]`` table of physical block ids.
        act_seqs: per-batch KV sequence lengths.
        attention_out: output tensor, written via ``pypto.assemble``.

    NOTE(review): the ``inside_*`` nested functions are presumably there so
    each pypto loop body gets its own Python scope while tracing. Confirm
    before flattening them.
    """
    block_size = params["block_size"]
    tile_config = params["tile_config"]
    max_unroll_times = params["max_unroll_times"]
    is_nz_format = params["is_nz_format"]
    dtype = q_nope.dtype
    d_n = q_nope.shape[1]
    d_r = q_rope.shape[1]
    # 1 / sqrt(d_qk): the QK head dimension is nope + rope.
    softmax_scale = (d_n+d_r) ** -0.5
    # Number of query rows processed per n_idx iteration.
    n_tile = tile_config.head_num_q_tile
    # c1/v1 tile the Q @ K^T matmul and softmax; c2/v2 tile the P @ V matmul.
    c1_tile = tile_config.c1_tile_shape
    v1_tile = tile_config.v1_tile_shape
    c2_tile = tile_config.c2_tile_shape
    v2_tile = tile_config.v2_tile_shape

    def inside_main_function():
        batch_size = block_table.shape[0]
        # Query rows per batch.
        n_q = q_nope.shape[0] // batch_size
        # NOTE(review): floor division. Rows past n_loop * n_tile are never
        # processed when n_q is not a multiple of n_tile; the golden uses ceil.
        n_loop = n_q // n_tile
        for b_idx in pypto.loop(0, batch_size, 1, name="LOOP_L0_bIdx", idx_name="b_idx"):
            def inside_b_idx_loop(b_idx):
                cur_seq = act_seqs[b_idx]
                # Ceil-divide: number of KV blocks covering this batch's sequence.
                bn_per_batch = (cur_seq + block_size - 1) // block_size
                bn_per_batch.as_variable()
                for n_idx in pypto.loop(0, n_loop, 1, name="LOOP_L1_nIdx", idx_name="n_idx"):
                    def inside_n_idx_loop(b_idx, n_idx, bn_per_batch):
                        nonlocal n_tile
                        cur_n_tile = n_tile
                        # Online-softmax accumulators (FP32):
                        # oi = unnormalised output, li = row sum, mi = row max.
                        oi_update = pypto.tensor([n_tile, d_n], pypto.DT_FP32, "oi_update")
                        li_update = pypto.tensor([n_tile, 1], pypto.DT_FP32, "li_update")
                        mi_update = pypto.tensor([n_tile, 1], pypto.DT_FP32, "mi_update")
                        # First query row of this tile; also its output row offset.
                        cur_offset = b_idx * n_q + n_idx * n_tile
                        oi_offset = [cur_offset, 0]
                        # unroll_List receives a one-element set holding the unroll factor.
                        for bn in pypto.loop(0, bn_per_batch, 1, name="LOOP_L2_bn",
                                             idx_name="bn", unroll_List={max_unroll_times}):
                            def inside_bn_loop(**kwargs):
                                b_idx = kwargs.get("b_idx")
                                block_table = kwargs.get("block_table")
                                cur_seq = kwargs.get("cur_seq")
                                bn = kwargs.get("bn")
                                block_size = kwargs.get("block_size")
                                nonlocal oi_update, li_update, mi_update
                                cur_s2_tile = block_size
                                # qi = [q_nope | q_rope] for the current row tile.
                                qn = pypto.view(q_nope, [cur_n_tile, d_n], [cur_offset, 0])
                                qr = pypto.view(q_rope, [cur_n_tile, d_r], [cur_offset, 0])
                                qi = pypto.tensor([cur_n_tile, d_n + d_r], dtype, "qi")
                                pypto.assemble(qn, [0, 0], qi)
                                pypto.assemble(qr, [0, d_n], qi)
                                # Physical cache block that holds logical block bn.
                                cur_block_idx = block_table[b_idx, bn]
                                cur_block_idx.as_variable()
                                # valid_shape presumably marks only the rows inside the
                                # sequence as valid; .min(block_size) clamps the row count,
                                # so only the last block can be partial.
                                kn = pypto.view(k_nope_cache, [cur_s2_tile, d_n],
                                                [cur_block_idx * block_size, 0],
                                                valid_shape=[(cur_seq - bn * block_size).min(block_size), d_n])
                                kr = pypto.view(k_rope_cache, [cur_s2_tile, d_r],
                                                [cur_block_idx * block_size, 0],
                                                valid_shape=[(cur_seq - bn * block_size).min(block_size), d_r])
                                # kj = [k_nope | k_rope] for this block, in NZ or ND tile format.
                                kj_format = pypto.TileOpFormat.TILEOP_NZ if is_nz_format else (
                                    pypto.TileOpFormat.TILEOP_ND
                                )
                                kj = pypto.tensor([cur_s2_tile, d_n + d_r], dtype, "kj", kj_format)
                                pypto.assemble(kn, [0, 0], kj)
                                pypto.assemble(kr, [0, d_n], kj)
                                kj = pypto.view(kj, [cur_s2_tile, d_n + d_r], [0, 0],
                                                valid_shape=[(cur_seq - bn * block_size).min(block_size), d_r + d_n])
                                vj = pypto.view(v_nope_cache, [cur_s2_tile, d_n], [cur_block_idx * block_size, 0],
                                                valid_shape=[(cur_seq - bn * block_size).min(block_size), d_n])
                                # S = qi @ kj^T, accumulated in FP32.
                                pypto.set_semantic_label("MatMul")
                                pypto.set_cube_tile_shapes(
                                    [c1_tile[0], c1_tile[1]], [c1_tile[2], c1_tile[3]],
                                    [c1_tile[4], c1_tile[5]])
                                pypto.set_matrix_size([qi.shape[0], 0, kj.shape[0]])
                                sij = pypto.matmul(qi, kj, pypto.DT_FP32, b_trans=True)
                                # Block-local softmax statistics: row max, exp, row sum.
                                pypto.set_vec_tile_shapes(v1_tile[0], v1_tile[1])
                                pypto.set_semantic_label("SoftMax")
                                sij_scale = pypto.mul(sij, float(softmax_scale))
                                pypto.set_semantic_label("SoftMax")
                                tilda_mij = pypto.amax(sij_scale, dim=-1, keepdim=True)
                                tsub = pypto.sub(sij_scale, tilda_mij)
                                tilda_pij = pypto.exp(tsub)
                                # P is cast back to the input dtype for the P @ V matmul
                                # (the "_f16" suffix is just a name: the dtype is q_nope.dtype).
                                tilda_pij_f16 = pypto.cast(tilda_pij, dtype)
                                tilda_lij = pypto.sum(tilda_pij, dim=-1, keepdim=True)
                                if pypto.cond(pypto.is_loop_begin(bn)):
                                    # First KV block: initialise the accumulators from this block.
                                    def inside_if_loop_begin():
                                        nonlocal oi_update, li_update, mi_update
                                        pypto.set_cube_tile_shapes(
                                            [c2_tile[0], c2_tile[1]], [c2_tile[2], c2_tile[3]],
                                            [c2_tile[4], c2_tile[5]])
                                        pypto.set_semantic_label("b1-matmul2")
                                        pypto.set_matrix_size(
                                            [tilda_pij_f16.shape[0], tilda_pij_f16.shape[1],
                                             vj.shape[1]])
                                        oi_tmp = pypto.matmul(tilda_pij_f16,
                                                              vj, pypto.DT_FP32)
                                        pypto.set_vec_tile_shapes(v2_tile[0], v2_tile[1])
                                        pypto.set_semantic_label("b1-after-matmul2")
                                        if pypto.cond(pypto.is_loop_end(bn)):
                                            # First and last block (single-block sequence):
                                            # normalise and write the tile out immediately.
                                            pypto.set_semantic_label("b1-after-matmul2")
                                            oi_update[:] = (pypto.div(oi_tmp, tilda_lij))
                                            pypto.assemble(oi_update, oi_offset, attention_out)
                                        else:
                                            oi_update[:] = (oi_tmp)
                                            li_update[:] = (tilda_lij)
                                            mi_update[:] = (tilda_mij)
                                    inside_if_loop_begin()
                                else:
                                    # Later KV blocks: rescale the running state and this
                                    # block to the new row max, then add them together.
                                    def inside_else_loop_begin():
                                        nonlocal oi_update, li_update, mi_update
                                        oi = oi_update
                                        li = li_update
                                        mi = mi_update
                                        pypto.set_semantic_label("Softmax-acc")
                                        mi_new = pypto.maximum(mi, tilda_mij)
                                        # t2 = exp(m_old - m_new) rescales the running state;
                                        # t4 = exp(m_block - m_new) rescales this block.
                                        t1 = pypto.sub(mi, mi_new)
                                        t2 = pypto.exp(t1)
                                        t3 = pypto.sub(tilda_mij, mi_new)
                                        t4 = pypto.exp(t3)
                                        t5 = pypto.mul(t4, tilda_lij)
                                        t6 = pypto.mul(t2, li)
                                        li_new = pypto.add(t6, t5)
                                        q3 = pypto.mul(oi, t2)
                                        pypto.set_semantic_label("bn-matmul2")
                                        pypto.set_cube_tile_shapes(
                                            [c2_tile[0], c2_tile[1]], [c2_tile[2], c2_tile[3]], [c2_tile[4],
                                                                                                c2_tile[5]])
                                        pypto.set_matrix_size(
                                            [tilda_pij_f16.shape[0],
                                             tilda_pij_f16.shape[1], vj.shape[1]])
                                        q1 = pypto.matmul(tilda_pij_f16, vj, pypto.DT_FP32)
                                        pypto.set_vec_tile_shapes(v2_tile[0], v2_tile[1])
                                        pypto.set_semantic_label("bn-after-matmul2")
                                        q2 = pypto.mul(q1, t4)
                                        oi_tmp = pypto.add(q3, q2)
                                        if pypto.cond(pypto.is_loop_end(bn)):
                                            # Last KV block: normalise by the running sum and write out.
                                            oi_update[:] = (pypto.div(oi_tmp, li_new))
                                            pypto.assemble(oi_update, oi_offset, attention_out)
                                        else:
                                            oi_update[:] = (oi_tmp)
                                            li_update[:] = (li_new)
                                            mi_update[:] = (mi_new)
                                    inside_else_loop_begin()
                            inside_bn_loop(
                                b_idx=b_idx,
                                block_table=block_table,
                                cur_seq=cur_seq,
                                bn=bn,
                                block_size=block_size,
                                bn_per_batch=bn_per_batch)
                    inside_n_idx_loop(b_idx, n_idx, bn_per_batch)
            inside_b_idx_loop(b_idx)
    inside_main_function()
def op_page_attention_golden(params, q_nope, k_nope_cache, v_cache, q_rope, k_rope_cache,
                             block_table, act_seqs, attention_out):
    """PyTorch reference for ``op_page_attention``.

    Recomputes paged attention in FP32 with the same block-by-block
    online-softmax recurrence as the kernel.

    Args:
        params: dict with ``b``, ``n_q``, ``s_q``, ``n_kv``, ``kv_lora_rank``,
            ``qk_rope_dim``, ``n_tile``, ``block_size`` and ``block_num``.
        q_nope: query nope part, reshaped here to ``[b, n_q, s_q, kv_lora_rank]``.
        k_nope_cache: paged key nope cache, reshaped to
            ``[block_num, block_size, n_kv * kv_lora_rank]``.
        v_cache: paged value cache, reshaped to
            ``[block_num, block_size, n_kv * kv_lora_rank]``.
        q_rope: query rope part, reshaped to ``[b, n_q, s_q, qk_rope_dim]``.
        k_rope_cache: paged key rope cache, reshaped to
            ``[block_num, block_size, n_kv * qk_rope_dim]``.
        block_table: ``block_table[i][j]`` is the physical block holding the
            ``j``-th logical block of batch ``i``.
        act_seqs: KV sequence length of each batch (ints or 0-d tensors).
        attention_out: unused; accepted so the signature mirrors the kernel.

    Returns:
        1-tuple ``(out,)``. ``out`` is FP32 with shape
        ``[b * n_q * s_q, kv_lora_rank]``, rows ordered batch-major, then by
        head tile.

    Raises:
        ValueError: if a batch has a sequence length that yields no KV block
            (attention over an empty key set is undefined).

    NOTE(review): the per-block reshapes to ``d_k`` / ``d_v`` columns assume
    ``n_kv == 1``.
    """
    b = params["b"]
    n_q = params["n_q"]
    s_q = params["s_q"]
    n_kv = params["n_kv"]
    kv_lora_rank = params["kv_lora_rank"]
    qk_rope_dim = params["qk_rope_dim"]
    d_q = kv_lora_rank + qk_rope_dim
    d_k = kv_lora_rank + qk_rope_dim
    d_v = kv_lora_rank
    n_tile = params["n_tile"]
    block_size = params["block_size"]
    block_num = params["block_num"]
    matmul_dtype = torch.float32

    q_nope = q_nope.reshape(b, n_q, s_q, kv_lora_rank)
    q_rope = q_rope.reshape(b, n_q, s_q, qk_rope_dim)
    k_nope_cache = k_nope_cache.reshape(block_num, block_size, n_kv * kv_lora_rank)
    k_rope_cache = k_rope_cache.reshape(block_num, block_size, n_kv * qk_rope_dim)
    v_cache = v_cache.reshape(block_num, block_size, n_kv * d_v)
    q_bnsd = torch.cat([q_nope, q_rope], dim=-1)
    k_cache = torch.cat([k_nope_cache, k_rope_cache], dim=-1)
    scalar = d_q ** -0.5

    # Integer ceil-division (same formula as the kernel), computed once and
    # used for every "how many blocks / is this the last block" decision.
    seq_lens = [int(seq) for seq in act_seqs]
    block_num_per_batch = [(seq + block_size - 1) // block_size for seq in seq_lens]
    # Unlike the kernel, a partial last head tile is handled (slicing clips it).
    n_loop = math.ceil(n_q / n_tile)

    tiled_out = []
    for b_index in range(b):
        cur_seq = seq_lens[b_index]
        bn_per_batch = block_num_per_batch[b_index]
        if bn_per_batch <= 0:
            raise ValueError(
                f"act_seqs[{b_index}] = {cur_seq}: attention needs at least one key")
        for n_idx in range(n_loop):
            qi = q_bnsd[b_index, n_idx * n_tile:(n_idx + 1) * n_tile, :, :]
            qi = qi.reshape(-1, qi.shape[-1]).to(matmul_dtype)
            oi_update = li_update = mi_update = None
            for bn in range(bn_per_batch):
                cur_block_idx = int(block_table[b_index][bn])
                # Only the last logical block may be partially filled.
                s2_tile_cur = min(block_size, cur_seq - bn * block_size)
                kj = k_cache[cur_block_idx, 0:s2_tile_cur, :].reshape(s2_tile_cur, d_k)
                vj = v_cache[cur_block_idx, 0:s2_tile_cur, :].reshape(s2_tile_cur, d_v)
                # Block-local softmax statistics: row max, exp, row sum.
                sij_scale = torch.matmul(qi, kj.to(matmul_dtype).mT) * scalar
                tilda_mij = sij_scale.max(dim=-1, keepdim=True).values
                tilda_pij = torch.exp(sij_scale - tilda_mij)
                tilda_lij = tilda_pij.sum(dim=-1, keepdim=True)
                pv = torch.matmul(tilda_pij.to(matmul_dtype), vj.to(matmul_dtype))
                if bn == 0:
                    oi_tmp, li_new, mi_new = pv, tilda_lij, tilda_mij
                else:
                    # Online-softmax merge: rescale the running state and this
                    # block to the common row max before adding them.
                    mi_new = torch.maximum(mi_update, tilda_mij)
                    old_scale = torch.exp(mi_update - mi_new)
                    new_scale = torch.exp(tilda_mij - mi_new)
                    li_new = old_scale * li_update + new_scale * tilda_lij
                    oi_tmp = oi_update * old_scale + pv * new_scale
                if bn == bn_per_batch - 1:
                    oi_update = oi_tmp / li_new
                else:
                    oi_update, li_update, mi_update = oi_tmp, li_new, mi_new
            tiled_out.append(oi_update)
    attent_out = torch.cat(tiled_out, dim=0)
    return (attent_out, )
class PATest(TestBuilder):
    """System test that drives ``op_page_attention`` through ``TestBuilder``.

    ``get_input_from_param`` builds random Q/K data, scatters K/V into a
    randomly permuted paged cache, and registers the kernel inputs.
    """

    def __init__(self, params: dict, kernel, kernel_golden, tiling: int):
        # params is a mapping of test settings (read as self.params["..."] below).
        super().__init__(params, kernel, kernel_golden, tiling)

    def get_input_from_param(self):
        """Generate kernel inputs and return the matching golden inputs.

        Side effects: registers the kernel inputs with ``self.setup_inputs``,
        sets tolerances with ``self.set_tol`` and stores the number of cache
        blocks actually allocated in ``self.params["block_num"]``.

        Returns:
            list: ``[q_nope, k_cache_nope, v_cache, q_rope, k_cache_rope,
            block_table, actual_seq_len, self.params]``.
        """
        def gen_uniform_data(data_shape, min_value, max_value, dtype):
            # Uniform samples in [min_value, max_value); all zeros when both bounds are 0.
            if min_value == 0 and max_value == 0:
                return torch.zeros(data_shape, dtype=dtype)
            if dtype == torch.bool:
                return torch.rand(data_shape) < 0.5
            return (torch.rand(data_shape) * (max_value - min_value) + min_value).to(dtype)

        def convert_tensors_contiguous(tensor_list):
            # Replace non-contiguous tensors in place; other entries are untouched.
            for idx, item in enumerate(tensor_list):
                if isinstance(item, torch.Tensor) and not item.is_contiguous():
                    tensor_list[idx] = item.contiguous()
            return tensor_list

        b = self.params["b"]
        n_q = self.params["n_q"]
        skv = self.params["skv"]
        block_size = self.params["block_size"]
        dtype = self.params["dtype"]
        s_q = self.params["s_q"]
        n_kv = self.params["n_kv"]
        kv_lora_rank = self.params["kv_lora_rank"]
        qk_rope_dim = self.params["qk_rope_dim"]
        d_q = kv_lora_rank + qk_rope_dim
        d_k = kv_lora_rank + qk_rope_dim
        d_v = kv_lora_rank

        # Every batch uses the same KV length `skv`.
        actual_seq_len = torch.full((b,), skv, dtype=torch.int32)
        s_max = int(actual_seq_len.max())
        shape_q = [b * n_q * s_q, d_q]
        shape_k = [b, s_max, n_kv * d_k]
        # RNG call order (q, k, then randperm) matches the original generator.
        q_bnsd = gen_uniform_data(shape_q, -1, 1, dtype)
        k_tensor_bsh_raw = gen_uniform_data(shape_k, -1, 1, dtype)
        # V reuses the leading n_kv * d_v columns of K.
        v_tensor_bsh_raw = k_tensor_bsh_raw[:, :, :n_kv * d_v]

        block_num_per_batch = [(int(seq) + block_size - 1) // block_size
                               for seq in actual_seq_len]
        block_num = sum(block_num_per_batch)
        max_blocks = (s_max + block_size - 1) // block_size

        # Hand out physical block ids in random order; -1 marks unused slots.
        block_idx_list = torch.randperm(block_num, dtype=torch.int32)
        block_table = torch.full((b, max_blocks), -1, dtype=torch.int32)
        next_block = 0
        for batch_idx, num_blocks in enumerate(block_num_per_batch):
            block_table[batch_idx, :num_blocks] = block_idx_list[next_block:next_block + num_blocks]
            next_block += num_blocks

        k_cache = torch.zeros(block_num, block_size, n_kv * d_k, dtype=dtype)
        v_cache = torch.zeros(block_num, block_size, n_kv * d_v, dtype=dtype)
        # Zero-pad each logical sequence to a whole number of blocks.
        padded_len = max_blocks * block_size
        k_tensor_bsh = torch.zeros(b, padded_len, n_kv * d_k, dtype=dtype)
        v_tensor_bsh = torch.zeros(b, padded_len, n_kv * d_v, dtype=dtype)
        k_tensor_bsh[:, :k_tensor_bsh_raw.shape[1], :] = k_tensor_bsh_raw
        v_tensor_bsh[:, :v_tensor_bsh_raw.shape[1], :] = v_tensor_bsh_raw
        # Scatter every logical block into its physical slot of the paged cache.
        for b_idx in range(b):
            for block_i, kv_cache_blk_id in enumerate(block_table[b_idx]):
                if kv_cache_blk_id == -1:
                    continue
                block_offset = block_i * block_size
                k_cache[kv_cache_blk_id, 0:block_size, :] = k_tensor_bsh[
                    b_idx, block_offset:block_offset + block_size, :]
                v_cache[kv_cache_blk_id, 0:block_size, :] = v_tensor_bsh[
                    b_idx, block_offset:block_offset + block_size, :]

        q_nope = q_bnsd[:, :kv_lora_rank]
        q_rope = q_bnsd[:, kv_lora_rank:]
        # NOTE(review): this treats the first n_kv * kv_lora_rank columns as the
        # nope part, which matches a per-head [nope | rope] layout only when
        # n_kv == 1. Confirm before testing n_kv > 1.
        k_cache_nope_h = kv_lora_rank * n_kv
        k_cache_nope = k_cache[:, :, :k_cache_nope_h]
        k_cache_rope = k_cache[:, :, k_cache_nope_h:]
        # Flatten [block_num, block_size, H] -> [block_num * block_size, H]; the
        # kernel addresses block i at row i * block_size.
        flat_rows = block_num * block_size
        k_cache_nope = k_cache_nope.reshape(flat_rows, k_cache_nope.shape[-1])
        k_cache_rope = k_cache_rope.reshape(flat_rows, k_cache_rope.shape[-1])
        v_cache = v_cache.reshape(flat_rows, v_cache.shape[-1])

        kernel_inputs = convert_tensors_contiguous([
            q_nope, k_cache_nope, v_cache, q_rope, k_cache_rope, block_table, actual_seq_len])
        self.setup_inputs(*kernel_inputs)
        self.set_tol(rtol=5e-4, atol=5e-4)
        self.params["block_num"] = block_num
        golden_inputs = [q_nope, k_cache_nope, v_cache, q_rope, k_cache_rope,
                         block_table, actual_seq_len, self.params]
        return golden_inputs
class TileConfig:
    """Tiling parameters consumed by ``op_page_attention``.

    Attributes:
        head_num_q_tile: number of query rows processed per tile.
        c1_tile_shape: six values, passed pairwise to ``set_cube_tile_shapes``
            for the Q @ K^T matmul.
        v1_tile_shape: two values for ``set_vec_tile_shapes`` around the softmax.
        c2_tile_shape: six values, passed pairwise to ``set_cube_tile_shapes``
            for the P @ V matmul.
        v2_tile_shape: two values for ``set_vec_tile_shapes`` after the P @ V matmul.
    """

    def __init__(self, head_num_q_tile, c1_tile_shape, v1_tile_shape,
                 c2_tile_shape, v2_tile_shape):
        self.head_num_q_tile = head_num_q_tile
        self.c1_tile_shape = c1_tile_shape
        self.v1_tile_shape = v1_tile_shape
        self.c2_tile_shape = c2_tile_shape
        self.v2_tile_shape = v2_tile_shape

    def __repr__(self):
        # Readable form for test logs and failure messages.
        return (f"{type(self).__name__}("
                f"head_num_q_tile={self.head_num_q_tile!r}, "
                f"c1_tile_shape={self.c1_tile_shape!r}, "
                f"v1_tile_shape={self.v1_tile_shape!r}, "
                f"c2_tile_shape={self.c2_tile_shape!r}, "
                f"v2_tile_shape={self.v2_tile_shape!r})")
def test():
    """Run the paged-attention system test: 4 batches, 32 query heads, 256 KV tokens."""
    # The kernel reads the head tile from TileConfig.head_num_q_tile and the
    # golden reads params["n_tile"]; one local keeps the two in sync.
    head_tile = 32
    tile_config = TileConfig(head_num_q_tile=head_tile,
                             c1_tile_shape=(32, 32, 64, 64, 128, 128),
                             v1_tile_shape=(32, 64),
                             c2_tile_shape=(32, 32, 64, 64, 128, 128),
                             v2_tile_shape=(32, 64))
    model_golden_params = {
        "block_size": 128,
        "tile_config": tile_config,
        "max_unroll_times": 1,
        "is_nz_format": False,
        "b": 4,
        "n_q": 32,
        "skv": 256,
        # Placeholder: PATest.get_input_from_param overwrites it with the
        # number of cache blocks it actually allocates.
        "block_num": 64,
        "dtype": torch.float32,
        "s_q": 1,
        "n_kv": 1,
        "kv_lora_rank": 512,
        "qk_rope_dim": 64,
        "n_tile": head_tile,
    }
    pa_test = PATest(model_golden_params, op_page_attention, op_page_attention_golden, tiling=32)
    pa_test()