MulNoNan
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:完成安全乘法计算,当 x2 为 0 时返回 0,从而屏蔽
0 * inf = NaN、0 * NaN = NaN两类异常。等价于 TensorFlow 的tf.math.multiply_no_nans。 -
计算公式:
y={0,if x2=0x1×x2,if x2≠0y = \begin{cases} 0, & \text{if } x2 = 0 \\ x1 \times x2, & \text{if } x2 \neq 0 \end{cases}
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x1 | 输入 | 公式中的乘法输入张量x1。 | FLOAT16, FLOAT, INT32, BFLOAT16 | ND |
| x2 | 输入 | 公式中的乘法输入张量x2,作为判 0 主体,shape需与x1可广播。 | 同x1 | ND |
| y | 输出 | 公式中的输出张量y,shape为x1、x2广播后的统一形状。 | 同x1 | ND |
约束说明
- x1、x2、y 必须为同一种数据类型,不支持混合数据类型。
- 仅判 x2 一侧,
x2 = -0同样进入零臂(IEEE-754-0 == 0);x2 != 0时所有 IEEE 行为按普通乘法保留(NaN / Inf 正常传播)。 - 支持任意 NumPy 广播形态(含标量、单维 broadcast、跨 rank broadcast),支持动态 shape 与动态 rank。
实现方案
| 层 | 文件 | 说明 |
|---|---|---|
| 计算图原型 | op_graph/mul_no_nan_proto.h |
REG_OP(MulNoNan),二输入一输出 |
| 算子定义 | op_host/mul_no_nan_def.cpp |
OpDef::AddConfig("ascend950", ...) |
| InferShape | op_host/mul_no_nan_infershape.cpp |
复用 Ops::Base::InferShape4Broadcast |
| Tiling | op_host/arch35/mul_no_nan_tiling_arch35.{h,cpp} |
按 dtype 分支调用 Ops::Base::BroadcastBaseTiling<OpDag> |
| DAG | op_kernel/arch35/mul_no_nan_dag.h |
fp32/int32 通路原生计算;fp16/bf16 通路提升 fp32 中间精度 |
| Struct | op_kernel/arch35/mul_no_nan_struct.h |
BRC_TEMP_SCH_MODE_KEY_DECL/SEL |
| Kernel 入口 | op_kernel/mul_no_nan_apt.cpp |
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY) + BroadcastSch<schMode, OpDag> |
fp32 / int32 通路
In0/In1 -- CopyInBrc --+
+-- Mul(x1,x2) --+
+-- Select(mask, MulRes, Zero) -- CopyOut -- Out0
Const(0) -- Duplicate -- Zero --+
|
+-- Compare(NE, x2, Zero) -- mask
Vec::Compare<u8, T, NE> 输出位掩码,Vec::Select<u8, T, TENSOR_TENSOR> 按
mask ? MulRes : Zero 逐元素选择。x2 == 0 时即使 MulRes 是 NaN(0 · inf
/ 0 · NaN)也被 Select 丢弃,输出 0。
fp16 / bf16 通路
In0/In1 -- CopyInBrc -- Cast(->fp32) --+
+-- Mul --+
+-- Select -- Cast(->T,RINT) -- CopyOut -- Out0
Const(0,fp32) -- Duplicate -- Zero --+
|
+-- Compare(NE) -- mask
把 fp16 / bf16 整体提升到 fp32 做 cmp/sel/mul,末端用 CAST_MODE_RINT
(round-to-nearest-even)回退。原因:
- 与 DSL
mul_no_nan.py在 fp16vcmpsel不可用时 fallback 到 fp32 的行为一致; - 避免 fp16 中间溢出(如
3e4 · 3e4在 fp16 中先 inf 再被 select 处理会 引入额外的 saturation 不确定性),fp32 中间有充足动态范围。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| 图模式 | test_geir_mul_no_nan | 通过算子IR构图方式调用MulNoNan算子。 |