"""
Gated Delta Rule Implementation Module
This module implements the core computation functions for the Chunk Gated Delta Rule
attention mechanism used in the Qwen3Next model. It provides efficient linear attention
computation with O(n) complexity instead of the O(n²) of traditional attention.
Main Functions:
- l2norm: L2 normalization for query and key
- pre_attn: Pre-attention computation including gate cumsum and decay mask
- inverse_pto: Block-wise matrix inversion
- cal_value_and_key_cumdecay: Value and key cumulative decay computation
- recurrent_state_attn_all: Recurrent state attention computation
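- chunk_gated_delta_rule: Main fused operator entry point (sequence length divisible by the chunk size)
- chunk_gated_delta_rule_unaligned: Fused operator entry point for sequence lengths not divisible by the chunk size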
Example:
See qwen3_next_gated_delta_rule.py for usage examples.
"""
import pypto
def l2norm(
    query: pypto.Tensor, key: pypto.Tensor, eps: float = 1e-6
) -> tuple[pypto.Tensor, pypto.Tensor]:
    """
    Normalize query and key along their last axis.

    Each input x is divided by sqrt(sum(x * x, axis=-1, keepdim=True) + eps).

    Parameters
    ---------
    query: [L, D]
    key: [L, D]
    eps: constant added to the squared sum before the square root (default 1e-6)
    Return
    ---------
    query_after_l2norm: [L, D]
    key_after_l2norm: [L, D]
    """
    pypto.set_vec_tile_shapes(128, 128)

    # Query first, then key: the operations are issued in the same order
    # as a single-expression form would issue them.
    query_sq_sum = (query * query).sum(-1, keepdim=True)
    query_after_l2norm = query / pypto.sqrt(query_sq_sum + eps)

    key_sq_sum = (key * key).sum(-1, keepdim=True)
    key_after_l2norm = key / pypto.sqrt(key_sq_sum + eps)

    return query_after_l2norm, key_after_l2norm
def pre_attn(
    gate_view: pypto.Tensor,
    key_view_2d: pypto.Tensor,
    beta_view: pypto.Tensor,
    tril: pypto.Tensor,
    mask: pypto.Tensor,
) -> tuple[pypto.Tensor, pypto.Tensor, pypto.Tensor, pypto.Tensor]:
    """
    Calculate gate_cum, decay_mask, the masked key-key product and key_beta.

    Parameters
    ---------
    gate_view: [L, 1]
    key_view_2d: [L, D]
    beta_view: [L, 1]
    tril: [L, L]
    mask: [L, L]
    Return
    ---------
    gate_cum: [L, 1], computed as tril @ gate_view
    decay_mask: [L, L], exp((gate_cum - gate_cum^T) * tril)
    a: [L, L], (key_beta @ key_view_2d^T) * decay_mask * mask
    key_beta: [L, D], key_view_2d * beta_view
    """
    pypto.set_vec_tile_shapes(128, 128)
    pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])

    # Gate accumulation is expressed as a matmul with tril.
    gate_cum = pypto.matmul(tril, gate_view, pypto.DT_FP32)

    # Pairwise gate differences, multiplied by tril before the exponential.
    gate_cum_t = gate_cum.transpose(0, 1)
    masked_diff = (gate_cum - gate_cum_t) * tril
    decay_mask = masked_diff.exp()

    key_beta = key_view_2d * beta_view
    kkt = pypto.matmul(key_beta, key_view_2d, pypto.DT_FP32, b_trans=True)

    # Apply the decay first, then the mask (same left-to-right order as
    # kkt * decay_mask * mask).
    a = kkt * decay_mask
    a = a * mask
    return gate_cum, decay_mask, a, key_beta
def inverse_pto(**kwargs) -> pypto.Tensor:
    """
    Calculate the inverse of a big matrix by merging diagonal blocks.

    The eight diagonal blocks of side ``size // 8`` are placed side by side and
    handled together by inverse_pto_min_length; the results are then merged
    pairwise with inverse_matmul in three levels (8 -> 4 -> 2 -> 1 blocks).

    Parameters
    ---------
    attn: [L, L]
    eye: [L // 8, L]
    size: matrix size
    zeros_16: 16 * 16 zero matrix (used when merging blocks of side size // 8)
    zeros_32: 32 * 32 zero matrix (used when merging blocks of side size // 4)
    zeros_64: 64 * 64 zero matrix (used when merging blocks of side size // 2)
    Return
    ---------
    attn_inv: [L, L]
    """
    attn = kwargs.get("attn")
    eye = kwargs.get("eye")
    size = kwargs.get("size")
    zeros_16 = kwargs.get("zeros_16")
    zeros_32 = kwargs.get("zeros_32")
    zeros_64 = kwargs.get("zeros_64")

    block_len = size // 8
    pypto.set_vec_tile_shapes(128, 128)

    # Diagonal blocks of attn; `+ 0.0` presumably materializes each view as
    # its own tensor -- confirm against pypto view semantics.
    diag_blocks = [
        attn.view([block_len, block_len], [block_len * idx, block_len * idx]) + 0.0
        for idx in range(8)
    ]
    packed = pypto.concat(diag_blocks, dim=1)
    packed_inv = inverse_pto_min_length(packed, eye, block_len, block_len * 8)

    # Split the packed result back into the eight per-block results.
    level0 = [
        packed_inv[:, block_len * idx:block_len * (idx + 1)] + 0.0
        for idx in range(8)
    ]

    # Level 1: merge neighbouring blocks of side block_len.
    level1 = [
        inverse_matmul(
            attn=attn,
            attn_1_1_inv=level0[pair * 2],
            attn_2_2_inv=level0[pair * 2 + 1],
            x_ofs=block_len * pair * 2,
            y_ofs=block_len * pair * 2,
            m_len=block_len,
            zero_tensor=zeros_16,
        )
        for pair in range(4)
    ]

    # Level 2: merge neighbouring blocks of side 2 * block_len.
    level2 = [
        inverse_matmul(
            attn=attn,
            attn_1_1_inv=level1[pair * 2],
            attn_2_2_inv=level1[pair * 2 + 1],
            x_ofs=block_len * pair * 4,
            y_ofs=block_len * pair * 4,
            m_len=block_len * 2,
            zero_tensor=zeros_32,
        )
        for pair in range(2)
    ]

    # Level 3: merge the two halves into the full result.
    return inverse_matmul(
        attn=attn,
        attn_1_1_inv=level2[0],
        attn_2_2_inv=level2[1],
        x_ofs=0,
        y_ofs=0,
        m_len=block_len * 4,
        zero_tensor=zeros_64,
    )
def inverse_pto_min_length(
    attn_dim1: pypto.Tensor,
    eye: pypto.Tensor,
    row_num: int,
    col_num: int,
) -> pypto.Tensor:
    """
    Calculate inverse of matrix with tail concat optimization.

    ``attn_dim1`` is treated as ``col_num // row_num`` blocks of ``row_num``
    columns each, laid side by side along dim 1. The first two rows are taken
    unchanged; every later row ``i`` is replaced by
    ``row_i + sum_{j < i} block[i, j] * result[j, :]`` (evaluated for all
    blocks at once) and appended to the rows computed so far. ``eye`` is added
    to the stacked rows at the end.

    Parameters
    ---------
    attn_dim1: [L // 8, L]
    eye: [L // 8, L]
    row_num: L // 8 (must be >= 2)
    col_num: L
    Return
    ---------
    res: [L // 8, L]
    Raises
    ---------
    ValueError: if row_num < 2 (the first two rows are always consumed)
    """
    if row_num < 2:
        raise ValueError(f"row_num must be >= 2, got {row_num}")

    size = col_num // row_num

    # Rows 0 and 1 seed the result; only the latest accumulation is ever read,
    # so a single running tensor replaces a per-step container.
    attn_inv_acc = attn_dim1[:2, :]
    pypto.set_vec_tile_shapes(128, 128)

    for i in range(2, row_num):
        # `+ 0.0` presumably forces a fresh tensor -- confirm against pypto semantics.
        attn_inv_cur = attn_inv_acc + 0.0
        row = attn_dim1.view([1, col_num], [i, 0])

        # First i entries of row i in every block, arranged as one column of
        # length size * i that lines up with the reshaped accumulator below.
        row_expand = (
            row.reshape([size, row_num])
            .view([size, i], [0, 0])
            .transpose(1, 0)
            .reshape([size * i, 1])
        )
        attn_inv_cur_reshape = attn_inv_cur.reshape([size * i, row_num])

        # Scale each earlier row (per block) by its coefficient and sum over rows.
        prod_mul = (row_expand * attn_inv_cur_reshape).reshape([i, col_num])
        prod = prod_mul.sum(0, keepdim=True)

        attn_update = row + prod
        attn_inv_acc = pypto.concat([attn_inv_cur, attn_update], dim=0)

    res = attn_inv_acc + eye
    return res
def inverse_matmul(**kwargs) -> pypto.Tensor:
    """
    Merge two diagonal-block results into one block of twice the side length.

    The output is assembled from four quadrants:
    upper-left = attn_1_1_inv, upper-right = zero_tensor,
    lower-left = (attn_2_2_inv @ attn_2_1) @ attn_1_1_inv,
    lower-right = attn_2_2_inv,
    where attn_2_1 is the [m_len, m_len] view of attn at
    (x_ofs + m_len, y_ofs).

    Parameters
    ---------
    attn: [L, L]
    attn_1_1_inv: attn upper left matrix
    attn_2_2_inv: attn bottom right matrix
    x_ofs: row offset
    y_ofs: column offset
    m_len: matrix length
    zero_tensor: zero matrix
    Return
    ---------
    attn_inv: [m_len * 2, m_len * 2]
    """
    attn = kwargs.get("attn")
    upper_left_inv = kwargs.get("attn_1_1_inv")
    lower_right_inv = kwargs.get("attn_2_2_inv")
    x_ofs = kwargs.get("x_ofs")
    y_ofs = kwargs.get("y_ofs")
    m_len = kwargs.get("m_len")
    zero_tensor = kwargs.get("zero_tensor")

    pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])

    # Off-diagonal block of attn directly below the upper-left block.
    lower_left = attn.view([m_len, m_len], [x_ofs + m_len, y_ofs])
    lower_left_inv = lower_right_inv @ lower_left
    lower_left_inv = lower_left_inv @ upper_left_inv

    full_len = m_len * 2
    attn_inv = pypto.tensor([full_len, full_len], dtype=upper_left_inv.dtype)

    # Fill the quadrants in the order: upper-left, upper-right, lower-left,
    # lower-right.
    attn_inv[0:m_len, 0:m_len] = upper_left_inv
    attn_inv[0:m_len, m_len:full_len] = zero_tensor
    attn_inv[m_len:full_len, 0:m_len] = lower_left_inv
    attn_inv[m_len:full_len, m_len:full_len] = lower_right_inv
    return attn_inv
def cal_value_and_key_cumdecay(
    attn: pypto.Tensor,
    value_view: pypto.Tensor,
    beta_view: pypto.Tensor,
    key_beta: pypto.Tensor,
    gate_cum: pypto.Tensor,
) -> tuple[pypto.Tensor, pypto.Tensor]:
    """
    Calculate value and k cumdecay.

    value_out   = attn @ (value_view * beta_view)
    key_cum_out = attn @ (key_beta * exp(gate_cum))

    Parameters
    ---------
    attn: [L, L]
    value_view: [L, D]
    beta_view: multiplied element-wise with value_view
        (NOTE(review): callers in this file pass an [L, 1] view -- confirm)
    key_beta: [L, D]
    gate_cum: [L, 1]
    Return
    ---------
    value_out: [L, D]
    key_cum_out: [L, D]
    """
    pypto.set_vec_tile_shapes(128, 128)
    pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])

    # Value path: scale by beta, then apply attn.
    value_out = pypto.matmul(attn, value_view * beta_view, pypto.DT_FP32)

    # Key path: weight key_beta by exp(gate_cum), then apply attn.
    key_cum_out = pypto.matmul(attn, key_beta * pypto.exp(gate_cum), pypto.DT_FP32)

    return value_out, key_cum_out
def recurrent_state_attn_all(**kwargs) -> tuple[pypto.Tensor, pypto.Tensor]:
    """
    Calculate the chunk attention output and the updated recurrent state.

    Parameters
    ---------
    query: [L, D]
    key: [L, D]
    value:[L, Dv]
    k_cumdecay:[L, Dk]
    gate: [L, 1]
    state: [D, D]
    decay_mask: [L, L]
    tril: [L, L]
    Return
    ---------
    chunk_attn_out: [L, D]
    state_new:[Dv, Dk]

    Notes
    ---------
    The cube/vec tile shapes are changed several times between matmuls; the
    position of each tile-shape call relative to the operations is part of
    the behaviour and is kept as is.
    """
    query = kwargs.get("query")
    key = kwargs.get("key")
    value = kwargs.get("value")
    k_cumdecay = kwargs.get("k_cumdecay")
    gate = kwargs.get("gate")
    state = kwargs.get("state")
    decay_mask = kwargs.get("decay_mask")
    tril = kwargs.get("tril")

    value_dim = value.shape[-1]
    # Number of valid rows of gate; its last valid row is used below.
    valid_len = gate.valid_shape[0]
    gate_exp = gate.exp()

    pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])
    pypto.set_vec_tile_shapes(64, 128)

    # Keys weighted by exp(last_gate - gate); queries weighted by exp(gate).
    last_gate = gate[valid_len - 1:valid_len, :]
    key_weight = (last_gate - gate).exp()
    key_decayed = key * key_weight
    query_decayed = query * gate_exp

    # Contributions that read the incoming state.
    pypto.set_cube_tile_shapes([128, 128], [128, 128], [64, 64])
    v_prime = pypto.matmul(k_cumdecay, state, pypto.DT_FP32, b_trans=True)
    attn_inter = pypto.matmul(query_decayed, state, pypto.DT_FP32, b_trans=True)

    pypto.set_cube_tile_shapes([64, 64], [128, 128], [128, 128])
    state_from_vprime = pypto.matmul(v_prime, key_decayed, pypto.DT_FP32, a_trans=True)

    pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])
    state_from_value = pypto.matmul(value, key_decayed, pypto.DT_FP32, a_trans=True)
    attn = pypto.matmul(query, key, pypto.DT_FP32, b_trans=True)

    # New state: incoming state scaled by exp of the last valid gate, plus the
    # value term, minus the v_prime term.
    last_gate_exp = pypto.expand_clone(gate_exp[valid_len - 1:valid_len, :], (value_dim, 1))
    decayed_state = state * last_gate_exp
    state_new = decayed_state + state_from_value - state_from_vprime

    # Intra-chunk attention with decay and tril applied.
    attn_masked = attn * decay_mask * tril
    chunk_attn_value = pypto.matmul(attn_masked, value, pypto.DT_FP32)

    pypto.set_cube_tile_shapes([128, 128], [128, 128], [64, 64])
    chunk_attn_vprime = pypto.matmul(attn_masked, v_prime, pypto.DT_FP32)

    chunk_attn_out = attn_inter + chunk_attn_value - chunk_attn_vprime
    return chunk_attn_out, state_new
def chunk_gated_delta_rule(b, nqk, nv, d, l):
# Factory for the aligned variant: declares the tensor signature shapes and
# returns the JIT-decorated `kernel` defined below (see the kernel docstring).
# nqk, nv, d and l are used as static dimensions of the signature shapes.
# The token dimension is declared dynamic.
t = pypto.DYNAMIC
# NOTE(review): the result of `b + 1` is overwritten by the next line, and `b`
# itself is replaced right after, so the `b` argument does not influence any
# shape declared below.
b1 = b + 1
b1 = pypto.DYNAMIC
b = pypto.DYNAMIC
query_shape = [t, nqk, d]
key_shape = [t, nqk, d]
value_shape = [t, nv, d]
beta_shape = [t, nv]
gate_shape = [t, nv]
states_shape = [b, nv, d, d]
mask_shape = [l, l]
tril_mask_shape = [l, l]
# 16 rows: equals size // 8 for the size=128 passed to inverse_pto in the kernel.
eye_shape = [16, l]
act_seq_len_shape = [b1]
core_attn_out_shape = [t, nv, d]
last_state_data_shape = [b, nv, d, d]
# NOTE(review): the integer keys of vec_nbuffer_setting are not explained in
# this file -- presumably tuning values tied to the generated kernel; confirm
# before changing the kernel body.
@pypto.frontend.jit(
runtime_options={
"stitch_function_max_num": 2,
"device_sched_parallelism": 8
},
pass_options={
"vec_nbuffer_setting": {8: 16, 1: 16, 28: 32, 6: 4, 25: 32, 26: 8,
2: 8, 3: 8, 10: 32, 5: 4, 7: 16, 9: 32, 0: 4, 4: 4, -2: 1}
}
)
def kernel(
query: pypto.Tensor(query_shape, pypto.DT_FP32),
key: pypto.Tensor(key_shape, pypto.DT_FP32),
value: pypto.Tensor(value_shape, pypto.DT_FP32),
beta: pypto.Tensor(beta_shape, pypto.DT_FP32),
gate: pypto.Tensor(gate_shape, pypto.DT_FP32),
states: pypto.Tensor(states_shape, pypto.DT_FP32),
mask: pypto.Tensor(mask_shape, pypto.DT_FP32),
tril_mask: pypto.Tensor(tril_mask_shape, pypto.DT_FP32),
eye: pypto.Tensor(eye_shape, pypto.DT_FP32),
act_seq_len: pypto.Tensor(act_seq_len_shape, pypto.DT_INT32),
core_attn_out: pypto.Tensor(core_attn_out_shape, pypto.DT_FP32),
last_state_data: pypto.Tensor(last_state_data_shape, pypto.DT_FP32)
):
"""
Chunk Gated Delta Rule fused operator.
This is the main entry point for the Gated Delta Rule attention computation
in the scenario where Sequence length is divisible by L.
It processes input sequences in chunks of size L=128, maintaining recurrent
state across chunks for efficient long sequence modeling.
Parameters
----------
query : Input query tensor, shape [T, Nqk, D], dtype float32
key : Input key tensor, shape [T, Nqk, D], dtype float32
value : Input value tensor, shape [T, Nv, D], dtype float32
beta : Beta scaling factor, shape [T, Nv], dtype float32
gate : Gate signal, shape [T, Nv], dtype float32
states : Initial recurrent states, shape [B, Nv, D, D], dtype float32
mask : Attention mask (lower triangular negative), shape [L, L], dtype float32
tril_mask : Lower triangular mask, shape [L, L], dtype float32
eye : Identity matrix (specially processed), shape [16, 128], dtype float32
act_seq_len : Cumulative sequence length indices, shape [B+1], dtype int32
core_attn_out : Output attention tensor, shape [T, Nv, D], dtype float32
last_state_data : Output updated states, shape [B, Nv, D, D], dtype float32
"""
# Read the dimensions back from the argument shapes; these rebind the names
# of the enclosing factory inside the kernel.
_, nqk, d = query.shape
_, nv, d = value.shape
b = states.shape[0]
l, l = mask.shape
# Number of value heads per query/key head.
group = nv // nqk
# NOTE(review): last_state is rebound from states[b_idx, nv_idx] inside the
# head loop before any read visible in this function -- confirm whether this
# allocation is still required.
last_state = pypto.tensor([d, d], pypto.DT_FP32)
pypto.experimental.set_operation_options(combine_axis=True)
# One iteration per sequence: consecutive act_seq_len entries give the
# sequence length `s` and its start offset `b_ofs` in the token dimension.
for b_idx in pypto.loop(b, name="LOOP_B_TND", idx_name="b_idx"):
s = act_seq_len[b_idx + 1] - act_seq_len[b_idx]
b_ofs = act_seq_len[b_idx]
# One iteration per value head.
for nv_idx in pypto.loop(nv, name="LOOP_Nv_TND", idx_name="nv_idx", parallel=True):
# Query/key head shared by this value head.
nqk_idx = nv_idx // group
pypto.set_vec_tile_shapes(16, 16, 128, 128)
# Start from the caller-provided state of this sequence/head.
last_state = states[b_idx, nv_idx]
# Walk the sequence in steps of l tokens.
for s_idx in pypto.loop(0, s, l, name="LOOP_S_TND", idx_name="s_idx", unroll_list=[16, 1]):
bs_ofs = b_ofs + s_idx
# Presumably min(remaining tokens, l); passed as valid_shape to the views below.
actual_l = (s - s_idx).min(l)
# Views of one chunk for a single head; query/key use the shared head index.
query_view = pypto.view(query, [l, 1, d], [bs_ofs, nqk_idx, 0], valid_shape=[actual_l, 1, d])
key_view = pypto.view(key, [l, 1, d], [bs_ofs, nqk_idx, 0], valid_shape=[actual_l, 1, d])
value_view = pypto.view(value, [l, 1, d], [bs_ofs, nv_idx, 0], valid_shape=[actual_l, 1, d])
beta_view = pypto.view(beta, [l, 1], [bs_ofs, nv_idx], valid_shape=[actual_l, 1])
gate_view = pypto.view(gate, [l, 1], [bs_ofs, nv_idx], valid_shape=[actual_l, 1])
pypto.set_vec_tile_shapes(128, 128, 128)
# Drop the singleton head axis.
query_view_2d = pypto.reshape(query_view, [l, d], valid_shape=[actual_l, d])
key_view_2d = pypto.reshape(key_view, [l, d], valid_shape=[actual_l, d])
value_view_2d = pypto.reshape(value_view, [l, d], valid_shape=[actual_l, d])
# Zero blocks passed to inverse_pto, one per merge level.
zeros_16 = pypto.full(size=[16, 16], fill_value=0.0, dtype=pypto.DT_FP32)
zeros_32 = pypto.full(size=[32, 32], fill_value=0.0, dtype=pypto.DT_FP32)
zeros_64 = pypto.full(size=[64, 64], fill_value=0.0, dtype=pypto.DT_FP32)
query_norm, key_norm = l2norm(query_view_2d, key_view_2d)
# Scale the normalized query by 1 / sqrt(d).
scale = 1 / d ** 0.5
query_scale = query_norm * scale
gate_cum, decay_mask, a_block, key_beta = pre_attn(gate_view, key_norm, beta_view, tril_mask, mask)
# NOTE(review): size is fixed at 128 (and the zero blocks at 16/32/64) rather
# than derived from l.
a_block_inverse = inverse_pto(attn=a_block, eye=eye, size=128, zeros_16=zeros_16,
zeros_32=zeros_32, zeros_64=zeros_64)
value_out, key_cum_out = cal_value_and_key_cumdecay(a_block_inverse, value_view_2d,
beta_view, key_beta, gate_cum)
chunk_attn_out, cur_state = recurrent_state_attn_all(query=query_scale, key=key_norm,
value=value_out, k_cumdecay=key_cum_out, gate=gate_cum, state=last_state,
decay_mask=decay_mask, tril=tril_mask)
# Copy the new state into last_state so the next chunk reads it.
last_state[:] = cur_state
# Write this chunk's output rows for the current head.
core_attn_out[bs_ofs:bs_ofs + l, nv_idx] = chunk_attn_out
# Store the running state in the output state tensor.
last_state_data[b_idx, nv_idx] = last_state
return kernel
def chunk_gated_delta_rule_unaligned(b, nqk, nv, d, l):
# Factory for the unaligned variant: declares the tensor signature shapes and
# returns the JIT-decorated `kernel` defined below (see the kernel docstring).
# nqk, nv, d and l are used as static dimensions of the signature shapes.
# The token dimension is declared dynamic.
t_unaligned = pypto.DYNAMIC
# NOTE(review): the result of `b + 1` is overwritten by the next line, and `b`
# itself is replaced right after, so the `b` argument does not influence any
# shape declared below.
b1 = b + 1
b1 = pypto.DYNAMIC
b = pypto.DYNAMIC
query_shape = [t_unaligned, nqk, d]
key_shape = [t_unaligned, nqk, d]
value_shape = [t_unaligned, nv, d]
beta_shape = [t_unaligned, nv]
gate_shape = [t_unaligned, nv]
states_shape = [b, nv, d, d]
mask_shape = [l, l]
tril_mask_shape = [l, l]
# 16 rows: equals size // 8 for the size=128 passed to inverse_pto in the kernel.
eye_shape = [16, l]
act_seq_len_shape = [b1]
core_attn_out_shape = [t_unaligned, nv, d]
last_state_data_shape = [b, nv, d, d]
# NOTE(review): the integer keys of vec_nbuffer_setting are not explained in
# this file -- presumably tuning values tied to the generated kernel; confirm
# before changing the kernel body.
@pypto.frontend.jit(
runtime_options={
"stitch_function_max_num": 2,
"device_sched_parallelism": 8
},
pass_options={
"vec_nbuffer_setting": {9: 16, 1: 16, 28: 32, 6: 4, 25: 32, 26: 8,
2: 8, 3: 8, 10: 32, 5: 4, 7: 16, 8: 32, 0: 4, 4: 4, 35: 8, 17: 8, 14: 30, 36: 16, -2: 1}
}
)
def kernel(
query: pypto.Tensor(query_shape, pypto.DT_FP32),
key: pypto.Tensor(key_shape, pypto.DT_FP32),
value: pypto.Tensor(value_shape, pypto.DT_FP32),
beta: pypto.Tensor(beta_shape, pypto.DT_FP32),
gate: pypto.Tensor(gate_shape, pypto.DT_FP32),
states: pypto.Tensor(states_shape, pypto.DT_FP32),
mask: pypto.Tensor(mask_shape, pypto.DT_FP32),
tril_mask: pypto.Tensor(tril_mask_shape, pypto.DT_FP32),
eye: pypto.Tensor(eye_shape, pypto.DT_FP32),
act_seq_len: pypto.Tensor(act_seq_len_shape, pypto.DT_INT32),
core_attn_out: pypto.Tensor(core_attn_out_shape, pypto.DT_FP32),
last_state_data: pypto.Tensor(last_state_data_shape, pypto.DT_FP32)
):
"""
Chunk Gated Delta Rule fused operator.
This is the main entry point for the Gated Delta Rule attention computation
in the scenario where Sequence length is not divisible by L.
It processes input sequences in chunks of size L=128, maintaining recurrent
state across chunks for efficient long sequence modeling.
Parameters
----------
query : Input query tensor, shape [T, Nqk, D], dtype float32
key : Input key tensor, shape [T, Nqk, D], dtype float32
value : Input value tensor, shape [T, Nv, D], dtype float32
beta : Beta scaling factor, shape [T, Nv], dtype float32
gate : Gate signal, shape [T, Nv], dtype float32
states : Initial recurrent states, shape [B, Nv, D, D], dtype float32
mask : Attention mask (lower triangular negative), shape [L, L], dtype float32
tril_mask : Lower triangular mask, shape [L, L], dtype float32
eye : Identity matrix (specially processed), shape [16, L], dtype float32
act_seq_len : Cumulative sequence length indices, shape [B+1], dtype int32
core_attn_out : Output attention tensor, shape [T, Nv, D], dtype float32
last_state_data : Output updated states, shape [B, Nv, D, D], dtype float32
"""
# Read the dimensions back from the argument shapes; these rebind the names
# of the enclosing factory inside the kernel.
_, nqk, d = query.shape
_, nv, d = value.shape
b = states.shape[0]
l, l = mask.shape
# Number of value heads per query/key head.
group = nv // nqk
pypto.experimental.set_operation_options(combine_axis=True)
# One iteration per sequence: consecutive act_seq_len entries give the
# sequence length `s` and its start offset `b_ofs` in the token dimension.
for b_idx in pypto.loop(b, name="LOOP_B_TND", idx_name="b_idx"):
s = act_seq_len[b_idx + 1] - act_seq_len[b_idx]
b_ofs = act_seq_len[b_idx]
# One iteration per value head.
for nv_idx in pypto.loop(nv, name="LOOP_Nv_TND", idx_name="nv_idx", parallel=True):
# Query/key head shared by this value head.
nqk_idx = nv_idx // group
pypto.set_vec_tile_shapes(16, 16, 128, 128)
# Start from the caller-provided state of this sequence/head.
last_state = states[b_idx, nv_idx]
# Walk the sequence in steps of l tokens.
for s_idx in pypto.loop(0, s, l, name="LOOP_S_TND", idx_name="s_idx", unroll_list=[16, 1]):
bs_ofs = b_ofs + s_idx
# Presumably min(remaining tokens, l); passed as valid_shape to the views below.
actual_l = (s - s_idx).min(l)
# Zero blocks passed to inverse_pto, one per merge level.
zeros_16 = pypto.full(size=[16, 16], fill_value=0.0, dtype=pypto.DT_FP32)
zeros_32 = pypto.full(size=[32, 32], fill_value=0.0, dtype=pypto.DT_FP32)
zeros_64 = pypto.full(size=[64, 64], fill_value=0.0, dtype=pypto.DT_FP32)
# Views of one chunk for a single head; query/key use the shared head index.
query_view = pypto.view(query, [l, 1, d], [bs_ofs, nqk_idx, 0], valid_shape=[actual_l, 1, d])
key_view = pypto.view(key, [l, 1, d], [bs_ofs, nqk_idx, 0], valid_shape=[actual_l, 1, d])
value_view = pypto.view(value, [l, 1, d], [bs_ofs, nv_idx, 0], valid_shape=[actual_l, 1, d])
beta_view = pypto.view(beta, [l, 1], [bs_ofs, nv_idx], valid_shape=[actual_l, 1])
gate_view = pypto.view(gate, [l, 1], [bs_ofs, nv_idx], valid_shape=[actual_l, 1])
pypto.set_vec_tile_shapes(128, 128, 128)
# Drop the singleton head axis.
query_view_2d = pypto.reshape(query_view, [l, d], valid_shape=[actual_l, d])
key_view_2d = pypto.reshape(key_view, [l, d], valid_shape=[actual_l, d])
value_view_2d = pypto.reshape(value_view, [l, d], valid_shape=[actual_l, d])
# Last chunk of the sequence: the inputs go through fillpad with constant 0.0
# (presumably filling the rows beyond valid_shape -- confirm against the
# fillpad documentation) before the same pipeline as the other branch runs.
if pypto.is_loop_end(s_idx):
pad_q = pypto.fillpad(query_view_2d, "constant", 0.0)
pad_k = pypto.fillpad(key_view_2d, "constant", 0.0)
pad_v = pypto.fillpad(value_view_2d, "constant", 0.0)
pad_b = pypto.fillpad(beta_view, "constant", 0.0)
pad_g = pypto.fillpad(gate_view, "constant", 0.0)
query_norm, key_norm = l2norm(pad_q, pad_k)
# Scale the normalized query by 1 / sqrt(d).
scale = 1 / d ** 0.5
query_scale = query_norm * scale
gate_cum, decay_mask, a_block, key_beta = pre_attn(pad_g, key_norm, pad_b, tril_mask, mask)
# NOTE(review): size is fixed at 128 (and the zero blocks at 16/32/64) rather
# than derived from l.
a_block_inverse = inverse_pto(attn=a_block, eye=eye, size=128, zeros_16=zeros_16,
zeros_32=zeros_32, zeros_64=zeros_64)
value_out, key_cum_out = cal_value_and_key_cumdecay(
a_block_inverse, pad_v, pad_b, key_beta, gate_cum)
chunk_attn_out, cur_state = recurrent_state_attn_all(query=query_scale, key=key_norm,
value=value_out, k_cumdecay=key_cum_out, gate=gate_cum, state=last_state,
decay_mask=decay_mask, tril=tril_mask)
# Copy the new state into last_state and store it in the output state tensor.
last_state[:] = cur_state
last_state_data[b_idx, nv_idx] = last_state
pypto.set_vec_tile_shapes(128, 16, 128)
# The output is reshaped with valid_shape=[actual_l, 1, d] and written with
# assemble instead of a slice assignment (presumably so only the valid rows
# are written -- confirm).
chunk_attn_out_reshaped = chunk_attn_out.reshape([l, 1, d], valid_shape=[actual_l, 1, d])
pypto.assemble(chunk_attn_out_reshaped, [bs_ofs, nv_idx, 0], core_attn_out)
# Any other chunk: same pipeline on the unpadded views.
else:
query_norm, key_norm = l2norm(query_view_2d, key_view_2d)
# Scale the normalized query by 1 / sqrt(d).
scale = 1 / d ** 0.5
query_scale = query_norm * scale
gate_cum, decay_mask, a_block, key_beta = pre_attn(
gate_view, key_norm, beta_view, tril_mask, mask)
a_block_inverse = inverse_pto(attn=a_block, eye=eye, size=128, zeros_16=zeros_16,
zeros_32=zeros_32, zeros_64=zeros_64)
value_out, key_cum_out = cal_value_and_key_cumdecay(
a_block_inverse, value_view_2d, beta_view, key_beta, gate_cum)
chunk_attn_out, cur_state = recurrent_state_attn_all(query=query_scale, key=key_norm,
value=value_out, k_cumdecay=key_cum_out, gate=gate_cum, state=last_state,
decay_mask=decay_mask, tril=tril_mask)
# Copy the new state into last_state and store it in the output state tensor.
last_state[:] = cur_state
last_state_data[b_idx, nv_idx] = last_state
# Write this chunk's output rows for the current head.
core_attn_out[bs_ofs:bs_ofs + l, nv_idx] = chunk_attn_out
return kernel