AllGatherMatmulV2

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 x
Atlas 推理系列产品 x
Atlas 训练系列产品 x

功能说明

  • 算子功能: 完成AllGather通信与MatMul计算融合。在支持x1和x2输入类型为FLOAT16/BFLOAT16的基础上,同时也支持低精度数据类型,此时算子在Matmul计算后会做对应的反量化计算,支持的低精度数据类型与量化方式如下:

    • Ascend 950PR/Ascend 950DT:

      新增了对低精度数据类型FLOAT8_E4M3FN/FLOAT8_E5M2/HIFLOAT8的支持。支持pertensor、perblock、mx量化方式

    • Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:

      新增了对低精度数据类型INT8/INT4的支持。支持pertoken/perchannel量化方式

  • 计算公式

    • 情形1:如果x1和x2数据类型为FLOAT16/BFLOAT16时,入参x1进行AllGather后,对x1、x2进行MatMul计算。

    output=AllGather(x1)@x2+biasoutput=AllGather(x1)@x2 + bias

    gatherOut=AllGather(x1)gatherOut=AllGather(x1)

    • 情形2:如果x1和x2数据类型为FLOAT8_E4M3FN/FLOAT8_E5M2/HIFLOAT8的pertensor场景,或者x1和x2数据类型为INT8/INT4的perchannel、pertoken场景,且不输出amaxOut,入参x1进行AllGather后,对x1、x2进行MatMul计算,然后进行dequant操作。

    output=(x1Scale∗x2Scale)∗(AllGather(x1)@x2+bias)output=(x1Scale*x2Scale)*(AllGather(x1)@x2 + bias)

    gatherOut=AllGather(x1)gatherOut=AllGather(x1)

    • 情形3:如果x1和x2数据类型为FLOAT8_E4M3FN/FLOAT8_E5M2/HIFLOAT8的perblock场景,且不输出amaxOut, 当x1为(m, k)、x2为(k, n)时, x1Scale为(ceildiv(m, 128), ceildiv(k, 128))、x2Scale为(ceildiv(k, 128), ceildiv(n, 128))时,入参x1和x1Scale进行AllGather后,对x1、x2进行perblock量化MatMul计算,然后进行dequant操作。

      output=∑0⌊kblockSize=128⌋(AllGather(x1)pr@x2rq∗(AllGather(x1Scale)pr∗x2Scalerq))output=\sum_{0}^{\left \lfloor \frac{k}{blockSize=128} \right \rfloor} (AllGather(x1)_{pr}@x2_{rq}*(AllGather(x1Scale)_{pr}*x2Scale_{rq}))

      gatherOut=AllGather(x1)gatherOut=AllGather(x1)

    • 情形4:如果x1和x2数据类型为FLOAT8_E4M3FN/FLOAT8_E5M2的mx量化场景,x1为(m, k)、x2 为(n, k),且x1Scale为(m, ceilDiv(k, 64), 2)、x2Scale为(n, ceilDiv(k, 64), 2),入参x1和x1Scale进行AllGather后,对x1、x2进行MatMul计算,然后进行dequant操作;

      output=∑0⌊kblockSize=32⌋(AllGather(x1)pr@x2rq∗(AllGather(x1Scale)pr∗x2Scalerq))output=\sum_{0}^{\left \lfloor \frac{k}{blockSize=32} \right \rfloor} (AllGather(x1)_{pr}@x2_{rq}*(AllGather(x1Scale)_{pr}*x2Scale_{rq}))

      gatherOut=AllGather(x1)gatherOut=AllGather(x1)

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x1 输入 MM左矩阵,即计算公式中的x1。 FLOAT16、BFLOAT16、FLOAT8_E4M3FN、FLOAT8_E5M2、HIFLOAT8、INT8、INT4 ND
x2 输入 MM右矩阵,即计算公式中的x2。 FLOAT16、BFLOAT16、FLOAT8_E4M3FN、FLOAT8_E5M2、HIFLOAT8、INT8、INT4 ND
bias 可选输入 即计算公式中的bias。 FLOAT16、BFLOAT16、FLOAT ND
x1Scale 可选输入 mm左矩阵反量化参数。 FLOAT16、BFLOAT16、FLOAT ND
x2Scale 可选输入 mm右矩阵反量化参数。 FLOAT16、BFLOAT16、FLOAT ND
quantScale 可选输入 即计算公式中的bias。 FLOAT ND
output 输出 AllGather通信与MatMul计算的结果,即计算公式中的output。 FLOAT16、BFLOAT16、FLOAT ND
gatherOut 输出 仅输出all_gather通信后的结果。即公式中的gatherOut。 FLOAT16、BFLOAT16、FLOAT8_E4M3FN、FLOAT8_E5M2、HIFLOAT8、INT8、INT4 ND
amaxOut 可选输出 MM计算的最大值结果,即公式中的amaxOut。 FLOAT ND
blockSize 属性 用于表示mm输出矩阵在M轴方向上和N轴方向上可以用于对应方向上的多少个数的量化。 INT64 -
group 属性 通信域名称。 STRING -
gatherIndex 属性 标识gather目标。 INT64 -
commTurn 属性 通信数据切分数,即总数据量/单次通信量。 INT64 -
streamMode 属性 流模式的枚举。 INT64 -
groupSize 可选属性 用于表示反量化中x1Scale/x2Scale输入的一个数在其所在的对应维度方向上可以用于该方向x1/x2输入的多少个数的反量化。 INT64 -
commMode 属性 通信模式。 STRING -
## 约束说明
  • 确定性计算:

    • 该算子默认确定性实现。
  • Ascend 950PR/Ascend 950DT:

    • 输入x1为2维,其维度为(m, k)。x2必须是2维,其维度为(k, n),轴满足mm算子入参要求,k轴相等,且k轴取值范围为[256, 65535)。
    • bias为1维,shape为(n,)。
    • 输出output为2维,其维度为(m*rank_size, n),rank_size为卡数。
    • 输出gatherout为2维,其维度为(m*rank_size, k),rank_size为卡数。
    • 当x1、x2的数据类型为FLOAT16/BFLOAT16时,output计算输出数据类型和x1、x2保持一致。
    • 当x1、x2的数据类型为FLOAT8_E4M3FN/FLOAT8_E5M2/HIFLOAT8时,output输出数据类型支持FLOAT16、BFLOAT16、FLOAT。
    • 当x1、x2的数据类型为FLOAT16/BFLOAT16/HIFLOAT8时,x1和x2数据类型需要保持一致。
    • 当x1、x2数据类型为FLOAT8_E4M3FN/FLOAT8_E5M2时,x1和x2数据类型可以为其中一种。
    • 当x1、x2数据类型为FLOAT16/BFLOAT16/HIFLOAT8/FLOAT8_E4M3FN/FLOAT8_E5M2时,x2矩阵支持转置/不转置场景,x1矩阵只支持不转置场景。
    • 当groupSize取值为549764202624,bias必须为空。
    • 支持2、4、8、16、32、64卡。
    • allgather(x1)集合通信数据总量不能超过16*256MB,集合通信数据总量计算方式为:m * k * sizeof(x1_dtype) * 卡数。由于shape不同,算子内部实现可能存在差异,实际支持的总通信量可能略小于该值。
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:

    • 只支持x2矩阵转置/不转置,x1矩阵仅支持不转置场景。
    • 输入x1必须是2维,其shape为(m, k)。
    • 输入x2必须是2维,其shape为(k, n),轴满足mm算子入参要求,k轴相等,且k轴取值范围为[256, 65535)。
    • bias仅支持输入nullptr。
    • 输出为2维,其shape为(m*rank_size, n), rank_size为卡数。
    • 不支持空tensor。
    • x1和x2的数据类型需要保持一致。
    • x1和x2数据类型为INT4时,k与n必须为偶数。
    • 支持2、4、8卡。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_all_gather_matmul_v2 通过aclnnAllGatherMatmulV2接口方式调用AllGatherMatmulV2算子。