MatmulAllReduce
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | x |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | x |
| Atlas 推理系列产品 | x |
| Atlas 训练系列产品 | x |
功能说明
-
算子功能:完成MatMul计算与AllReduce通信融合。
-
计算公式:
非量化场景:
-
情形1:
yi=xi×weighti+biasiy_i=x_i\times weight_i + bias_i
-
情形2:
output=AllReduce(x1@x2+bias+x3)output = AllReduce(x1 @ x2 + bias + x3)
-
情形3:对量化后的入参x1、x2进行MatMul计算后,接着进行Dequant计算,接着与x3进行Add操作,最后做AllReduce计算。
output=AllReduce(dequantScale∗(x1int8@x2int8+biasint32)+x3)output= AllReduce(dequantScale*(x1_{int8}@x2_{int8} + bias_{int32}) + x3)
-
情形4:对量化后的入参x1、x2进行MatMul计算后,接着进行Dequant和pertoken计算,接着与x3进行Add操作,最后做AllReduce计算。
output=AllReduce(dequantScale∗pertokenScaleOptional∗(x1int8@x2int8+biasOptionalint32)+x3Optional)output= AllReduce(dequantScale * pertokenScaleOptional * (x1_{int8}@x2_{int8} + biasOptional_{int32}) + x3Optional)
-
情形5:对量化后的入参x1、x2进行MatMul、Dequant和pertoken计算,接着与x3进行Add操作,再对输出进行perchannel量化,然后进行AllToAll通信,对第一次通讯结果进行reduceSum计算,接着进行AllGather通信,最后对第二次通信结果进行Dequant,得到最终输出。
matmulAddOutput=(dequantScale∗pertokenScaleOptional∗(x1int8@x2int8+biasOptionalint32)+x3Optional);matmulAddOutput = (dequantScale * pertokenScaleOptional * (x1_{int8}@x2_{int8} + biasOptional_{int32}) + x3Optional);
alltoallOutPutint8=AllToAll(matmulAddOutput/commQuantScale1Optional);alltoallOutPut_{int8} = AllToAll(matmulAddOutput / commQuantScale1Optional);
reduceSumOutPutint8=(add(alltoallOutPutint8)∗(commQuantScale1Optional/commQuantScale2Optional));reduceSumOutPut_{int8} = (add(alltoallOutPut_{int8}) * (commQuantScale1Optional / commQuantScale2Optional));
outPut=(AllGather(reduceSumOutPutint8)∗commQuantScale2Optional);outPut = (AllGather(reduceSumOutPut_{int8}) * commQuantScale2Optional);
-
情形6:
-
commQuantScale1Optional, commQuantScale2Optional不为空时:
matmulAddOutput=(x2Scale∗x1ScaleOptional∗(x1int8@x2int8+biasOptionalint32)+x3Optional);matmulAddOutput = (x2Scale * x1ScaleOptional * (x1_{int8}@x2_{int8} + biasOptional_{int32}) + x3Optional);
alltoallOutputint8=AllToAll(matmulAddOutput/commQuantScale1Optional);alltoallOutput_{int8} = AllToAll(matmulAddOutput / commQuantScale1Optional);
reduceSumOutputint8=(add(alltoallOutputint8)∗(commQuantScale1Optional/commQuantScale2Optional));reduceSumOutput_{int8} = (add(alltoallOutput_{int8}) * (commQuantScale1Optional / commQuantScale2Optional));
output=(AllGather(reduceSumOutputint8)∗commQuantScale2Optional);output = (AllGather(reduceSumOutput_{int8}) * commQuantScale2Optional);
-
x1,x2为INT8,无x1ScaleOptional,x2Scale为INT64/UINT64,可选biasOptional为INT32,out为BFLOAT16/FLOAT16:
output=AllReduce((x1@x2+biasOptional)∗x2Scale+x3Optional)output = AllReduce((x1@x2 + biasOptional) * x2Scale + x3Optional)
-
x1,x2为INT8,x1ScaleOptional为FLOAT32,x2Scale为FLOAT32/BFLOAT16,可选biasOptional为INT32, out为FLOAT16/BFLOAT16:
output=AllReduce((x1@x2+biasOptional)∗x2Scale∗x1ScaleOptional+x3Optional)output = AllReduce((x1@x2 + biasOptional) * x2Scale * x1ScaleOptional + x3Optional)
-
x1,x2为FLOAT4_E2M1/FLOAT8_E4M3FN/FLOAT8_E5M2,x1ScaleOptional为FLOAT8_E8M0,x2Scale为FLOAT8_E8M0,可选biasOptional为FLOAT32, out为FLOAT16/BFLOAT16/FLOAT32:
output=AllReduce((x1∗x1ScaleOptional)@(x2∗x2Scale)+biasOptional+x3Optional)output = AllReduce((x1* x1ScaleOptional)@(x2* x2Scale) + biasOptional + x3Optional)
-
x1,x2为FLOAT8_E4M3FN/FLOAT8_E5M2/HIFLOAT8,x1ScaleOptional为FLOAT32,x2Scale为FLOAT32,可选bias为FLOAT32, out为FLOAT16/BFLOAT16/FLOAT32:
output=AllReduce((x1@x2+biasOptional)∗x2Scale∗x1ScaleOptional+x3Optional)output = AllReduce((x1@x2 + biasOptional) * x2Scale * x1ScaleOptional + x3Optional)
-
x1,x2为FLOAT8_E4M3FN/FLOAT8_E5M2/HIFLOAT8,x1ScaleOptional为FLOAT32,x2Scale为FLOAT32,无biasOptional。当x1为(a0, a1),x2为(b0, b1)时x1ScaleOptional为(ceildiv(a0,128), ceildiv(a1,128)),x2Scale为(ceildiv(b0,128), ceildiv(b1,128)),out为FLOAT16/BFLOAT16/FLOAT32:
outputpq=AllReduce(∑0⌊k128⌋(x1pr@x2rq∗(x1ScaleOptionalpr∗x2Scalerq))+x3)output_{pq} = AllReduce(\sum_{0}^{\left \lfloor \frac{k}{128} \right \rfloor} (x1_{pr}@x2_{rq}*(x1ScaleOptional_{pr}*x2Scale_{rq})) + x3)
-
-
情形7:
output = Allreduce(x1 @ ((x2 + antiquantOffset) ∗antiquantScale) + bias+ x3) output = Allreduce(x1 @ ((x2 + antiquantOffset) *antiquantScale) + bias+ x3)
-
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x1 | 输入 | MatMul左矩阵,即公式中的输入x1。 | FLOAT16、BFLOAT16、INT8、HIFLOAT8、FLOAT8_E5M2、FLOAT8_E4M3FN、FLOAT4_E2M1 | ND |
| x2 | 输入 | MatMul右矩阵,即公式中的输入x2。 | FLOAT16、BFLOAT16、INT8、INT4、HIFLOAT8、FLOAT8_E5M2、FLOAT8_E4M3FN、FLOAT4_E2M1 | ND、FRACTAL_NZ |
| bias | 可选输入 | 公式中的输入bias。 | FLOAT16、BFLOAT16、INT32、FLOAT | ND |
| x3 | 可选输入 | MatMul计算后的Add计算,即公式中的输入x3。 | FLOAT16、BFLOAT16、FLOAT | ND |
| antiquantScale | 可选输入 | 公式中的输入antiquantScale。 | FLOAT16、BFLOAT16 | ND |
| antiquantOffset | 可选输入 | 公式中的输入antiquantOffset。 | FLOAT16、BFLOAT16 | ND |
| dequantScale | 可选输入 | MatMul计算后的去量化系数,即公式中的输入dequantScale。 | FLOAT16、BFLOAT16、FLOAT、UINT64、FLOAT8_E8M0 | ND |
| pertoken_scale | 可选输入 | MatMul计算后的pertoken去量化系数,即公式中的输入pertokenScaleOptional。 | FLOAT、BFLOAT16、FLOAT8_E8M0 | ND |
| comm_quant_scale_1 | 可选输入 | matmulAdd计算后的perchannel量化系数,即公式中的输入comm_quant_scale_1。 | FLOAT、BFLOAT16、FLOAT16 | ND |
| comm_quant_scale_2 | 可选输入 | allGather计算后的perchannel量化系数,即公式中的输入comm_quant_scale_2。 | FLOAT、BFLOAT16、FLOAT16 | ND |
| output | 输出 | 计算+通信的结果,即公式中的输出output。 | FLOAT、BFLOAT16、FLOAT16 | ND |
| group | 属性 |
|
CHAR*、STRING | - |
| reduceOp | 可选属性 |
|
CHAR*、STRING | - |
| is_trans_a | 可选属性 |
|
BOOL | - |
| is_trans_b | 可选属性 |
|
BOOL | - |
| commTurn | 可选属性 |
|
INT64 | - |
| antiquant_group_size | 可选属性 | 伪量化pergroup模式下,对x2进行反量化计算的groupSize输入。 | INT64 | - |
| group_size | 可选属性 |
|
INT64 | - |
| comm_quant_mode | 可选属性 |
|
INT64 | - |
约束说明
- 增量场景不使能MC2,全量场景使能MC2。
- 输入x1可为二维或者三维,其shape为(b, s, k)或者(m, k),不支持非连续输入。
- 输入x2必须是二维。其shape为(k, n),k轴满足mm算子入参要求,k轴相等,m的范围为[1, 2147483647],k、n的范围为[1, 65535]。
- 输入x2的数据格式:
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:支持ND(当前版本仅支持二维输入)和FRACTAL_NZ格式(当前版本仅支持四维输入)。
- Ascend 950PR/Ascend 950DT:支持ND格式。
- 当x2的数据格式为FRACTAL_NZ时,配合aclnnCalculateMatmulWeightSizeV2和aclnnTransMatmulWeight到数据格式NZ的转换,非连续Tensor仅支持transpose场景。当x2的数据格式为ND时,当前版本仅支持二维输入。
- 传入的x1、x2、antiquantScale或者output不为空指针。
- 当输入x1的shape为(b, s, k)时,x3(非空场景)与输出output的shape为(b, s, n),pertoken_scale的shape为(b*s);当输入x1的shape为(m, k)时,x3(非空场景)与输出output的shape为(m, n),pertoken_scale的shape为(m)。
- 输入comm_quant_scale_1和comm_quant_scale_2可选,可为空,当x2为(k, n)时, shape可为(n)或者(1,n)。
- 输入dequantScale可选,可为空,shape在pertensor场景为(1),perchannel场景为(n)/(1, n)。输出为BFLOAT16时,直接将BFLOAT16类型的dequantScale传入本接口。输出为FLOAT16时,如果pertokenScaleOptional不为空,可直接将FLOAT32类型的dequantScale传入本接口,如果pertokenScaleOptional为空,则需提前调用TransQuantParamV2算子的aclnn接口来将dequantScale转成INT64/UINT64数据类型。
- bias若非空,当前版本仅支持一维,shape大小与output最后一维大小相等。antiquantScale在pertensor场景下shape为(1),在perchannel场景下shape为(1,n)/(n),在pergroup场景shape为(ceil(k,antiquantGroupSize), n)。antiquantOffset若非空,其shape与antiquantScale一致。
- x1和x2,x3(非空场景)、antiquantScale、antiquantOffset(非空场景)、output、bias(非空场景)的数据类型和数据格式需要在支持的范围之内。
- x1、antiquantScale、antiquantOffset(非空场景)、x3(非空场景)、bias(非空场景)和output的数据类型相同。antiquantGroupSize在不支持pergroup场景时,传入0,在支持pergroup场景时,传入值的范围为[32, min(k-1,INT_MAX)],且为32的倍数。k取值范围与mm接口保持一致。
- group_size在perblock场景下,只支持549764202624。其他场景,只支持0。
- 只支持x2矩阵转置/不转置,x1矩阵不支持转置场景。
- 属性reduceOp当前版本仅支持输入"sum"。
- 属性commTurn当前版本仅支持输入0。
- 在长序列场景,随着b/s或者m的增大,可能出现OOM或者计算超时。
- 仅支持hccs链路all mesh组网。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:支持1、2、4、8卡。
- Ascend 950PR/Ascend 950DT:支持1、2、4、8、16、32、64卡。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:一个模型中的通算融合MC2算子,仅支持相同通信域。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_matmul_all_reduce.cpp | 通过aclnnMatMulAllReduce接口方式调用MatMulAllReduce算子。 |