AllGatherMatmulV2
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | x |
| Atlas 推理系列产品 | x |
| Atlas 训练系列产品 | x |
功能说明
-
算子功能: 完成AllGather通信与MatMul计算融合。在支持x1和x2输入类型为FLOAT16/BFLOAT16的基础上,同时也支持低精度数据类型,此时算子在Matmul计算后会做对应的反量化计算,支持的低精度数据类型与量化方式如下:
-
计算公式:
- 情形1:如果x1和x2数据类型为FLOAT16/BFLOAT16时,入参x1进行AllGather后,对x1、x2进行MatMul计算。
output=AllGather(x1)@x2+biasoutput=AllGather(x1)@x2 + bias
gatherOut=AllGather(x1)gatherOut=AllGather(x1)
- 情形2:如果x1和x2数据类型为FLOAT8_E4M3FN/FLOAT8_E5M2/HIFLOAT8的pertensor场景,或者x1和x2数据类型为INT8/INT4的perchannel、pertoken场景,且不输出amaxOut,入参x1进行AllGather后,对x1、x2进行MatMul计算,然后进行dequant操作。
output=(x1Scale∗x2Scale)∗(AllGather(x1)@x2+bias)output=(x1Scale*x2Scale)*(AllGather(x1)@x2 + bias)
gatherOut=AllGather(x1)gatherOut=AllGather(x1)
-
情形3:如果x1和x2数据类型为FLOAT8_E4M3FN/FLOAT8_E5M2/HIFLOAT8的perblock场景,且不输出amaxOut, 当x1为(m, k)、x2为(k, n)时, x1Scale为(ceildiv(m, 128), ceildiv(k, 128))、x2Scale为(ceildiv(k, 128), ceildiv(n, 128))时,入参x1和x1Scale进行AllGather后,对x1、x2进行perblock量化MatMul计算,然后进行dequant操作。
output=∑0⌊kblockSize=128⌋(AllGather(x1)pr@x2rq∗(AllGather(x1Scale)pr∗x2Scalerq))output=\sum_{0}^{\left \lfloor \frac{k}{blockSize=128} \right \rfloor} (AllGather(x1)_{pr}@x2_{rq}*(AllGather(x1Scale)_{pr}*x2Scale_{rq}))
gatherOut=AllGather(x1)gatherOut=AllGather(x1)
-
情形4:如果x1和x2数据类型为FLOAT8_E4M3FN/FLOAT8_E5M2的mx量化场景,x1为(m, k)、x2 为(n, k),且x1Scale为(m, ceilDiv(k, 64), 2)、x2Scale为(n, ceilDiv(k, 64), 2),入参x1和x1Scale进行AllGather后,对x1、x2进行MatMul计算,然后进行dequant操作;
output=∑0⌊kblockSize=32⌋(AllGather(x1)pr@x2rq∗(AllGather(x1Scale)pr∗x2Scalerq))output=\sum_{0}^{\left \lfloor \frac{k}{blockSize=32} \right \rfloor} (AllGather(x1)_{pr}@x2_{rq}*(AllGather(x1Scale)_{pr}*x2Scale_{rq}))
gatherOut=AllGather(x1)gatherOut=AllGather(x1)
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x1 | 输入 | MM左矩阵,即计算公式中的x1。 | FLOAT16、BFLOAT16、FLOAT8_E4M3FN、FLOAT8_E5M2、HIFLOAT8、INT8、INT4 | ND |
| x2 | 输入 | MM右矩阵,即计算公式中的x2。 | FLOAT16、BFLOAT16、FLOAT8_E4M3FN、FLOAT8_E5M2、HIFLOAT8、INT8、INT4 | ND |
| bias | 可选输入 | 即计算公式中的bias。 | FLOAT16、BFLOAT16、FLOAT | ND |
| x1Scale | 可选输入 | mm左矩阵反量化参数。 | FLOAT16、BFLOAT16、FLOAT | ND |
| x2Scale | 可选输入 | mm右矩阵反量化参数。 | FLOAT16、BFLOAT16、FLOAT | ND |
| quantScale | 可选输入 | 即计算公式中的bias。 | FLOAT | ND |
| output | 输出 | AllGather通信与MatMul计算的结果,即计算公式中的output。 | FLOAT16、BFLOAT16、FLOAT | ND |
| gatherOut | 输出 | 仅输出all_gather通信后的结果。即公式中的gatherOut。 | FLOAT16、BFLOAT16、FLOAT8_E4M3FN、FLOAT8_E5M2、HIFLOAT8、INT8、INT4 | ND |
| amaxOut | 可选输出 | MM计算的最大值结果,即公式中的amaxOut。 | FLOAT | ND |
| blockSize | 属性 | 用于表示mm输出矩阵在M轴方向上和N轴方向上可以用于对应方向上的多少个数的量化。 | INT64 | - |
| group | 属性 | 通信域名称。 | STRING | - |
| gatherIndex | 属性 | 标识gather目标。 | INT64 | - |
| commTurn | 属性 | 通信数据切分数,即总数据量/单次通信量。 | INT64 | - |
| streamMode | 属性 | 流模式的枚举。 | INT64 | - |
| groupSize | 可选属性 | 用于表示反量化中x1Scale/x2Scale输入的一个数在其所在的对应维度方向上可以用于该方向x1/x2输入的多少个数的反量化。 | INT64 | - |
| commMode | 属性 | 通信模式。 | STRING | - |
-
确定性计算:
- 该算子默认确定性实现。
-
Ascend 950PR/Ascend 950DT:
- 输入x1为2维,其维度为(m, k)。x2必须是2维,其维度为(k, n),轴满足mm算子入参要求,k轴相等,且k轴取值范围为[256, 65535)。
- bias为1维,shape为(n,)。
- 输出output为2维,其维度为(m*rank_size, n),rank_size为卡数。
- 输出gatherout为2维,其维度为(m*rank_size, k),rank_size为卡数。
- 当x1、x2的数据类型为FLOAT16/BFLOAT16时,output计算输出数据类型和x1、x2保持一致。
- 当x1、x2的数据类型为FLOAT8_E4M3FN/FLOAT8_E5M2/HIFLOAT8时,output输出数据类型支持FLOAT16、BFLOAT16、FLOAT。
- 当x1、x2的数据类型为FLOAT16/BFLOAT16/HIFLOAT8时,x1和x2数据类型需要保持一致。
- 当x1、x2数据类型为FLOAT8_E4M3FN/FLOAT8_E5M2时,x1和x2数据类型可以为其中一种。
- 当x1、x2数据类型为FLOAT16/BFLOAT16/HIFLOAT8/FLOAT8_E4M3FN/FLOAT8_E5M2时,x2矩阵支持转置/不转置场景,x1矩阵只支持不转置场景。
- 当groupSize取值为549764202624,bias必须为空。
- 支持2、4、8、16、32、64卡。
- allgather(x1)集合通信数据总量不能超过16*256MB,集合通信数据总量计算方式为:m * k * sizeof(x1_dtype) * 卡数。由于shape不同,算子内部实现可能存在差异,实际支持的总通信量可能略小于该值。
-
Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
- 只支持x2矩阵转置/不转置,x1矩阵仅支持不转置场景。
- 输入x1必须是2维,其shape为(m, k)。
- 输入x2必须是2维,其shape为(k, n),轴满足mm算子入参要求,k轴相等,且k轴取值范围为[256, 65535)。
- bias仅支持输入nullptr。
- 输出为2维,其shape为(m*rank_size, n), rank_size为卡数。
- 不支持空tensor。
- x1和x2的数据类型需要保持一致。
- x1和x2数据类型为INT4时,k与n必须为偶数。
- 支持2、4、8卡。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_all_gather_matmul_v2 | 通过aclnnAllGatherMatmulV2接口方式调用AllGatherMatmulV2算子。 |