MulNoNan

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:完成安全乘法计算,当 x2 为 0 时返回 0,从而屏蔽 0 * inf = NaN0 * NaN = NaN 两类异常。等价于 TensorFlow 的 tf.math.multiply_no_nans

  • 计算公式:

    y={0,if x2=0x1×x2,if x2≠0y = \begin{cases} 0, & \text{if } x2 = 0 \\ x1 \times x2, & \text{if } x2 \neq 0 \end{cases}

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x1 输入 公式中的乘法输入张量x1。 FLOAT16, FLOAT, INT32, BFLOAT16 ND
x2 输入 公式中的乘法输入张量x2,作为判 0 主体,shape需与x1可广播。 同x1 ND
y 输出 公式中的输出张量y,shape为x1、x2广播后的统一形状。 同x1 ND

约束说明

  • x1、x2、y 必须为同一种数据类型,不支持混合数据类型。
  • 仅判 x2 一侧,x2 = -0 同样进入零臂(IEEE-754 -0 == 0);x2 != 0 时所有 IEEE 行为按普通乘法保留(NaN / Inf 正常传播)。
  • 支持任意 NumPy 广播形态(含标量、单维 broadcast、跨 rank broadcast),支持动态 shape 与动态 rank。

实现方案

文件 说明
计算图原型 op_graph/mul_no_nan_proto.h REG_OP(MulNoNan),二输入一输出
算子定义 op_host/mul_no_nan_def.cpp OpDef::AddConfig("ascend950", ...)
InferShape op_host/mul_no_nan_infershape.cpp 复用 Ops::Base::InferShape4Broadcast
Tiling op_host/arch35/mul_no_nan_tiling_arch35.{h,cpp} 按 dtype 分支调用 Ops::Base::BroadcastBaseTiling<OpDag>
DAG op_kernel/arch35/mul_no_nan_dag.h fp32/int32 通路原生计算;fp16/bf16 通路提升 fp32 中间精度
Struct op_kernel/arch35/mul_no_nan_struct.h BRC_TEMP_SCH_MODE_KEY_DECL/SEL
Kernel 入口 op_kernel/mul_no_nan_apt.cpp KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY) + BroadcastSch<schMode, OpDag>

fp32 / int32 通路

In0/In1 -- CopyInBrc --+
                       +-- Mul(x1,x2) --+
                                        +-- Select(mask, MulRes, Zero) -- CopyOut -- Out0
            Const(0)  -- Duplicate -- Zero --+
                                             |
                       +-- Compare(NE, x2, Zero) -- mask

Vec::Compare<u8, T, NE> 输出位掩码,Vec::Select<u8, T, TENSOR_TENSOR>mask ? MulRes : Zero 逐元素选择。x2 == 0 时即使 MulResNaN0 · inf / 0 · NaN)也被 Select 丢弃,输出 0。

fp16 / bf16 通路

In0/In1 -- CopyInBrc -- Cast(->fp32) --+
                                       +-- Mul --+
                                                 +-- Select -- Cast(->T,RINT) -- CopyOut -- Out0
            Const(0,fp32) -- Duplicate -- Zero --+
                                                 |
                              +-- Compare(NE) -- mask

把 fp16 / bf16 整体提升到 fp32 做 cmp/sel/mul,末端用 CAST_MODE_RINT (round-to-nearest-even)回退。原因:

  1. 与 DSL mul_no_nan.py 在 fp16 vcmpsel 不可用时 fallback 到 fp32 的行为一致;
  2. 避免 fp16 中间溢出(如 3e4 · 3e4 在 fp16 中先 inf 再被 select 处理会 引入额外的 saturation 不确定性),fp32 中间有充足动态范围。

调用说明

调用方式 样例代码 说明
图模式 test_geir_mul_no_nan 通过算子IR构图方式调用MulNoNan算子。