* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "custom_op_kernel.h"
#include "register/kernel_registry.h"
#include "common/checker.h"
#include "graph/custom_op_factory.h"
#include "graph/custom_op.h"
#include "kernel/memory/multi_stream_mem_block.h"
#include "graph/def_types.h"
#include "graph/utils/type_utils.h"
#include "exe_graph/runtime/eager_op_execution_context.h"
#include "runtime/kernel.h"
namespace gert {
namespace kernel {
namespace {
// Offsets of the extra kernel inputs that follow the compute node's own inputs.
// Within this file the values are added to the node input count: `kFunc` locates the
// custom op object pointer, and `kEnd` marks where the template tensors start.
// kAllocator / kStream are not read in this file; presumably they locate the allocator
// and the stream inputs -- TODO confirm against the lowering code.
enum class CustomOpInput {
  kAllocator = 0,
  kStream = 1,
  kFunc = 2,
  // Number of extra inputs; also the offset of the first input after them.
  kEnd = 3
};
std::string PrintNodeType(const KernelContext *context) {
std::stringstream ss;
auto extend_context = reinterpret_cast<const ExtendedKernelContext *>(context);
auto compute_node_info = extend_context->GetComputeNodeInfo();
if (compute_node_info == nullptr) {
return "compute_node_info is nullptr";
}
ss << "node_type[" << compute_node_info->GetNodeType() << "]";
return ss.str();
}
// Appends `shape` to `ss` as a bracketed, comma-separated dimension list, e.g. "[2, 3, 4]".
// A shape without dimensions is written as "[]".
void ShapeToStringStream(std::stringstream &ss, const Shape &shape) {
  const auto dim_count = shape.GetDimNum();
  ss << "[";
  for (size_t dim_idx = 0U; dim_idx < dim_count; ++dim_idx) {
    // The separator goes before every dimension except the first.
    if (dim_idx != 0U) {
      ss << ", ";
    }
    ss << shape.GetDim(dim_idx);
  }
  ss << "]";
}
// Appends a one-line description of `tensor` to `ss`: storage format, data type,
// storage shape and data address. `tensor` is dereferenced without a null check,
// so callers must pass a valid pointer.
void PrintTensor(std::stringstream &ss, const gert::Tensor *tensor) {
  const auto format_text = ge::TypeUtils::FormatToSerialString(tensor->GetStorageFormat());
  const auto dtype_text = ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType());
  ss << "format: " << format_text << ", datatype: " << dtype_text << ", shape: ";
  ShapeToStringStream(ss, tensor->GetStorageShape());
  ss << ", addr: " << ge::PtrToValue(tensor->GetAddr());
}
}
// Formats the stream id and task id reported by rtGetTaskIdAndStreamID for trace output.
// The reported task value is split into its low 16 bits (task_id) and the remaining
// high bits (flip_num). Returns an empty string when the runtime query fails.
std::string PrintStreamIdAndTaskId() {
  constexpr uint32_t kTaskIdMask = 0xFFFFU;
  constexpr uint32_t kFlipNumShift = 16U;

  uint32_t stream_id = 0U;
  uint32_t flip_task_id = 0U;
  if (rtGetTaskIdAndStreamID(&flip_task_id, &stream_id) != RT_ERROR_NONE) {
    return std::string();
  }

  std::stringstream id_stream;
  id_stream << "stream_id=" << stream_id << ", task_id=" << (flip_task_id & kTaskIdMask)
            << ", flip_num=" << (flip_task_id >> kFlipNumShift) << ", flip_task_id=" << flip_task_id;
  return id_stream.str();
}
// Kernel run function: looks up the custom op object for a node type and stores it in output 0.
// Input 0 is read as a C string holding the node type. Returns ge::GRAPH_SUCCESS once the op
// object has been stored; each GE_ASSERT_NOTNULL guards one of the pointers used below
// (presumably logging and returning a failure status when the pointer is null).
ge::graphStatus FindCustomOpFunc(KernelContext *context) {
const char *node_type = context->GetInputValue<char *>(0);
GE_ASSERT_NOTNULL(node_type, "Failed to find custom op func, node type is nullptr");
auto custom_op_ptr = ge::CustomOpFactory::CreateOrGetCustomOp(node_type);
GE_ASSERT_NOTNULL(custom_op_ptr);
auto chain = context->GetOutput(0);
GE_ASSERT_NOTNULL(chain);
// A null deleter is passed, so the output chain does not take a deleter for the op object;
// presumably the factory keeps ownership -- TODO confirm.
chain->Set(custom_op_ptr, nullptr);
return ge::GRAPH_SUCCESS;
}
// Creates one Tensor per compute-node output and stores it in the matching kernel output chain.
// Each tensor starts with a default-constructed StorageShape and takes its format and data type
// from the node's output description.
//
// @param extended_kernel_context  view of the kernel context used to read the output count and descs
// @param context                  kernel context whose outputs [0, output count) receive the tensors
// @return ge::GRAPH_SUCCESS when every output tensor was created and stored; otherwise the status
//         produced by the failing GE_ASSERT_NOTNULL check
static ge::graphStatus CreateOutputTensors(const ExtendedKernelContext *extended_kernel_context,
                                           KernelContext *context) {
  const size_t node_output_num = extended_kernel_context->GetComputeNodeOutputNum();
  for (size_t index = 0U; index < node_output_num; ++index) {
    auto chain = context->GetOutput(index);
    GE_ASSERT_NOTNULL(chain);
    auto output_desc = extended_kernel_context->GetOutputDesc(index);
    GE_ASSERT_NOTNULL(output_desc);
    // nothrow new yields nullptr on allocation failure; check it before publishing so a failed
    // allocation is reported here rather than leaving a null tensor in the output chain.
    auto output_tensor =
        new (std::nothrow) Tensor(StorageShape(), output_desc->GetFormat(), output_desc->GetDataType());
    GE_ASSERT_NOTNULL(output_tensor);
    chain->SetWithDefaultDeleter(output_tensor);
  }
  return ge::GRAPH_SUCCESS;
}
// Allocates the vector that holds workspace memory blocks and stores it in the kernel output
// located at index `node_output_num`, i.e. right after the per-output tensors.
// The output chain receives the vector together with the default deleter.
static ge::graphStatus CreateWorkspaceHolder(KernelContext *context, size_t node_output_num) {
  using WorkspaceBlocks = std::vector<memory::MultiStreamMemBlock *>;

  auto workspace_memory_av = context->GetOutput(node_output_num);
  GE_ASSERT_NOTNULL(workspace_memory_av);

  auto workspace_memory_holder = new (std::nothrow) WorkspaceBlocks();
  GE_ASSERT_NOTNULL(workspace_memory_holder);
  // Pre-size the holder for a single block.
  workspace_memory_holder->reserve(1U);

  workspace_memory_av->SetWithDefaultDeleter(workspace_memory_holder);
  return ge::GRAPH_SUCCESS;
}
// Outputs creator for the custom-op execution kernels: creates the output tensors first and
// then the workspace holder that sits right after them in the kernel outputs.
// `node` is part of the creator signature but is not used here.
static ge::graphStatus CreateWorkspacesMemory(const ge::FastNode *node, KernelContext *context) {
  static_cast<void>(node);

  auto *extended_kernel_context = reinterpret_cast<ExtendedKernelContext *>(context);
  GE_ASSERT_NOTNULL(extended_kernel_context);

  // The output count doubles as the index of the workspace holder output.
  const size_t node_output_num = extended_kernel_context->GetComputeNodeOutputNum();

  GE_ASSERT_SUCCESS(CreateOutputTensors(extended_kernel_context, context));
  GE_ASSERT_SUCCESS(CreateWorkspaceHolder(context, node_output_num));
  return ge::GRAPH_SUCCESS;
}
// Copies the storage shape and the origin shape from each template tensor input to the output
// tensor with the same ordinal. Template tensors are read from the kernel inputs starting at
// `node_input_num + CustomOpInput::kEnd`.
static ge::graphStatus CopyShapeFromTemplateTensors(KernelContext *context, size_t node_input_num,
                                                    size_t node_output_num) {
  const size_t template_tensor_start = node_input_num + static_cast<size_t>(CustomOpInput::kEnd);
  for (size_t out_idx = 0U; out_idx < node_output_num; ++out_idx) {
    const auto template_tensor = context->GetInputPointer<Tensor>(template_tensor_start + out_idx);
    const auto output_tensor = context->GetOutputPointer<Tensor>(out_idx);
    GE_ASSERT_NOTNULL(template_tensor);
    GE_ASSERT_NOTNULL(output_tensor);

    output_tensor->MutableStorageShape() = template_tensor->GetStorageShape();
    output_tensor->MutableOriginShape() = template_tensor->GetOriginShape();
  }
  return ge::GRAPH_SUCCESS;
}
// Shared body of the custom-op execution kernels.
// Reads the custom op object from the input at `node input count + CustomOpInput::kFunc`,
// requires it to implement ge::EagerExecuteOp, and calls its Execute with the kernel context
// viewed as an EagerOpExecutionContext.
// Returns ge::GRAPH_FAILED (after logging) when the op does not implement EagerExecuteOp.
static ge::graphStatus ExecuteCustomOpImpl(KernelContext *context) {
  auto *eager_context = reinterpret_cast<EagerOpExecutionContext *>(context);
  GE_ASSERT_NOTNULL(eager_context);

  const size_t func_input_index =
      eager_context->GetComputeNodeInputNum() + static_cast<size_t>(CustomOpInput::kFunc);
  auto custom_op_ptr = context->GetInputValue<ge::BaseCustomOp *>(func_input_index);
  GE_ASSERT_NOTNULL(custom_op_ptr);

  auto *eager_execute_op_ptr = dynamic_cast<ge::EagerExecuteOp *>(custom_op_ptr);
  if (eager_execute_op_ptr != nullptr) {
    GE_ASSERT_SUCCESS(eager_execute_op_ptr->Execute(reinterpret_cast<EagerOpExecutionContext *>(context)));
    return ge::GRAPH_SUCCESS;
  }

  GELOGE(ge::FAILED, "%s is custom op but did not implement EagerExecuteOp", eager_context->GetNodeType());
  return ge::GRAPH_FAILED;
}
// Run function of the ExecuteCustomOp kernel; forwards directly to ExecuteCustomOpImpl
// and returns its status.
ge::graphStatus ExecuteCustomOpFunc(KernelContext *context) {
return ExecuteCustomOpImpl(context);
}
// Run function of the ExecuteCustomOpWithInferShape kernel.
// Before executing the op, the storage and origin shapes of the template tensor inputs are
// copied onto the output tensors (see CopyShapeFromTemplateTensors); execution itself is
// delegated to ExecuteCustomOpImpl.
ge::graphStatus ExecuteCustomOpWithInferShapeFunc(KernelContext *context) {
auto *eager_context = reinterpret_cast<EagerOpExecutionContext *>(context);
GE_ASSERT_NOTNULL(eager_context);
// Output shapes must be in place before the op runs.
GE_ASSERT_SUCCESS(CopyShapeFromTemplateTensors(context, eager_context->GetComputeNodeInputNum(),
eager_context->GetComputeNodeOutputNum()));
return ExecuteCustomOpImpl(context);
}
// Run function of the FreeCustomOpWorkspaces kernel.
// Input 0 is a mutable vector of GertMemBlock pointers and input 1 is a GertAllocator.
// Every non-null block is freed on the allocator's stream id, then the vector is emptied.
ge::graphStatus FreeCustomOpWorkspacesFunc(KernelContext *context) {
  auto memory_vec = context->MutableInputPointer<std::vector<GertMemBlock *>>(0);
  GE_ASSERT_NOTNULL(memory_vec);
  auto gert_allocator = context->GetInputValue<GertAllocator *>(1);
  GE_ASSERT_NOTNULL(gert_allocator);

  const auto stream_id = gert_allocator->GetStreamId();
  for (auto *mem_block : *memory_vec) {
    if (mem_block == nullptr) {
      continue;
    }
    mem_block->Free(stream_id);
  }
  memory_vec->clear();
  return ge::GRAPH_SUCCESS;
}
static std::vector<std::string> CustomOpExecuteKernelTrace(const KernelContext *context) {
auto extend_context = reinterpret_cast<const ExtendedKernelContext *>(context);
auto compute_node_info = extend_context->GetComputeNodeInfo();
if (compute_node_info == nullptr) {
return {PrintNodeType(context), "compute_node_info is nullptr"};
}
auto *eager_op_context = reinterpret_cast<const EagerOpExecutionContext *>(context);
std::stringstream input_tensor_ss;
input_tensor_ss << "input tensor: ";
for (size_t i = 0U; i < compute_node_info->GetInputsNum(); ++i) {
auto tensor = eager_op_context->GetInputTensor(i);
if (tensor == nullptr) {
return {PrintNodeType(context), "The " + std::to_string(i) + "th's input tensor is nullptr"};
}
input_tensor_ss << "[" << i << "]: [";
PrintTensor(input_tensor_ss, tensor);
input_tensor_ss << "]";
if ((i + 1U) < compute_node_info->GetInputsNum()) {
input_tensor_ss << ", ";
}
}
std::stringstream output_tensor_ss;
output_tensor_ss << "output tensor: ";
for (size_t i = 0U; i < compute_node_info->GetOutputsNum(); ++i) {
auto tensor = eager_op_context->GetOutputTensor(i);
if (tensor == nullptr) {
return {PrintNodeType(context), "The " + std::to_string(i) + "th's output tensor is nullptr"};
}
output_tensor_ss << "[" << i << "]: [";
PrintTensor(output_tensor_ss, tensor);
output_tensor_ss << "]";
if ((i + 1U) < compute_node_info->GetOutputsNum()) {
output_tensor_ss << ", ";
}
}
return {PrintNodeType(context), input_tensor_ss.str(), output_tensor_ss.str(), PrintStreamIdAndTaskId()};
}
// Profiling info filler for the custom-op execution kernels.
// Sets the block dim to the maximum uint32_t value, then collects the storage-shape dims of
// every input and output tensor and hands them to `prof_info.FillShapeInfo`.
// Returns ge::GRAPH_SUCCESS when all tensors are present and the shape info was filled.
ge::graphStatus CustomOpProfilingDataFill(const KernelContext *context, ProfilingInfoWrapper &prof_info) {
  prof_info.SetBlockDim(std::numeric_limits<uint32_t>::max());

  const auto compute_node_info = reinterpret_cast<const ExtendedKernelContext *>(context)->GetComputeNodeInfo();
  GE_ASSERT_NOTNULL(compute_node_info);
  const auto eager_context = reinterpret_cast<const EagerOpExecutionContext *>(context);
  GE_ASSERT_NOTNULL(eager_context);

  // Flattens a shape into the list of its dimensions.
  const auto storage_dims = [](const Shape &shape) {
    std::vector<int64_t> dims;
    for (size_t dim_idx = 0U; dim_idx < shape.GetDimNum(); ++dim_idx) {
      dims.emplace_back(shape.GetDim(dim_idx));
    }
    return dims;
  };

  std::vector<std::vector<int64_t>> input_shapes;
  const auto node_input_num = compute_node_info->GetInputsNum();
  for (size_t in_idx = 0UL; in_idx < node_input_num; ++in_idx) {
    auto tensor = eager_context->GetInputTensor(in_idx);
    GE_ASSERT_NOTNULL(tensor);
    input_shapes.emplace_back(storage_dims(tensor->GetStorageShape()));
  }

  std::vector<std::vector<int64_t>> output_shapes;
  const auto node_output_num = compute_node_info->GetOutputsNum();
  for (size_t out_idx = 0UL; out_idx < node_output_num; ++out_idx) {
    auto tensor = eager_context->GetOutputTensor(out_idx);
    GE_ASSERT_NOTNULL(tensor);
    output_shapes.emplace_back(storage_dims(tensor->GetStorageShape()));
  }

  GE_ASSERT_SUCCESS(prof_info.FillShapeInfo(input_shapes, output_shapes));
  return ge::GRAPH_SUCCESS;
}
// Kernel registrations.
// FindCustomOp: resolves the custom op object for a node type.
REGISTER_KERNEL(FindCustomOp).RunFunc(FindCustomOpFunc);
// ExecuteCustomOp: runs the custom op; outputs are the per-output tensors plus the workspace holder.
REGISTER_KERNEL(ExecuteCustomOp).OutputsCreator(CreateWorkspacesMemory)
.RunFunc(ExecuteCustomOpFunc).TracePrinter(CustomOpExecuteKernelTrace)
.ProfilingInfoFiller(CustomOpProfilingDataFill);
// ExecuteCustomOpWithInferShape: same as ExecuteCustomOp, but output shapes are first copied
// from the template tensor inputs.
REGISTER_KERNEL(ExecuteCustomOpWithInferShape).OutputsCreator(CreateWorkspacesMemory)
.RunFunc(ExecuteCustomOpWithInferShapeFunc).TracePrinter(CustomOpExecuteKernelTrace)
.ProfilingInfoFiller(CustomOpProfilingDataFill);
// FreeCustomOpWorkspaces: frees the memory blocks held in a workspace vector.
REGISTER_KERNEL(FreeCustomOpWorkspaces).RunFunc(FreeCustomOpWorkspacesFunc);
}
}