AlltoAllMatmul

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:完成AlltoAll通信、Permute(保证通信后地址连续)和Matmul计算的融合,先通信后计算,支持非量化、K-C量化、K-C动态量化和mx量化模式

  • 计算公式:假设x1输入shape为(BS, H),mx量化场景下x1Scale输入shape为(BS, ceil(H/64), 2),rankSize为NPU卡数

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:

      • 非量化场景:

        commOut=AlltoAll(x1.view(rankSize,BS/rankSize,H))permutedOut=commOut.permute(1,0,2).view(BS/rankSize,rankSize∗H)output=permutedOut@x2+biascommOut = AlltoAll(x1.view(rankSize, BS/rankSize, H)) \\ permutedOut = commOut.permute(1, 0, 2).view(BS/rankSize, rankSize*H) \\ output = permutedOut @ x2 + bias \\

      • K-C量化场景:

        commOut=AlltoAll(x1.view(rankSize,BS/rankSize,H))permutedOut=commOut.permute(1,0,2).view(BS/rankSize,rankSize∗H)outputquant=x1@x2output=outputquant×x1scale×x2scaleoutput=output+biascommOut = AlltoAll(x1.view(rankSize, BS/rankSize, H)) \\ permutedOut = commOut.permute(1, 0, 2).view(BS/rankSize, rankSize*H) \\ output_{quant} = x1 @ x2 \\ output = output_{quant} \times x1_{scale} \times x2_{scale} \\ output = output + bias

      • K-C动态量化场景:

        commOut=AlltoAll(x1.view(rankSize,BS/rankSize,H))permutedOut=commOut.permute(1,0,2).view(BS/rankSize,rankSize∗H)x1quant,x1scale=Quant(permutedOut)outputquant=x1quant@x2output=outputquant×x1scale×x2scaleoutput=output+biascommOut = AlltoAll(x1.view(rankSize, BS/rankSize, H)) \\ permutedOut = commOut.permute(1, 0, 2).view(BS/rankSize, rankSize*H) \\ x1_{quant}, x1_{scale} = Quant(permutedOut) \\ output_{quant} = x1_{quant} @ x2 \\ output = output_{quant} \times x1_{scale} \times x2_{scale} \\ output = output + bias

    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:

      • 非量化场景:

        commOut=AlltoAll(x1.view(rankSize,BS/rankSize,H))permutedOut=commOut.permute(1,0,2).view(BS/rankSize,rankSize∗H)output=permutedOut@x2+biascommOut = AlltoAll(x1.view(rankSize, BS/rankSize, H)) \\ permutedOut = commOut.permute(1, 0, 2).view(BS/rankSize, rankSize*H) \\ output = permutedOut @ x2 + bias \\

    • Ascend 950PR/Ascend 950DT:

      • 非量化场景:

        commOut=AlltoAll(x1.view(rankSize,BS/rankSize,H))permutedOut=commOut.permute(1,0,2).view(BS/rankSize,rankSize∗H)output=permutedOut@x2+biascommOut = AlltoAll(x1.view(rankSize, BS/rankSize, H)) \\ permutedOut = commOut.permute(1, 0, 2).view(BS/rankSize, rankSize*H) \\ output = permutedOut @ x2 + bias \\

      • K-C动态量化场景:

        commOut=AlltoAll(x1.view(rankSize,BS/rankSize,H))permutedOut=commOut.permute(1,0,2).view(BS/rankSize,rankSize∗H)dynQuantX1,dynQuantX1Scale=dynamicQuant(permutedOut)output=(dynQuantX1@x2+bias)×dynQuantX1Scale×x2ScalecommOut = AlltoAll(x1.view(rankSize, BS/rankSize, H)) \\ permutedOut = commOut.permute(1, 0, 2).view(BS/rankSize, rankSize*H) \\ dynQuantX1, dynQuantX1Scale = dynamicQuant(permutedOut) \\ output = (dynQuantX1@x2 + bias) \times dynQuantX1Scale \times x2Scale

      • mx量化场景:

        commOut=AlltoAll(x1.view(rankSize,BS/rankSize,H))permutedOut=commOut.permute(1,0,2).view(BS/rankSize,rankSize∗H)commScale=AlltoAll(x1Scale.view(rankSize,BS/rankSize,ceil(H/64),2))permutedScale=commScale.permute(1,0,2,3).view(BS/rankSize,ceil(H/64)∗rankSize,2)output=∑0⌊kblockSize=32⌋(permutedOut@x2∗(permutedScale∗x2Scale))+biascommOut = AlltoAll(x1.view(rankSize, BS/rankSize, H)) \\ permutedOut = commOut.permute(1, 0, 2).view(BS/rankSize, rankSize*H) \\ commScale = AlltoAll(x1Scale.view(rankSize, BS/rankSize, ceil(H/64), 2)) \\ permutedScale = commScale.permute(1, 0, 2, 3).view(BS/rankSize, ceil(H/64)*rankSize, 2) \\ output = \sum_{0}^{\left \lfloor \frac{k}{blockSize=32} \right \rfloor} (permutedOut @ x2 * (permutedScale * x2Scale)) + bias

参数说明​

参数名 输入/输出/属性 描述 数据类型 数据格式
x1 输入 融合算子的左矩阵,即公式中的输入x1。 FLOAT16、BFLOAT16、INT4、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1 ND
x2 输入 融合算子的右矩阵,也是Matmul的右矩阵,即公式中的输入x2。 FLOAT16、BFLOAT16、INT8、INT4、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1 ND
bias 可选输入 可选输入,阵乘运算后累加的偏置,对应公式中的bias。 FLOAT16、BFLOAT16、FLOAT32 ND
x1_scale 可选输入 左矩阵的量化系数,对应公式中的x1Scale。 FLOAT32、FLOAT8_E8M0 ND
x2_scale 可选输入 右矩阵的量化系数,对应公式中的x2Scale。 FLOAT32、FLOAT8_E8M0 ND
comm_scale 可选输入 预留参数,低比特通信的量化系数。 - -
x1_offset 可选输入 预留参数,左矩阵的量化偏置。 - -
x2_offset 可选输入 预留参数,右矩阵的量化偏置。 - -
y 输出 计算+通信的结果,即公式中的输出output。 FLOAT16、BFLOAT16、FLOAT32 ND
alltoall_out 输出 接收AlltoAll和Permute后的结果,即公式中的输出permutedOut。 与输入x1保持一致 ND
group 必选属性 Host侧标识列组的字符串,即通信域名称,通过Hccl接口HcclGetCommName获取commName作为该参数,字符串长度要求(0, 128)。 STRING -
world_size 必选属性 使用的npu卡数,公式中的rankSize。 INT -
all2all_axes 可选属性 AlltoAll和Permute数据交换的方向,支持配置空或者[-2, -1],传入空时默认按[-2, -1]处理,表示将输入由(BS, H)转为(BS/rankSize, H*rankSize)。 aclIntArray*(元素类型INT64) ND
x1_quant_mode 可选属性 左矩阵的量化方式,按照实际场景配置。 INT -
x2_quant_mode 可选属性 右矩阵的量化方式,按照实际场景配置。 INT -
comm_quant_mode 可选属性 低比特通信的量化方式,预留参数,当前仅支持配置为0,表示不量化。 INT -
comm_quant_dtype 可选属性 低比特通信的量化类型,预留参数,当前仅支持配置为-1,表示ACL_DT_UNDEFINED。 INT -
x1_quant_dtype 可选属性 量化Matmul左矩阵的量化类型,AlltoAll通信与Permute后的结果,按照该参数配置量化后作为Matmul计算的左矩阵输入,按照实际场景配置。 INT -
transpose_x1 可选属性 标识左矩阵是否转置过,暂不支持配置为True。 bool -
transpose_x2 可选属性 标识右矩阵是否转置过,配置为True时右矩阵Shape为(N,H*rankSize)。 bool -
group_size 可选属性 用于Matmul计算三个方向上的量化分组大小,仅在scale输入都是2维及以上数据时取值有效,其他场景默认传入0即可。 INT -
alltoall_out_flag 可选属性 用于标识是否需要保留AlltoAll和Permute后的结果。 bool -

x1QuantMode、x2QuantMode、commQuantMode的枚举值与量化模式关系如下:

  • 0: 不量化
  • 1: pertensor
  • 2: perchannel
  • 3: pertoken
  • 4: pergroup
  • 5: perblock
  • 6: mx量化
  • 7: pertoken动态量化

约束说明

  • 默认支持确定性计算。
  • NPU卡数(rankSize),根据设备型号有不同限制:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:支持2、4、8卡。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持2、4、8、16卡。
    • Ascend 950PR/Ascend 950DT:支持2、4、8、16卡。
  • 空tensor和非连续tensor的支持度根据不同设备型号有不同的限制:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持任何空tensor;不支持任何非连续tensor。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品、Ascend 950PR/Ascend 950DT:仅支持非量化场景下输入x1的第一维度(BS)为0的空tensor,其它空tensor均不支持;仅支持输入x2的转置非连续tensor,其它非连续tensor均不支持。
  • 输入x1必须是2维,其shape为(BS, H),BS必须整除NPU卡数,BS和N的值不得超过2147483647(INT32_MAX),不支持转置。
  • 输入x2必须是2维,其shape为(H*rankSize, N),H*rankSize范围根据芯片型号和场景不同有不同约束,详见量化aclnn约束说明非量化aclnn约束说明。当处于mx量化场景时,x2必须转置,其shape为(N, H*rankSize),transpose_x2配置为True。
  • bias若非空,其维度必须为1维,shape为(N)。
  • x1_scale若非空,在mx量化场景时,其维度为3维,shape为(BS, ceil(H/64), 2);在K-C量化场景时,其维度为1维,shape为(BS);在K-C动态量化场景时,其维度为1维,shape为(H*rankSize)。
  • x2_scale若非空,在mx量化场景时,其维度为3维,shape为(N, ceil(H*rankSize/64), 2);其它场景中其维度为1维,shape为(N)。
  • all2all_axes为1维数组,shape必须为(2)。
  • 目前支持的量化模式,根据设备型号有不同限制:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:支持K-C量化和K-C动态量化模式,x1QuantMode=3或7,x2QuantMode=2。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:目前不支持量化场景。
    • Ascend 950PR/Ascend 950DT:支持K-C动态量化模式,x1QuantMode=7,x2QuantMode=2;mx量化模式,x1QuantMode=6,x2QuantMode=6。
  • 非量化场景x1、x2计算输入的数据类型要和output、alltoAllOutOptional计算输出的数据类型一致,传入的x1、x2与output均不为空指针。
  • 量化场景x1和alltoAllOutOptional的数据类型一致,传入的x1、x2、x2Scale与output均不为空指针。
  • x1、x2和bias计算输入的数据类型根据不同设备型号有不同的限制:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:
      • 非量化场景下,output计算输出的数据类型为FLOAT16时,bias计算输入的数据类型支持FLOAT16;output计算输出的数据类型为BFLOAT16时,bias计算输入的数据类型支持FLOAT32。
      • 量化场景下,数据类型组合详见量化aclnn约束说明
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:
      • 非量化场景下,output计算输出的数据类型为FLOAT16时,bias计算输入的数据类型支持FLOAT16;output计算输出的数据类型为BFLOAT16时,bias计算输入的数据类型支持FLOAT32。
      • A3目前不支持量化场景。
    • Ascend 950PR/Ascend 950DT:
      • 非量化场景下,x1/x2计算输入的数据类型为FLOAT16时,bias计算输入的数据类型支持FLOAT16和FLOAT32;x1/x2计算输入的数据类型为BFLOAT16时,bias计算输入的数据类型支持BFLOAT16和FLOAT32。
      • 量化场景下,支持K-C动态量化模式和mx量化模式,x1计算输入的数据类型为FLOAT16、BFLOAT16、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1,x2计算输入的数据类型为FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1,bias的数据类型为FLOAT32或者bias为空,具体类型组合详见量化aclnn约束说明
      • mx量化模式下,当x1和x2的数据类型为FLOAT4_E2M1时,两者的数据类型必须一致。
  • 通算融合算子不支持并发调用,不同的通算融合算子也不支持并发调用。
  • 不支持跨超节点通信,只支持超节点内。
  • 通信引擎约束:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:支持MTE通信。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持AICPU通信。
    • Ascend 950PR/Ascend 950DT:支持CCU通信。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_allto_all_matmul.cpp 通过aclnnAlltoAllMatMul接口方式调用非量化场景的AlltoAllMatMul算子。
aclnn接口 test_aclnn_allto_all_quant_matmul.cpp 通过aclnnAlltoAllQuantMatMul接口方式调用量化场景的AlltoAllMatMul算子。