MatmulAllReduceAddRmsNorm
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | x |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
-
算子功能:完成mm + all_reduce + add + rms_norm计算。
-
计算公式:
-
情景一:
mm_out=allReduce(x1@x2+bias)mm\_out = allReduce(x1 @ x2 + bias)
y=mm_out+residualy = mm\_out + residual
normOut=yRMS(y)∗gamma,RMS(y)=1d∑i=1dyi2+epsilonnormOut = \frac{y}{RMS(y)} * gamma, RMS(y) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} y_{i}^{2} + epsilon}
-
情景二:
mmout=allReduce(dequantscale∗(x1int8@x2int8+biasint32))mm_out = allReduce(dequant_scale * (x1_{int8}@x2_{int8} + bias_{int32}))
y=mmout+residualy = mm_out + residual
normOut=yRMS(y)∗gamma,RMS(y)=1d∑i=1dyi2+epsilonnormOut = \frac{y}{RMS(y)} * gamma, RMS(y) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} y_{i}^{2} + epsilon}
-
情景三:
mm_out=allReduce(x1@(x2∗antiquantscale+antiquantoffset)+bias)mm\_out = allReduce(x1 @ (x2*antiquant_scale + antiquant_offset) + bias)
y=mm_out+residualy = mm\_out + residual
normOut=yRMS(y)∗gamma,RMS(y)=1d∑i=1dyi2+epsilonnormOut = \frac{y}{RMS(y)} * gamma, RMS(y) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} y_{i}^{2} + epsilon}
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x1 | 输入 | MatMul左矩阵,即公式中的输入x1。 | FLOAT16、BFLOAT16、INT8 | ND |
| x2 | 输入 | MatMul右矩阵,即公式中的输入x2。 | FLOAT16、BFLOAT16、INT8、INT4 | ND |
| bias | 可选输入 | Matmul计算之后的Add项,即公式中的输入bias。 | FLOAT16、BFLOAT16、INT32 | ND |
| residual | 输入 | AddRmsNorm融合算子的残差输入,即公式中的输入residual。 | FLOAT16、BFLOAT16 | ND |
| gamma | 输入 | AddRmsNorm融合算子的RmsNorm计算输入,即公式中的输入gamma。 | FLOAT16、BFLOAT16 | ND |
| antiquant_scale | 可选输入 | 公式中的输入antiquant_scale。 | FLOAT16、BFLOAT16 | ND |
| antiquant_offset | 可选输入 | 对x2进行伪量化计算的offset参数,公式中的输入antiquant_offset。 | FLOAT16、BFLOAT16 | ND |
| dequant_scale | 可选输入 | mm计算后的全量化系数,公式中的输入dequant_scale。 | FLOAT16、BFLOAT16、UINT64、INT64 | ND |
| y | 输出 |
|
FLOAT16、BFLOAT16 | ND |
| norm_out | 输出 |
|
FLOAT16、BFLOAT16 | ND |
| group | 属性 |
|
CHAR*、STRING | - |
| reduceOp | 可选属性 |
|
CHAR*、STRING | - |
| is_trans_a | 可选属性 |
|
BOOL | - |
| is_trans_b | 可选属性 | BOOL | - | |
| commTurn | 可选属性 |
|
INT64 | - |
| antiquant_group_size | 可选属性 |
|
INT64 | - |
| epsilon | 可选属性 |
|
DOUBLE | - |
约束说明
- 增量场景不使能MC2,全量场景使能MC2
- 输入x1可为二维或者三维,其shape为(b, s, k)或者(s, k)。x2必须是二维,其shape为(k, n),轴满足mm算子入参要求,k轴相等,m的范围为[1, 2147483647],k的范围为[1, 65535],n的范围为[0, 65535]。bias若非空,bias为一维,其shape为(n)。bias可选,可为空,非空时当前版本仅支持一维输入。
- 输入residual必须是三维,其shape为(b, s, n),当x1为二维时,residual的(b*s)等于x1的s,不支持非连续的tensor。输入gamma必须是一维,其shape为(n),不支持非连续的tensor。
- antiquant_scale满足pertensor场景shape为(1),perchannel场景shape为(1,n)/(n),pergroup场景shape为(ceil(k,antiquant_group_size),n)。antiquant_offset可选,可为空,非空时shape与antiquant_scale一致。
- dequant_scale的shape在pertensor场景为(1),perchannel场景为(n)/(1, n)。
- 输出y和normOut的维度和数据类型同residual。bias若非空,shape大小与normOut最后一维相等。
- bias、residual、gamma、y、normOut计算输入的数据类型要一致。
- antiquant_group_size在不支持pergroup场景时,传入0,在支持pergroup场景时,传入值的范围为[32, min(k-1,INT_MAX)],且为32的倍数。k取值范围与mm接口保持一致。
- 支持(b*s)、n为0的空tensor,不支持k为0的空tensor。
- 只支持x2矩阵转置/不转置,x1矩阵支持不转置场景。
- 属性reduceOp当前版本仅支持输入"sum"。
- 属性commTurn当前版本仅支持输入0。
- 支持1、2、4、8卡,并且仅支持hccs链路all mesh组网。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:一个模型中的通算融合MC2算子,仅支持相同通信域。类型要一致。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_matmul_all_reduce_add_rms_norm.cpp | 通过aclnnMatmulAllReduceAddRmsNorm接口方式调用MatmulAllReduceAddRmsNorm算子。 |