BROADCAST_INFER
产品支持情况
头文件
#include <graph/operator_reg.h>
功能说明
提供公共函数宏封装,供算子开发者开发InferShape函数。该函数基于2个输入的shape,设置输出的shape。该宏只是设置shape,未设置dtype。
-
如果2个输入的shape一致,会按输入的shape设置输出shape。
-
如果2个输入的shape不一致,会按照Broadcast的策略,取2个输入shape的并集。
比如输入shape分别为(1,2,3,4)和(3,1,3,4),则该宏会设置算子的输出shape为(3,2,3,4)。
函数原型
BROADCAST_INFER(in1_name, in2_name, out_name)
该函数会自动调用如下函数:
graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape,
const function<vector<int64_t>()> &get_in2_shape,
const function<void(const std::vector<int64_t> &y_shape)> &set_out_shape);
参数说明
返回值说明
执行成功或失败。
约束说明
无
调用示例
IMPLEMT_INFERFUNC(RightShift, RightShiftInfer) {
DataType type = op.GetInputDesc("x").GetDataType();
SET_OUTPUT_TYPE(op, "z", type);
return BROADCAST_INFER("x", "y", "z")(op);
}