BlockSparseAttentionGrad

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品
Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品
Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • ​算子功能​:aclnnBlockSparseAttention稀疏注意力反向计算,支持灵活的块级稀疏模式,通过BlockSparseMask指定每个Q块选择的KV块,实现高效的稀疏注意力计算。

  • ​计算公式​:

    稀疏块大小:blockShapeX×blockShapeYblockShapeX×blockShapeY,BlockSparseMask指定稀疏模式。

    已知正向计算公式为:

    attentionOut=Softmax(Mask(scale⋅query⋅keysparseT,atten_mask))⋅valuesparseattentionOut=Softmax(Mask(scale⋅query⋅key_{sparse}^{T}, atten\_mask))⋅value_{sparse}

    为方便表达,以变量SSPP表示计算公式:

    S=Mask(scale⋅query⋅keysparseT,atten_mask)S = Mask(scale⋅query⋅key_{sparse}^{T},atten\_mask)

    P=SoftMax(S)P = SoftMax(S)

    V=valuesparseV = value_{sparse}

    Out=PVOut = PV

    则反向计算公式为:

    softmax_grad=softmaxGrad(dOut,attentionOut)softmax\_grad = softmaxGrad(dOut, attentionOut)

    dP=dOut∗VTdP=dOut * V^T

    dS=P∗(dP−softmax_grad)dS = P * (dP-softmax\_grad)

    dV=PT∗dOutdV=P^T * dOut

    dQ=(dS∗K)∗scaledQ=(dS*K)*scale

    dK=(dST∗Q)∗scaledK=(dS^T*Q)*scale

BlockSparseAttentionGrad输入dout、query、key、value, attentionOut的数据排布格式支持从多种维度排布解读,可通过qInputLayout和kvInputLayout传入。为了方便理解后续支持的具体排布格式(如 BNSD、TND 等),此处先对排布格式中各缩写字母所代表的维度含义进行统一说明:

  • B:表示输入样本批量大小(Batch)
  • T:B和S合轴紧密排列的长度(Total tokens)
  • S:表示输入样本序列长度(Seq-Length)
  • H:表示隐藏层的大小(Head-Size)
  • N:表示多头数(Head-Num)
  • D:表示隐藏层最小的单元尺寸,需满足D=H/N(Head-Dim)

当前支持的布局:

  • qInputLayout: "TND" "BNSD"
  • kvInputLayout: "TND" "BNSD"

参数说明

参数名 输入/输出 描述 使用说明 数据类型 数据格式 维度(shape) 非连续Tensor
dout(aclTensor*) 输入 反向输出梯度,代表最终输出对当前算子的梯度信息。 不支持空Tensor。
支持的shape为:
  • TND: [totalQTokens, headNum, headDim]。
  • BNSD: [batch, headNum, maxQSeqLength, headDim]。
FLOAT16、BFLOAT16 ND 3-4 ×
query(aclTensor*) 输入 注意力计算中的查询向量,即公式中的query。 不支持空Tensor。
支持的shape为:
  • TND: [totalQTokens, headNum, headDim]。
  • BNSD: [batch, headNum, maxQSeqLength, headDim]。
FLOAT16、BFLOAT16 ND 3-4 ×
key(aclTensor*) 输入 注意力计算中的键向量,即公式中的key。 不支持空Tensor。
支持的shape为:
  • TND: [totalKTokens, numKeyValueHeads, headDim]。
  • BNSD: [batch, numKeyValueHeads, maxKvSeqLength, headDim]。
FLOAT16、BFLOAT16 ND 3-4 ×
value(aclTensor*) 输入 注意力计算中的值向量,即公式中的value。 不支持空Tensor。
支持的shape为:
  • TND: [totalVTokens, numKeyValueHeads, headDim]。
  • BNSD: [batch, numKeyValueHeads, maxKvSeqLength, headDim]。
FLOAT16、BFLOAT16 ND 3-4 ×
attentionOut(aclTensor*) 输入 正向 BlockSparseAttention 计算的输出结果,即公式中的attentionOut。 不支持空Tensor。
支持的shape为:
  • TND: [totalQTokens, headNum, headDim]。
  • BNSD: [batch, headNum, maxQSeqLength, headDim]。
FLOAT16、BFLOAT16 ND 3-4 ×
softmaxLse(aclTensor*) 输入 Softmax计算的log-sum-exp中间结果。用于反向计算梯度的对数和指数逆推。 不支持空Tensor。
支持的shape为:
  • TND: [totalQTokens, headNum, 1]。
  • BNSD: [batch, headNum, maxQSeqLength, 1]。
FLOAT ND 3-4 ×
blockSparseMaskOptional(aclTensor*) 输入 块状稀疏掩码,表示实际的稀疏pattern。决定哪些block实际参与注意力计算。 不支持空Tensor。
可选输入(当前版本为必选):
  • shape为[batch, headNum, ceilDiv(maxQSeqLength, blockShapeX), ceilDiv(maxKvSeqLength, blockShapeY)]。
  • 表示按block划分后哪些block需要参与计算(为1),哪些block不参与计算(为0)。
  • 如传入nullptr,则视为不开启块稀疏计算,即所有token之间的注意力分数都会被计算。
BOOL ND 4 ×
attenMaskOptional(aclTensor*) 输入 注意力掩码,即公式中的atten_mask。用于屏蔽不应参与计算的特定token。 支持空Tensor。
当前不支持,应传入nullptr。
BOOL ND 2 ×
blockShapeOptional(aclIntArray*) 输入 稀疏块形状数组。指定每个稀疏块的二维尺寸(行数和列数)。
  • 当配置了blockSparseMaskOptional时:如配置此输入,算子会从中获取稀疏块尺寸;如不配置此输入,算子将默认稀疏块尺寸为[128,128]。
INT64 - 1 -
  • 当未配置blockSparseMaskOptional时:无论此项如何配置,算子均将忽略。
当配置此输入时的元素要求:
  • 必须包含至少两个元素 [blockShapeX, blockShapeY]。
  • blockShapeX: Q方向块大小,值必须大于0。
  • blockShapeY: KV方向块大小,值必须大于0。
actualSeqLengthsOptional(aclIntArray*) 输入 query的实际序列长度数组。
用于描述变长序列场景下(即含有 Padding 填充数据的场景),每个 Batch 中实际有效的 query token 数量。
变长序列场景(当 qInputLayout 为 "TND" 时):该项输入必须配置。因为 TND 格式为一维连续排布,算子需要依赖该数组来准确切分界定各个序列的真实边界。 INT64 - 1 -
定长/变长场景(当 qInputLayout 为 "BNSD" 时):
  • 如配置该项,算子会按指定的有效长度处理,忽略 Padding 部分的数据,提升性能;
  • 如不配置(传 nullptr),算子将默认把 query shape 中的 S 维度作为有效长度进行全量处理。
actualSeqLengthsKvOptional(aclIntArray*) 输入 key/value的实际序列长度数组。
用于描述变长序列场景下(即含有 Padding 填充数据的场景),每个 Batch 中实际有效的 key/value token 数量。
变长序列场景(当 kvInputLayout 为 "TND" 时):该项输入必须配置。因为 TND 格式为一维连续排布,算子需要依赖该数组来准确切分界定各个序列的真实边界。 INT64 - 1 -
定长/变长场景(当 kvInputLayout 为 "BNSD" 时):
  • 如配置该项,算子会按指定的有效长度处理,忽略 Padding 部分的数据,提升性能;
  • 如不配置(传 nullptr),算子将默认把 key/value shape 中的 S 维度作为有效长度进行全量处理。
qInputLayout(char*) 输入 query的数据排布格式。指示输入张量在内存中的具体排布(如连续或合轴排列)。 当前仅支持"TND"、"BNSD",qInputLayout与kvInputLayout需要保持一致。 - - - -
kvInputLayout(char*) 输入 key和value的数据排布格式。指示输入张量在内存中的具体排布。 当前仅支持"TND"、"BNSD",qInputLayout与kvInputLayout需要保持一致。 - - - -
numKeyValueHeads(int64_t) 输入 key/value的注意力头数。用于支持GQA(分组查询注意力)机制下的头数比例映射。 - - - - -
maskType(int64_t) 输入 注意力计算中的掩码类型。指定采用何种预设规则的掩码逻辑。 当前只支持传 0:代表不加mask场景。 - - - -
scaleValue(double) 输入 缩放系数,即公式中的scale。用于注意力分数的归一化处理。 一般设置为D^-0.5。 - - - -
preTokens(int64_t) 输入 滑窗向前包含的token数量。限制当前token只能与前方的多少个历史token计算注意力。 用于滑窗attention场景,当前不支持滑窗attention,只支持传入2147483647。 - - - -
nextTokens(int64_t) 输入 滑窗向后包含的token数量。限制当前token只能与后方的多少个未来token计算注意力。 用于滑窗attention场景,当前不支持滑窗attention,只支持传入2147483647。 - - - -
dq(aclTensor*) 输出 query的梯度输出结果,即公式中的dq。 不支持空Tensor。
数据类型和shape与输入query保持一致。
FLOAT16、BFLOAT16 ND 3-4
dk(aclTensor*) 输出 key的梯度输出结果,即公式中的dk。 不支持空Tensor。
数据类型和shape与输入key保持一致。
FLOAT16、BFLOAT16 ND 3-4
dv(aclTensor*) 输出 value的梯度输出结果,即公式中的dv。 不支持空Tensor。
数据类型和shape与输入value保持一致。
FLOAT16、BFLOAT16 ND 3-4
workspaceSize(uint64_t*) 输出 返回需要在Device侧申请的workspace大小。 - - - - -
executor(aclOpExecutor**) 输出 返回op执行器,包含了算子计算流程。 - - - - -
    - Atlas A2 训练产品、Atlas A3 训练产品: 不支持FLOAT8_E5M2、FLOAT8_E4M3FN。

约束说明

  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
  • actualSeqLengthsOptional在qInputLayout为“TND”时必选;actualSeqLengthsKvOptional在kvInputLayout为“TND”时必选。
  • 根据算子支持的输入 Layout,query 张量 Shape 中对应的 head 维度大小记为 N1,key 和 value 张量 Shape 中对应的 head 维度大小记为 N2。必须满足 N1 >= N2 且 N1 % N2 == 0。(例如:在 BNSD 布局下,N1 对应 query 的第 2 维,N2 对应 key/value 的第 2 维)
  • headdim=128。
  • 当前只支持 BNSD 和 MHA(N1==N2)。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_block_sparse_attention_grad 非TND(TND代表Total sequence length, Num heads, Head dimension,通常用于表示变长序列场景下的连续内存排布格式)场景,通过aclnnBlockSparseAttentionGrad接口方式调用BlockSparseAttentionGrad算子。