LightningIndexerGrad

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT x
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:lightning_indexer_gradlightning_indexer的反向算子,基于正向算子的输出sparseIndices计算querykeyweights的梯度。

  • 计算公式: LightningIndexer反向计算公式如下:

    S=Relu(Matmul(Query,Gather(Key,Indices)))S = Relu(Matmul(Query, Gather(Key, Indices)))

    Y=Dy∗WeightsY = Dy*Weights

    dW=Reduce(S∗dy)dW = Reduce(S * dy)

    dQ=Matmul(ReluGrad(Y,S),Gather(Key,Indices))dQ = Matmul(ReluGrad(Y, S), Gather(Key, Indices))

    dK=ScatterAdd(Matmul(ReluGrad(Y,S),Q),Indices)dK = ScatterAdd(Matmul(ReluGrad(Y, S), Q), Indices)

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
query 输入 公式中的输入Q,不支持空tensor和非连续。layout为BSND时,shape为(B,S1,N1,D);layout为TND时,shape为(T1,N1,D)。 FLOAT16、BFLOAT16 ND
key 输入 公式中的输入K,不支持空tensor和非连续。layout为BSND时,shape为(B,S2,N2,D);layout为TND时,shape为(T2,N2,D)。 FLOAT16、BFLOAT16 ND
dy 输入 公式中的输入dY,表示输出梯度,不支持空tensor和非连续。layout为BSND时,shape为(B,S1,N1,D);layout为TND时,shape为(T1,N1,D)。 FLOAT16、BFLOAT16 ND
sparseIndices 输入 公式中的输入Indices,为LightningIndexer正向算子的sparseIndicesOut输出,不支持空tensor和非连续。layout为BSND时,shape为(B,S1,N2,K);layout为TND时,shape为(T1,N2,K),其中K为topK保留的block数量。 INT32 ND
weights 输入 公式中的输入W,不支持空tensor和非连续。layout为BSND时,shape为(B,S1,N1);layout为TND时,shape为(T1,N1)。 FLOAT16、BFLOAT16 ND
actualSeqLengthsQuery 输入 每个Batch中Query的有效token数,不支持空tensor和非连续。可传入None表示与query的S长度相同;支持长度为B的一维tensor,且每个Batch的有效token数不超过query中的维度S大小且不小于0。layout为TND时该入参必须传入,并以元素数量作为B值;每个元素表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须大于等于前一个元素的值。 INT32 ND
actualSeqLengthsKey 输入 每个Batch中Key的有效token数,不支持空tensor和非连续。可传入None表示与key的S长度相同;支持长度为B的一维tensor,且每个Batch的有效token数不超过key中的维度S大小且不小于0。layout为TND时该入参必须传入,每个元素表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须大于等于前一个元素的值。 INT32 ND
layout 输入 用于标识输入Query/Key的数据排布格式,默认值为"BSND",当前支持BSND、TND。 STRING -
headNum 输入 代表head个数。 INT64 -
sparseMode 输入 表示sparse的模式。sparse_mode为0时代表defaultMask模式;sparse_mode为3时代表rightDownCausal模式的mask,对应以右顶点为划分的下三角场景。 INT32 -
preTokens 输入 用于稀疏计算,表示attention需要和前几个Token计算关联,仅支持默认值2^63-1。 INT64 -
nextTokens 输入 用于稀疏计算,表示attention需要和后几个Token计算关联,仅支持默认值2^63-1。 INT64 -
dQuery 输出 公式中的dQ输出,表示query的梯度,不支持空tensor和非连续。数据类型与query一致,shape与query一致。 FLOAT16、BFLOAT16 ND
dKey 输出 公式中的dK输出,表示key的梯度,不支持空tensor和非连续。数据类型与key一致,shape与key一致。 FLOAT16、BFLOAT16 ND
dWeights 输出 公式中的dW输出,表示weights的梯度,不支持空tensor和非连续。数据类型与weights一致,shape与weights一致。 FLOAT16、BFLOAT16、FLOAT ND
deterministic 输入 表示当前是否支持确定性计算,默认值为False。 BOOL -

约束说明

  • inputLayout支持TND/BSND。

  • 关于数据shape的约束,以Layout的BSND举例。其中:

    • B(Batchsize):取值范围为1~1024。
    • N(Head-Num):取值为1~64。
    • G(Group):取值为N。
    • S1(Seq-LengthQ):取值范围为1~128K。
    • S2(Seq-LengthK):取值范围为topK~128K。
    • D(Head-Dim):取值为128。
    • TopK:取值为2048。

调用示例

调用方式 样例代码 说明
aclnn接口 test_aclnn_lightning_indexer_grad 通过 aclnnLightningIndexerGrad 接口方式调用算子