* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/utils/execute_graph_adapter.h"
#include "common/checker.h"
#include "fast_graph/fast_graph_impl.h"
#include "graph/compute_graph.h"
#include "graph/normal_graph/compute_graph_impl.h"
#include "graph_metadef/graph/debug/ge_util.h"
#include "graph/utils/execute_graph_utils.h"
#include "graph/utils/graph_utils.h"
#include "mmpa/mmpa_api.h"
namespace af {
namespace {
// Upper bound on the subgraph nesting depth accepted by ConvertExecuteGraphToComputeGraph;
// deeper recursion is rejected with an error instead of being converted.
constexpr int32_t kHybridSubgraphRecursion{32};
}  // namespace
/**
 * Creates a new ComputeGraph named after `src_graph` and fills it (nodes, edges, subgraphs,
 * members and attributes) from the given ExecuteGraph.
 * @param src_graph execute graph to convert; must not be null.
 * @return the converted compute graph. On any failure the GE_ASSERT_* macros return early
 *         (presumably yielding a null ComputeGraphPtr for this return type).
 */
ComputeGraphPtr ExecuteGraphAdapter::ConvertExecuteGraphToComputeGraph(ExecuteGraph *src_graph) {
  GE_ASSERT_NOTNULL(src_graph);
  const auto converted = ComGraphMakeShared<ComputeGraph>(src_graph->GetName());
  GE_ASSERT_NOTNULL(converted);
  // The top-level graph starts the subgraph recursion at depth zero.
  constexpr int32_t kRootDepth = 0;
  const auto convert_ret = ConvertExecuteGraphToComputeGraph(src_graph, converted, kRootDepth);
  GE_ASSERT_GRAPH_SUCCESS(convert_ret, "Convert execute graph:%s to compute graph failed.",
                          src_graph->GetName().c_str());
  return converted;
}
/**
 * Recursively converts `src_graph` into the already-created `dst_graph`.
 *
 * Steps: copy nodes (recursing into subgraphs), relink data/control edges between the copied
 * nodes, rebuild the subgraph list of `dst_graph` in the order reported by `src_graph`, then copy
 * output/input node info and attributes.
 *
 * @param src_graph execute graph to read from; must not be null.
 * @param dst_graph compute graph to populate; must not be null.
 * @param depth     current subgraph nesting depth; conversion fails once it exceeds
 *                  kHybridSubgraphRecursion.
 * @return GRAPH_SUCCESS on success, an error status otherwise.
 */
graphStatus ExecuteGraphAdapter::ConvertExecuteGraphToComputeGraph(ExecuteGraph *src_graph,
                                                                   const ComputeGraphPtr &dst_graph,
                                                                   const int32_t depth) {
  GE_ASSERT_NOTNULL(src_graph);
  GE_ASSERT_NOTNULL(dst_graph);
  GE_ASSERT_TRUE(depth <= kHybridSubgraphRecursion, "param depth:%d larger than %d(allow max subgraphs).", depth,
                 kHybridSubgraphRecursion);
  // graph_shared_ is dereferenced below to iterate the source nodes.
  GE_ASSERT_NOTNULL(src_graph->graph_shared_);

  // Maps every source FastNode to the Node created for it in dst_graph (or one of its subgraphs).
  std::unordered_map<FastNode *, Node *> all_new_nodes;
  GE_ASSERT_GRAPH_SUCCESS(CopyOpAndSubgraph(src_graph, dst_graph, all_new_nodes, depth),
                          "Copy op and subgraph from %s to %s failed.", src_graph->GetName().c_str(),
                          dst_graph->GetName().c_str());

  // Edges can only be relinked once every node of this graph has a counterpart.
  for (const auto &n : src_graph->graph_shared_->nodes_) {
    GE_ASSERT_NOTNULL(n);
    GE_ASSERT_GRAPH_SUCCESS(RelinkGraphEdges(&FastGraphUtils::GetNode(n), all_new_nodes),
                            "Relink edge for node %s failed.", FastGraphUtils::GetNode(n).GetNamePtr());
  }

  // Rebuild the subgraph list so its order follows the source graph rather than insertion order.
  const auto &old_subgraphs = src_graph->GetAllSubgraphs();
  std::vector<ComputeGraphPtr> new_subgraphs;
  new_subgraphs.reserve(old_subgraphs.size());
  for (const auto &sub_graph : old_subgraphs) {
    GE_ASSERT_NOTNULL(sub_graph);
    const auto new_subgraph = dst_graph->GetSubgraph(sub_graph->GetName());
    GE_ASSERT_NOTNULL(new_subgraph, "Subgraph %s not found in dst graph %s.", sub_graph->GetName().c_str(),
                      dst_graph->GetName().c_str());
    new_subgraphs.emplace_back(new_subgraph);
  }
  dst_graph->SetAllSubgraphs(new_subgraphs);

  GE_ASSERT_GRAPH_SUCCESS(CopyMembers(src_graph, dst_graph, all_new_nodes));
  InheritOriginalAttr(src_graph, dst_graph);
  return GRAPH_SUCCESS;
}
/**
 * Adds a copy of every node of `src_graph` to `dst_graph` (reusing the same OpDesc and id) and
 * recursively converts each subgraph referenced by a node's subgraph instance names.
 *
 * Source subgraphs are looked up on the source root graph; converted subgraphs get the
 * destination root graph as parent graph, are registered on it, and get the copied node as
 * parent node.
 *
 * @param src_graph     execute graph whose nodes are copied.
 * @param dst_graph     compute graph receiving the copied nodes.
 * @param all_new_nodes output map, source FastNode -> newly created Node.
 * @param depth         nesting depth of `src_graph`; subgraphs are converted at depth + 1.
 * @return GRAPH_SUCCESS on success, an error status otherwise.
 */
graphStatus ExecuteGraphAdapter::CopyOpAndSubgraph(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph,
                                                   std::unordered_map<FastNode *, Node *> &all_new_nodes,
                                                   const int32_t depth) {
  const auto src_root = ExecuteGraphUtils::FindRootGraph(src_graph);
  GE_ASSERT_NOTNULL(src_root);
  const auto dst_root = GraphUtils::FindRootGraph(dst_graph);
  GE_ASSERT_NOTNULL(dst_root);

  for (const auto &quick_node : src_graph->graph_shared_->nodes_) {
    GE_ASSERT_NOTNULL(quick_node);
    auto *const old_node = &FastGraphUtils::GetNode(quick_node);
    const auto &desc = old_node->GetOpDescPtr();
    GE_ASSERT_NOTNULL(desc);
    const auto new_node = dst_graph->AddNode(desc, desc->GetId());
    GE_ASSERT_NOTNULL(new_node, "Add node:%s for dst graph failed.", desc->GetName().c_str());
    all_new_nodes[old_node] = new_node.get();

    // Subgraph instance names are processed from last to first.
    const auto &names = desc->GetSubgraphInstanceNames();
    for (auto name_it = names.crbegin(); name_it != names.crend(); ++name_it) {
      const auto &name = *name_it;
      const auto src_sub = src_root->GetSubGraph(name);
      // An empty name that resolves to no graph is skipped rather than treated as an error.
      if (name.empty() && (src_sub == nullptr)) {
        continue;
      }
      GE_ASSERT_NOTNULL(src_sub);
      const auto dst_sub = ComGraphMakeShared<ComputeGraph>(src_sub->GetName());
      GE_ASSERT_NOTNULL(dst_sub);
      // Parent graph is set before recursing, presumably so the recursive FindRootGraph(dst_sub)
      // resolves to dst_root.
      dst_sub->SetParentGraph(dst_root);
      GE_ASSERT_GRAPH_SUCCESS(ConvertExecuteGraphToComputeGraph(src_sub, dst_sub, depth + 1),
                              "Copy subgraph from %s to %s failed.", src_sub->GetName().c_str(),
                              dst_sub->GetName().c_str());
      (void) dst_root->AddSubGraph(dst_sub);
      dst_sub->SetParentNode(new_node);
    }
  }
  return GRAPH_SUCCESS;
}
/**
 * Recreates the outgoing data and control edges of `old_node` between the corresponding copied
 * nodes. Edges whose destination has no copy in `all_new_nodes` are silently skipped, and null
 * edge slots are ignored.
 *
 * @param old_node      source node whose outgoing edges are replayed; must not be null and must
 *                      be a key of `all_new_nodes`.
 * @param all_new_nodes map from source FastNode to its copied Node.
 * @return GRAPH_SUCCESS on success, an error status otherwise.
 */
graphStatus ExecuteGraphAdapter::RelinkGraphEdges(FastNode *old_node,
                                                  const std::unordered_map<FastNode *, Node *> &all_new_nodes) {
  // old_node is dereferenced in the failure message below, so reject null up front.
  GE_ASSERT_NOTNULL(old_node);
  const auto iter = all_new_nodes.find(old_node);
  GE_ASSERT_TRUE(iter != all_new_nodes.end(), "all_new_nodes not contain %s", old_node->GetNamePtr());
  Node *const new_node = iter->second;
  GE_ASSERT_NOTNULL(new_node);

  const auto &old_out_edges = old_node->GetAllOutDataEdgesRef();
  for (size_t out_i = 0U; out_i < old_out_edges.size(); ++out_i) {
    // The log message formats the index with %d; passing a size_t there is undefined behavior,
    // so convert once and use the same 32-bit index for the anchor lookup and the message.
    const auto out_idx = static_cast<int32_t>(out_i);
    for (const auto old_edge : old_out_edges[out_i]) {
      if (old_edge == nullptr) {
        continue;
      }
      const auto old_dst_node = old_edge->dst;
      GE_ASSERT_NOTNULL(old_dst_node);
      const auto dst_index = static_cast<int32_t>(old_edge->dst_input);
      const auto dst_iter = all_new_nodes.find(old_dst_node);
      if (dst_iter == all_new_nodes.end()) {
        continue;
      }
      Node *const new_dst_node = dst_iter->second;
      GE_ASSERT_NOTNULL(new_dst_node);
      GE_ASSERT_GRAPH_SUCCESS(
          GraphUtils::AddEdge(new_node->GetOutDataAnchor(out_idx), new_dst_node->GetInDataAnchor(dst_index)),
          "Add edge %s:%d -> %s:%d failed.", new_node->GetName().c_str(), out_idx, new_dst_node->GetName().c_str(),
          dst_index);
    }
  }

  for (const auto old_control_out_edge : old_node->GetAllOutControlEdgesRef()) {
    if (old_control_out_edge == nullptr) {
      continue;
    }
    const auto old_dst_node = old_control_out_edge->dst;
    GE_ASSERT_NOTNULL(old_dst_node);
    const auto dst_iter = all_new_nodes.find(old_dst_node);
    if (dst_iter == all_new_nodes.end()) {
      continue;
    }
    Node *const new_dst_node = dst_iter->second;
    GE_ASSERT_NOTNULL(new_dst_node);
    GE_ASSERT_GRAPH_SUCCESS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), new_dst_node->GetInControlAnchor()),
                            "Add control edge %s -> %s failed.", new_node->GetName().c_str(),
                            new_dst_node->GetName().c_str());
  }
  return GRAPH_SUCCESS;
}
/**
 * Copies graph-level members from `src_graph` to `dst_graph`: the graph output node info, the
 * input node list (both translated through `all_new_nodes`), and the attribute map.
 * Source nodes without a copy in `all_new_nodes` are skipped.
 *
 * @param src_graph     execute graph to read from; must not be null.
 * @param dst_graph     compute graph to populate; must not be null.
 * @param all_new_nodes map from source FastNode to its copied Node.
 * @return GRAPH_SUCCESS on success, an error status otherwise.
 */
graphStatus ExecuteGraphAdapter::CopyMembers(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph,
                                             const std::unordered_map<FastNode *, Node *> &all_new_nodes) {
  GE_ASSERT_NOTNULL(src_graph);
  GE_ASSERT_NOTNULL(dst_graph);
  GE_ASSERT_NOTNULL(src_graph->graph_shared_);
  GE_ASSERT_NOTNULL(dst_graph->impl_);

  const auto &out_nodes_info = src_graph->graph_shared_->GetAllOutNodeInfo();
  std::vector<std::pair<NodePtr, int32_t>> new_out_nodes_info;
  new_out_nodes_info.reserve(out_nodes_info.size());
  for (const auto &info : out_nodes_info) {
    GE_ASSERT_NOTNULL(info.first);
    const auto it = all_new_nodes.find(info.first);
    if (it != all_new_nodes.end()) {
      GE_ASSERT_NOTNULL(it->second);
      // The node is already owned by a NodePtr held by dst_graph (created via AddNode). Wrapping the
      // raw pointer in a fresh shared_ptr would create a second, independent owner and double-delete
      // the node; join the existing ownership instead.
      new_out_nodes_info.emplace_back(it->second->shared_from_this(), info.second);
    }
  }
  GE_ASSERT_SUCCESS(dst_graph->SetGraphOutNodesInfo(new_out_nodes_info));

  const auto &input_nodes = src_graph->graph_shared_->GetAllInputNodeInfo();
  for (const auto &node : input_nodes) {
    GE_ASSERT_NOTNULL(node);
    const auto it = all_new_nodes.find(node);
    if (it != all_new_nodes.end()) {
      GE_ASSERT_NOTNULL(it->second);
      (void) dst_graph->AddInputNode(it->second->shared_from_this());
    }
  }
  dst_graph->impl_->attrs_ = src_graph->attrs_;
  return GRAPH_SUCCESS;
}
/**
 * Transfers every attribute returned by AttrUtils::GetAllAttrs(src_graph) onto `dst_graph` via
 * TrySetAttr, then calls dst_graph->CopyFrom(*src_graph).
 * Failure to set an individual attribute is best-effort: it is logged as a warning and the
 * remaining attributes are still processed.
 *
 * @param src_graph execute graph providing the attributes.
 * @param dst_graph compute graph receiving them.
 */
void ExecuteGraphAdapter::InheritOriginalAttr(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph) {
  const auto &src_attrs = AttrUtils::GetAllAttrs(src_graph);
  for (const auto &entry : src_attrs) {
    const auto &attr_name = entry.first;
    const auto set_ret = dst_graph->TrySetAttr(attr_name, entry.second);
    if (set_ret == GRAPH_SUCCESS) {
      continue;
    }
    GELOGW("Set inherit original attr[%s] failed, Please Check.", attr_name.c_str());
  }
  dst_graph->CopyFrom(*src_graph);
}
}