* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef METADEF_CXX_INC_GRAPH_BASE_CUSTOM_OP_H
#define METADEF_CXX_INC_GRAPH_BASE_CUSTOM_OP_H
#include "exe_graph/runtime/eager_op_execution_context.h"
#include <functional>
#include <memory>
#include "graph/ge_error_codes.h"
#include "exe_graph/runtime/infer_datatype_context.h"
#include "exe_graph/runtime/infer_shape_context.h"
#include "exe_graph/runtime/op_compile_context.h"
#include "exe_graph/runtime/update_args_context.h"
namespace ge {
* 自定义算子能力接口的公共基类。
* 用户可按需组合继承 CompilableOp、EagerExecuteOp、ShapeInferOp,
* 以声明算子支持的构图编译、Eager 执行和推理能力。
*/
class BaseCustomOp {
public:
virtual ~BaseCustomOp() = default;
};
class PortableOp : virtual public BaseCustomOp {
public:
* 序列化自定义算子的 kernel bin 数据
* @param buffer 输出的二进制数据,由算子自定义格式,GE 不解析只透传
* @return 状态码,默认实现返回 GRAPH_SUCCESS
*/
virtual graphStatus Serialize(std::vector<uint8_t> &buffer) = 0;
* 反序列化自定义算子的 kernel bin 数据
* @param buffer 输入的二进制数据
* @return 状态码,默认实现返回 GRAPH_SUCCESS
*/
virtual graphStatus Deserialize(const std::vector<uint8_t> &buffer) = 0;
~PortableOp() override = default;
};
* 自定义算子的构图编译接口。
* 当算子进入 GE 构图编译流程后,若实现该接口,GE 会回调 Compile
* 完成算子编译相关处理。
*/
class CompilableOp : virtual public BaseCustomOp {
public:
~CompilableOp() override = default;
* 自定义算子及时编译函数
* @param ctx 算子编译上下文
* @return 状态码
*/
virtual graphStatus Compile(gert::OpCompileContext *ctx) = 0;
};
* 自定义算子的 Eager 执行接口。
* 适用于算子基于运行时上下文执行的场景。
*/
class EagerExecuteOp : virtual public BaseCustomOp {
public:
~EagerExecuteOp() override = default;
* 自定义算子的执行函数
* @param ctx 执行时上下文,可通过上下文获取input tensor,分配输出内存,分配workspace等
* @return 状态码
*/
virtual graphStatus Execute(gert::EagerOpExecutionContext *ctx) = 0;
};
* 自定义算子的 Args 刷新能力接口。
* 继承此接口的算子会在 I/O 地址变化时被框架回调 UpdateHostArgs。
*/
class ArgsUpdater : virtual public BaseCustomOp {
public:
~ArgsUpdater() override = default;
* @param ctx UpdateArgsContext,可通过上下文获取更新后的 I/O 地址和 args buffer
* @return 状态码,GRAPH_FAILED 将终止后续刷新
*/
virtual graphStatus UpdateHostArgs(gert::UpdateArgsContext *ctx) = 0;
};
* 自定义算子的 Shape 推理接口。
* 适用于算子基于推理上下文执行形状和数据类型推导的场景。
*/
class ShapeInferOp : virtual public BaseCustomOp {
public:
~ShapeInferOp() override = default;
* 形状推理函数,用于推导算子输出的形状
* @param ctx 形状推理上下文,可通过上下文获取输入张量形状,设置输出张量形状等
* @return 状态码
*/
virtual graphStatus InferShape(gert::InferShapeContext *ctx) = 0;
* 数据类型推理函数,用于推导算子输出的数据类型
* @param ctx 数据类型推理上下文,可通过上下文获取输入张量数据类型,设置输出张量数据类型等
* @return 状态码
*/
virtual graphStatus InferDataType(gert::InferDataTypeContext *ctx) = 0;
};
using BaseOpCreator = std::function<std::unique_ptr<BaseCustomOp>()>;
* 自定义算子创建器注册辅助类。
* 通常配合 REG_AUTO_MAPPING_OP 宏静态注册算子类型和创建函数。
*/
class CustomOpCreatorRegister {
public:
CustomOpCreatorRegister(const AscendString &operator_type, const BaseOpCreator &op_creator);
~CustomOpCreatorRegister() = default;
};
}
#define REG_JOIN(g_register, y) g_register##y
#define REG_AUTO_MAPPING_OP(custom_op_class) REG_AUTO_MAPPING_OP_UNIQ(__COUNTER__, custom_op_class)
#define REG_AUTO_MAPPING_OP_UNIQ(ctr, custom_op_class) \
static const ge::CustomOpCreatorRegister REG_JOIN(custom_op_register, ctr)( \
#custom_op_class, []() -> std::unique_ptr<ge::BaseCustomOp> { return std::make_unique<custom_op_class>(); })
#endif