"""
"""
import math
import os
import torch
import pytest
from win_attention_impl import deepseekv4_win_atten, get_mask, sliding_win_atten_graph
class SWA(torch.nn.Module):
    """Module wrapper that forwards its inputs to ``sliding_win_atten_graph``.

    Wrapping the call in an ``nn.Module`` gives ``torch.compile`` a module
    object to trace; the wrapper itself holds no parameters or state.
    """

    def forward(self, q, ori_block_table, ori_kv, seqused_kv, sinks, win_size,
                mask, cu_seqlens_q):
        """Return ``sliding_win_atten_graph`` applied to the given arguments.

        All arguments are passed through positionally, in the same order.
        """
        return sliding_win_atten_graph(
            q,
            ori_block_table,
            ori_kv,
            seqused_kv,
            sinks,
            win_size,
            mask,
            cu_seqlens_q,
        )
def gen_uniform_data(data_shape, min_value, max_value, dtypes, device_id):
    """Create a random tensor of the requested shape and dtype.

    Args:
        data_shape: Tensor shape as a list or tuple.
        min_value: Lower bound of the uniform range.
        max_value: Upper bound of the uniform range.
        dtypes: ``torch.dtype`` of the result.
        device_id: NPU index (int or digit string) -> tensor is placed on
            ``npu:<device_id>``. A ``torch.device`` or a device string such as
            ``'cpu'`` is also accepted and used as-is, so the generator can
            run on hosts without an NPU.

    Returns:
        * all zeros when ``min_value == max_value == 0``;
        * random booleans when ``dtypes`` is ``torch.bool``;
        * otherwise values filled by ``uniform_(min_value, max_value)``.
    """
    # Resolve the target device; plain indices keep the original NPU mapping.
    if isinstance(device_id, torch.device):
        device = device_id
    elif isinstance(device_id, str) and not device_id.isdigit():
        device = torch.device(device_id)
    else:
        device = f'npu:{device_id}'

    shape = tuple(data_shape) if isinstance(data_shape, list) else data_shape

    if min_value == 0 and max_value == 0:
        return torch.zeros(shape, dtype=dtypes, device=device)
    if dtypes == torch.bool:
        return torch.randint(0, 2, size=shape, dtype=torch.bool, device=device)
    # NOTE: rand() followed by uniform_() is kept on purpose - replacing it
    # would change how much of the RNG stream is consumed and therefore the
    # data produced under a fixed seed.
    return torch.rand(shape, dtype=dtypes, device=device).uniform_(min_value, max_value)
def gen_win_attn_data_tnd(t, n_q, d_q, n_kv, d_kv, block_size, seqused_kv_list, dtypes, device_id):
    """Generate seeded random inputs for the paged sliding-window attention test.

    Args:
        t: Total number of query tokens.
        n_q: Number of query heads.
        d_q: Query head dimension.
        n_kv: Number of KV heads.
        d_kv: KV head dimension.
        block_size: Number of tokens per KV-cache block.
        seqused_kv_list: KV length of every batch entry.
        dtypes: ``torch.dtype`` used for q, the KV cache and the output buffer.
        device_id: Device selector forwarded to ``gen_uniform_data``.

    Returns:
        Tuple ``(q_tnd, block_table, kv_cache, sinks, atten_out)``:
        * ``q_tnd``: ``[t, n_q, d_q]`` queries in ``[-1, 1)``.
        * ``block_table``: int32 ``[b, ceil(max_seq / block_size)]`` mapping
          logical to physical blocks; unused slots hold ``-1``.
        * ``kv_cache``: ``[block_num, block_size, n_kv, d_kv]`` paged KV data.
        * ``sinks``: float32 ``[n_q]`` values.
        * ``atten_out``: zero tensor of shape ``[t, n_q, d_kv]``.

    Side effects:
        Reseeds the global torch RNG with 42.
    """
    torch.manual_seed(42)
    b = len(seqused_kv_list)
    s_kv_max = max(seqused_kv_list)

    # The draw order (sinks, q, kv, block permutation) fixes the generated
    # values under the seed above, so it must not be reordered.
    sinks = gen_uniform_data([n_q], -1, 1, torch.float32, device_id)
    q_tnd = gen_uniform_data([t * n_q, d_q], -1, 1, dtypes, device_id).reshape(t, n_q, d_q)
    kv_bsnd = gen_uniform_data([b, s_kv_max, n_kv, d_kv], -1, 1, dtypes, device_id)
    # Place every derived tensor on the device the generated inputs live on.
    device = kv_bsnd.device

    blocks_per_batch = [math.ceil(seq / block_size) for seq in seqused_kv_list]
    block_num = sum(blocks_per_batch)
    max_blocks = math.ceil(s_kv_max / block_size)
    block_idx_list = torch.randperm(block_num, dtype=torch.int32)

    # Build the block table on the host with one slice assignment per batch,
    # then transfer it once, instead of issuing one device write per block.
    block_table_host = torch.full((b, max_blocks), -1, dtype=torch.int32)
    cursor = 0
    for row, count in enumerate(blocks_per_batch):
        block_table_host[row, :count] = block_idx_list[cursor:cursor + count]
        cursor += count
    block_table = block_table_host.to(device)

    # Scatter the contiguous [b, s, n, d] KV data into its physical blocks.
    kv_cache = torch.zeros([block_num, block_size, n_kv, d_kv], dtype=dtypes, device=device)
    for row, count in enumerate(blocks_per_batch):
        for logical_blk in range(count):
            offset = logical_blk * block_size
            chunk = kv_bsnd[row, offset:offset + block_size]
            physical_blk = int(block_table_host[row, logical_blk])
            # The last block of the longest sequence may be only partly filled.
            kv_cache[physical_blk, :chunk.shape[0]] = chunk

    atten_out = torch.zeros([t, n_q, d_kv], dtype=dtypes, device=device)
    return q_tnd, block_table, kv_cache, sinks, atten_out
def win_atten_calc_tnd(input_params_win_attn, seqused_kv_list, sinks, q_tnd,
                       kv_cache, block_table, cu_seqlens_q, device_id):
    """Reference (golden) sliding-window attention with sinks over a paged KV cache.

    Each query token attends to at most ``win`` KV tokens ending at its own
    position, where the queries of a batch entry are aligned to the tail of
    that entry's KV sequence. The gathered KV rows are used both as keys and
    as values, and only KV head 0 is read.

    Args:
        input_params_win_attn: Sequence ``[t, n_kv, n_q, d_q, win, scalar]``;
            ``n_kv`` (index 1) is not used here.
        seqused_kv_list: KV length of every batch entry.
        sinks: Per-head sink values, reshaped to ``(n_q, 1)`` and added to the
            softmax denominator only.
        q_tnd: Queries indexed as ``[token, head, dim]``.
        kv_cache: Paged cache indexed as ``[block, slot, head, dim]``.
        block_table: ``[batch, logical_block]`` -> physical block id.
        cu_seqlens_q: Cumulative query-token offsets, ``len == batch + 1``.
        device_id: NPU index (int or digit string) selecting ``npu:<id>`` for
            the output; a ``torch.device`` or a device string such as
            ``'cpu'`` is also accepted and used as-is.

    Returns:
        bfloat16 tensor of shape ``[t, n_q, d_q]``.
    """
    t = input_params_win_attn[0]
    n_q = input_params_win_attn[2]
    d_q = input_params_win_attn[3]
    win = input_params_win_attn[4]
    scalar = input_params_win_attn[5]

    # Resolve the output device; plain indices keep the original NPU mapping.
    if isinstance(device_id, torch.device):
        device = device_id
    elif isinstance(device_id, str) and not device_id.isdigit():
        device = torch.device(device_id)
    else:
        device = f'npu:{device_id}'

    atten_out = torch.zeros([t, n_q, d_q], dtype=torch.bfloat16, device=device)
    block_size = kv_cache.shape[1]
    sinks_col = sinks.reshape(n_q, 1)  # loop-invariant, hoisted

    for b_index, actual_seq in enumerate(seqused_kv_list):
        q_begin = cu_seqlens_q[b_index]
        cur_s_q = cu_seqlens_q[b_index + 1] - q_begin
        for s1_index in range(cur_s_q):
            t_index = q_begin + s1_index
            q_tensor_cur = q_tnd[t_index:(t_index + 1), :, :].reshape(n_q, d_q)

            # Exclusive end position of this query inside the KV sequence.
            cur_loc = actual_seq - cur_s_q + s1_index + 1
            valid_len = min(cur_loc, win)
            cur_start_pos = cur_loc - valid_len

            # Gather the blocks that overlap the window [cur_start_pos, cur_loc).
            start_block = cur_start_pos // block_size
            start_offset = cur_start_pos % block_size
            end_block = (cur_loc - 1) // block_size
            kv_list = [
                kv_cache[block_table[b_index, block_idx], :, 0, :]
                for block_idx in range(start_block, end_block + 1)
            ]
            kv_cur = torch.cat(kv_list, dim=0)
            kv_cur = kv_cur[start_offset:start_offset + valid_len, :]

            # Scores in float32; the max is taken over the KV scores only.
            acc_s = torch.matmul(q_tensor_cur.to(torch.float32),
                                 kv_cur.to(torch.float32).transpose(1, 0))
            acc_s = acc_s * scalar
            scores_max = torch.max(acc_s, dim=-1, keepdim=True)[0]
            acc_s = torch.exp(acc_s - scores_max)
            # The sink contributes to the normaliser but has no value row.
            sum_exp = torch.sum(acc_s, dim=-1, keepdim=True) + torch.exp(sinks_col - scores_max)

            # Probabilities are cast to bfloat16 before the second matmul.
            v1_res = (acc_s / sum_exp).to(torch.bfloat16)
            mm2_res = torch.matmul(v1_res, kv_cur)
            atten_out[t_index:(t_index + 1), :, :] = mm2_res
    return atten_out
def test_win_atten_tnd_mask(allow_in_graph=False) -> None:
    """Check the NPU sliding-window attention (TND layout, masked) against the reference.

    Args:
        allow_in_graph: When true, the operator is run through a
            torchair-compiled ``SWA`` module; otherwise
            ``deepseekv4_win_atten`` is called directly.

    The NPU index is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable (default 0).
    """
    batch_sizes = [64]
    for b in batch_sizes:
        # Fixed problem configuration.
        q_tokens_per_batch = 2
        win_size = 128
        block_size = 128
        n_q = 64
        n_kv = 1
        head_dim = 512
        d_q = head_dim
        d_kv = head_dim
        dtypes = torch.bfloat16
        scalar = d_q ** -0.5

        # Every batch entry contributes the same number of query tokens.
        cu_seqlens_q = [q_tokens_per_batch * i for i in range(b + 1)]
        t = cu_seqlens_q[-1]
        input_params_win_attn = [t, n_kv, n_q, d_q, win_size, scalar]

        device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
        torch.npu.set_device(device_id)

        seqused_kv_list = [8192] * b
        seqused_kv_dev = torch.tensor(
            seqused_kv_list, dtype=torch.int32, device=f'npu:{device_id}')
        cu_seqlens_q_dev = torch.tensor(
            cu_seqlens_q, dtype=torch.int32, device=f'npu:{device_id}')

        q_tnd, block_table, kv_cache, sinks, _ = gen_win_attn_data_tnd(
            t, n_q, d_q, n_kv, d_kv, block_size, seqused_kv_list, dtypes, device_id)
        mask2 = get_mask(4, n_q, device_id, block_size)

        if not allow_in_graph:
            atten_out = deepseekv4_win_atten(
                q_tnd, block_table, kv_cache, seqused_kv_dev,
                sinks, win_size, mask=mask2, cu_seqlens_q=cu_seqlens_q_dev)
        else:
            # torchair is imported lazily so the eager path does not need it.
            import torchair as tng
            from torchair.configs.compiler_config import CompilerConfig

            compiler_config = CompilerConfig()
            compiler_config.mode = "reduce-overhead"
            npu_backend = tng.get_npu_backend(compiler_config=compiler_config)
            model = torch.compile(SWA(), dynamic=False, fullgraph=True, backend=npu_backend)

            graph_inputs = (
                q_tnd.npu(),
                block_table.npu(),
                kv_cache.npu(),
                seqused_kv_dev.npu(),
                sinks.npu(),
            )
            mask2_npu = mask2.npu()
            cu_seqlens_q_npu = cu_seqlens_q_dev.npu()
            atten_out = model(*graph_inputs, win_size, mask2_npu, cu_seqlens_q_npu)

        # The reference takes the host-side (list) sequence descriptions.
        golden = win_atten_calc_tnd(
            input_params_win_attn, seqused_kv_list, sinks, q_tnd,
            kv_cache, block_table, cu_seqlens_q, device_id)

        from utils.compare import compare
        compare(golden, atten_out, "SWA tnd mask 版本", rtol=0.0078125, atol=0.0001)
if __name__ == "__main__":
test_win_atten_tnd_mask(False)