* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "common/graph/fe_graph_utils.h"
#include "common/fe_type_utils.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "common/fe_inner_attr_define.h"
#include "common/platform_utils.h"
namespace fe {
namespace {
/**
 * Resolve the node connected to input data anchor `anchor_index` of `node`.
 *
 * @param node          node whose input is inspected; false is returned when it is null
 * @param anchor_index  index of the input data anchor on `node`
 * @param peer_out_node [out] owner node of the peer output anchor; only meaningful when true is returned
 * @return true if the input anchor exists, is connected, and the peer output anchor has an owner node
 */
bool HasPeerOutNode(const ge::Node *node, const int &anchor_index,
                    ge::NodePtr &peer_out_node) {
  FE_CHECK(node == nullptr, FE_LOGW("index:%d node is nullptr", anchor_index), return false);
  auto in_anchor = node->GetInDataAnchor(anchor_index);
  FE_CHECK(in_anchor == nullptr, FE_LOGW("index:%d in_anchor is nullptr",
           anchor_index), return false);
  auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  FE_CHECK(peer_out_anchor == nullptr, FE_LOGW("index:%d peer_out_anchor is nullptr",
           anchor_index), return false);
  peer_out_node = peer_out_anchor->GetOwnerNode();
  // The object checked here is the owner node, so the warning must name it (not the anchor).
  FE_CHECK(peer_out_node == nullptr, FE_LOGW("index:%d peer_out_node is nullptr",
           anchor_index), return false);
  return true;
}
}  // namespace
/**
 * Dump every subgraph returned by graph.GetAllSubgraphs() through DumpGraphAndOnnx, using the same suffix.
 * The top-level `graph` itself is not dumped here.
 */
void FeGraphUtils::DumpSubGraphAndOnnx(const ge::ComputeGraph &graph, const std::string &suffix) {
  const auto sub_graphs = graph.GetAllSubgraphs();
  for (const auto &sub_graph : sub_graphs) {
    DumpGraphAndOnnx(*sub_graph, suffix);
  }
}
/**
 * Dump `graph` twice with the same suffix: first through DumpGraph (ge::GraphUtils::DumpGEGraph),
 * then through ge::GraphUtils::DumpGEGraphToOnnx.
 */
void FeGraphUtils::DumpGraphAndOnnx(const ge::ComputeGraph &graph, const std::string &suffix) {
  FeGraphUtils::DumpGraph(graph, suffix);
  ge::GraphUtils::DumpGEGraphToOnnx(graph, suffix);
}
/**
 * Dump `graph` with ge::GraphUtils::DumpGEGraph.
 * That call is given a shared pointer, so `graph` is first copied into a new ComputeGraph
 * via FeComGraphMakeShared (presumably a make_shared-style helper — confirm).
 */
void FeGraphUtils::DumpGraph(const ge::ComputeGraph &graph, const std::string &suffix) {
  ge::ComputeGraphPtr graph_copy = FeComGraphMakeShared<ge::ComputeGraph>(graph);
  ge::GraphUtils::DumpGEGraph(graph_copy, suffix);
}
/**
 * True when `op_desc_ptr` is non-null, its type satisfies IsRootGraphData, and it is not a subgraph Data
 * (i.e. it lacks ATTR_NAME_PARENT_NODE_INDEX, per IsSubGraphData).
 */
bool FeGraphUtils::IsMainGraphData(const ge::OpDescPtr &op_desc_ptr) {
  return (op_desc_ptr != nullptr) && IsRootGraphData(op_desc_ptr->GetType()) && !IsSubGraphData(op_desc_ptr);
}
/**
 * True when `op_desc_ptr` is a non-null NETOUTPUT that is not a subgraph NetOutput
 * (none of its inputs carries ATTR_NAME_PARENT_NODE_INDEX, per IsSubGraphNetOutput).
 */
bool FeGraphUtils::IsMainGraphNetOutput(const ge::OpDescPtr &op_desc_ptr) {
  return (op_desc_ptr != nullptr) && (op_desc_ptr->GetType() == NETOUTPUT) && !IsSubGraphNetOutput(op_desc_ptr);
}
/** True when `op_desc_ptr` is either a subgraph Data or a subgraph NetOutput. */
bool FeGraphUtils::IsSubGraphDataOrNetOutput(const ge::OpDescPtr &op_desc_ptr) {
  if (IsSubGraphData(op_desc_ptr)) {
    return true;
  }
  return IsSubGraphNetOutput(op_desc_ptr);
}
/** True when `op_desc_ptr` is neither a subgraph Data nor a subgraph NetOutput. */
bool FeGraphUtils::IsNotSubGraphDataAndNetOutput(const ge::OpDescPtr &op_desc_ptr) {
  return !(IsSubGraphData(op_desc_ptr) || IsSubGraphNetOutput(op_desc_ptr));
}
/**
 * True when `op_desc_ptr` is a non-null DATA op that carries ATTR_NAME_PARENT_NODE_INDEX
 * (the attribute that links it to an input of the subgraph's parent node).
 */
bool FeGraphUtils::IsSubGraphData(const ge::OpDescPtr &op_desc_ptr) {
  const bool is_data = (op_desc_ptr != nullptr) && (op_desc_ptr->GetType() == DATA);
  return is_data && op_desc_ptr->HasAttr(ge::ATTR_NAME_PARENT_NODE_INDEX);
}
/**
 * True when `op_desc` is a non-null NETOUTPUT and at least one of its input tensor descs
 * carries ATTR_NAME_PARENT_NODE_INDEX.
 */
bool FeGraphUtils::IsSubGraphNetOutput(const ge::OpDescPtr &op_desc) {
  const bool is_net_output = (op_desc != nullptr) && (op_desc->GetType() == NETOUTPUT);
  if (!is_net_output) {
    return false;
  }
  for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
    if (ge::AttrUtils::HasAttr(input_desc, ge::ATTR_NAME_PARENT_NODE_INDEX)) {
      return true;
    }
  }
  return false;
}
/**
 * For a Data node inside a subgraph, find the output anchor feeding the matching input of the subgraph's
 * parent node. That input is selected by the Data node's ATTR_NAME_PARENT_NODE_INDEX attribute.
 *
 * @param data_node_ptr           Data node inside a subgraph
 * @param pre_out_data_anchor_ptr [out] peer output anchor of the parent node's input; it may be null when that
 *                                input is unconnected, and SUCCESS is still returned in that case
 * @return SUCCESS, or FAILED when the parent index attribute is missing; null intermediate objects
 *         return early through FE_CHECK_NOTNULL
 */
Status FeGraphUtils::GetPreOutAnchorOfSubData(const ge::NodePtr &data_node_ptr,
                                              ge::OutDataAnchorPtr &pre_out_data_anchor_ptr) {
  FE_CHECK_NOTNULL(data_node_ptr);
  ge::OpDescPtr data_op_desc_ptr = data_node_ptr->GetOpDesc();
  FE_CHECK_NOTNULL(data_op_desc_ptr);

  uint32_t parent_node_index = 0;
  const bool has_parent_index =
      ge::AttrUtils::GetInt(data_op_desc_ptr, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_node_index);
  if (!has_parent_index) {
    REPORT_FE_ERROR("[GraphOptJdgInst][UpdFmtAndDtype][GetPreOutAncr] attr %s is missing for node %s",
                    ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), data_op_desc_ptr->GetName().c_str());
    return FAILED;
  }

  // Walk: Data -> owning subgraph -> parent node -> parent's input anchor -> its producer's output anchor.
  auto owner_graph = data_node_ptr->GetOwnerComputeGraph();
  FE_CHECK_NOTNULL(owner_graph);
  ge::NodePtr parent_node_ptr = owner_graph->GetParentNode();
  FE_CHECK_NOTNULL(parent_node_ptr);
  ge::InDataAnchorPtr in_data_anchor_ptr = parent_node_ptr->GetInDataAnchor(parent_node_index);
  FE_CHECK_NOTNULL(in_data_anchor_ptr);

  pre_out_data_anchor_ptr = in_data_anchor_ptr->GetPeerOutAnchor();
  return SUCCESS;
}
/**
 * Collect the NETOUTPUT input anchors referenced by `reflections`.
 *
 * Only cells with in_out == ge::NODE_IN whose node is a NETOUTPUT are used. For each such cell, the input
 * data anchor whose index equals cell.in_out_idx is appended to vec_netoutput_in_ahchor. The vector is not
 * cleared first.
 *
 * @return SUCCESS if vec_netoutput_in_ahchor is non-empty afterwards, FAILED otherwise; a null anchor met
 *         while scanning returns early through FE_CHECK_NOTNULL
 */
Status FeGraphUtils::GetPreSubNetoutputInAnchor(std::unordered_set<ge::RefCell, ge::RefCellHash> &reflections,
                                                std::vector<ge::InDataAnchorPtr> &vec_netoutput_in_ahchor) {
  for (const auto &cell : reflections) {
    const bool is_net_output_input = (cell.in_out == ge::NODE_IN) && (cell.node->GetType() == NETOUTPUT);
    if (!is_net_output_input) {
      continue;
    }
    for (auto &in_anchor : cell.node->GetAllInDataAnchors()) {
      FE_CHECK_NOTNULL(in_anchor);
      if (in_anchor->GetIdx() != cell.in_out_idx) {
        continue;
      }
      vec_netoutput_in_ahchor.push_back(in_anchor);
      break;
    }
  }
  return vec_netoutput_in_ahchor.empty() ? FAILED : SUCCESS;
}
/**
 * For input `input_index` of a subgraph NETOUTPUT, collect the consumers of the matching output of the
 * subgraph's parent node. That output is selected by the input desc's ATTR_NAME_PARENT_NODE_INDEX attribute.
 *
 * @param net_output_node_ptr  NETOUTPUT node inside a subgraph
 * @param input_index          input index of the NETOUTPUT to follow
 * @param next_in_data_anchors [out] peer input anchors of the parent's output are appended (not cleared first)
 * @return SUCCESS; FAILED when the parent index attribute cannot be read; null intermediate objects return
 *         early through FE_CHECK_NOTNULL
 */
Status FeGraphUtils::GetNextInAnchorsOfSubNetOutput(const ge::NodePtr &net_output_node_ptr, const int &input_index,
                                                    std::vector<ge::InDataAnchorPtr> &next_in_data_anchors) {
  FE_CHECK_NOTNULL(net_output_node_ptr);
  ge::OpDescPtr op_desc_ptr = net_output_node_ptr->GetOpDesc();
  FE_CHECK_NOTNULL(op_desc_ptr);
  auto input_desc = op_desc_ptr->GetInputDescPtr(input_index);
  // The initial value is never observed: GetInt either fills it or the function returns below.
  uint32_t parent_index = 0;
  if (!ge::AttrUtils::GetInt(input_desc, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
    FE_LOGE("Failed to get attribute %s from input %d of node [%s].", ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(),
            input_index, op_desc_ptr->GetName().c_str());
    return FAILED;
  }
  auto owner_graph = net_output_node_ptr->GetOwnerComputeGraph();
  FE_CHECK_NOTNULL(owner_graph);
  ge::NodePtr parent_node_ptr = owner_graph->GetParentNode();
  FE_CHECK_NOTNULL(parent_node_ptr);
  ge::OutDataAnchorPtr out_data_anchor_ptr = parent_node_ptr->GetOutDataAnchor(parent_index);
  FE_CHECK_NOTNULL(out_data_anchor_ptr);
  for (const auto &peer_in_anchor : out_data_anchor_ptr->GetPeerInDataAnchors()) {
    next_in_data_anchors.push_back(peer_in_anchor);
  }
  return SUCCESS;
}
/**
 * Collect the DATA output anchors referenced by `reflections`.
 *
 * Only cells with in_out == ge::NODE_OUT whose node is a DATA are used. For each such cell, the output data
 * anchor whose index equals cell.in_out_idx is appended to out_data_anchors. The vector is not cleared first.
 *
 * @return SUCCESS if out_data_anchors is non-empty afterwards, FAILED otherwise; a null anchor met while
 *         scanning returns early through FE_CHECK_NOTNULL
 */
Status FeGraphUtils::GetNextSubDatasOutAnchors(std::unordered_set<ge::RefCell, ge::RefCellHash> &reflections,
                                               std::vector<ge::OutDataAnchorPtr> &out_data_anchors) {
  for (const auto &cell : reflections) {
    const bool is_data_output = (cell.in_out == ge::NODE_OUT) && (cell.node->GetType() == DATA);
    if (!is_data_output) {
      continue;
    }
    for (auto &out_anchor : cell.node->GetAllOutDataAnchors()) {
      FE_CHECK_NOTNULL(out_anchor);
      if (out_anchor->GetIdx() != cell.in_out_idx) {
        continue;
      }
      out_data_anchors.push_back(out_anchor);
      break;
    }
  }
  return out_data_anchors.empty() ? FAILED : SUCCESS;
}
/**
 * Rewrite the tensor descs referenced by `reflections` according to `relation_update_info_a`.
 *
 * For every RefCell, take a copy of the desc of cell.node at cell.in_out_idx: the input desc when
 * cell.in_out == ge::NODE_IN, otherwise the output desc. The copy is modified and written back:
 *   - if primary_format != ge::FORMAT_RESERVED, the format becomes the one built by ge::GetFormatFromSubAndC0
 *     from primary_format, sub_format and the C0 bit derived from the desc's data type, and the shape
 *     becomes relation_update_info_a.shape;
 *   - if attr_name is non-empty and the node is neither DATA nor NETOUTPUT, the int attribute attr_name
 *     is set to attr_value.
 *
 * @param reflections            related (node, in/out, index) cells to update
 * @param relation_update_info_a target format / sub-format / shape and optional int attribute
 * @return SUCCESS. A null node, op desc or owner graph returns early through FE_CHECK_NOTNULL; cells
 *         visited before that point have already been written back.
 */
Status FeGraphUtils::UpdateFormatOfRelatedEdges(const std::unordered_set<ge::RefCell, ge::RefCellHash> &reflections,
const RelationUpdateInfo &relation_update_info_a) {
FE_LOGD("relationUpdateInfo: primary_format=[%s], sub_format=[%d], shape=[%s].",
ge::TypeUtils::FormatToSerialString(relation_update_info_a.primary_format).c_str(),
relation_update_info_a.sub_format, GetShapeDims(relation_update_info_a.shape).c_str());
for (const auto &cell : reflections) {
ge::NodePtr node_ptr = cell.node;
FE_CHECK_NOTNULL(node_ptr);
ge::OpDescPtr op_desc_ptr = node_ptr->GetOpDesc();
FE_CHECK_NOTNULL(op_desc_ptr);
auto owner_graph = node_ptr->GetOwnerComputeGraph();
FE_CHECK_NOTNULL(owner_graph);
string graph_name = owner_graph->GetName();
string node_name = node_ptr->GetName();
FE_LOGD("Graph[%s]Op[type=%s,name=%s]: cell.in_out_idx=[%d], cell.in_out=[%d].", graph_name.c_str(),
node_ptr->GetType().c_str(), node_name.c_str(), cell.in_out_idx, cell.in_out);
auto index = cell.in_out_idx;
// Work on a copy of the desc; it is written back via UpdateInputDesc/UpdateOutputDesc at the end of the iteration.
auto desc = (cell.in_out == ge::NODE_IN ? op_desc_ptr->GetInputDesc(static_cast<uint32_t>(index))
: op_desc_ptr->GetOutputDesc(static_cast<uint32_t>(index)));
string input_or_output = cell.in_out == ge::NODE_IN ? STR_INPUT_LOWERCASE : STR_OUTPUT_LOWERCASE;
// FORMAT_RESERVED means "leave format and shape unchanged".
if (relation_update_info_a.primary_format != ge::FORMAT_RESERVED) {
// Old format/shape are kept only for the log below.
ge::Format cur_format = desc.GetFormat();
ge::GeShape cur_shape = desc.GetShape();
// The C0 part of the new format depends on this desc's own data type.
int32_t c0_bit_val = GetC0BitByDataType(desc.GetDataType());
auto new_format = static_cast<ge::Format>(
ge::GetFormatFromSubAndC0(relation_update_info_a.primary_format, relation_update_info_a.sub_format,
c0_bit_val));
desc.SetFormat(new_format);
desc.SetShape(relation_update_info_a.shape);
FE_LOGD(
"Graph[%s]Op[type=%s,name=%s]: update the %s %d desc, cur_format=[%s], cur_shape=[%s], new_format=[%s], "
"newShape=[%s].",
graph_name.c_str(), node_ptr->GetType().c_str(), node_name.c_str(), input_or_output.c_str(), index,
ge::TypeUtils::FormatToSerialString(cur_format).c_str(), GetShapeDims(cur_shape).c_str(),
ge::TypeUtils::FormatToSerialString(new_format).c_str(), GetShapeDims(relation_update_info_a.shape).c_str());
}
// DATA and NETOUTPUT nodes never receive the optional attribute.
if (!relation_update_info_a.attr_name.empty() &&
(op_desc_ptr->GetType() != DATA && op_desc_ptr->GetType() != NETOUTPUT)) {
(void)ge::AttrUtils::SetInt(desc, relation_update_info_a.attr_name, relation_update_info_a.attr_value);
}
// Write the modified copy back; the update status is ignored.
if (cell.in_out == ge::NODE_IN) {
(void)op_desc_ptr->UpdateInputDesc(static_cast<uint32_t>(index), desc);
} else {
(void)op_desc_ptr->UpdateOutputDesc(static_cast<uint32_t>(index), desc);
}
}
return SUCCESS;
}
/**
 * Check that every tensor referenced by `reflections` has the same origin shape.
 *
 * Each RefCell selects a desc of cell.node at cell.in_out_idx: the input desc for NODE_IN, otherwise the
 * output desc. The origin-shape dims of the first visited desc become the reference, and every other desc
 * is compared against it.
 *
 * @param reflections related (node, in/out, index) cells; visited in unordered_set iteration order
 * @return true if all origin shapes are equal (also for an empty set); false if any node, op desc, owner
 *         graph or tensor desc is null, or if any origin shape differs from the reference
 */
bool FeGraphUtils::CheckRelatedEdgesOriginShape(const std::unordered_set<ge::RefCell, ge::RefCellHash> &reflections) {
  // Null checks are written out instead of using FE_CHECK_NOTNULL. That macro is meant for the Status-returning
  // functions of this file; in a bool function its early-return status is implicitly converted to bool, and a
  // non-zero status would read as `true` ("shapes consistent").
  bool has_reference = false;
  vector<int64_t> ref_origin_shape_dims;
  for (const auto &cell : reflections) {
    ge::NodePtr node_ptr = cell.node;
    if (node_ptr == nullptr) {
      FE_LOGW("Relations: node is nullptr.");
      return false;
    }
    ge::OpDescPtr op_desc_ptr = node_ptr->GetOpDesc();
    if (op_desc_ptr == nullptr) {
      FE_LOGW("Relations: op desc of node[%s] is nullptr.", node_ptr->GetName().c_str());
      return false;
    }
    auto owner_graph = node_ptr->GetOwnerComputeGraph();
    if (owner_graph == nullptr) {
      FE_LOGW("Relations: owner graph of node[%s] is nullptr.", node_ptr->GetName().c_str());
      return false;
    }
    string graph_name = owner_graph->GetName();
    string node_name = node_ptr->GetName();
    string input_output = cell.in_out == ge::NODE_IN ? STR_INPUT_LOWERCASE : STR_OUTPUT_LOWERCASE;
    FE_LOGD("Relations context: the %s %d of Graph[%s]Op[%s].", input_output.c_str(),
            cell.in_out_idx, graph_name.c_str(), node_name.c_str());
    auto index = cell.in_out_idx;
    auto desc = (cell.in_out == ge::NODE_IN ? op_desc_ptr->GetInputDescPtr(static_cast<uint32_t>(index))
                                            : op_desc_ptr->GetOutputDescPtr(static_cast<uint32_t>(index)));
    if (desc == nullptr) {
      return false;
    }
    vector<int64_t> origin_shape_dims = desc->GetOriginShape().GetDims();
    if (!has_reference) {
      ref_origin_shape_dims = origin_shape_dims;
      has_reference = true;
      continue;
    }
    if (ref_origin_shape_dims != origin_shape_dims) {
      FE_LOGD("Relations: the %s %d of Graph[%s]Op[%s], shape is not equal.", input_output.c_str(), cell.in_out_idx,
              graph_name.c_str(), node_name.c_str());
      return false;
    }
  }
  return true;
}
/**
 * Derive the graph id from the graph's ATTR_NAME_SESSION_GRAPH_ID attribute.
 *
 * The graph id is the part after the first '_'. graph_id is only overwritten when the attribute exists,
 * is non-empty, contains '_', and has at least one character after it; otherwise it keeps its value.
 */
void FeGraphUtils::GetGraphIdFromAttr(const ge::ComputeGraph &graph, string &graph_id) {
  string session_graph_id;
  const bool has_session_graph_id =
      ge::AttrUtils::GetStr(graph, ge::ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) && !session_graph_id.empty();
  if (has_session_graph_id) {
    const size_t separator = session_graph_id.find('_');
    const bool has_suffix = (separator != string::npos) && (separator + 1 < session_graph_id.size());
    if (has_suffix) {
      graph_id = session_graph_id.substr(separator + 1);
    }
  }
  FE_LOGD("Get session_graph_id=%s graph_id=%s.", session_graph_id.c_str(), graph_id.c_str());
}
/**
 * Resolve the node feeding input 0 of `parent_node` across PartitionedCall boundaries
 * (ge::NodeUtils::GetInNodeCrossPartionedCallNode) and check whether its type is in `types`.
 *
 * @param types       op types to match against
 * @param parent_node [in/out] starting node; replaced by the resolved node when one is found
 * @return true only when a node is resolved and its type is in `types`; false when parent_node is null,
 *         the lookup fails, or no node is resolved
 */
bool FeGraphUtils::CheckTypeOnRootGraph(const std::unordered_set<string> &types, ge::NodePtr &parent_node) {
  // The failure/diagnostic logs below dereference parent_node, so reject null up front.
  if (parent_node == nullptr) {
    FE_LOGW("[SubGraphOpt][PreCompileOp][SetTensorConstVal] parent node is nullptr.");
    return false;
  }
  ge::NodePtr really_parent_node = nullptr;
  if (ge::NodeUtils::GetInNodeCrossPartionedCallNode(parent_node, 0, really_parent_node) != SUCCESS) {
    FE_LOGW(
        "[SubGraphOpt][PreCompileOp][SetTensorConstVal] Node[%s, %s]: failed to getInNodeCrossPartionedCallNode.",
        parent_node->GetName().c_str(), parent_node->GetType().c_str());
    return false;
  }
  if (really_parent_node == nullptr) {
    FE_LOGD("real parent node for %s is null.", parent_node->GetName().c_str());
    return false;
  }
  std::string node_type = really_parent_node->GetType();
  FE_LOGD("Parent_node:%s type:%s really_parent_node:%s type:%s", parent_node->GetName().c_str(),
          parent_node->GetType().c_str(), really_parent_node->GetName().c_str(),
          really_parent_node->GetType().c_str());
  parent_node = really_parent_node;
  return types.count(node_type) != 0;
}
/**
 * When `type` is PARTITIONEDCALL, try to resolve the node that feeds the selected output inside the
 * PartitionedCall's subgraph; otherwise do nothing.
 *
 * Steps:
 *   1. Take parent_node's owner graph and its kPartSrcGraph ext attr.
 *   2. Find the root graph from that source graph.
 *   3. Locate the root-graph subgraph whose parent node has the same name as parent_node.
 *   4. In that subgraph, take the node connected to the NETOUTPUT input selected by the "anchorIndex"
 *      int attribute of `node`.
 *
 * @param name               name of `node`, used only for logging
 * @param type               [in/out] type to inspect; replaced by the resolved node's type on success
 * @param parent_node        [in/out] PartitionedCall node; replaced by the resolved node on success
 * @param really_parent_node [out] set to the resolved node on success, untouched otherwise
 * @param node               node whose "anchorIndex" attribute selects the NETOUTPUT input
 *
 * On any lookup failure a warning is logged and the function returns with all outputs unchanged.
 * The one silent case is when no subgraph matches.
 * NOTE(review): assumes parent_node and node->GetOpDesc() are non-null; neither is checked here.
 */
void FeGraphUtils::ProcessPartitionedCall(const std::string &name, std::string &type, ge::NodePtr &parent_node,
ge::NodePtr &really_parent_node, ge::NodePtr &node) {
if (type != PARTITIONEDCALL) {
return;
}
const auto &func_graph = parent_node->GetOwnerComputeGraph();
FE_CHECK(func_graph == nullptr, FE_LOGW("GetOwnerComputeGraph Failed."), return);
const auto &src_graph = func_graph->TryGetExtAttr(kPartSrcGraph, ge::ComputeGraphPtr());
FE_CHECK(src_graph == nullptr, FE_LOGW("TryGetExtAttr Failed."), return);
const auto &root_graph = ge::GraphUtils::FindRootGraph(src_graph);
FE_CHECK(root_graph == nullptr, FE_LOGW("FindRootGraph Failed"), return);
// Subgraphs are matched to parent_node by node name.
for (const auto &subgraph : root_graph->GetAllSubgraphs()) {
if (subgraph->GetParentNode() == nullptr) {
continue;
}
if (subgraph->GetParentNode()->GetName() == parent_node->GetName()) {
const auto &net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT);
// Fires when the matched subgraph has no NETOUTPUT node.
FE_CHECK(net_output_node == nullptr, FE_LOGW("GetSubgraph Failed"), return);
int32_t parent_node_anchor_index;
if (!ge::AttrUtils::GetInt(node->GetOpDesc(), "anchorIndex", parent_node_anchor_index)) {
FE_LOGW("Node [%s] failed to get anchorIndex.", name.c_str());
return;
}
auto in_node = ge::NodeUtils::GetInDataNodeByIndex(*net_output_node, parent_node_anchor_index);
if (in_node == nullptr) { FE_LOGW("in_node is nullptr!"); return; }
FE_LOGD("[SubGraphOpt][IsNodeSpecType][in_node %s, type %s]", in_node->GetName().c_str(),
in_node->GetType().c_str());
// All three outputs now point at the node found inside the subgraph.
type = in_node->GetType();
parent_node = in_node;
really_parent_node = parent_node;
break;
}
}
}
/**
 * Decide whether `node`, or the node it stands for, has a type contained in `types`.
 *
 * - Own type in `types`: matched = true and `node` is left unchanged.
 * - PlaceHolder:
 *     1. Follow its ATTR_NAME_PARENT_NODE ext attr.
 *     2. Resolve a PartitionedCall parent through ProcessPartitionedCall.
 *     3. If the type is still unmatched, resolve input 0 across PartitionedCall boundaries with
 *        ge::NodeUtils::GetInNodeCrossPartionedCallNode.
 *   When a resolved node exists, `node` is overwritten with it. If the PlaceHolder has no parent
 *   node, matched stays false.
 * - Subgraph DATA: delegated to CheckTypeOnRootGraph, which may also overwrite `node`.
 * - Anything else: matched = false.
 *
 * @param types   op types to match against
 * @param node    [in/out] node to inspect; may be replaced with the resolved node (see above)
 * @param matched [out] whether a type in `types` was found
 */
void FeGraphUtils::IsNodeSpecificType(const std::unordered_set<string> &types,
ge::NodePtr &node, bool &matched) {
auto type = node->GetType();
auto name = node->GetName();
matched = types.count(type) != 0;
if (matched) {
return;
}
if (type == OP_TYPE_PLACE_HOLDER) {
ge::NodePtr parent_node = nullptr;
parent_node = node->GetOpDesc()->TryGetExtAttr(ATTR_NAME_PARENT_NODE, parent_node);
if (parent_node != nullptr) {
type = parent_node->GetType();
FE_LOGD("The parent node of place holder[%s] is [%s, %s].", name.c_str(),
parent_node->GetName().c_str(), parent_node->GetType().c_str());
ge::NodePtr really_parent_node = parent_node;
ProcessPartitionedCall(name, type, parent_node, really_parent_node, node);
// Short-circuit: GetInNodeCrossPartionedCallNode (which writes really_parent_node) only runs when
// `type` is still unmatched; otherwise really_parent_node keeps the value set above.
bool parent_node_invalid = (types.count(type) == 0 &&
ge::NodeUtils::GetInNodeCrossPartionedCallNode(parent_node, 0,
really_parent_node) != SUCCESS);
if (parent_node_invalid) {
FE_LOGW("[SubGraphOpt][IsNodeSpecType][Op %s, type %s]: Failed to getInNodeCrossPartionedCallNode.",
name.c_str(), type.c_str());
return;
}
if (really_parent_node != nullptr) {
// The caller's node is replaced by the resolved node.
node = really_parent_node;
type = really_parent_node->GetType();
FE_LOGD("Parent node:%s type:%s really parent node:%s type:%s.", parent_node->GetName().c_str(),
parent_node->GetType().c_str(), really_parent_node->GetName().c_str(),
really_parent_node->GetType().c_str());
}
matched = types.count(type);
}
} else if (FeGraphUtils::IsSubGraphData(node->GetOpDesc())) {
matched = FeGraphUtils::CheckTypeOnRootGraph(types, node);
} else {
FE_LOGD("Cannot match any types for node %s and type %s.", node->GetName().c_str(), type.c_str());
}
}
/**
 * Check whether input `anchor_index` of `node` is fed by a node whose type is in kConstTypes,
 * as decided by IsNodeSpecificType.
 *
 * @param node          node whose input is inspected; false when null
 * @param anchor_index  input data anchor index on `node`
 * @param peer_out_node [out] set to the connected producer node. IsNodeSpecificType may then replace it
 *                      with the node it resolves to (e.g. through PlaceHolder / PartitionedCall).
 * @return true when the (resolved) producer matches kConstTypes
 */
bool FeGraphUtils::IsPeerOutConst(const ge::Node *node, const int &anchor_index,
                                  ge::NodePtr &peer_out_node) {
  if (node == nullptr || !HasPeerOutNode(node, anchor_index, peer_out_node)) {
    return false;
  }
  bool is_const_node = false;
  IsNodeSpecificType(kConstTypes, peer_out_node, is_const_node);
  return is_const_node;
}
/**
 * Check whether input `anchor_index` of `node` is fed by a node whose type is in kWeightTypes,
 * as decided by IsNodeSpecificType.
 *
 * @param node          node whose input is inspected; false when null
 * @param anchor_index  input data anchor index on `node`
 * @param peer_out_node [out] set to the connected producer node. IsNodeSpecificType may then replace it
 *                      with the node it resolves to (e.g. through PlaceHolder / PartitionedCall).
 * @return true when the (resolved) producer matches kWeightTypes
 */
bool FeGraphUtils::IsPeerOutWeight(ge::Node *node, const int &anchor_index,
                                   ge::NodePtr &peer_out_node) {
  if (node == nullptr || !HasPeerOutNode(node, anchor_index, peer_out_node)) {
    return false;
  }
  FE_LOGD("[IsPeerOutWeight] Peer out node is %s.", peer_out_node->GetName().c_str());
  bool is_weight_node = false;
  IsNodeSpecificType(kWeightTypes, peer_out_node, is_weight_node);
  return is_weight_node;
}
/**
 * Read the AOE_TYPE string attribute from the root graph that `graph` belongs to.
 *
 * If the owner graph of graph's first direct node carries a kPartSrcGraph ext attr, the root graph is
 * searched from that source graph; otherwise it is searched from `graph` itself.
 *
 * @param graph    graph to start from
 * @param aoe_type [out] value of AOE_TYPE on the root graph
 * @return SUCCESS when the attribute was read; FAILED when graph has no usable direct node, a lookup
 *         fails, or the attribute is missing
 */
Status FeGraphUtils::GetAoeTypeFromRootGraph(ge::ComputeGraph& graph, std::string &aoe_type) {
  auto nodes = graph.GetDirectNode();
  // nodes.at(0) is out of range on an empty graph, and a null first node would be dereferenced below.
  if (nodes.empty() || nodes.at(0) == nullptr) {
    FE_LOGW("Graph[%s] has no valid direct node.", graph.GetName().c_str());
    return FAILED;
  }
  ge::ComputeGraphPtr root_graph;
  const auto &func_graph = nodes.at(0)->GetOwnerComputeGraph();
  FE_CHECK(func_graph == nullptr, FE_LOGW("GetOwnerComputeGraph Failed!"), return FAILED);
  const auto &src_graph = func_graph->TryGetExtAttr(kPartSrcGraph, ge::ComputeGraphPtr());
  if (src_graph == nullptr) {
    root_graph = ge::GraphUtils::FindRootGraph(graph.shared_from_this());
  } else {
    root_graph = ge::GraphUtils::FindRootGraph(src_graph);
  }
  FE_CHECK(root_graph == nullptr, FE_LOGW("FindRootGraph Failed!"), return FAILED);
  if (!ge::AttrUtils::GetStr(*root_graph, AOE_TYPE, aoe_type)) {
    return FAILED;
  }
  return SUCCESS;
}
/**
 * Determine the op type seen through "transparent" nodes.
 *
 * 1. peer_op_type starts as node's own type.
 * 2. If that type is in kGeDeleteOpType, recurse into the first input data node (is_input == true) or
 *    the first output data node (is_input == false).
 * 3. If the resulting type is PlaceHolder or End, follow the ATTR_NAME_PARENT_NODE ext attr of `node`
 *    recursively. Without that attr, use node's PARENT_OP_TYPE string attribute, if present.
 *
 * @param node         starting node; nothing is written when it is null
 * @param is_input     walk towards producers (true) or consumers (false)
 * @param peer_op_type [out] resolved op type
 * NOTE(review): there is no recursion depth / cycle guard; assumes the PlaceHolder/End parent chain terminates.
 */
void FeGraphUtils::FindPeerOpType(const ge::NodePtr &node, const bool is_input, std::string &peer_op_type) {
if (node == nullptr) {
return;
}
peer_op_type = node->GetType();
if (kGeDeleteOpType.count(peer_op_type) != 0) {
ge::Node::Vistor<ge::NodePtr> peer_nodes = is_input ? node->GetInDataNodes() : node->GetOutDataNodes();
if (!peer_nodes.empty()) {
FindPeerOpType(peer_nodes.at(0), is_input, peer_op_type);
}
}
// NOTE(review): after recursing through a kGeDeleteOpType node, this step still reads `node`'s own attributes
// rather than the peer's — confirm intended. node->GetOpDesc() is not null-checked here.
if (peer_op_type == OP_TYPE_PLACE_HOLDER || peer_op_type == OP_TYPE_END) {
ge::NodePtr peer_node = node->GetOpDesc()->TryGetExtAttr<ge::NodePtr>(ATTR_NAME_PARENT_NODE, nullptr);
if (peer_node != nullptr) {
FindPeerOpType(peer_node, is_input, peer_op_type);
} else {
(void)ge::AttrUtils::GetStr(node->GetOpDesc(), PARENT_OP_TYPE, peer_op_type);
}
}
}
/**
 * Read the precision mode from the graph's "graph_precision_mode" int attribute.
 *
 * A value in [0, ENUM_UNDEFINED] is used directly. A missing attribute (which leaves the -1 default) or an
 * out-of-range value falls back to FEContextUtils::GetPrecisionMode.
 * NOTE(review): the upper bound is inclusive, so ENUM_UNDEFINED itself is accepted from the graph — confirm
 * this is intended.
 */
void FeGraphUtils::GetPrecisionModeFromGraph(const ge::ComputeGraph& graph, fe::PrecisionMode &precision_mode) {
  int precision_mode_num = -1;
  (void)ge::AttrUtils::GetInt(graph, "graph_precision_mode", precision_mode_num);
  const int max_mode_num = static_cast<int>(fe::PrecisionMode::ENUM_UNDEFINED);
  const bool in_range = (precision_mode_num >= 0) && (precision_mode_num <= max_mode_num);
  if (in_range) {
    precision_mode = static_cast<fe::PrecisionMode>(precision_mode_num);
    FE_LOGD("[GraphOptJdgInst][GetPrecisionModeFromGraph] Get precision mode %d from graph.", precision_mode_num);
    return;
  }
  FEContextUtils::GetPrecisionMode(precision_mode);
  FE_LOGD("[GraphOptJdgInst][GetPrecisionModeFromGraph] The precision mode num %d from graph is invalid.",
          precision_mode_num);
}
}