AttentionToFFN

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:将Attention节点上数据发往FFN节点。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 本卡发送的token数据,3D Tensor,shape为 (X, Bs, H)。 FLOAT16、BFLOAT16 ND
sessionId 输入 表示当前Attention Worker节点的Id,1D Tensor,shape为 (X, )。 INT32 ND
microBatchId 输入 表示当前microBatch的Id,1D Tensor,shape为 (X, )。 INT32 ND
layerId 输入 表示当前模型层数的Id,1D Tensor,shape为 (X, )。 INT32 ND
expertIds 输入 每个micro batch组中每个token的topK个专家索引,3D Tensor,shape为 (X, Bs, K),expertIds取值区间为[0, moeExpertNum)。 INT32 ND
expertRankTable 输入 每个micro batch组中专家Id到FFN卡专家部署的映射表(外部需保证值正确),3D Tensor,shape为 (L, moeExpertNum + sharedExpertNum, M)。 INT32 ND
scalesOptional 输入 可选参数,表示每个专家的量化平滑参数,3D Tensor,shape为 (L, moeExpertNum + sharedExpertNum, H)。 FLOAT32 ND
activeMaskOptional 输入 可选参数,表示token是否参与通信,可传有效数据或空指针,2D Tensor,shape为 (X, Bs)。 BOOL ND
group 属性 通信域名称(专家并行),字符串长度[1, 128)。 STRING -
worldSize 属性 通信域大小,取值区间[2, 768]。 INT64 -
ffnTokenInfoTableShape 属性 表示FFN节点上token信息表格shape大小的列表,长度为3。 INT32 -
ffnTokenDataShape 属性 表示FFN节点上token数据表格shape大小的列表,长度为5。 INT32 -
attnTokenInfoTableShape 属性 表示Attention节点上token信息表格shape大小的列表,长度为3。 INT32 -
moeExpertNum 属性 MoE专家数量,取值范围(0, 1024]。 INT64 -
quantMode 属性 表示量化模式,支持0(非量化)、2(动态量化)。 INT64 -
syncFlag 属性 表示FFN节点同步模式,支持0(同步)、1(异步)。 INT64 -
ffnStartrankId 属性 表示FFN节点的起始Id。 INT64 -

约束说明

  • 调用算子过程中使用的groupworldSizetokenInfoTableShapetokenDataShape参数及HCCL_BUFFSIZE取值所有卡需保持一致,网络中不同层中也需保持一致。

  • 参数说明里shape格式说明:

    • X:表示micro batch sequence size(token组数),当前版本只支持X = 1。
    • BS:表示batch sequence size(本卡最终输出的token数量)取值范围为0 < BS ≤ 512。
    • K:表示选取topK个专家,,取值范围为0 < K ≤ 16且满足0 < K ≤ moeExpertNum。
    • H:表示hidden size(隐藏层大小),取值范围为1024 ≤ H ≤ 8192。
    • L:表示模型层数,当前版本只支持L = 1。
    • M:表示expertRankTable最后一维的长度,具体体现为部署在FFN节点上数量最多的专家部署信息列表的长度。
    • moeExpertNum:表示MoE专家数量,取值范围为0 < moeExpertNum ≤ 1024。
  • 通信域使用约束:

    • AttentionToFFN算子的通信域中不允许有其他算子。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_attention_to_ffn.cpp 通过 aclnnAttentionToFFN 接口方式调用AttentionToFFN算子。