"""
deepseekv4 Attention Module
This module implements the Attention mechanism for deepseekv4 model, which uses
a paged memory management approach similar to operating systems to efficiently
handle variable-length sequences and dynamic batch sizes in cfa_attention computation.
Main Functions:
- cfa_attention: Main cfa_attention function with Attention support
- ifa_flash: JIT compiled kernel implementing Flash Attention with paged KV cache
- gen_block_table: Generate block mapping table for Attention
- kv_cache_concat_bsnd: Convert paged KV cache to BSND format
"""
from dataclasses import dataclass
import torch
import pypto
import pytest
import numpy as np
import math
import os
from compress_flash_attention_impl import cfa_attention, cfa_graph
# Pin the torch and NumPy random generators so generated inputs are reproducible.
torch.manual_seed(0)
np.random.seed(0)
# Print NumPy floats with six fixed decimal places.
np.set_printoptions(formatter={"float": "{:.6f}".format})
@dataclass
class AttentionConfig:
    """Shape and paging parameters describing one attention test case.

    The per-field notes reflect how the helpers in this file read the fields.
    """

    # --- required dimensions ---
    b: int  # batch size
    s1: int  # query rows per batch entry (q is built as [b * s1, n1, q_d])
    s2: int  # KV sequence length used to size the block tables
    n1: int  # number of query heads
    n2: int  # number of KV heads
    q_d: int  # query head dimension
    kv_d: int  # KV head dimension

    # --- paging parameters ---
    block_size: int = 128  # rows per KV cache block
    cmp_ratio: int = 128  # divisor applied to sequence lengths for the compressed cache
    max_blocks: int = 0  # number of columns in the block table
    actual_seq: torch.Tensor = None  # per-batch sequence lengths; None until set
    block_table_batch: int = 0  # number of rows in the block table
    kv_num_blocks: int = 0  # number of blocks in the compressed KV cache
def gen_block_table(actual_seq_len, block_size, block_table_shape, cmp_ratio=128, enable_win=False):
    """Build a block table that maps logical blocks to randomly shuffled physical blocks.

    Args:
        actual_seq_len: 1-D tensor of per-batch sequence lengths.
        block_size: Number of rows held by one cache block.
        block_table_shape: Shape of the returned table (rows per batch entry).
        cmp_ratio: Divisor applied to each sequence length before counting blocks.
        enable_win: When True the ratio is forced to 1 (uncompressed cache).

    Returns:
        An int32 tensor on ``actual_seq_len.device``; unused slots hold -1.
    """
    ratio = 1 if enable_win else cmp_ratio

    # Blocks each batch entry actually maps (after compression) ...
    blocks_per_row = [
        math.ceil(seq.item() // ratio / block_size) for seq in actual_seq_len
    ]
    # ... and the size of the physical pool, counted on uncompressed lengths.
    pool_size = sum(math.ceil(seq.item() / block_size) for seq in actual_seq_len)

    pool = torch.arange(0, pool_size, dtype=torch.int32)
    pool = pool[torch.randperm(pool.size(0))]

    table = torch.full(
        block_table_shape, -1, dtype=torch.int32, device=actual_seq_len.device
    )

    cursor = 0
    for row, count in enumerate(blocks_per_row):
        for col in range(count):
            table[row][col] = pool[cursor + col]
        cursor += count
    return table
def get_decode_case(device="cpu"):
    """Return the fixed decode-case configuration used by the c128 test.

    Args:
        device: Device on which the per-batch sequence-length tensor is created.

    Returns:
        An ``AttentionConfig`` in which every batch entry has the full KV length.
    """
    batch = 64
    q_rows = 2
    kv_len = 8 * 1024
    head_dim = 512
    q_heads = 64
    kv_heads = 1
    block_size = 128
    cmp_ratio = 128

    # Ceil-divide: blocks needed to hold one full-length sequence.
    blocks_per_seq = (kv_len + block_size - 1) // block_size
    seq_lens = torch.tensor([kv_len] * batch, dtype=torch.int32, device=device)

    return AttentionConfig(
        b=batch,
        s1=q_rows,
        s2=kv_len,
        n1=q_heads,
        n2=kv_heads,
        q_d=head_dim,
        kv_d=head_dim,
        block_size=block_size,
        cmp_ratio=cmp_ratio,
        max_blocks=blocks_per_seq,
        actual_seq=seq_lens,
        block_table_batch=batch,
        kv_num_blocks=batch * blocks_per_seq,
    )
class MM(torch.nn.Module):
    """Thin ``nn.Module`` wrapper so that ``cfa_graph`` can be passed to ``torch.compile``."""

    def forward(
        self,
        q: torch.Tensor,
        cmp_kv: torch.Tensor,
        sinks: torch.Tensor,
        cmp_block_table: torch.Tensor,
        seqused_kv: torch.Tensor,
        ori_kv: torch.Tensor,
        ori_block_table: torch.Tensor,
        cmp_ratio: int = 1,
    ):
        """Forward every argument unchanged to ``cfa_graph`` and return its result."""
        return cfa_graph(
            q,
            cmp_kv,
            sinks,
            cmp_block_table,
            seqused_kv,
            ori_kv,
            ori_block_table,
            cmp_ratio,
        )
def softmax(x, sinks, is_fp16=False, is_new_sink=False):
    """Softmax over the last dimension with an optional sink term in the denominator.

    Args:
        x: Input scores.
        sinks: Optional per-row sink values; ``None`` gives a plain softmax.
        is_fp16: When True, compute in float32 and cast all results back to
            the input dtype.
        is_new_sink: When False the sink is added to the denominator as is;
            when True ``exp(sink - row_max)`` is added instead.

    Returns:
        Tuple ``(probabilities, row_max, denominator)``.
    """
    if is_fp16:
        original_dtype = x.dtype
        x = x.float()

    row_max = torch.max(x, dim=-1, keepdim=True).values
    shifted_exp = torch.exp(x - row_max)
    denom = shifted_exp.sum(dim=-1, keepdim=True)

    if sinks is not None:
        sink_col = sinks.unsqueeze(-1)
        # In-place accumulation keeps the denominator in its current dtype.
        if is_new_sink:
            denom.add_(torch.exp(sink_col - row_max))
        else:
            denom.add_(sink_col)

    probs = shifted_exp / denom
    if not is_fp16:
        return probs, row_max, denom
    return (
        probs.to(original_dtype),
        row_max.to(original_dtype),
        denom.to(original_dtype),
    )
def matmul_proxy(left, right):
    """Multiply ``left @ right`` in float32 and cast the product back to ``left``'s dtype."""
    result_dtype = left.dtype
    product = torch.matmul(left.float(), right.float())
    return product.to(result_dtype)
def get_block_kv(kv_2d, cmp_block_table, b_idx, s2_idx, block_size, cur_seq):
    """Return the rows of ``kv_2d`` for logical block ``s2_idx`` of batch ``b_idx``.

    The physical block is looked up in ``cmp_block_table``; a negative entry is
    clamped to block 0. The tile is shortened when fewer than ``block_size``
    rows remain before ``cur_seq``.
    """
    physical_block = max(cmp_block_table[b_idx][s2_idx], 0)
    rows_in_tile = min(block_size, cur_seq - s2_idx * block_size)
    row_begin = physical_block * block_size
    return kv_2d[row_begin:row_begin + rows_in_tile, :]
def flash_end(out, sinks, li_upd, mi_upd, oi_upd, n2g_ofs, g_tile, bs_ofs, dtype, is_new_sink=False):
    """Normalise the accumulated attention output and store it into ``out``.

    Args:
        out: Output buffer, written in place at row ``bs_ofs`` and the head
            range starting at ``n2g_ofs``.
        sinks: Optional sink values added to the softmax denominator.
        li_upd: Running softmax denominator, one value per head in the tile.
            It is not modified.
        mi_upd: Running row maximum, used only when ``is_new_sink`` is True.
        oi_upd: Un-normalised accumulated output for the tile.
        n2g_ofs: First head index of the tile.
        g_tile: Number of heads in the tile.
        bs_ofs: Row index into ``out``.
        dtype: Dtype the result is cast to before being stored.
        is_new_sink: When False the sink is added as is; when True
            ``exp(sink - mi_upd)`` is added instead.
    """
    li = li_upd.unsqueeze(-1)
    if sinks is not None:
        # unsqueeze returns a view of li_upd; clone before the in-place add so
        # the caller's running denominator is left untouched.
        li = li.clone()
        if is_new_sink:
            li += torch.exp(sinks - mi_upd).unsqueeze(-1)
        else:
            li += sinks.unsqueeze(-1)

    oi_final = oi_upd / li

    attn_out_start = n2g_ofs
    attn_out_end = n2g_ofs + g_tile
    if attn_out_end > out.shape[1]:
        # The tile would overrun the head dimension: shift it back so that it
        # ends exactly at the last head.
        attn_out_end = out.shape[1]
        attn_out_start = attn_out_end - g_tile

    out[bs_ofs:bs_ofs + 1, attn_out_start:attn_out_end, :] = oi_final.unsqueeze(0).to(dtype)
def kv_cache_concat_bsnd(kv_cache_out, cmp_block_table, actual_seqs):
    """Gather a paged KV cache into a dense, zero-padded BSND tensor.

    Args:
        kv_cache_out: Paged cache indexed as [num_blocks, block_size, n2, d].
        cmp_block_table: Per-batch rows of physical block indices; a row is
            read up to its first -1 entry.
        actual_seqs: 1-D tensor of per-batch sequence lengths; its maximum,
            rounded up to a multiple of ``block_size``, sets the padded length.

    Returns:
        Tensor of shape [b, kv_max, n2, d] with the dtype and device of
        ``kv_cache_out``. Positions not covered by a mapped block are zero.
    """
    b = actual_seqs.shape[0]
    _, block_size, n2, d = kv_cache_out.shape
    longest = torch.max(actual_seqs).item()
    kv_max = (longest + block_size - 1) // block_size * block_size

    # Allocate the result on the cache's device and fill it directly; no
    # per-batch staging buffer is needed.
    cmp_kv = torch.zeros(
        [b, kv_max, n2, d], dtype=kv_cache_out.dtype, device=kv_cache_out.device
    )
    for b_idx in range(b):
        for slot, block_idx in enumerate(cmp_block_table[b_idx]):
            block_id = int(block_idx)
            if block_id == -1:
                # First unused slot: the rest of the row stays zero-padded.
                break
            start_idx = slot * block_size
            cmp_kv[b_idx, start_idx:start_idx + block_size] = kv_cache_out[block_id]
    return cmp_kv
def ifa_flash_torch(q, cmp_kv, sinks, cmp_block_table, seqused_kv, output_flash, tmp_out, cmp_ratio=128, is_new_sink=False,
    ori_kv=None, ori_block_table=None):
    """
    PyTorch reference for block-wise (online-softmax) attention over a paged, compressed KV cache.

    Args:
        q: Query, read as [batch_size * s1, num_head, head_size]
        cmp_kv: Paged compressed KV cache, read as [num_blocks, block_size, kv_head_num, head_size];
            the same rows serve as both key and value
        sinks: Optional sink values forwarded to flash_end (None disables them)
        cmp_block_table: Block mapping table for the compressed cmp_kv cache, indexed [batch, logical_block]
        seqused_kv: Per-batch sequence length [batch_size]; query row s1_idx uses
            seqused_kv - (s1 - 1 - s1_idx) positions
        output_flash: Output buffer [batch_size * s1, num_head, head_size], written in place
        tmp_out: Not used by this function
        cmp_ratio: Divisor converting a position count into a compressed KV length
        is_new_sink: Forwarded to flash_end to select how the sink enters the denominator
        ori_kv: Optional paged uncompressed cache; when given together with ori_block_table,
            the last (at most 128) positions are attended first
        ori_block_table: Block mapping table for ori_kv
    Returns:
        output_flash
    """
    fp32 = torch.float32
    q_shape = q.shape
    device = q.device
    dtype = q.dtype
    bs1, n1, d = q_shape[0], q_shape[1], q_shape[2]
    b = seqused_kv.shape[0]
    s1 = bs1 // b
    k_shape = cmp_kv.shape
    _, block_size, n2, _ = k_shape
    g = n1 // n2
    # One tile covers the whole head group, so the g_idx loop below runs once.
    g_tile = g
    # Flatten both operands to 2-D so that heads/blocks are addressed by row offset.
    kv_2d = cmp_kv.reshape(-1, d)
    q_2d = q.reshape(-1, d)
    scale = d ** -0.5
    # Maximum number of trailing uncompressed positions attended from ori_kv.
    win = 128
    for b_idx in range(b):
        for s1_idx in range(s1):
            # Compressed KV length visible to this query row; earlier rows see fewer positions.
            cur_seq = (seqused_kv[b_idx] - (s1 - 1 - s1_idx)) // cmp_ratio
            cur_seq = max(cur_seq, 0)
            s2_loop = math.ceil(cur_seq / block_size)
            for g_idx in range(g // g_tile):
                # Running output, denominator and row maximum of the online softmax.
                oi_upd = torch.zeros((g_tile, d), device=device, dtype=fp32)
                li_upd = torch.zeros(g_tile, device=device, dtype=fp32)
                mi_upd = torch.zeros(g_tile, device=device, dtype=fp32)
                bs_ofs = b_idx * s1 + s1_idx
                n2g_ofs = g_idx * g_tile
                qi_start = bs_ofs * n1 + n2g_ofs
                qi_end = qi_start + g_tile
                qi = q_2d[qi_start:qi_end, :]
                if ori_kv is not None and ori_block_table is not None:
                    # Window part: gather the blocks that contain the last valid_win_len
                    # positions, then trim to exactly that range.
                    kv_win_2d = ori_kv.reshape(-1, d)
                    valid_len = seqused_kv[b_idx] - (s1 - s1_idx - 1)
                    valid_win_len = min(valid_len, win)
                    valid_start_pos = valid_len - valid_win_len
                    valid_end_pos = valid_len - 1
                    start_offset = valid_start_pos % block_size
                    start_block = valid_start_pos // block_size
                    end_block = valid_end_pos // block_size
                    kv_list = []
                    for block_idx in range(start_block, end_block + 1):
                        block_idx_valid = max(ori_block_table[b_idx, block_idx], 0)
                        block_offset = block_idx_valid * block_size
                        kv_block = kv_win_2d[block_offset: block_offset + block_size, :]
                        kv_list.append(kv_block)
                    kv_cur = torch.cat(kv_list, axis=0)
                    kv_cur = kv_cur[start_offset : start_offset + valid_win_len, :]
                    # The window initialises the running statistics.
                    mm1 = matmul_proxy(qi, kv_cur.t())
                    muls_res = mm1 * scale
                    tilda_mij, _ = torch.max(muls_res, dim=-1, keepdim=True)
                    tsub = muls_res - tilda_mij
                    tilda_pij = torch.exp(tsub)
                    tilda_lij = torch.sum(tilda_pij, dim=-1, keepdim=True)
                    oi_tmp = matmul_proxy(tilda_pij.to(dtype), kv_cur)
                    oi_upd = oi_tmp
                    li_upd = tilda_lij.squeeze(-1)
                    mi_upd = tilda_mij.squeeze(-1)
                if s2_loop == 0:
                    # No compressed blocks for this row: finalise with what has been accumulated so far.
                    flash_end(output_flash, sinks, li_upd, mi_upd, oi_upd, n2g_ofs, g_tile, bs_ofs, dtype, \
                        is_new_sink=is_new_sink)
                for s2_idx in range(s2_loop):
                    kvj = get_block_kv(kv_2d, cmp_block_table, b_idx, s2_idx, block_size, cur_seq)
                    mm1 = matmul_proxy(qi, kvj.t())
                    muls_res = mm1 * scale
                    tilda_mij, _ = torch.max(muls_res, dim=-1, keepdim=True)
                    # NOTE(review): this tests only ori_kv, while the window branch above also
                    # requires ori_block_table; with ori_kv set and ori_block_table None the first
                    # block merges into zero-initialised statistics - confirm that is intended.
                    if s2_idx == 0 and ori_kv is None:
                        # First block without a window: initialise the running statistics.
                        tsub = muls_res - tilda_mij
                        tilda_pij = torch.exp(tsub)
                        tilda_lij = torch.sum(tilda_pij, dim=-1, keepdim=True)
                        oi_tmp = matmul_proxy(tilda_pij.to(dtype), kvj)
                        oi_upd = oi_tmp
                        li_upd = tilda_lij.squeeze(-1)
                        mi_upd = tilda_mij.squeeze(-1)
                    else:
                        # Merge this block: rescale the previous sum and output to the new row maximum.
                        mi = mi_upd.unsqueeze(-1)
                        max_new, _ = torch.max(
                            torch.cat([mi, tilda_mij], dim=-1), dim=-1, keepdim=True
                        )
                        tsub = muls_res - max_new
                        tilda_pij = torch.exp(tsub)
                        tilda_lij = torch.sum(tilda_pij, dim=-1, keepdim=True)
                        tsub2 = torch.sub(mi, max_new)
                        mi_upd = max_new.squeeze(-1)
                        update_mul = torch.exp(tsub2)
                        li = li_upd.unsqueeze(-1)
                        sum_new = li * update_mul + tilda_lij
                        li_upd = sum_new.squeeze(-1)
                        q1 = matmul_proxy(tilda_pij.to(dtype), kvj)
                        oi_upd = oi_upd * update_mul + q1
                    if s2_idx == s2_loop - 1:
                        # Last block: normalise and store the result.
                        flash_end(output_flash, sinks, li_upd, mi_upd, oi_upd, n2g_ofs, g_tile, bs_ofs, dtype, \
                            is_new_sink=is_new_sink)
    return output_flash
def ifa_golden(q, cmp_kv, sinks, cmp_block_table, seqused_kv, output_flash, tmp_out, enable_flash=True, cmp_ratio=1,
               is_new_sink=True, ori_kv=None, ori_block_table=None):
    """Compute the golden attention output with PyTorch.

    Args:
        q: Query, read as [batch_size * s1, num_head, head_size].
        cmp_kv: Paged compressed KV cache [num_blocks, block_size, kv_head_num, head_size].
        sinks: Optional sink values passed to the softmax.
        cmp_block_table: Block table for ``cmp_kv``.
        seqused_kv: Per-batch sequence lengths [batch_size].
        output_flash: Output buffer, written in place.
        tmp_out: Forwarded to ``ifa_flash_torch`` on the flash path; unused otherwise.
        enable_flash: True delegates to ``ifa_flash_torch``; False uses the
            dense softmax path below.
        cmp_ratio: Divisor converting a position count into a compressed length.
        is_new_sink: Selects how the sink enters the softmax denominator.
        ori_kv: Optional paged uncompressed cache for the trailing window.
        ori_block_table: Block table for ``ori_kv``.

    Returns:
        ``output_flash`` on the flash path; ``(output_flash, kv_bs)`` on the
        dense path, where ``kv_bs`` is the KV matrix of the last processed row
        (``None`` if no row was processed).
    """
    if enable_flash:
        return ifa_flash_torch(
            q=q,
            cmp_kv=cmp_kv,
            sinks=sinks,
            cmp_block_table=cmp_block_table,
            seqused_kv=seqused_kv,
            output_flash=output_flash,
            tmp_out=tmp_out,
            cmp_ratio=cmp_ratio,
            is_new_sink=is_new_sink,
            ori_kv=ori_kv,
            ori_block_table=ori_block_table,
        )

    fp64 = torch.float64
    win = 128  # trailing uncompressed positions attended from ori_kv
    b = seqused_kv.shape[0]
    s1 = q.shape[0] // b
    nkv = cmp_kv.shape[2]
    d = cmp_kv.shape[3]
    softmax_scale = d ** -0.5

    # Gather the compressed cache once.
    kv_bsnd = kv_cache_concat_bsnd(cmp_kv, cmp_block_table, seqused_kv // cmp_ratio)

    # The window cache is optional. It used to be read unconditionally, which
    # raised UnboundLocalError whenever ori_kv / ori_block_table were omitted.
    has_window = ori_kv is not None and ori_block_table is not None
    k_win_bsnd = kv_cache_concat_bsnd(ori_kv, ori_block_table, seqused_kv) if has_window else None

    kv_bs = None
    for i in range(b):
        for j in range(s1):
            # Positions visible to this query row, and their compressed count.
            seq_end = seqused_kv[i] - (s1 - 1 - j)
            seq_len = seq_end // cmp_ratio
            for n2_idx in range(nkv):
                q_bs = q[i * s1 + j]
                kv_bs = kv_bsnd[i, :seq_len, n2_idx:n2_idx + 1].reshape(seq_len, d)
                if has_window:
                    # Window rows go first, followed by the compressed rows.
                    kv_win_view = k_win_bsnd[i, max(seq_end - win, 0):seq_end, :, :].reshape(-1, d)
                    kv_bs = torch.cat([kv_win_view, kv_bs], dim=0)
                # NOTE(review): matmul_proxy multiplies in float32 and softmax(..., True)
                # also casts to float32, so the fp64 casts do not yield fp64 arithmetic.
                q_bs = q_bs.to(fp64)
                kv_bs_64 = kv_bs.to(fp64)
                qk_bmm_res = matmul_proxy(q_bs, kv_bs_64.transpose(1, 0))
                qk_ele_res = qk_bmm_res * softmax_scale
                softmax_res, _, _ = softmax(qk_ele_res, sinks, True, is_new_sink=is_new_sink)
                bmm2_res = matmul_proxy(softmax_res.to(output_flash.dtype), kv_bs.to(output_flash.dtype))
                output_flash[i * s1 + j] = bmm2_res.to(output_flash.dtype)
    return output_flash, kv_bs
def c128(enable_flash: bool, enable_high_perf: bool, enable_graph: bool, device: str, attn_cfg: AttentionConfig):
    """Run the c128 attention case on ``device`` and compare it against the PyTorch golden.

    Random inputs and block tables are built from ``attn_cfg``; the golden result
    comes from ``ifa_golden`` on its dense path. The kernel is then invoked ten
    times, either eagerly via ``cfa_attention`` or through a torchair-compiled
    ``MM`` module when ``enable_graph`` is set, and the last result is compared
    with the golden.

    ``enable_flash`` and ``enable_high_perf`` are accepted but not read here.
    """
    tensor_opts = {"dtype": torch.bfloat16, "device": device}

    b, s1 = attn_cfg.b, attn_cfg.s1
    d, nq, nkv = attn_cfg.q_d, attn_cfg.n1, attn_cfg.n2
    cmp_ratio = attn_cfg.cmp_ratio
    block_size = attn_cfg.block_size
    seqused_kv = attn_cfg.actual_seq

    q_shape = [b * s1, nq, d]
    cmp_kv_shape = [attn_cfg.kv_num_blocks, block_size, nkv, d]
    cmp_blk_tbl_shape = [attn_cfg.block_table_batch, attn_cfg.max_blocks]

    # The window cache must hold at least block_size + s1 - 1 positions per batch entry.
    win_max_actual_seq = max(max(seqused_kv), block_size + s1 - 1)
    win_max_blocks = math.ceil(win_max_actual_seq / block_size)
    ori_kv_shape = [b * win_max_blocks, block_size, nkv, d]
    ori_blk_tbl_shape = cmp_blk_tbl_shape

    def uniform(shape, **opts):
        # Values drawn uniformly from [-1, 1).
        return torch.empty(shape, **opts).uniform_(-1, 1)

    # Creation order is kept fixed so the random streams are consumed identically.
    q = uniform(q_shape, **tensor_opts)
    cmp_kv = uniform(cmp_kv_shape, **tensor_opts)
    sinks = uniform(nq, dtype=torch.float32, device=device)
    ori_kv = uniform(ori_kv_shape, **tensor_opts)
    ori_block_table = gen_block_table(
        seqused_kv, block_size, ori_blk_tbl_shape, cmp_ratio=cmp_ratio, enable_win=True
    )
    tmp_out_golden = torch.ones((b * s1 * 2 * block_size, q_shape[2]), **tensor_opts)
    output_flash = torch.zeros(q_shape, **tensor_opts)
    cmp_block_table = gen_block_table(seqused_kv, block_size, cmp_blk_tbl_shape, cmp_ratio=cmp_ratio)
    attention_out = torch.zeros(q_shape, **tensor_opts)

    # The golden writes into output_flash in place; its return value is not needed.
    ifa_golden(
        q, cmp_kv, sinks, cmp_block_table, seqused_kv, output_flash, tmp_out_golden,
        enable_flash=False, cmp_ratio=cmp_ratio, is_new_sink=True,
        ori_kv=ori_kv, ori_block_table=ori_block_table,
    )

    kernel_args = (q, cmp_kv, sinks, cmp_block_table, seqused_kv, ori_kv, ori_block_table, cmp_ratio)
    if enable_graph:
        import torchair as tng
        from torchair.configs.compiler_config import CompilerConfig

        compiler_config = CompilerConfig()
        compiler_config.mode = "reduce-overhead"
        npu_backend = tng.get_npu_backend(compiler_config=compiler_config)
        model = torch.compile(MM(), dynamic=False, fullgraph=True, backend=npu_backend)
        for _ in range(10):
            attention_out = model(*kernel_args)
            pypto.runtime._device_synchronize()
    else:
        for _ in range(10):
            attention_out = cfa_attention(*kernel_args)

    import utils.compare as compare

    compare.compare(output_flash, attention_out, "golden vs npu", rtol=0.0078125, atol=0.0001)
def test_c128_decode(enable_flash: bool=False, enable_high_perf: bool=False, enable_graph: bool=False,
                     device_id: int = 0):
    """Run the c128 decode case on an NPU.

    Args:
        enable_flash: Forwarded to ``c128``.
        enable_high_perf: Forwarded to ``c128``.
        enable_graph: Forwarded to ``c128``; selects the compiled-graph path.
        device_id: NPU index used when ``TILE_FWK_DEVICE_ID`` is not set.
    """
    # The environment variable still takes precedence; the argument is now the
    # fallback instead of being discarded, so --device-id has an effect.
    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', device_id))
    torch.npu.set_device(device_id)
    device = f'npu:{device_id}'
    attn_cfg = get_decode_case(device=device)
    c128(enable_flash=enable_flash, enable_high_perf=enable_high_perf, enable_graph=enable_graph, device=device,
         attn_cfg=attn_cfg)
if __name__ == "__main__":
import argparse as ap
import utils.golden.attn_golden as attn_golden
p = ap.ArgumentParser(description="参数配置")
p.add_argument("-f", "--enable-flash", action="store_true", help="开启flash模式")
p.add_argument("-p", "--high-perf", action="store_true", help="启用高性能模式")
p.add_argument("-g", "--enable-graph", action="store_true", help="启用高性能模式")
p.add_argument("-c", "--device-id", type=int, default=0, help="显卡序号,默认0")
p.add_argument("-u", "--upper", type=int, default=6000, help="融合上限法")
args = p.parse_args()
test_c128_decode(enable_flash=args.enable_flash, enable_high_perf=args.high_perf, enable_graph=args.enable_graph,
device_id=args.device_id)