ChunkGatedDeltaRule
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:完成chunk版的Gated Delta Rule计算。
-
计算公式:
Gated Delta Rule(门控Delta规则,GDR)是一种应用于循环神经网络的算子,也被应用于一种线性注意力机制中。在每个时间步 tt,GDR根据当前的输入 qtq_t、ktk_t、vtv_t、上一个隐藏状态 St−1S_{t-1}、衰减系数 αt\alpha_t 以及更新强度 βt\beta_t,计算当前的注意力输出 oto_t 和新的隐藏状态 StS_t,其计算公式如下:
St:=St−1(αt(I−βtktktT))+βtvtktT=αtSt−1+βt(vt−αtSt−1kt)ktTS_t := S_{t-1}(\alpha_t(I - \beta_t k_t k_t^T)) + \beta_t v_t k_t^T = \alpha_t S_{t-1} + \beta_t (v_t - \alpha_t S_{t-1}k_t)k_t^T
ot:=St(qt⋅scale)o_t := S_t (q_t \cdot scale)
其中,St−1,St∈RDv×DkS_{t-1},S_t \in \mathbb{R}^{D_v \times D_k},qt,kt∈RDkq_t, k_t \in \mathbb{R}^{D_k},vt∈RDvv_t \in \mathbb{R}^{D_v},αt∈R\alpha_t \in \mathbb{R},βt∈R\beta_t \in \mathbb{R},ot∈RDvo_t \in \mathbb{R}^{D_v}。
Chunked Gated Delta Rule是GDR的chunk版实现(参考论文),它通过将输入序列切块,实现了一定的并行效果,在长上下文场景其计算效率相对Recurrent Gated Delta Rule更高,适用于prefill阶段。输入一个长度为L的序列,该算子可以计算出每一步的输出 ot,t∈{1,2,..,L}o_t, t \in \{1, 2, .., L\} 以及最终的状态矩阵 SLS_L。
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| query | 输入 | 公式中的q。 | BFLOAT16 | ND |
| key | 输入 | 公式中的输入k。 | BFLOAT16 | ND |
| value | 输入 | 公式中的输入v。 | BFLOAT16 | ND |
| beta | 输入 | 公式中的β。 | BFLOAT16 | ND |
| initial_state | 输入 | 初始状态矩阵,公式中的输入S_0。 | BFLOAT16 | ND |
| actual_seq_lengths | 输入 | 每个batch的序列长度。 | INT32 | ND |
| g | 输入 | 衰减系数,公式中的α=e^g | FLOAT32 | ND |
| out | 输出 | 每一步的attention结果,公式中的o_t。 | BFLOAT16 | ND |
| final_state | 输出 | 最终的状态矩阵,公式中的S_L。 | BFLOAT16 | ND |
| scale_value | 可选属性 | query的缩放因子,公式中的scale。默认为1.0 | FLOAT | - |
约束说明
-
为方便理解后续排布格式(如 BNSD、TND 等),统一说明各缩写维度含义:
- B:输入样本批量大小(Batch)。
- T:设 LiL_i 为第 ii 个序列长度,则 T=∑iBLiT=\sum_i^B L_i 表示累积序列长度。
- Nk:Query 和 Key 头数。
- Nv:Value 头数。
- Dk:Query 和 Key 隐藏层维度。
- Dv:Value 隐藏层维度。
-
当前仅支持 TND 布局:
- query、key 形状:(T,Nk,Dk)(T, Nk, Dk)
- value、out 形状:(T,Nv,Dv)(T, Nv, Dv)
- beta、g 形状:(T,Nv)(T, Nv)
- actual_seq_lengths 形状:(B,)(B,)
- initial_state、final_state 形状:(B,Nv,Dv,Dk)(B, Nv, Dv, Dk)
维度需满足以下约束:
- 0<Nv≤64,0<Nk≤640 \lt Nv \le 64,0 \lt Nk \le 64,且 Nv mod Nk=0Nv \bmod Nk = 0
- 0<Dv≤128,0<Dk≤1280 \lt Dv \le 128,0 \lt Dk \le 128
- B>0,T>0B \gt 0,T \gt 0
-
受算法数值特性限制,需满足以下取值约束,否则易出现数值溢出、精度异常:
- 张量元素:0<query<10 < \text{query} < 1
- 张量元素:0<key<10 < \text{key} < 1
- 张量元素:g<0g < 0
- 张量元素:0<beta<10 < \text{beta} < 1
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn | test_aclnn_chunk_gated_delta_rul.cpp | 通过aclnnChunkGatedDeltaRule调用aclnnChunkGatedDeltaRule算子 |