* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "hybrid/executor/worker/shape_inference_engine.h"
#include "graph/runtime_inference_context.h"
#include "graph/shape_refiner.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "common/math/math_util.h"
#include "common/profiling/profiling_properties.h"
#include "common/profiling/profiling_manager.h"
#include "common/profiling_definitions.h"
#include "hybrid/node_executor/node_executor.h"
#include "hybrid/model/infer/shape_utils.h"
namespace ge {
namespace {
// Slot index used when binding a tensor to a subgraph data node and when
// addressing that node's matching tensor descriptor.
constexpr int32_t kDataInputIndex{0};
}  // namespace
namespace hybrid {
ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *const execution_context, const bool force_infer_shape)
: force_infer_shape_(force_infer_shape),
execution_context_(execution_context) {}
// Remembers the subgraph context that the propagate/update helpers of this
// engine operate on. The pointer is stored as-is (no ownership is taken here).
void ShapeInferenceEngine::Config(SubgraphContext *const subgraph_context) {
  subgraph_context_ = subgraph_context;
}
// Entry point for one node: first infers the node's shapes (DoInferShape), then
// pushes the resulting output shapes downstream (PropagateOutputShapes).
// Returns SUCCESS when both steps complete.
// NOTE(review): HYBRID_CHK_STATUS_RET is defined elsewhere; it presumably logs the
// given message and returns the failing status early - confirm against the macro.
Status ShapeInferenceEngine::InferShape(const NodeState &node_state) const {
// Step 1: shape inference for this node.
HYBRID_CHK_STATUS_RET(DoInferShape(node_state),
"[Invoke][DoInferShape] failed for [%s(%s)].",
node_state.GetName().c_str(), node_state.GetType().c_str());
// Step 2: propagation, bracketed by the kInferShapePropgate profiling record.
// NOTE(review): if the check below returns early, PROFILING_END is not reached.
PROFILING_START(node_state.GetProfilingIndex(), profiling::kInferShapePropgate);
HYBRID_CHK_STATUS_RET(PropagateOutputShapes(node_state),
"[Invoke][PropagateOutputShapes] failed for [%s(%s)].",
node_state.GetName().c_str(), node_state.GetType().c_str());
PROFILING_END(node_state.GetProfilingIndex(), profiling::kInferShapePropgate);
return SUCCESS;
}
// Seeds the subgraph's data nodes with the caller-provided input tensors.
//
// For every non-null input node i:
//   - if the node has an observer, or the graph contains a control-flow op, the
//     tensor (and, when required, its shape) is set on the data node itself via
//     UpdateShapeAndValue();
//   - otherwise the tensor is forwarded directly to the data node's consumers
//     via PropagateOutputs().
// `input_desc` may be shorter than `inputs`; a missing entry is passed on as
// nullptr and validated by the callee only if it actually needs the descriptor.
//
// Returns INTERNAL_ERROR when fewer tensors than input nodes are supplied, the
// null-check status when `graph_item` is null, or the first failing status of a
// per-input update.
Status ShapeInferenceEngine::InitInferShapes(const GraphItem *const graph_item, const std::vector<TensorValue> &inputs,
                                             const std::vector<ConstGeTensorDescPtr> &input_desc) const {
  GE_CHECK_NOTNULL(graph_item);
  // Bind by reference so the node list is not copied on every call.
  const auto &input_nodes = graph_item->GetInputNodes();
  if (inputs.size() < input_nodes.size()) {
    GELOGE(INTERNAL_ERROR,
           "[Check][Size][%s] Number of inputs [%zu] is not sufficient for subgraph which needs [%zu] inputs.",
           graph_item->GetName().c_str(), inputs.size(), input_nodes.size());
    REPORT_INNER_ERR_MSG("E19999",
                         "[%s] Number of inputs [%zu] is not sufficient for subgraph which needs [%zu] inputs.",
                         graph_item->GetName().c_str(), inputs.size(), input_nodes.size());
    return INTERNAL_ERROR;
  }

  for (size_t i = 0UL; i < input_nodes.size(); ++i) {
    const auto &input_node = input_nodes[i];
    if (input_node == nullptr) {
      GELOGD("[%s] Input[%zu] is not needed by subgraph, skip it.", graph_item->GetName().c_str(), i);
      continue;
    }

    const auto &input_tensor = inputs[i];
    GELOGD("[%s] Set input tensor[%zu] to inputs with index = %d, tensor = %s",
           graph_item->GetName().c_str(),
           i,
           input_node->input_start,
           input_tensor.DebugString().c_str());

    // A descriptor is optional: callers may supply fewer descriptors than tensors.
    const GeTensorDesc *const tensor_desc = (i < input_desc.size()) ? input_desc[i].get() : nullptr;

    if (input_node->has_observer || graph_item->HasCtrlFlowOp()) {
      GELOGD("[%s] Start to update input[%zu] for subgraph data node.", graph_item->GetName().c_str(), i);
      GE_CHK_STATUS_RET(UpdateShapeAndValue(input_node, &input_tensor, tensor_desc),
                        "[Update][ShapeAndValue] Failed to update shape and value for data node.");
    } else {
      GELOGD("[%s] Start to propagate input[%zu] value from data node.", graph_item->GetName().c_str(), i);
      GE_CHK_STATUS_RET(PropagateOutputs(input_node, &input_tensor, tensor_desc),
                        "[Propagate][Outputs] Failed to propagate outputs for data node.");
    }
  }

  GELOGD("[%s] Done setting inputs.", graph_item->GetName().c_str());
  return SUCCESS;
}
// Forwards `tensor` from a data node straight to every consumer input of that
// node. When the data node's output shape is not static, the shape and tensor
// size from `tensor_desc` are also copied into each consumer's input descriptor
// (`tensor_desc` is only required on that path).
// Finally the data node's own state is flagged so that it is skipped by the
// scheduler. Returns SUCCESS, or the first failing status encountered.
Status ShapeInferenceEngine::PropagateOutputs(const NodeItem *const node_item, const TensorValue *const tensor,
                                              const GeTensorDesc *const tensor_desc) const {
  for (int32_t out_idx = 0; out_idx < node_item->num_outputs; ++out_idx) {
    for (auto &consumer : node_item->outputs[static_cast<size_t>(out_idx)]) {
      // Each entry pairs the consumer's input slot with the consumer itself.
      auto &dst_node_item = consumer.second;
      const auto dst_input_idx = consumer.first;
      GE_CHECK_NOTNULL(dst_node_item);
      GE_CHK_STATUS_RET(subgraph_context_->SetInput(*dst_node_item, dst_input_idx, *tensor),
                        "[Invoke][SetInput] failed for node[%s] propagate input value to node[%s].",
                        node_item->NodeName().c_str(), dst_node_item->NodeName().c_str());

      // Shape information only has to travel along when it is not static.
      if (!node_item->is_output_shape_static) {
        GE_CHECK_NOTNULL(tensor_desc);
        GE_CHECK_NOTNULL(dst_node_item->op_desc->MutableInputDesc(static_cast<uint32_t>(dst_input_idx)));
        GE_CHK_STATUS_RET_NOLOG(ShapeUtils::CopyShapeAndTensorSize(
            *tensor_desc, *(dst_node_item->op_desc->MutableInputDesc(static_cast<uint32_t>(dst_input_idx)))));
      }
    }
  }

  const auto node_state = subgraph_context_->GetNodeState(node_item);
  GE_CHECK_NOTNULL(node_state);
  node_state->SetSkipSchedule(true);
  return SUCCESS;
}
// Binds `tensor` to input slot kDataInputIndex of a subgraph data node and,
// when shape inference is forced or the node is dynamic, refreshes the node's
// descriptors from `tensor_desc`:
//   - shape and tensor size are copied into the node's input descriptor;
//   - shape and origin shape are copied into the node's output descriptor;
//   - the node state is flagged to skip its own shape inference.
// `tensor_desc` is only required on that refresh path.
// Returns SUCCESS, or the status of the first failing step / null check.
Status ShapeInferenceEngine::UpdateShapeAndValue(const NodeItem *const node_item, const TensorValue *const tensor,
                                                 const GeTensorDesc *const tensor_desc) const {
  GE_CHK_STATUS_RET(subgraph_context_->SetInput(*node_item, kDataInputIndex, *tensor),
                    "[Invoke][SetInput] failed for node[%s] to set input tensor", node_item->NodeName().c_str());
  if (IsForceInferShape() || node_item->is_dynamic) {
    GELOGD("Start to update node [%s] for subgraph data node.", node_item->NodeName().c_str());
    GE_CHECK_NOTNULL(tensor_desc);
    const auto node_state = subgraph_context_->GetNodeState(node_item);
    GE_CHECK_NOTNULL(node_state);
    const auto op_desc = node_item->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);

    // The input descriptor may be absent; validate it before dereferencing,
    // exactly as is done for the output descriptor below.
    const auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(kDataInputIndex));
    GE_CHECK_NOTNULL(input_desc);
    GE_CHK_STATUS_RET_NOLOG(ShapeUtils::CopyShapeAndTensorSize(*tensor_desc, *input_desc));

    const auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(kDataInputIndex));
    GE_CHECK_NOTNULL(output_desc);
    output_desc->SetShape(tensor_desc->GetShape());
    output_desc->SetOriginShape(tensor_desc->GetOriginShape());
    node_state->SetSkipInferShape(true);
  }
  return SUCCESS;
}
// Performs shape inference for a single node.
// Order of operations, as written below:
//   1. wait for the shapes and input tensors this node depends on;
//   2. return immediately if the output shape is static and no forced inference is requested;
//   3. take the node item's "DoInferShape" guard;
//   4. fused-subgraph nodes are inferred through InferShapeForSubgraph();
//   5. DEPEND_COMPUTE nodes are skipped;
//   6. otherwise InferShapeAndTypeForRunning() is invoked unless the node state allows skipping it;
//   7. output tensor sizes are recalculated.
// Returns SUCCESS or the first failing status.
Status ShapeInferenceEngine::DoInferShape(const NodeState &node_state) const {
// Block until upstream shapes / tensors that this node needs are available.
GE_CHK_STATUS_RET_NOLOG(node_state.AwaitDependShapes(*execution_context_));
GE_CHK_STATUS_RET_NOLOG(node_state.AwaitInputTensors(*execution_context_));
auto &node_item = node_state.GetNodeItem();
// Static output shape and no forced inference: nothing to compute.
if (node_item.is_output_shape_static && (!node_item.is_need_force_infershape)) {
return SUCCESS;
}
// NOTE(review): presumably a scoped lock object; being bound to a const reference it
// stays alive until the function returns - confirm against NodeItem::MutexGuard.
const auto &guard = node_item.MutexGuard("DoInferShape");
// Fused nodes infer their shapes by walking the fused subgraph, then size their outputs.
if (node_item.fused_subgraph != nullptr) {
GE_CHK_STATUS_RET_NOLOG(InferShapeForSubgraph(node_item, *node_item.fused_subgraph));
GE_CHK_STATUS_RET_NOLOG(node_item.CalcOutputTensorSizes());
return SUCCESS;
}
// Marks `guard` as intentionally unused; its position does not affect the guard's lifetime.
(void)guard;
// Skip nodes whose shape inference type is DEPEND_COMPUTE.
if (node_item.shape_inference_type == DEPEND_COMPUTE) {
GELOGD("[%s] Skipping node with unknown shape type DEPEND_COMPUTE", node_item.NodeName().c_str());
return SUCCESS;
}
GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str());
// Run the operator's inference function unless the node state says it may be skipped.
if (!node_state.MaySkipShapeInference()) {
Operator *const op = node_state.GetOperator(execution_context_->stage_id);
GE_CHECK_NOTNULL(op);
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start");
PROFILING_START(node_state.GetProfilingIndex(), profiling::kInferShape);
GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, *op, true),
"[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str());
PROFILING_END(node_state.GetProfilingIndex(), profiling::kInferShape);
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End");
}
// Recompute output tensor sizes; the flag is true only for DEPEND_SHAPE_RANGE nodes.
RECORD_COMPILE_EVENT(execution_context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start");
GE_CHK_STATUS_RET_NOLOG(node_item.CalcOutputTensorSizes(node_item.shape_inference_type == DEPEND_SHAPE_RANGE));
RECORD_COMPILE_EVENT(execution_context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] End");
GELOGD("[%s] [HybridTrace] After shape inference. Node = %s",
node_item.NodeName().c_str(),
node_item.DebugString().c_str());
GELOGD("[%s] InferShapeAndType finished successfully.", node_item.NodeName().c_str());
return SUCCESS;
}
// Pushes a node's freshly inferred output shapes to its consumers.
// Nodes with a static output shape are left untouched. For the remaining nodes
// propagation is triggered unless the node item reports that its shape lies
// "in the future". Returns SUCCESS or the failing status of DoPropagate().
Status ShapeInferenceEngine::PropagateOutputShapes(const NodeState &node_state) const {
  const auto &item = node_state.GetNodeItem();
  if (!item.is_output_shape_static) {
    RECORD_SHAPE_INFERENCE_EVENT(execution_context_, item.NodeName().c_str(), "[PropagateOutputShapes] Start");
    const bool shape_in_future = item.IsShapeInFuture();
    if (!shape_in_future) {
      // DoPropagate() is non-const while the node state only exposes a const item.
      GE_CHK_STATUS_RET_NOLOG(const_cast<NodeItem &>(item).DoPropagate());
    }
    RECORD_SHAPE_INFERENCE_EVENT(execution_context_, item.NodeName().c_str(), "[PropagateOutputShapes] End");
    GELOGD("[%s] Propagating output shapes finished successfully.", item.NodeName().c_str());
  }
  return SUCCESS;
}
// Infers the output shapes of a fused node by running shape inference over the
// nodes of its fused subgraph:
//   1. copy the parent node's input shapes onto the mapped subgraph descriptors;
//   2. attach the tensors the subgraph depends on (SetDependingTensor);
//   3. for each subgraph node in stored order, run InferShapeAndType and push the
//      result to its peers (UpdatePeerNodeShape);
//   4. copy the mapped subgraph shapes back onto the parent node's outputs.
// Returns SUCCESS, or the first failing status / null-check status.
Status ShapeInferenceEngine::InferShapeForSubgraph(const NodeItem &node_item,
                                                   const FusedSubgraph &fused_subgraph) const {
  GELOGD("[%s] Start to infer shape by fused subgraph", node_item.NodeName().c_str());
  for (const auto &it : fused_subgraph.input_mapping) {
    const auto parent_tensor_desc = node_item.MutableInputDesc(it.first);
    GE_CHECK_NOTNULL(parent_tensor_desc);
    GELOGD("Start to update shape by input[%d]", it.first);
    GELOGD("Update shape to [%s]", parent_tensor_desc->GetShape().ToString().c_str());
    GELOGD("Update original shape to [%s]", parent_tensor_desc->GetOriginShape().ToString().c_str());
    for (const auto &tensor_desc : it.second) {
      GE_CHECK_NOTNULL(tensor_desc);
      tensor_desc->SetShape(parent_tensor_desc->GetShape());
      tensor_desc->SetOriginShape(parent_tensor_desc->GetOriginShape());
    }
  }

  GE_CHK_STATUS_RET(SetDependingTensor(fused_subgraph), "[Set][DependingTensor] Failed for set depending tensor.");

  for (size_t i = 0; i < fused_subgraph.nodes.size(); ++i) {
    const auto &node = fused_subgraph.nodes[i];
    // The node is dereferenced several times below; reject a null entry up front.
    GE_CHECK_NOTNULL(node);
    GELOGD("[%s] Start to invoke InferShapeAndType", node->GetName().c_str());
    PROFILING_START(fused_subgraph.name_to_prof_indexes[i], profiling::kInferShape);
    // Report which node failed; a bare status check would log without any context.
    GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndType(node),
                      "[Invoke][InferShapeAndType] for %s failed.", node->GetName().c_str());
    PROFILING_END(fused_subgraph.name_to_prof_indexes[i], profiling::kInferShape);
    GELOGD("[%s] Done invoking InferShapeAndType", node->GetName().c_str());
    GE_CHK_STATUS_RET(UpdatePeerNodeShape(*node),
                      "[Update][PeerNodeShape] failed for [%s].", node->GetName().c_str());
  }

  for (const auto &it : fused_subgraph.output_mapping) {
    const int32_t parent_output_idx = it.first;
    GELOGD("Update parent output[%d]", parent_output_idx);
    const auto &input_desc = it.second;
    GE_CHECK_NOTNULL(input_desc);
    const auto parent_output_tensor_desc = node_item.MutableOutputDesc(parent_output_idx);
    GE_CHECK_NOTNULL(parent_output_tensor_desc);
    GELOGD("Update shape to [%s]", input_desc->GetShape().ToString().c_str());
    GELOGD("Update original shape to [%s]", input_desc->GetOriginShape().ToString().c_str());
    parent_output_tensor_desc->SetOriginShape(input_desc->GetOriginShape());
    parent_output_tensor_desc->SetShape(input_desc->GetShape());
  }

  GELOGD("[%s] Done shape inference by subgraph successfully.", node_item.NodeName().c_str());
  return SUCCESS;
}
// Copies the shape and origin shape of every output of `node` onto the input
// descriptor of each peer connected to that output.
// A peer whose input descriptor is missing is reported as an error and skipped;
// the remaining peers are still updated.
// Returns SUCCESS, or the null-check status when a required object is missing.
Status ShapeInferenceEngine::UpdatePeerNodeShape(const Node &node) const {
  const auto op_desc = node.GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  for (const auto &out_anchor : node.GetAllOutDataAnchors()) {
    const auto output_tensor = op_desc->MutableOutputDesc(static_cast<uint32_t>(out_anchor->GetIdx()));
    for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
      const auto peer_node = peer_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(peer_node);
      const auto peer_op_desc = peer_node->GetOpDesc();
      GE_CHECK_NOTNULL(peer_op_desc);
      const auto peer_input_desc = peer_op_desc->MutableInputDesc(static_cast<uint32_t>(peer_anchor->GetIdx()));
      if (peer_input_desc == nullptr) {
        GELOGE(GRAPH_FAILED, "[Call][MutableInputDesc] for %s(%s) return nullptr.",
               peer_op_desc->GetName().c_str(), peer_op_desc->GetType().c_str());
        REPORT_INNER_ERR_MSG("E19999", "%s(%s) call MutableInputDesc return nullptr.",
                             peer_op_desc->GetName().c_str(), peer_op_desc->GetType().c_str());
        continue;
      }

      // Validated lazily, right before first use, so that an output without any
      // usable consumer is still tolerated while a real use can never dereference null.
      GE_CHECK_NOTNULL(output_tensor);

      // peer_op_desc was validated above; use it instead of re-walking the anchor chain.
      GELOGI("Peer input op desc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
             peer_op_desc->GetName().c_str(),
             output_tensor->GetShape().GetDimNum(), output_tensor->GetDataType(),
             output_tensor->GetOriginDataType());
      peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
      peer_input_desc->SetShape(output_tensor->GetShape());
      GELOGI("Peer input op desc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
             peer_op_desc->GetName().c_str(),
             peer_input_desc->GetShape().GetDimNum(), peer_input_desc->GetDataType(),
             peer_input_desc->GetOriginDataType());
    }
  }
  return SUCCESS;
}
// For every data dependency of the fused subgraph, fetches the producer's
// tensor from the runtime context and attaches it (attribute ATTR_NAME_VALUE)
// to the dependent op's input descriptor.
// Returns INTERNAL_ERROR when a tensor cannot be fetched, the null-check status
// when a destination op or descriptor is missing, SUCCESS otherwise.
Status ShapeInferenceEngine::SetDependingTensor(const FusedSubgraph &fused_subgraph) const {
  // An empty dependency list simply yields zero iterations.
  for (const auto &dep : fused_subgraph.data_dependencies) {
    const auto src_node_id = dep.first.first;
    const auto src_output_idx = dep.first.second;
    const auto &dst_op_desc = dep.second.first;
    const auto dst_input_idx = dep.second.second;
    GE_CHECK_NOTNULL(dst_op_desc);
    const auto dst_tensor_desc = dst_op_desc->MutableInputDesc(static_cast<uint32_t>(dst_input_idx));
    GE_CHECK_NOTNULL(dst_tensor_desc);

    GeTensorPtr tensor = nullptr;
    if (execution_context_->runtime_context_.GetTensor(src_node_id, src_output_idx, tensor) != GRAPH_SUCCESS) {
      // Both messages use PRId64 so the node id is formatted portably.
      REPORT_INNER_ERR_MSG("E19999", "Failed to get tensor, node_id = %" PRId64 ", output_idx = %d.",
                           src_node_id, src_output_idx);
      GELOGE(INTERNAL_ERROR, "[Get][Tensor]Failed to get tensor, node_id = %" PRId64 ", output_idx = %d.",
             src_node_id, src_output_idx);
      return INTERNAL_ERROR;
    }
    // Best effort, as before: a failure to attach the attribute is deliberately ignored.
    (void)AttrUtils::SetTensor(dst_tensor_desc, ATTR_NAME_VALUE, tensor);
  }
  return SUCCESS;
}
}
}