* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "infer_shape_range.h"
#include "infer_shape.h"
#include "graph/ge_error_codes.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/normal_graph/operator_impl.h"
#include "graph/compute_graph.h"
#include "graph/operator_factory_impl.h"
#include "graph/buffer.h"
#include "register/kernel_registry.h"
#include "framework/common/debug/ge_log.h"
#include "common/checker.h"
#include "compatible_utils.h"
#include "common/util/mem_utils.h"
#include "core/debug/kernel_tracing.h"
#include "graph/utils/type_utils.h"
#include "graph_metadef/common/ge_common/util.h"
namespace gert {
namespace kernel {
namespace {
std::vector<std::string> CompatibleInferShapeKernelTrace(const KernelContext *context) {
return {PrintNodeType(context),
PrintInputShapeInfo(context, static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInputNum)),
PrintOutputShapeInfo(context)};
}
std::vector<std::string> CompatibleInferShapeRangeKernelTrace(const KernelContext *context) {
return {PrintNodeType(context),
PrintInputRangeInfo(context, static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInputNum)),
PrintOutputRangeInfo(context)};
}
ge::graphStatus UpdateInputShapeToOpDesc(KernelContext *context, ge::OpDescPtr &op_desc) {
auto other_inputs_size = static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInputNum);
GE_ASSERT_TRUE(context->GetInputNum() >= other_inputs_size);
auto input_shapes_num = context->GetInputNum() - other_inputs_size;
auto all_valid_input_desc = op_desc->GetAllInputsDescPtr();
GE_ASSERT_EQ(all_valid_input_desc.size(), input_shapes_num);
auto infer_context = reinterpret_cast<InferShapeContext *>(context);
auto input_shape_start_pos = other_inputs_size;
for (size_t i = 0U; i < input_shapes_num; ++i) {
auto input_shape = infer_context->GetInputShape(input_shape_start_pos + i);
GE_ASSERT_NOTNULL(input_shape);
std::vector<int64_t> shape_dims;
for (size_t j = 0U; j < input_shape->GetDimNum(); ++j) {
shape_dims.emplace_back(input_shape->GetDim(j));
}
const auto &input_desc = all_valid_input_desc.at(i);
GE_ASSERT_NOTNULL(input_desc);
input_desc->SetShape(ge::GeShape(shape_dims));
input_desc->SetOriginShape(ge::GeShape(shape_dims));
auto input_desc_in_context = infer_context->GetInputDesc(i);
GE_ASSERT_NOTNULL(input_desc_in_context);
input_desc->SetFormat(input_desc_in_context->GetOriginFormat());
input_desc->SetOriginFormat(input_desc_in_context->GetOriginFormat());
}
auto output_shapes_num = context->GetOutputNum();
auto output_num_on_op = op_desc->GetOutputsSize();
GE_ASSERT_EQ(output_shapes_num, output_num_on_op);
for (size_t i = 0U; i < output_shapes_num; ++i) {
const auto &output_desc = op_desc->MutableOutputDesc(i);
GE_ASSERT_NOTNULL(output_desc);
auto output_desc_in_context = infer_context->GetOutputDesc(i);
GE_ASSERT_NOTNULL(output_desc_in_context);
output_desc->SetFormat(output_desc_in_context->GetOriginFormat());
output_desc->SetOriginFormat(output_desc_in_context->GetOriginFormat());
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus UpdateOutputShapeToContext(const ge::OpDescPtr &op_desc, KernelContext *context) {
auto output_num = context->GetOutputNum();
auto output_num_on_op = op_desc->GetOutputsSize();
if (output_num_on_op != output_num) {
GELOGE(ge::PARAM_INVALID, "Output num on op %s is %zu, output num on context is %zu, not match.",
op_desc->GetName().c_str(), output_num_on_op, output_num);
return ge::PARAM_INVALID;
}
auto infer_context = reinterpret_cast<InferShapeContext *>(context);
for (size_t i = 0U; i < output_num_on_op; ++i) {
const auto &output_shape = op_desc->GetOutputDesc(i).GetShape().GetDims();
auto output_shape_on_context = infer_context->GetOutputShape(i);
GE_ASSERT_NOTNULL(output_shape_on_context);
output_shape_on_context->SetDimNum(output_shape.size());
for (size_t j = 0U; j < output_shape.size(); ++j) {
output_shape_on_context->SetDim(j, output_shape[j]);
}
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus UpdateInputShapeRangeToOpDesc(KernelContext *context, ge::OpDescPtr &op_desc) {
auto other_inputs_size = static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInputNum);
GE_ASSERT_TRUE(context->GetInputNum() >= other_inputs_size);
auto input_shapes_num = context->GetInputNum() - other_inputs_size;
auto input_num_on_op = op_desc->GetInputsSize();
if (input_num_on_op != input_shapes_num) {
GELOGE(ge::PARAM_INVALID, "Input num on op %s is %zu, input num on context is %zu, not match.",
op_desc->GetName().c_str(), input_num_on_op, input_shapes_num);
return ge::PARAM_INVALID;
}
auto infer_range_context = reinterpret_cast<InferShapeRangeContext *>(context);
auto input_shape_start_pos = other_inputs_size;
for (size_t i = 0U; i < input_shapes_num; ++i) {
auto input_shape_min = infer_range_context->GetInputShapeRange(input_shape_start_pos + i)->GetMin();
auto input_shape_max = infer_range_context->GetInputShapeRange(input_shape_start_pos + i)->GetMax();
GE_ASSERT_EQ(input_shape_min->GetDimNum(), input_shape_max->GetDimNum());
std::vector<int64_t> shape_dims;
std::vector<std::pair<int64_t, int64_t>> range;
for (size_t j = 0U; j < input_shape_min->GetDimNum(); ++j) {
auto dim = (input_shape_min->GetDim(j) == input_shape_max->GetDim(j)) ? input_shape_min->GetDim(j) : -1;
shape_dims.emplace_back(dim);
range.emplace_back(std::pair<int64_t, int64_t>(input_shape_min->GetDim(j), input_shape_max->GetDim(j)));
}
const auto &input_desc = op_desc->MutableInputDesc(i);
input_desc->SetOriginShape(ge::GeShape(shape_dims));
input_desc->SetShape(ge::GeShape(shape_dims));
input_desc->SetOriginShapeRange(range);
input_desc->SetShapeRange(range);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus UpdateOutputShapeRangeToContext(const ge::OpDescPtr &op_desc, KernelContext *context) {
auto output_num = context->GetOutputNum();
size_t size_of_output = 0U;
GE_ASSERT_TRUE(
!ge::MulOverflow(static_cast<size_t>(op_desc->GetOutputsSize()), kShapeRangeOutputOfNode, size_of_output));
if (size_of_output != output_num) {
GELOGE(ge::PARAM_INVALID, "Output num on op %s is %zu, output num on context is %zu, not match.",
op_desc->GetName().c_str(), size_of_output, output_num);
return ge::PARAM_INVALID;
}
auto infer_range_context = reinterpret_cast<InferShapeRangeContext *>(context);
for (size_t i = 0U; i < op_desc->GetOutputsSize(); ++i) {
std::vector<std::pair<int64_t, int64_t>> range;
GE_CHK_STATUS_RET(op_desc->GetOutputDesc(i).GetShapeRange(range), "Fail to get output shape range from op desc.");
GE_ASSERT_NOTNULL(infer_range_context->GetOutputShapeRange(i));
infer_range_context->GetOutputShapeRange(i)->GetMin()->SetDimNum(range.size());
infer_range_context->GetOutputShapeRange(i)->GetMax()->SetDimNum(range.size());
for (size_t j = 0U; j < range.size(); ++j) {
infer_range_context->GetOutputShapeRange(i)->GetMin()->SetDim(j, range[j].first);
infer_range_context->GetOutputShapeRange(i)->GetMax()->SetDim(j, range[j].second);
}
}
return ge::GRAPH_SUCCESS;
}
}
ge::graphStatus BuildCreateOpOutputs(const ge::FastNode *node, KernelContext *context) {
(void) node;
auto compute_node_chain = context->GetOutput(0U);
GE_ASSERT_NOTNULL(compute_node_chain);
auto op = new (std::nothrow) ge::Operator();
GE_ASSERT_NOTNULL(op);
compute_node_chain->SetWithDefaultDeleter(op);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus CreateOpFromBuffer(KernelContext *context) {
auto input_num = context->GetInputNum();
if (input_num < static_cast<size_t>(CreateOpFromBufferInputIndex::kInputNum)) {
return ge::GRAPH_FAILED;
}
auto operator_buffer_index = static_cast<size_t>(CreateOpFromBufferInputIndex::kOpBuffer);
auto operator_buffer = context->GetInputPointer<ContinuousVector>(operator_buffer_index);
if (operator_buffer == nullptr) {
return ge::GRAPH_FAILED;
}
auto operator_buffer_size_index = static_cast<size_t>(CreateOpFromBufferInputIndex::kOpBufferSize);
auto operator_buffer_size = context->GetInputValue<size_t>(operator_buffer_size_index);
ge::OpDescPtr op_desc = nullptr;
GE_CHK_STATUS_RET_NOLOG(KernelCompatibleUtils::UnSerializeOpDesc(operator_buffer, operator_buffer_size, op_desc));
GE_ASSERT_NOTNULL(op_desc);
auto dummy_graph = ge::MakeShared<ge::ComputeGraph>("dummy");
GE_ASSERT_NOTNULL(dummy_graph);
auto dummy_node = dummy_graph->AddNode(op_desc);
GE_ASSERT_NOTNULL(dummy_node);
auto out_op = context->GetOutputPointer<ge::Operator>(0U);
GE_ASSERT_NOTNULL(out_op);
*out_op = ge::OpDescUtils::CreateOperatorFromNode(dummy_node);
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(CreateOpFromBuffer).RunFunc(CreateOpFromBuffer).OutputsCreator(BuildCreateOpOutputs);
ge::graphStatus BuildFindCompatibleInferShapeFuncOutputs(const ge::FastNode *node, KernelContext *context) {
(void) node;
auto chain = context->GetOutput(0);
GE_ASSERT_NOTNULL(chain);
auto infer_func = new (std::nothrow)ge::InferShapeFunc();
GE_ASSERT_NOTNULL(infer_func);
chain->SetWithDefaultDeleter(infer_func);
return ge::GRAPH_SUCCESS;
}
ge::graphStatus FindCompatibleInferShapeFunc(KernelContext *context) {
auto node_type = context->GetInputValue<char *>(0);
auto infer_fun_ptr = context->GetOutputPointer<ge::InferShapeFunc>(0);
if (node_type == nullptr || infer_fun_ptr == nullptr) {
GELOGE(ge::FAILED, "Failed to find infer shape kernel, node type %s", node_type);
return ge::GRAPH_FAILED;
}
const auto &op_funcs = ge::OperatorFactoryImpl::GetInferShapeFunc(node_type);
if (op_funcs == nullptr) {
GELOGE(ge::FAILED, "Failed to find v1 infer shape func, node type %s", node_type);
return ge::GRAPH_FAILED;
}
*infer_fun_ptr = op_funcs;
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(FindCompatibleInferShapeFunc)
.RunFunc(FindCompatibleInferShapeFunc)
.OutputsCreator(BuildFindCompatibleInferShapeFuncOutputs);
REGISTER_KERNEL(FindCompatibleInferShapeRangeFunc)
.RunFunc(FindCompatibleInferShapeFunc)
.OutputsCreator(BuildFindCompatibleInferShapeFuncOutputs);
ge::graphStatus CompatibleInferShape(KernelContext *context) {
auto input_num = context->GetInputNum();
if (input_num < static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInputNum)) {
return ge::GRAPH_FAILED;
}
auto op_index = static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kOperator);
auto op = context->GetInputValue<ge::Operator *>(op_index);
GE_ASSERT_NOTNULL(op);
auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(*op);
GE_ASSERT_NOTNULL(op_desc);
GE_CHK_STATUS_RET(UpdateInputShapeToOpDesc(context, op_desc), "Fail to update input shapes to Operator.");
std::vector<ge::GeTensorPtr> tensor_holder;
if (!op_desc->GetOpInferDepends().empty()) {
auto callback = [&context, &tensor_holder](const ge::ConstNodePtr &node, const size_t index,
ge::GeTensorPtr &tensor) {
(void) node;
auto infer_shape_context = reinterpret_cast<InferShapeContext *>(context);
auto input_start_pos = static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInputNum);
auto shape_tensor = infer_shape_context->GetInputTensor(index + input_start_pos);
GE_CHECK_NOTNULL(shape_tensor);
return KernelCompatibleUtils::ConvertRTTensorToGeTensor(shape_tensor, tensor, tensor_holder);
};
ge::OpDescUtils::SetCallbackGetConstInputFuncToOperator(*op, callback);
}
auto op_infer_fun_index = static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInferFunc);
auto op_infer_fun = context->GetInputValue<ge::InferShapeFunc *>(op_infer_fun_index);
GE_ASSERT_NOTNULL(op_infer_fun);
auto ret = (*op_infer_fun)(*op);
if (ret != ge::GRAPH_SUCCESS) {
GELOGE(ret, "Fail to infer shape on op %s.", op_desc->GetName().c_str());
return ret;
}
GE_CHK_STATUS_RET(UpdateOutputShapeToContext(op_desc, context), "Fail to update output shapes to Context.");
auto extend_context = reinterpret_cast<ExtendedKernelContext *>(context);
auto compute_node_info = extend_context->GetComputeNodeInfo();
if (compute_node_info == nullptr) {
return ge::GRAPH_FAILED;
}
GE_CHK_STATUS_RET(TransformAllOutputsShape(compute_node_info, context), "Fail to transfrom output shapes.");
return ge::GRAPH_SUCCESS;
}
ge::graphStatus CompatibleInferShapeRange(KernelContext *context) {
auto input_num = context->GetInputNum();
if (input_num < static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInputNum)) {
GELOGE(ge::PARAM_INVALID,
"input num only has %zu inputs now, at least need %zu inputs, including infer_func and operator.", input_num,
static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInputNum));
return ge::GRAPH_FAILED;
}
auto op_index = static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kOperator);
auto op = context->GetInputValue<ge::Operator *>(op_index);
GE_ASSERT_NOTNULL(op);
auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(*op);
GE_ASSERT_NOTNULL(op_desc);
GE_CHK_STATUS_RET(UpdateInputShapeRangeToOpDesc(context, op_desc), "Fail to update input shapes to Operator.");
std::vector<ge::GeTensorPtr> tensor_holder;
if (!op_desc->GetOpInferDepends().empty()) {
auto callback = [&context, &tensor_holder](const ge::ConstNodePtr &node, const size_t index,
ge::GeTensorPtr &tensor) {
(void) node;
auto infer_range_context = reinterpret_cast<InferShapeRangeContext *>(context);
auto input_start_pos = static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInputNum);
auto tensor_range = infer_range_context->GetInputTensorRange(index + input_start_pos);
GE_CHECK_NOTNULL(tensor_range);
auto shape_range_tensor = tensor_range->GetMin();
GE_CHECK_NOTNULL(shape_range_tensor);
return KernelCompatibleUtils::ConvertRTTensorToGeTensor(shape_range_tensor, tensor, tensor_holder);
};
ge::OpDescUtils::SetCallbackGetConstInputFuncToOperator(*op, callback);
}
auto op_infer_fun_index = static_cast<size_t>(CompatibleInferShapeOrRangeInputIndex::kInferFunc);
auto op_infer_fun = context->GetInputValue<ge::InferShapeFunc *>(op_infer_fun_index);
GE_ASSERT_NOTNULL(op_infer_fun);
auto ret = (*op_infer_fun)(*op);
if (ret != ge::GRAPH_SUCCESS) {
GELOGE(ret, "Fail to infer shape on op %s.", op_desc->GetName().c_str());
return ret;
}
GE_CHK_STATUS_RET(UpdateOutputShapeRangeToContext(op_desc, context),
"Fail to update output shapes range to Context.");
auto extend_context = reinterpret_cast<ExtendedKernelContext *>(context);
auto compute_node_info = extend_context->GetComputeNodeInfo();
if (compute_node_info == nullptr) {
return ge::GRAPH_FAILED;
}
GE_CHK_STATUS_RET(TransformAllOutputsMaxShape(compute_node_info, context), "Fail to transfrom output shapes.");
return ge::GRAPH_SUCCESS;
}
REGISTER_KERNEL(CompatibleInferShape)
.RunFunc(CompatibleInferShape)
.OutputsCreator(BuildInferShapeOutputs)
.TracePrinter(CompatibleInferShapeKernelTrace);
REGISTER_KERNEL(CompatibleInferShapeRange)
.RunFunc(CompatibleInferShapeRange)
.OutputsCreator(BuildInferShapeRangeOutputs)
.TracePrinter(CompatibleInferShapeRangeKernelTrace);
}
}