ChunkGatedDeltaRule

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:完成chunk版的Gated Delta Rule计算。

  • 计算公式:

    Gated Delta Rule(门控Delta规则,GDR)是一种应用于循环神经网络的算子,也被应用于一种线性注意力机制中。在每个时间步 tt,GDR根据当前的输入 qtq_tktk_tvtv_t、上一个隐藏状态 St−1S_{t-1}、衰减系数 αt\alpha_t 以及更新强度 βt\beta_t,计算当前的注意力输出 oto_t 和新的隐藏状态 StS_t,其计算公式如下:

    St:=St−1(αt(I−βtktktT))+βtvtktT=αtSt−1+βt(vt−αtSt−1kt)ktTS_t := S_{t-1}(\alpha_t(I - \beta_t k_t k_t^T)) + \beta_t v_t k_t^T = \alpha_t S_{t-1} + \beta_t (v_t - \alpha_t S_{t-1}k_t)k_t^T

    ot:=St(qt⋅scale)o_t := S_t (q_t \cdot scale)

    其中,St−1,St∈RDv×DkS_{t-1},S_t \in \mathbb{R}^{D_v \times D_k}qt,kt∈RDkq_t, k_t \in \mathbb{R}^{D_k}vt∈RDvv_t \in \mathbb{R}^{D_v}αt∈R\alpha_t \in \mathbb{R}βt∈R\beta_t \in \mathbb{R}ot∈RDvo_t \in \mathbb{R}^{D_v}

    Chunked Gated Delta Rule是GDR的chunk版实现(参考论文),它通过将输入序列切块,实现了一定的并行效果,在长上下文场景其计算效率相对Recurrent Gated Delta Rule更高,适用于prefill阶段。输入一个长度为L的序列,该算子可以计算出每一步的输出 ot,t∈{1,2,..,L}o_t, t \in \{1, 2, .., L\} 以及最终的状态矩阵 SLS_L

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
query 输入 公式中的q。 BFLOAT16 ND
key 输入 公式中的输入k。 BFLOAT16 ND
value 输入 公式中的输入v。 BFLOAT16 ND
beta 输入 公式中的β。 BFLOAT16 ND
initial_state 输入 初始状态矩阵,公式中的输入S_0。 BFLOAT16 ND
actual_seq_lengths 输入 每个batch的序列长度。 INT32 ND
g 输入 衰减系数,公式中的α=e^g FLOAT32 ND
out 输出 每一步的attention结果,公式中的o_t。 BFLOAT16 ND
final_state 输出 最终的状态矩阵,公式中的S_L。 BFLOAT16 ND
scale_value 可选属性 query的缩放因子,公式中的scale。默认为1.0 FLOAT -

约束说明

  • 为方便理解后续排布格式(如 BNSD、TND 等),统一说明各缩写维度含义:

    • B:输入样本批量大小(Batch)。
    • T:设 LiL_i 为第 ii 个序列长度,则 T=∑iBLiT=\sum_i^B L_i 表示累积序列长度。
    • Nk:Query 和 Key 头数。
    • Nv:Value 头数。
    • Dk:Query 和 Key 隐藏层维度。
    • Dv:Value 隐藏层维度。
  • 当前仅支持 TND 布局:

    • query、key 形状:(T,Nk,Dk)(T, Nk, Dk)
    • value、out 形状:(T,Nv,Dv)(T, Nv, Dv)
    • beta、g 形状:(T,Nv)(T, Nv)
    • actual_seq_lengths 形状:(B,)(B,)
    • initial_state、final_state 形状:(B,Nv,Dv,Dk)(B, Nv, Dv, Dk)

    维度需满足以下约束:

    • 0<Nv≤64,0<Nk≤640 \lt Nv \le 64,0 \lt Nk \le 64,且 Nv mod Nk=0Nv \bmod Nk = 0
    • 0<Dv≤128,0<Dk≤1280 \lt Dv \le 128,0 \lt Dk \le 128
    • B>0,T>0B \gt 0,T \gt 0
  • 受算法数值特性限制,需满足以下取值约束,否则易出现数值溢出、精度异常:

    • 张量元素:0<query<10 < \text{query} < 1
    • 张量元素:0<key<10 < \text{key} < 1
    • 张量元素:g<0g < 0
    • 张量元素:0<beta<10 < \text{beta} < 1

调用说明

调用方式 样例代码 说明
aclnn test_aclnn_chunk_gated_delta_rul.cpp 通过aclnnChunkGatedDeltaRule调用aclnnChunkGatedDeltaRule算子