/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/ref_relation.h"
#include <unordered_set>
#include <set>
#include <unordered_map>
#include "common/util/mem_utils.h"
#include "common/checker.h"
#include "framework/common/debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "graph/utils/graph_utils.h"
#include "graph/def_types.h"
#include "graph_metadef/common/ge_common/util.h"
namespace ge {
namespace {
const char_t *kRefIdx = "_parent_node_index";
const char_t *kWhile = "While";
const char_t *kIf = "If";
const char_t *kCase = "Case";
const char_t *kStatelessWhile = "StatelessWhile";
std::set<std::string> function_op = {kWhile, kIf, kCase};
}
/**
 * Private implementation of RefRelations.
 *
 * Relations are stored as groups of RefCells: every cell of a group is related to every other cell of the
 * same group (e.g. a function node's i-th input and the subgraph Data nodes whose parent index is i).
 */
class RefRelations::Impl {
public:
  /**
   * Inserts every cell of the group recorded for `key` into `result`.
   * A key without any recorded relation is not an error: it is logged at debug level and `result` is left
   * unchanged. Always returns GRAPH_SUCCESS.
   */
  graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set<RefCell, RefCellHash> &result) const {
    const auto iter = look_up_table_.find(key.hash_key);
    if (iter == look_up_table_.end()) {
      GELOGD("[RefRelations][Check] cannot find any relations! key value of dest relation is %s", key.hash_key.c_str());
      return GRAPH_SUCCESS;
    }
    const std::vector<RefCell> &related_cells = iter->second;
    result.insert(related_cells.cbegin(), related_cells.cend());
    return GRAPH_SUCCESS;
  }
  // Builds all groups for the graph tree containing `graph` and (re)builds look_up_table_.
  graphStatus BuildRefRelations(ge::ComputeGraph &graph);
  // Drops every recorded group and lookup entry. Always returns GRAPH_SUCCESS.
  graphStatus Clear() {
    GELOGD("Start clear boundary reflections between main graph and sub graph!");
    look_up_table_.clear();
    values_.clear();
    return GRAPH_SUCCESS;
  }
private:
  friend class RefRelations;
  // Maps each cell's hash_key to (a copy of) the group containing that cell.
  graphStatus BuildLookUpTables();
  // Groups for non-While function nodes: one group per input index, then one per output index.
  graphStatus BuildRefRelationsForBranch(
      const NodePtr &root_node,
      const std::vector<std::vector<NodePtr>> &classed_data_nodes,
      const std::vector<std::vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
      std::vector<std::vector<RefCell>> &node_refs) const;
  // Groups for While/StatelessWhile nodes: input i and output i share one group.
  graphStatus BuildRefRelationsForWhile(
      const NodePtr &root_node,
      const std::vector<std::vector<NodePtr>> &classed_data_nodes,
      const std::vector<std::vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
      std::vector<std::vector<RefCell>> &node_refs) const;
  // Dispatches to the While or Branch builder based on root_node's type.
  graphStatus BuildRelationsWithFuncNodeType(
      const NodePtr &root_node,
      const std::vector<std::vector<NodePtr>> &classed_data_nodes,
      const std::vector<std::vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
      std::vector<std::vector<RefCell>> &node_refs) const;
  // Collects the Data and NetOutput nodes of the named subgraphs.
  void GetDataAndNetoutputOfSubGraph(
      const ge::ComputeGraph &root_graph,
      std::vector<NodePtr> &graph_data_nodes,
      std::vector<NodePtr> &netoutput_nodes,
      const std::vector<std::string> &sub_graph_names,
      const std::string &node_type) const;
  // Copy-assigns the root of the graph tree containing `graph` into `root_graph`.
  graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) const;
  // Buckets Data nodes by their parent node index.
  graphStatus ProcessSubgraphDataNodes(std::vector<NodePtr> &graph_data_nodes,
                                       std::vector<std::vector<NodePtr>> &classed_data_nodes) const;
  // Buckets (NetOutput, input index) pairs by their parent node index.
  graphStatus ProcessSubgraphNetoutput(
      const std::vector<NodePtr> &netoutput_nodes,
      std::vector<std::vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) const;
  // Groups Variable nodes that share a name across the graph tree.
  void BuildRelationsForVariables(const ge::ComputeGraph &root_graph);
  // hash_key -> group containing that cell.
  std::unordered_map<std::string, std::vector<RefCell>> look_up_table_;
  // One entry per function node (or per duplicated variable); each entry is a list of groups.
  std::vector<std::vector<std::vector<RefCell>>> values_;
};
/**
 * Builds reference groups for a branch-style function node (any type other than While/StatelessWhile).
 *
 * For every input index i a group is appended holding root_node's i-th input cell plus the NODE_IN and
 * NODE_OUT cells (index 0) of every Data node bucketed at i. Then, for every output index o, a group is
 * appended holding root_node's o-th output cell plus each (netoutput, input index) pair bucketed at o.
 *
 * @param root_node               the function node owning the subgraphs
 * @param classed_data_nodes      subgraph Data nodes bucketed by parent node index
 * @param classed_netoutput_nodes (netoutput, input index) pairs bucketed by parent node index
 * @param node_refs               output: input groups are appended first, then output groups
 * @return always GRAPH_SUCCESS
 */
graphStatus RefRelations::Impl::BuildRefRelationsForBranch(
    const NodePtr &root_node,
    const std::vector<std::vector<NodePtr>> &classed_data_nodes,
    const std::vector<std::vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
    std::vector<std::vector<RefCell>> &node_refs) const {
  GELOGD("Enter BuildRefRelationsForBranch!");
  const auto &root_name = root_node->GetName();

  for (size_t in_idx = 0UL; in_idx < classed_data_nodes.size(); ++in_idx) {
    std::vector<RefCell> input_group;
    input_group.emplace_back(root_name, root_node, NODE_IN, static_cast<int32_t>(in_idx));
    for (const auto &data_node : classed_data_nodes[in_idx]) {
      input_group.emplace_back(data_node->GetName(), data_node, NODE_IN, 0);
      input_group.emplace_back(data_node->GetName(), data_node, NODE_OUT, 0);
    }
    node_refs.emplace_back(std::move(input_group));
  }

  for (size_t out_idx = 0UL; out_idx < classed_netoutput_nodes.size(); ++out_idx) {
    std::vector<RefCell> output_group;
    output_group.emplace_back(root_name, root_node, NODE_OUT, static_cast<int32_t>(out_idx));
    for (const auto &netoutput_and_idx : classed_netoutput_nodes[out_idx]) {
      const NodePtr &netoutput_node = netoutput_and_idx.first;
      output_group.emplace_back(netoutput_node->GetName(), netoutput_node, NODE_IN,
                                static_cast<int32_t>(netoutput_and_idx.second));
    }
    node_refs.emplace_back(std::move(output_group));
  }
  return GRAPH_SUCCESS;
}
/**
 * Indexes every group stored in values_: each cell's hash_key is mapped to a copy of the group that contains
 * the cell. Insertion uses emplace, so if a hash_key already has an entry (from this or an earlier build),
 * the existing entry is kept.
 *
 * @return always GRAPH_SUCCESS
 */
graphStatus RefRelations::Impl::BuildLookUpTables() {
  GELOGD("start to build look up table!");
  for (const auto &func_groups : values_) {
    for (const auto &group : func_groups) {
      for (const auto &cell : group) {
        (void)look_up_table_.emplace(cell.hash_key, group);
      }
    }
  }
  return GRAPH_SUCCESS;
}
/**
 * Builds reference groups for a While / StatelessWhile node.
 *
 * For every input index i of root_node one group is appended. It holds root_node's i-th input and i-th
 * output cells, the NODE_IN/NODE_OUT cells of every Data node bucketed at i, and every (netoutput, input
 * index) pair bucketed at i.
 *
 * Afterwards, for each Data node that directly feeds an input of the last NetOutput seen, the group of the
 * Data node's parent index and the group of that NetOutput input's parent index are merged (each receives
 * the other's cells).
 *
 * @param root_node               the While-type node
 * @param classed_data_nodes      subgraph Data nodes bucketed by parent node index
 * @param classed_netoutput_nodes (netoutput, input index) pairs bucketed by parent node index
 * @param node_refs               output: one group per root_node input is appended. Merge indices assume it
 *                                is empty on entry.
 * @return GRAPH_SUCCESS, or a failure when the buckets do not cover every input, the NetOutput has no op
 *         desc, or a parent index read from the graph is out of range.
 */
graphStatus RefRelations::Impl::BuildRefRelationsForWhile(
    const NodePtr &root_node,
    const std::vector<std::vector<NodePtr>> &classed_data_nodes,
    const std::vector<std::vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
    std::vector<std::vector<RefCell>> &node_refs) const {
  GELOGD("Enter BuildRefRelations for while op!");
  // Normalize to size_t so that the %zu conversions below match the argument type.
  const size_t input_num = static_cast<size_t>(root_node->GetAllInDataAnchorsSize());
  GE_ASSERT(classed_data_nodes.size() >= input_num, "classed_data_nodes size %zu must be >= input_num %zu",
            classed_data_nodes.size(), input_num);
  GE_ASSERT(classed_netoutput_nodes.size() >= input_num, "classed_netoutput_nodes size %zu must be >= input_num %zu",
            classed_netoutput_nodes.size(), input_num);

  NodePtr netoutput = nullptr;
  for (size_t ref_i = 0UL; ref_i < input_num; ++ref_i) {
    std::vector<RefCell> ref_i_all_refs;
    // A loop's i-th input and i-th output belong to the same group.
    ref_i_all_refs.emplace_back(root_node->GetName(), root_node, NODE_IN, static_cast<int32_t>(ref_i));
    ref_i_all_refs.emplace_back(root_node->GetName(), root_node, NODE_OUT, static_cast<int32_t>(ref_i));
    for (const auto &data : classed_data_nodes[ref_i]) {
      ref_i_all_refs.emplace_back(data->GetName(), data, NODE_IN, 0);
      ref_i_all_refs.emplace_back(data->GetName(), data, NODE_OUT, 0);
    }
    for (const auto &ele : classed_netoutput_nodes[ref_i]) {
      ref_i_all_refs.emplace_back(ele.first->GetName(), ele.first, NODE_IN, static_cast<int32_t>(ele.second));
      netoutput = ele.first;
    }
    node_refs.emplace_back(std::move(ref_i_all_refs));
  }

  /* A Data node may feed a NetOutput input directly, possibly with a different parent index, e.g. the
   * 0'th and 1'th tensor below should be the same addr:
   *   Data0  Data1
   *       \/
   *       /\
   *    netoutput
   * In that case the two groups are merged.
   */
  if (netoutput == nullptr) {
    return GRAPH_SUCCESS;
  }
  const auto netoutput_desc = netoutput->GetOpDesc();
  GE_ASSERT(netoutput_desc != nullptr, "netoutput %s has null op desc", netoutput->GetName().c_str());
  for (const auto &in_anchor : netoutput->GetAllInDataAnchorsPtr()) {
    const auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
    if (peer_out_data_anchor == nullptr) {
      continue;
    }
    const auto peer_out_data_node = peer_out_data_anchor->GetOwnerNodeBarePtr();
    if ((peer_out_data_node == nullptr) || (peer_out_data_node->GetOpDesc() == nullptr)) {
      GELOGW("[RefRelations][Check] Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null",
             netoutput->GetName().c_str());
      continue;
    }
    if (peer_out_data_node->GetType() != DATA) {
      continue;
    }
    const auto net_in_desc = netoutput_desc->MutableInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()));
    int32_t ref_d = 0;
    int32_t ref_n = 0;
    (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIdx, ref_d);
    (void)AttrUtils::GetInt(net_in_desc, kRefIdx, ref_n);
    // Both values come from graph attributes; reject anything that would index outside node_refs.
    GE_ASSERT((ref_d >= 0) && (static_cast<size_t>(ref_d) < node_refs.size()),
              "Data %s parent index %d out of range [0, %zu)", peer_out_data_node->GetName().c_str(), ref_d,
              node_refs.size());
    GE_ASSERT((ref_n >= 0) && (static_cast<size_t>(ref_n) < node_refs.size()),
              "netoutput %s input %d parent index %d out of range [0, %zu)", netoutput->GetName().c_str(),
              in_anchor->GetIdx(), ref_n, node_refs.size());
    const size_t ref_desc = static_cast<size_t>(ref_d);
    const size_t ref_in = static_cast<size_t>(ref_n);
    if (ref_desc == ref_in) {
      // Already one group; merging a group with itself would only duplicate its cells.
      continue;
    }
    // ref_desc != ref_in and node_refs is not resized here, so both references stay valid.
    std::vector<RefCell> &desc_refs = node_refs[ref_desc];
    std::vector<RefCell> &in_refs = node_refs[ref_in];
    // Snapshot desc_refs first so in_refs receives only desc's original cells, not its own cells back.
    const std::vector<RefCell> desc_snapshot(desc_refs);
    desc_refs.insert(desc_refs.end(), in_refs.cbegin(), in_refs.cend());
    in_refs.insert(in_refs.end(), desc_snapshot.cbegin(), desc_snapshot.cend());
  }
  return GRAPH_SUCCESS;
}
/**
 * Chooses the relation builder for root_node: While and StatelessWhile nodes use
 * BuildRefRelationsForWhile; every other function-node type uses BuildRefRelationsForBranch.
 *
 * @return the status returned by the selected builder
 */
graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType(
    const NodePtr &root_node,
    const std::vector<std::vector<NodePtr>> &classed_data_nodes,
    const std::vector<std::vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
    std::vector<std::vector<RefCell>> &node_refs) const {
  const auto &func_type = root_node->GetType();
  const bool is_while_like = (func_type == kWhile) || (func_type == kStatelessWhile);
  if (is_while_like) {
    return BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs);
  }
  return BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs);
}
/**
 * Collects the direct Data and NetOutput nodes of every subgraph listed in sub_graph_names.
 *
 * Subgraphs are looked up by name on root_graph; a name that cannot be resolved is logged and skipped.
 * For While/StatelessWhile nodes the NetOutput of the subgraph at position 0 of sub_graph_names is not
 * collected (presumably the cond graph — confirm against the op definition). Its Data nodes still are.
 *
 * @param root_graph       graph on which subgraph names are resolved
 * @param graph_data_nodes output: Data nodes are appended
 * @param netoutput_nodes  output: NetOutput nodes are appended
 * @param sub_graph_names  subgraph instance names of the function node
 * @param node_type        type of the function node owning the subgraphs
 */
void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph,
                                                       std::vector<NodePtr> &graph_data_nodes,
                                                       std::vector<NodePtr> &netoutput_nodes,
                                                       const std::vector<std::string> &sub_graph_names,
                                                       const std::string &node_type) const {
  const bool is_while = (node_type == kWhile) || (node_type == kStatelessWhile);
  // Use the position in sub_graph_names, not a count of resolved subgraphs: a missing subgraph 0 must not
  // shift the next subgraph into position 0 and make its NetOutput be skipped.
  for (size_t sub_graph_idx = 0UL; sub_graph_idx < sub_graph_names.size(); ++sub_graph_idx) {
    const std::string &name = sub_graph_names[sub_graph_idx];
    const auto sub_graph = root_graph.GetSubgraph(name);
    if (sub_graph == nullptr) {
      GELOGW("[RefRelations][Check] Cannot find sub graph %s, root graph: %s.", name.c_str(),
             root_graph.GetName().c_str());
      continue;
    }
    const bool skip_netoutput = is_while && (sub_graph_idx == 0UL);
    for (const auto &sub_graph_node : sub_graph->GetDirectNode()) {
      const auto &sub_graph_node_type = sub_graph_node->GetType();
      if (sub_graph_node_type == DATA) {
        graph_data_nodes.emplace_back(sub_graph_node);
      } else if ((sub_graph_node_type == NETOUTPUT) && (!skip_netoutput)) {
        netoutput_nodes.emplace_back(sub_graph_node);
      }
    }
  }
}
/**
 * Resolves the root of the graph tree containing `graph` and copy-assigns it into `root_graph`.
 * A graph without a parent is its own root.
 *
 * NOTE(review): the result is delivered via ComputeGraph assignment. If `root_graph` is a reference that
 * aliases `graph`, `graph` itself is overwritten with the root graph.
 *
 * @return GRAPH_SUCCESS, or GRAPH_PARAM_INVALID when a parent exists but no root graph is found from it
 */
graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) const {
  const auto parent_graph_ptr = graph.GetParentGraph();
  if (parent_graph_ptr != nullptr) {
    const auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr);
    if (root_graph_ptr == nullptr) {
      REPORT_INNER_ERR_MSG("E18888", "Get null root graph, graph:%s", parent_graph_ptr->GetName().c_str());
      GE_LOGE("[Find][Graph] Get null root graph");
      return GRAPH_PARAM_INVALID;
    }
    root_graph = *root_graph_ptr;
  } else {
    root_graph = graph;
  }
  return GRAPH_SUCCESS;
}
/**
 * Buckets subgraph Data nodes by their `_parent_node_index` attribute.
 *
 * classed_data_nodes is resized to (max index + 1), so it always has at least one bucket, even for empty
 * input. Within a bucket, nodes appear in reverse order of graph_data_nodes. graph_data_nodes is emptied on
 * success.
 *
 * @param graph_data_nodes   Data nodes to classify; cleared on success
 * @param classed_data_nodes output buckets indexed by parent node index
 * @return GRAPH_SUCCESS, or GRAPH_FAILED when a node lacks the attribute or carries a negative value
 *         (graph_data_nodes is left untouched in that case)
 */
graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(std::vector<NodePtr> &graph_data_nodes,
                                                         std::vector<std::vector<NodePtr>> &classed_data_nodes) const {
  GELOGD("start to process subgraph data nodes!");
  // Read each attribute once and validate it before anything is modified.
  std::vector<size_t> parent_indices;
  parent_indices.reserve(graph_data_nodes.size());
  size_t max_ref_idx = 0UL;
  for (const auto &data : graph_data_nodes) {
    int32_t ref_idx = -1;
    if (!AttrUtils::GetInt(data->GetOpDesc(), kRefIdx, ref_idx)) {
      REPORT_INNER_ERR_MSG("E18888", "Invalid SubGraph Data node[%s].no attr %s", data->GetName().c_str(), kRefIdx);
      GELOGE(GRAPH_FAILED, "[Get][Int] Invalid SubGraph Data node[%s].no attr %s",
             data->GetName().c_str(), kRefIdx);
      return GRAPH_FAILED;
    }
    // A negative index would wrap to a huge size_t and index out of bounds below.
    if (ref_idx < 0) {
      REPORT_INNER_ERR_MSG("E18888", "Invalid SubGraph Data node[%s], attr %s has negative value %d",
                           data->GetName().c_str(), kRefIdx, ref_idx);
      GELOGE(GRAPH_FAILED, "[Check][Param] Invalid SubGraph Data node[%s], attr %s has negative value %d",
             data->GetName().c_str(), kRefIdx, ref_idx);
      return GRAPH_FAILED;
    }
    const size_t parent_idx = static_cast<size_t>(ref_idx);
    parent_indices.push_back(parent_idx);
    max_ref_idx = (parent_idx > max_ref_idx) ? parent_idx : max_ref_idx;
  }
  classed_data_nodes.resize(max_ref_idx + 1UL);
  // Walk back-to-front so that per-bucket order matches the historical pop_back()-based distribution.
  for (size_t i = graph_data_nodes.size(); i > 0UL; --i) {
    classed_data_nodes[parent_indices[i - 1UL]].emplace_back(graph_data_nodes[i - 1UL]);
  }
  graph_data_nodes.clear();
  return GRAPH_SUCCESS;
}
/**
 * Buckets every (NetOutput node, input index) pair by the `_parent_node_index` attribute of that input's
 * tensor desc.
 *
 * classed_netoutput_nodes is resized to (max index + 1), so it always has at least one bucket.
 *
 * @param netoutput_nodes         subgraph NetOutput nodes
 * @param classed_netoutput_nodes output buckets indexed by parent node index
 * @return GRAPH_SUCCESS, or a failure when a node has no op desc, an input has no tensor desc, or an input
 *         lacks the attribute or carries a negative value
 */
graphStatus RefRelations::Impl::ProcessSubgraphNetoutput(
    const std::vector<NodePtr> &netoutput_nodes,
    std::vector<std::vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) const {
  GELOGD("[RefRelations]Start to process subgraph netoutput!");
  // Pass 1: validate every input and find the largest parent index.
  int32_t max_ref_idx = 0;
  for (const auto &sub_netoutput_node : netoutput_nodes) {
    const auto op_desc = sub_netoutput_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchorsPtr()) {
      const auto in_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(in_data_anchor->GetIdx()));
      if (in_desc == nullptr) {
        REPORT_INNER_ERR_MSG("E18888", "Invalid NetOutput node [%s] idx [%d], no tensor on it",
                             sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx());
        GELOGE(GRAPH_FAILED, "[Get][Tensor] Invalid NetOutput node [%s] idx [%d], no tensor on it",
               sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      int32_t ref_o = 0;
      if (!AttrUtils::GetInt(in_desc, kRefIdx, ref_o)) {
        REPORT_INNER_ERR_MSG("E18888", "Invalid NetOutput node [%s] idx [%d], no attr[_parent_node_index] on it",
                             sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx());
        GELOGE(GRAPH_FAILED, "[Get][Int] Invalid NetOutput node [%s] idx [%d], no attr[_parent_node_index] on it",
               sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      // A negative index would wrap to a huge size_t and index out of bounds in pass 2.
      if (ref_o < 0) {
        REPORT_INNER_ERR_MSG("E18888", "Invalid NetOutput node [%s] idx [%d], attr[_parent_node_index] value %d is negative",
                             sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx(), ref_o);
        GELOGE(GRAPH_FAILED, "[Check][Param] Invalid NetOutput node [%s] idx [%d], attr[_parent_node_index] value %d is negative",
               sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx(), ref_o);
        return GRAPH_FAILED;
      }
      max_ref_idx = (ref_o > max_ref_idx) ? ref_o : max_ref_idx;
    }
  }
  classed_netoutput_nodes.resize(static_cast<size_t>(max_ref_idx) + 1UL);
  // Pass 2: distribute the (validated) inputs into their buckets.
  for (const auto &sub_netoutput_node : netoutput_nodes) {
    const auto op_desc = sub_netoutput_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchorsPtr()) {
      const auto in_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(in_data_anchor->GetIdx()));
      int32_t ref_o = 0;
      if (!AttrUtils::GetInt(in_desc, kRefIdx, ref_o)) {
        continue;
      }
      if ((ref_o < 0) || (static_cast<size_t>(ref_o) >= classed_netoutput_nodes.size())) {
        return GRAPH_FAILED;
      }
      classed_netoutput_nodes[static_cast<size_t>(ref_o)].emplace_back(
          sub_netoutput_node, static_cast<size_t>(in_data_anchor->GetIdx()));
    }
  }
  return GRAPH_SUCCESS;
}
/**
 * Groups Variable nodes that share a name within the graph tree rooted at root_graph.
 *
 * Nothing is done when root_graph has no subgraphs. Otherwise every Variable node (found via
 * GetAllNodes()) is bucketed by name. Each name with more than one instance gets its own values_ entry:
 * a single group holding the NODE_OUT cell (index 0) of every instance, in name order.
 */
void RefRelations::Impl::BuildRelationsForVariables(const ge::ComputeGraph &root_graph) {
  if (root_graph.GetAllSubgraphs().empty()) {
    return;
  }
  std::map<std::string, std::vector<NodePtr>> instances_by_name;
  for (const auto &node : root_graph.GetAllNodes()) {
    if (node->GetType() != VARIABLE) {
      continue;
    }
    instances_by_name[node->GetName()].emplace_back(node);
  }
  for (const auto &entry : instances_by_name) {
    const std::string &var_name = entry.first;
    const std::vector<NodePtr> &var_nodes = entry.second;
    if (var_nodes.size() > 1UL) {
      GELOGD("Variable [%s] has %zu instances", var_name.c_str(), var_nodes.size());
      std::vector<RefCell> shared_cells;
      shared_cells.reserve(var_nodes.size());
      for (const auto &var_node : var_nodes) {
        shared_cells.emplace_back(var_name, var_node, NODE_OUT, 0);
      }
      std::vector<std::vector<RefCell>> single_group;
      single_group.emplace_back(std::move(shared_cells));
      values_.emplace_back(std::move(single_group));
    }
  }
}
/**
 * Builds every reference group for the graph tree that contains `graph`, then builds the lookup table.
 *
 * For each node of the root graph tree that owns subgraph instances, the Data and NetOutput nodes of those
 * subgraphs are collected, bucketed by parent node index and turned into groups (While-style or
 * branch-style, depending on the node type). Variables appearing more than once are grouped as well.
 *
 * @param graph any graph of the tree; it is not modified
 * @return GRAPH_SUCCESS, or the first failure status produced by a step
 */
graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) {
  GELOGD("Start to build ref relations!");
  // Resolve the root graph by reference only. Binding a reference to `graph` and assigning the root graph
  // through it would overwrite the caller's graph whenever `graph` is a subgraph.
  ComputeGraphPtr root_graph_holder = nullptr;
  const auto parent_graph_ptr = graph.GetParentGraph();
  if (parent_graph_ptr != nullptr) {
    root_graph_holder = GraphUtils::FindRootGraph(parent_graph_ptr);
    if (root_graph_holder == nullptr) {
      REPORT_INNER_ERR_MSG("E18888", "Get null root graph, graph:%s", parent_graph_ptr->GetName().c_str());
      GE_LOGE("[Find][Graph] Get null root graph");
      return GRAPH_PARAM_INVALID;
    }
  }
  const ge::ComputeGraph &root_graph = (root_graph_holder == nullptr) ? graph : *root_graph_holder;

  // Scan the whole tree from its root; subgraph names are also resolved on the root graph.
  for (const auto &node : root_graph.GetAllNodes()) {
    const auto &node_type = node->GetType();
    const auto &op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    const auto &sub_graph_names = op_desc->GetSubgraphInstanceNames();
    if (sub_graph_names.empty()) {
      continue;
    }
    std::vector<NodePtr> graph_data_nodes;
    std::vector<NodePtr> netoutput_nodes;
    GetDataAndNetoutputOfSubGraph(root_graph, graph_data_nodes, netoutput_nodes, sub_graph_names, node_type);
    std::vector<std::vector<NodePtr>> classed_data_nodes;
    std::vector<std::vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes;
    auto status = ProcessSubgraphDataNodes(graph_data_nodes, classed_data_nodes);
    if (status != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "[Process][SubgraphDataNodes] failed! ret:%d", status);
      return status;
    }
    status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes);
    if (status != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "[Process][SubgraphNetoutput] failed! ret:%d", status);
      return status;
    }
    std::vector<std::vector<RefCell>> node_refs;
    status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs);
    if (status != GRAPH_SUCCESS) {
      GELOGE(status, "[Build][Relations] WithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str());
      return status;
    }
    if (!node_refs.empty()) {
      values_.push_back(std::move(node_refs));
    }
  }
  BuildRelationsForVariables(root_graph);
  const auto status = BuildLookUpTables();
  if (status != GRAPH_SUCCESS) {
    GELOGE(status, "[Build][LookUpTables] failed! ret:%d", status);
    return status;
  }
  return GRAPH_SUCCESS;
}
/**
 * Creates the implementation object. If MakeShared fails, an inner error is reported and impl_ stays null.
 * The public methods below check impl_ with GE_CHECK_NOTNULL before using it.
 */
RefRelations::RefRelations() : impl_(MakeShared<Impl>()) {
  if (impl_ == nullptr) {
    REPORT_INNER_ERR_MSG("E18888", "new impl failed.");
    GELOGE(GRAPH_FAILED, "[New][Impl] MakeShared failed!");
  }
}
// Public entry point: inserts the cells of the relation group recorded for `key` (if any) into `result` by
// forwarding to the implementation object. GE_CHECK_NOTNULL guards against a null impl_ (the constructor
// leaves it null when MakeShared fails).
graphStatus RefRelations::LookUpRefRelations(const RefCell &key, std::unordered_set<RefCell, RefCellHash> &result) {
GE_CHECK_NOTNULL(impl_);
return impl_->LookUpRefRelations(key, result);
}
// Public entry point: builds all reference relations for the graph tree containing `graph` by forwarding
// to the implementation object. GE_CHECK_NOTNULL guards against a null impl_.
graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &graph) {
GE_CHECK_NOTNULL(impl_);
return impl_->BuildRefRelations(graph);
}
// Public entry point: drops all recorded relations and lookup entries by forwarding to the implementation
// object. GE_CHECK_NOTNULL guards against a null impl_.
graphStatus RefRelations::Clear() {
GE_CHECK_NOTNULL(impl_);
return impl_->Clear();
}
}