GroupedMatmulFinalizeRouting
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
- 算子功能:GroupedMatmul和MoeFinalizeRouting的融合算子,GroupedMatmul计算后的输出按照索引做combine动作
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x1 | 输入 | 输入x(左矩阵)。 | INT8 | ND |
| x2 | 输入 | 输入weight(右矩阵) | INT4、INT8 | ND、NZ |
| scaleOptional | 输入 | 量化参数中的缩放因子,perchannel量化参数 | INT64、BF16、FLOAT32 | ND |
| biasOptional | 输入 | 矩阵的偏移 | BF16、FLOAT32 | ND |
| offsetOptional | 输入 | 非对称量化的偏移量 | FLOAT32 | ND |
| antiquantScaleOptional | 输入 | 伪量化的缩放因子 | FLOAT32 | ND |
| antiquantOffsetOptional | 输入 | 伪量化的偏移量 | FLOAT32 | ND |
| pertokenScaleOptional | 输入 | 矩阵计算的反量化参数 | FLOAT32 | ND |
| groupListOptional | 输入 | 输入和输出分组轴方向的matmul大小分布 | INT64 | ND |
| sharedInputOptional | 输入 | moe计算中共享专家的输出,需要与moe专家的输出进行combine操作 | BF16 | ND |
| logitOptional | 输入 | moe专家对各个token的logit大小 | FLOAT32 | ND |
| rowIndexOptional | 输入 | moe专家输出按照该rowIndex进行combine,其中的值即为combine做scatter add的索引 | INT32、INT64 | ND |
| dtype | 属性 | 计算的输出类型:0:FLOAT32;1:FLOAT16;2:BFLOAT16。目前仅支持0。 | INT64 | |
| sharedInputWeight | 属性 | 共享专家与moe专家进行combine的系数,sharedInput先与该参数乘,然后在和moe专家结果累加。 | FLOAT | |
| sharedInputOffset | 属性 | 共享专家输出的在总输出中的偏移。 | INT64 | |
| transposeX | 属性 | 左矩阵是否转置,仅支持false。 | BOOL | |
| transposeW | 属性 | 右矩阵是否转置,仅支持false。 | BOOL | |
| groupListType | 属性 | 分组模式:配置为0:cumsum模式,即为前缀和;配置为1:count模式。 | INT64 | |
| tuningConfigOptional | 属性 | 数组中的第一个元素表示各个专家处理的token数的预期值,算子tiling时会按照数组的第一个元素合理进行tiling切分,性能更优。从第二个元素开始预留,用户无须填写。未来会进行扩展。兼容历史版本,用户如不使用该参数,不传入(即为nullptr)即可。 | INT64 | |
| out | 输出 | 输出结果。 | FLOAT32 | ND |
| workspaceSize | 输出 | 返回需要在Device侧申请的workspace大小。 | - | - |
| executor | 输出 | 返回op执行器,包含了算子计算流程。 | - | - |
约束说明
输入和输出支持以下数据类型组合:
| x1 | x2 | scaleOptional | biasOptional | offsetOptional | antiquantScaleOptional | antiquantOffsetOptional | pertokenScaleOptional | groupListOptional | sharedInputOptional | logitOptional | rowIndexOptional | out |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| INT8 | INT4 | INT64 | FLOAT32 | FLOAT32 | null | null | FLOAT32 | INT64 | BFLOAT16 | FLOAT32 | INT64 | FLOAT32 |
| INT8 | INT4 | INT64 | FLOAT32 | null | null | null | FLOAT32 | INT64 | BFLOAT16 | FLOAT32 | INT64 | FLOAT32 |
| INT8 | INT8(NZ) | FLOAT32 | null | null | null | null | FLOAT32 | INT64 | BFLOAT16 | FLOAT32 | INT64 | FLOAT |
| INT8 | INT8(NZ) | FLOAT32 | null | null | null | null | FLOAT32 | INT64 | BFLOAT16 | FLOAT32 | INT64 | FLOAT |
| INT8 | INT4(NZ) | INT64 | FLOAT32 | FLOAT32 | null | null | FLOAT32 | INT64 | BFLOAT16 | FLOAT32 | INT64 | FLOAT |
| INT8 | INT4(NZ) | INT64 | FLOAT32 | null | null | null | FLOAT32 | INT64 | BFLOAT16 | FLOAT32 | INT64 | FLOAT |
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_grouped_matmul_finalize_routing | 通过aclnnGroupedMatmulFinalizeRoutingV3接口方式调用GroupedMatmulFinalizeRouting算子。 |
| aclnn调用 | test_aclnn_grouped_matmul_finalize_routing_weight_nz | 通过aclnnGroupedMatmulFinalizeRoutingWeightNzV2接口方式调用GroupedMatmulFinalizeRoutingWeightNz算子。 |