FusedMulAdd
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:将
Mul、Add子图融合为单个算子,对三个输入按 NumPy 广播规则对齐后逐元素计算乘加。 -
计算公式:
y=x1×x2+x3y = x_1 \times x_2 + x_3
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x1 | 输入 | 公式中的乘法输入张量x1。 | FLOAT16, FLOAT, INT32 | ND |
| x2 | 输入 | 公式中的乘法输入张量x2,shape需与x1可广播。 | 同x1 | ND |
| x3 | 输入 | 公式中的加法输入张量x3,shape需与x1*x2的结果可广播。 | 同x1 | ND |
| y | 输出 | 公式中的输出张量y,shape为x1、x2、x3广播后的统一形状。 | 同x1 | ND |
约束说明
- x1、x2、x3、y 必须为同一种数据类型,不支持混合数据类型。
- 支持任意 NumPy 广播形态(含标量、单维 broadcast、跨 rank broadcast),支持动态 shape 与动态 rank。
实现方案
| 层 | 文件 | 说明 |
|---|---|---|
| 计算图原型 | op_graph/fused_mul_add_proto.h |
REG_OP(FusedMulAdd),三输入一输出 |
| 算子定义 | op_host/fused_mul_add_def.cpp |
OpDef::AddConfig("ascend950", ...) |
| InferShape | op_host/fused_mul_add_infershape.cpp |
复用 Ops::Base::InferShape4Broadcast(ctx, 3) |
| Tiling | op_host/arch35/fused_mul_add_tiling_arch35.{h,cpp} |
按 dtype 分支调用 Ops::Base::BroadcastBaseTiling<OpDag> |
| DAG | op_kernel/arch35/fused_mul_add_dag.h |
fp32/fp16 通路在 fp32 中间精度下用 Vec::Mul + Vec::Add;int32 通路用 Vec::Mul + Vec::Add |
| Struct | op_kernel/arch35/fused_mul_add_struct.h |
BRC_TEMP_SCH_MODE_KEY_DECL/SEL |
| Kernel 入口 | op_kernel/fused_mul_add_apt.cpp |
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY) + BroadcastSch<schMode, OpDag> |
fp16 / fp32 通路
In0/In1/In2 -- CopyInBrc -- Cast(->fp32) -- Vec::Mul(x1,x2) -- Vec::Add(+x3) -- Cast(->T,RINT) -- CopyOut -- Out0
全部输入先 Cast 到 fp32,再用 Vec::Mul + Vec::Add 两段式按公式
y = x1 * x2 + x3 顺序计算,最后 Cast 回 T 写出。
此处不使用
Vec::FusedMulAdd。该 API 底层实现会 in-place 写回 src2 buffer, 在 broadcast 大 tensor 跨 tile 复用输入 UB buffer 时会污染下一 tile 的输入, 导致 fp32/fp16 huge broadcast 场景精度错误。使用普通Vec::Mul + Vec::Add保证所有 placeholder 派生的 cast 结果只读。
int32 通路
In0/In1/In2 -- CopyInBrc -- Vec::Mul(x1,x2) -- Vec::Add(+x3) -- CopyOut -- Out0
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| 图模式 | test_geir_fused_mul_add | 通过算子IR构图方式调用FusedMulAdd算子。 |