FusedMulAddAdd
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:将
Mul、Add、Add子图融合为单个算子,对四个输入按 NumPy 广播规则对齐后逐元素计算乘加加,常用于BatchMatmul + bias + residual等模式。 -
计算公式:
y=x1×x2+x3+x4y = x_1 \times x_2 + x_3 + x_4
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x1 | 输入 | 公式中的乘法输入张量x1。 | FLOAT16, FLOAT, INT32 | ND |
| x2 | 输入 | 公式中的乘法输入张量x2,shape需可广播到x1。 | 同x1 | ND |
| x3 | 输入 | 公式中第一次加法的输入张量x3,shape需可广播到x1。 | 同x1 | ND |
| x4 | 输入 | 公式中第二次加法的输入张量x4,shape需可广播到x1。 | 同x1 | ND |
| y | 输出 | 公式中的输出张量y,shape与x1相同。 | 同x1 | ND |
约束说明
- x1、x2、x3、x4、y 必须为同一种数据类型,不支持混合数据类型。
- 计算顺序为
((x1 * x2) + x3) + x4,不可交换。 - x1 必须为完整的输出 shape:x2、x3、x4 的 shape 可按 NumPy 广播规则向上广播到 x1(支持标量、单维 broadcast、跨 rank broadcast),输出 y 的 shape 与 x1 一致。当前 runtime 不支持 x1 自身向上广播(即 x1 比输出 shape 小的场景,例如 x1=
[1]、x2=[3,4]),该类用例会在 RunGraph 阶段失败。
实现方案
| 层 | 文件 | 说明 |
|---|---|---|
| 计算图原型 | op_graph/fused_mul_add_add_proto.h |
REG_OP(FusedMulAddAdd),四输入一输出 |
| 算子定义 | op_host/fused_mul_add_add_def.cpp |
OpDef::AddConfig("ascend950", ...) |
| InferShape | op_host/fused_mul_add_add_infershape.cpp |
复用 Ops::Base::InferShape4Broadcast(ctx, 4) |
| Tiling | op_host/arch35/fused_mul_add_add_tiling_arch35.{h,cpp} |
按 dtype 分支调用 Ops::Base::BroadcastBaseTiling<OpDag> |
| DAG | op_kernel/arch35/fused_mul_add_add_dag.h |
fp32/fp16 通路在 fp32 中间精度下用 Vec::Mul + Vec::Add + Vec::Add;int32 通路用 Vec::Mul + Vec::Add + Vec::Add |
| Struct | op_kernel/arch35/fused_mul_add_add_struct.h |
BRC_TEMP_SCH_MODE_KEY_DECL/SEL |
| Kernel 入口 | op_kernel/fused_mul_add_add_apt.cpp |
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY) + BroadcastSch<schMode, OpDag> |
fp16 / fp32 通路
In0/In1/In2/In3 -- CopyInBrc -- Cast(->fp32) -- Vec::Mul(x1,x2) -- Vec::Add(+x3) -- Vec::Add(+x4) -- Cast(->T,RINT) -- CopyOut -- Out0
全部输入先 Cast 到 fp32,再用 Vec::Mul + Vec::Add + Vec::Add 三段式按公式
y = x1 * x2 + x3 + x4 顺序计算,最后 Cast 回 T 写出。
int32 通路
In0/In1/In2/In3 -- CopyInBrc -- Vec::Mul(x1,x2) -- Vec::Add(+x3) -- Vec::Add(+x4) -- CopyOut -- Out0
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| 图模式 | test_geir_fused_mul_add_add | 通过算子IR构图方式调用FusedMulAddAdd算子。 |