* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/fast_graph/execute_graph.h"
#include "framework/common/debug/ge_log.h"
#include "common/ge_common/ge_types.h"
#include "fast_graph/fast_graph_impl.h"
#include "graph/utils/fast_node_utils.h"
namespace ge {
namespace {
// Topological sorting algorithms selectable for an ExecuteGraph: breadth-first, depth-first and
// reverse depth-first (walking from sink nodes towards their inputs).
enum class FastTopoSortingMode { kBFS = 0, kDFS, kRDFS };
// Value of the MEMORY_OPTIMIZATION_POLICY option that enables memory-priority sorting.
const std::string kMemoryPriority = "MemoryPriority";
// Accepted integer values of the OPTION_TOPOSORTING_MODE option.
constexpr int32_t kTopoSortingBfs = 0;
constexpr int32_t kTopoSortingDfs = 1;
constexpr int32_t kTopoSortingReverseDfs = 2;
/**
 * Resolve which topological-sorting algorithm to use.
 *
 * The OPTION_TOPOSORTING_MODE option wins when it holds a complete decimal integer equal to one of
 * kTopoSortingBfs / kTopoSortingDfs / kTopoSortingReverseDfs. Any other non-empty value is reported
 * as invalid. In that case, and when the option is absent or empty, the choice falls back to the
 * train flag: BFS for training graphs, DFS otherwise.
 */
FastTopoSortingMode GetTopoSortingStrategy() {
  std::string topo_sorting_mode_str;
  if ((ge::GetContext().GetOption(ge::OPTION_TOPOSORTING_MODE, topo_sorting_mode_str) == GRAPH_SUCCESS) &&
      (!topo_sorting_mode_str.empty())) {
    const int32_t base = 10;
    const char *const begin = topo_sorting_mode_str.c_str();
    char *end = nullptr;
    const auto parsed = std::strtol(begin, &end, base);
    // strtol yields 0 both for "0" and for text without any digits (e.g. "abc"). Only accept a fully
    // consumed number, so malformed values reach the warning below instead of silently meaning BFS.
    const bool is_number = (end != begin) && (*end == '\0');
    if (is_number) {
      // Compare the untruncated value so out-of-range numbers cannot wrap onto a valid mode.
      if (parsed == kTopoSortingBfs) {
        return FastTopoSortingMode::kBFS;
      }
      if (parsed == kTopoSortingDfs) {
        return FastTopoSortingMode::kDFS;
      }
      if (parsed == kTopoSortingReverseDfs) {
        return FastTopoSortingMode::kRDFS;
      }
    }
    GELOGW("OPTION_TOPOSORTING_MODE = %s is invalid", topo_sorting_mode_str.c_str());
  }
  if (ge::GetContext().GetTrainGraphFlag()) {
    GELOGI("train flag is 1, use BFS.");
    return FastTopoSortingMode::kBFS;
  }
  GELOGI("train flag is 0, use DFS.");
  return FastTopoSortingMode::kDFS;
}
/**
 * Tell whether the MEMORY_OPTIMIZATION_POLICY option is exactly "MemoryPriority".
 * The status of GetOption is deliberately ignored; the local string starts out empty.
 */
bool IsMemoryPriority() {
  std::string policy;
  (void)ge::GetContext().GetOption(MEMORY_OPTIMIZATION_POLICY, policy);
  const bool memory_first = (policy == ge::kMemoryPriority);
  return memory_first;
}
/**
 * Decrement the pending in-edge counter of `node` and append `node` to `out_nodes` when the counter
 * reaches zero. Nodes that have no entry in `map_in_edge_num` are ignored.
 */
void GetOutNodesFromEdge(std::map<FastNode *, uint32_t> &map_in_edge_num, FastNode *node,
                         std::vector<FastNode *> &out_nodes) {
  const auto found = map_in_edge_num.find(node);
  if (found == map_in_edge_num.end()) {
    return;
  }
  auto &pending_inputs = found->second;
  --pending_inputs;
  if (pending_inputs == 0U) {
    out_nodes.push_back(node);
  }
}
bool InputIsLongLifeTimeNode(const FastNode *node, const ExecuteGraph *execute_graph) {
bool match = false;
auto num = node->GetDataInNum();
for (size_t i = 0LL; i < num; ++i) {
const auto &edge = node->GetInDataEdgeByIndex(i);
if (edge == nullptr) {
continue;
}
auto &peer_node = edge->src;
if ((peer_node == nullptr) || (peer_node->GetExtendInfo() == nullptr)) {
continue;
}
const auto type = peer_node->GetType();
static std::unordered_set<std::string> kDataSet = {DATA, REFDATA, AIPPDATA, ANN_DATA};
static const std::unordered_set<std::string> kConstPlaceHolderOpSet = {CONSTPLACEHOLDER};
auto graph = peer_node->GetExtendInfo()->GetOwnerGraphBarePtr();
const bool is_io_data =
(execute_graph == graph) && ((kDataSet.count(type) > 0U) || (kConstPlaceHolderOpSet.count(type) > 0U));
if ((!FastNodeUtils::GetConstOpType(peer_node)) && (type != VARIABLE) && (type != VARIABLEV2) && (!is_io_data)) {
return false;
} else {
match = true;
}
GELOGD("Node:%s peer:%s type :%s", node->GetName().c_str(), peer_node->GetName().c_str(),
peer_node->GetType().c_str());
}
return match;
}
/**
 * Find the earliest consumer of the last node of a delay group.
 *
 * GRAPH_FAILED is returned when:
 *  - `nodes` is empty;
 *  - the group has a single node whose inputs are not all long-lifetime (see InputIsLongLifeTimeNode);
 *  - the group has several nodes and its last node does not have exactly one data input;
 *  - the last node has no consumers.
 *
 * @param nodes          nodes currently scheduled in one slot; only nodes.back() is inspected for consumers.
 * @param index          [out] smallest op-desc id among the consumers of nodes.back().
 * @param out_count      [in/out] incremented once per consumer visited.
 * @param execute_graph  graph whose IO data nodes count as long-lifetime inputs.
 */
graphStatus GetOutNodeIndex(std::vector<FastNode *> &nodes, size_t &index, size_t &out_count,
                            const ExecuteGraph *execute_graph) {
  if (nodes.empty()) {
    return GRAPH_FAILED;
  }
  if ((nodes.size() == 1UL) && (!InputIsLongLifeTimeNode(nodes.front(), execute_graph))) {
    return GRAPH_FAILED;
  }
  const auto &node = nodes.back();
  auto op_desc = node->GetOpDescBarePtr();
  GE_CHECK_NOTNULL(op_desc);
  if ((nodes.size() != 1UL) && (node->GetDataInNum() != 1UL)) {
    return GRAPH_FAILED;
  }
  int64_t min_index = 0LL;
  FastNode *delay_node = nullptr;
  for (const auto &out_node : node->GetAllOutNodes()) {
    out_count++;
    GE_CHECK_NOTNULL(out_node);
    auto out_node_desc = out_node->GetOpDescBarePtr();
    GE_CHECK_NOTNULL(out_node_desc);
    GELOGD("Node:%s id:%ld peer node:%s id:%ld", node->GetName().c_str(), op_desc->GetId(),
           out_node_desc->GetName().c_str(), out_node_desc->GetId());
    // "No consumer recorded yet" is the first-iteration sentinel. Using min_index == 0 instead would
    // let a consumer with id 0 be replaced by any later consumer.
    if ((delay_node == nullptr) || (out_node_desc->GetId() < min_index)) {
      min_index = out_node_desc->GetId();
      delay_node = out_node;
    }
  }
  if (delay_node != nullptr) {
    index = static_cast<size_t>(min_index);
    if (index > (static_cast<size_t>(op_desc->GetId()) + 1UL)) {
      GELOGD("Node:%s id:%ld delay to:%s id:%zu", node->GetName().c_str(), op_desc->GetId(),
             delay_node->GetName().c_str(), index);
    }
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}
// Post-pass over an already topologically sorted node list.
//
// Each node gets its list position as op-desc id and is placed in its own slot. Walking the slots
// front to back, a delayable slot whose group passes GetOutNodeIndex() is moved: its nodes are
// prepended to the slot of the earliest consumer of the group's last node, provided that consumer
// lies more than one position later. If that last node has more than one consumer, the receiving slot
// is marked non-delayable. When at least one group moved, `nodes` is rebuilt by concatenating the
// non-empty slots in order.
// NOTE(review): presumably this shortens the lifetime of values fed by long-lifetime inputs
// (const/variable/data) under memory-priority sorting; confirm the intent.
// NOTE(review): assumes every node has a non-null op desc (SetId is called unchecked).
void DelayTopoSort(std::vector<FastNode *> &nodes, const ExecuteGraph *execute_graph) {
// One slot per sorted position: {slot may still be delayed, nodes scheduled in this slot}.
std::vector<std::pair<bool, std::vector<FastNode *>>> delay_nodes;
delay_nodes.resize(nodes.size());
for (size_t i = 0UL; i < delay_nodes.size(); ++i) {
// The op-desc id doubles as the slot index that GetOutNodeIndex reports for consumers.
nodes[i]->GetOpDescBarePtr()->SetId(static_cast<int64_t>(i));
delay_nodes[i].first = true;
delay_nodes[i].second.emplace_back(nodes[i]);
}
size_t delay_node_count = 0UL;
for (size_t i = 0UL; i < delay_nodes.size(); ++i) {
size_t delay_to_index = 0UL;
size_t out_count = 0UL;
// Move only to a valid slot that is not already the next position.
if (delay_nodes[i].first &&
(GetOutNodeIndex(delay_nodes[i].second, delay_to_index, out_count, execute_graph) == GRAPH_SUCCESS) &&
(delay_to_index < delay_nodes.size()) && (delay_to_index > (i + 1UL))) {
// Prepend so the delayed group runs right before the consumer's original nodes.
delay_nodes[delay_to_index].second.insert(delay_nodes[delay_to_index].second.begin(),
delay_nodes[i].second.begin(), delay_nodes[i].second.end());
// A multi-consumer producer pins the receiving slot: it must not be delayed again.
if (out_count > 1UL) {
delay_nodes[delay_to_index].first = false;
}
delay_nodes[i].second.clear();
delay_node_count++;
}
}
// Rebuild the order only if something actually moved.
if (delay_node_count > 0UL) {
nodes.clear();
for (size_t i = 0UL; i < delay_nodes.size(); ++i) {
if (!delay_nodes[i].second.empty()) {
nodes.insert(nodes.end(), delay_nodes[i].second.begin(), delay_nodes[i].second.end());
}
}
GELOGI("Delay %zu nodes.", delay_node_count);
}
}
void InitNodeStatus(const ExecuteGraph *compute_graph, std::vector<NodeStatus> &reverse_dfs_nodes_info) {
reverse_dfs_nodes_info.clear();
reverse_dfs_nodes_info.resize(compute_graph->GetDirectNodesSize());
int64_t index = 0;
for (const auto &node : compute_graph->GetDirectNode()) {
reverse_dfs_nodes_info[index].size = 0U;
reverse_dfs_nodes_info[index].status = FastWalkStatus::kNotWalked;
node->GetOpDescBarePtr()->SetId(index);
index++;
}
}
}
/// Build an empty graph named `name` on top of a new FastGraphImpl that records this object as owner.
ExecuteGraph::ExecuteGraph(const std::string &name) {
  auto impl = std::make_shared<FastGraphImpl<FastNode, ExecuteGraph>>(name);
  impl->SetOwnerGraph(this);
  graph_shared_ = std::move(impl);
}
/**
 * Shallow assignment. This graph shares the source's graph implementation and copies its subgraph
 * registry and input order. The attribute storage is handed to AttrHolder::SwapBase, which mutates
 * `exec_graph`; this is why the parameter is non-const.
 * NOTE(review): by its name SwapBase exchanges attributes with the source; confirm that callers
 * expect the source to be modified. Self-assignment is a no-op.
 */
ExecuteGraph &ExecuteGraph::operator=(ge::ExecuteGraph &exec_graph) {
  if (&exec_graph != this) {
    graph_shared_ = exec_graph.graph_shared_;
    names_to_subgraph_ = exec_graph.names_to_subgraph_;
    inputs_order_ = exec_graph.inputs_order_;
    AttrHolder::SwapBase(exec_graph);
  }
  return *this;
}
/**
 * Deep copy of `exec_graph` into this graph.
 * The structural copy is delegated to graph_shared_->DeepCopy. Every attribute of the source is then
 * set on this graph; failures are only logged. Finally the input order list is replaced by the
 * source's. Self-copy is a no-op.
 */
ExecuteGraph &ExecuteGraph::CompleteCopy(ge::ExecuteGraph &exec_graph) {
  if (&exec_graph == this) {
    return *this;
  }
  graph_shared_->DeepCopy(*(exec_graph.graph_shared_));
  const std::map<string, GeAttrValue> &source_attrs = AttrUtils::GetAllAttrs(exec_graph);
  for (const auto &name_and_value : source_attrs) {
    const auto set_status = this->TrySetAttr(name_and_value.first, name_and_value.second);
    if (set_status != GRAPH_SUCCESS) {
      GELOGW("Set inherit original attr[%s] failed, Please Check.", name_and_value.first.c_str());
    }
  }
  inputs_order_ = exec_graph.inputs_order_;
  return *this;
}
/// Add a node built from `op` to this graph; returns whatever the graph implementation returns.
FastNode *ExecuteGraph::AddNode(const OpDescPtr &op) {
  auto &impl = *graph_shared_;
  return impl.AddNode(op);
}
/// Add a node built from `op` with the given `id` to this graph; forwards to the graph implementation.
FastNode *ExecuteGraph::AddNode(const OpDescPtr &op, int64_t id) {
  auto &impl = *graph_shared_;
  return impl.AddNode(op, id);
}
/**
 * If the list element of `fast_node` currently belongs to a list in free mode, erase it from that
 * list. The Add* entry points call this before inserting the node into this graph.
 */
void ExecuteGraph::RemoveNodeFromNodesFree(const FastNode *const fast_node) const {
  auto list_element = FastGraphUtils::GetListElementAddr(fast_node);
  auto owning_list = list_element->owner;
  const bool is_free = (list_element->mode == ListMode::kFreeMode);
  if (is_free && (owning_list != nullptr)) {
    owning_list->erase(list_element);
  }
}
/**
 * Insert an existing node into this graph after detaching it from any free list.
 * Returns nullptr (and reports E18888) when `fast_node` is null.
 */
FastNode *ExecuteGraph::AddNode(FastNode *fast_node) {
  if (fast_node != nullptr) {
    RemoveNodeFromNodesFree(fast_node);
    return graph_shared_->AddNode(fast_node);
  }
  REPORT_INNER_ERR_MSG("E18888", "The node is nullptr.");
  GE_LOGE("[Check][Param] The node is nullptr.");
  return nullptr;
}
/// Add a node built from `op` at the front of this graph's node list (delegated to the implementation).
FastNode *ExecuteGraph::AddNodeFront(const OpDescPtr &op) {
  auto &impl = *graph_shared_;
  return impl.AddNodeFront(op);
}
/**
 * Insert an existing node at the front of this graph after detaching it from any free list.
 * Returns nullptr (and reports E18888) when `fast_node` is null.
 */
FastNode *ExecuteGraph::AddNodeFront(FastNode *const fast_node) {
  if (fast_node != nullptr) {
    RemoveNodeFromNodesFree(fast_node);
    return graph_shared_->AddNodeFront(fast_node);
  }
  REPORT_INNER_ERR_MSG("E18888", "The node is nullptr.");
  GE_LOGE("[Check][Param] The node is nullptr.");
  return nullptr;
}
/**
 * Remove `fast_node` from this graph through its list element.
 * Returns GRAPH_FAILED (and reports E18888) for a null node.
 */
graphStatus ExecuteGraph::RemoveJustNode(const FastNode *const fast_node) {
  if (fast_node != nullptr) {
    auto list_element = FastGraphUtils::GetListElementAddr(fast_node);
    return graph_shared_->RemoveJustNode(list_element);
  }
  REPORT_INNER_ERR_MSG("E18888", "The node is nullptr.");
  GE_LOGE("[Check][Param] The node is nullptr.");
  return GRAPH_FAILED;
}
/**
 * Connect output `src_index` of `src` to input `dst_index` of `dst`.
 * Null endpoints are rejected (nullptr returned). Endpoints that do not belong to this graph only
 * trigger a warning, and the edge is still added.
 */
FastEdge *ExecuteGraph::AddEdge(FastNode *const src, int32_t src_index, FastNode *const dst, int32_t dst_index) {
  const bool has_null_endpoint = (src == nullptr) || (dst == nullptr);
  if (has_null_endpoint) {
    REPORT_INNER_ERR_MSG("E18888", "The node is nullptr.");
    GE_LOGE("[Check][Param] The node is nullptr.");
    return nullptr;
  }
  const bool both_in_graph = CheckNodeIsInGraph(src) && CheckNodeIsInGraph(dst);
  if (!both_in_graph) {
    GELOGW("The src %s or dst %s not belong to graph.", src->GetNamePtr(), dst->GetNamePtr());
  }
  return graph_shared_->AddEdge(src, src_index, dst, dst_index);
}
/**
 * Remove `edge` from this graph through its list element.
 * Returns GRAPH_FAILED (and reports E18888) for a null edge.
 */
graphStatus ExecuteGraph::RemoveEdge(const FastEdge *const edge) {
  if (edge != nullptr) {
    auto list_element = FastGraphUtils::GetListElementAddr(edge);
    return graph_shared_->RemoveEdge(list_element);
  }
  REPORT_INNER_ERR_MSG("E18888", "The edge is nullptr.");
  GE_LOGE("[Check][Param] The edge is nullptr.");
  return GRAPH_FAILED;
}
/// Read-only access to the node that owns this graph as a subgraph (as stored by the implementation).
const FastNode *ExecuteGraph::GetParentNodeBarePtr() const {
  const FastNode *const parent = graph_shared_->GetParentNode();
  return parent;
}
/// Mutable access to the node that owns this graph as a subgraph (as stored by the implementation).
FastNode *ExecuteGraph::GetParentNodeBarePtr() {
  FastNode *const parent = graph_shared_->GetParentNode();
  return parent;
}
/// Record `node` as the parent node of this graph in the graph implementation.
void ExecuteGraph::SetParentNode(FastNode *const node) {
  auto &impl = *graph_shared_;
  impl.SetParentNode(node);
}
/**
 * Register `sub_graph` under its own name without the parent checks done by the two-argument
 * overload. An existing entry with the same name in names_to_subgraph_ is overwritten.
 * @return the ExecuteGraph stored in the new subgraph entry, or nullptr on failure.
 */
ExecuteGraph *ExecuteGraph::AddSubGraph(const std::shared_ptr<ExecuteGraph> &sub_graph) {
  if (sub_graph == nullptr) {
    REPORT_INNER_ERR_MSG("E18888", "Try to add a null subgraph");
    GE_LOGE("[Check][Param] Try to add a null subgraph");
    return nullptr;
  }
  auto quick_graph = graph_shared_->AddSubGraph(sub_graph.get());
  if (quick_graph != nullptr) {
    names_to_subgraph_[sub_graph->GetName()] = {sub_graph, quick_graph};
    return quick_graph->data;
  }
  return nullptr;
}
/**
 * Unregister `sub_graph`, looked up by its name.
 * @return GRAPH_PARAM_INVALID for a null pointer, otherwise the result of the by-name removal.
 */
graphStatus ExecuteGraph::RemoveSubGraph(const ExecuteGraph *const sub_graph) {
  if (sub_graph == nullptr) {
    // This is the removal path: the diagnostic must say "remove", not "add".
    REPORT_INNER_ERR_MSG("E18888", "Try to remove a null subgraph");
    GE_LOGE("[Check][Param] Try to remove a null subgraph");
    return GRAPH_PARAM_INVALID;
  }
  return RemoveSubGraph(sub_graph->GetName());
}
/**
 * Register `sub_graph_ptr` as a subgraph of this graph, after validating its parent links.
 *
 * Requirements: the subgraph has a parent graph and a parent node with extend info, that parent node
 * is owned by the subgraph's parent graph, and no subgraph with the same name is registered yet.
 * If `name` differs from the subgraph's own name, only a warning is logged; registration always
 * uses the subgraph's own name.
 *
 * @return the ExecuteGraph stored in the new subgraph entry, or nullptr on any failure.
 */
ExecuteGraph *ExecuteGraph::AddSubGraph(const std::shared_ptr<ExecuteGraph> &sub_graph_ptr, const std::string &name) {
  if (sub_graph_ptr == nullptr) {
    REPORT_INNER_ERR_MSG("E18888", "Try to add a null subgraph, name %s", name.c_str());
    GE_LOGE("[Check][Param] Try to add a null subgraph, name %s", name.c_str());
    return nullptr;
  }
  auto sub_graph = sub_graph_ptr.get();
  const auto parent_graph = sub_graph->GetParentGraphBarePtr();
  if (parent_graph == nullptr) {
    REPORT_INNER_ERR_MSG("E18888", "Try to add subgraph without parent graph, name %s", name.c_str());
    GE_LOGE("[Get][Graph] Try to add subgraph without parent graph, name %s", name.c_str());
    return nullptr;
  }
  const auto parent_node = sub_graph->GetParentNodeBarePtr();
  if ((parent_node == nullptr) || (parent_node->GetExtendInfo() == nullptr)) {
    REPORT_INNER_ERR_MSG("E18888", "Try to add a subgraph without parent node, name %s", name.c_str());
    GE_LOGE("[Get][Node] Try to add a subgraph without parent node, name %s", name.c_str());
    return nullptr;
  }
  if (parent_node->GetExtendInfo()->GetOwnerGraphBarePtr() != parent_graph) {
    // The message announces the parent *node* name, so pass the node's name, not the graph's.
    REPORT_INNER_ERR_MSG("E18888",
                         "Try to add a subgraph which parent node's graph is not equal to "
                         "the subgraph's parent graph, subgraph name %s, parent node name %s",
                         sub_graph->GetName().c_str(), parent_node->GetName().c_str());
    GE_LOGE(
        "[Check][Param] Try to add a subgraph which parent node's graph is not equal to "
        "the subgraph's parent graph, subgraph name %s, parent node name %s",
        sub_graph->GetName().c_str(), parent_node->GetName().c_str());
    return nullptr;
  }
  if (name != sub_graph->GetName()) {
    GELOGW("[Add][Subgraph] The subgraph name %s is different with input %s", sub_graph->GetName().c_str(),
           name.c_str());
  }
  if (names_to_subgraph_.find(sub_graph->GetName()) != names_to_subgraph_.end()) {
    // Name the duplicated subgraph, not this owning graph.
    REPORT_INNER_ERR_MSG("E18888", "The subgraph %s existed", sub_graph->GetName().c_str());
    GE_LOGE("[Check][Param] The subgraph %s existed", sub_graph->GetName().c_str());
    return nullptr;
  }
  auto ret = graph_shared_->AddSubGraph(sub_graph);
  if (ret == nullptr) {
    return nullptr;
  }
  names_to_subgraph_[sub_graph->GetName()] = {sub_graph_ptr, ret};
  return ret->data;
}
/**
 * Unregister the subgraph called `name` from the implementation and from names_to_subgraph_.
 * Unknown names are ignored; always returns GRAPH_SUCCESS.
 */
graphStatus ExecuteGraph::RemoveSubGraph(const std::string &name) {
  const auto entry = names_to_subgraph_.find(name);
  if (entry == names_to_subgraph_.end()) {
    return GRAPH_SUCCESS;
  }
  auto registered = entry->second.quick_graph;
  graph_shared_->RemoveSubGraph(registered);
  names_to_subgraph_.erase(entry);
  return GRAPH_SUCCESS;
}
/**
 * Look up a subgraph by name.
 * If this graph has a parent graph, the lookup is delegated to the parent. Only a graph without a
 * parent consults its own names_to_subgraph_ map. Returns nullptr when the name is unknown.
 */
ExecuteGraph *ExecuteGraph::GetSubGraph(const std::string &name) const {
  const ExecuteGraph *const parent = graph_shared_->GetParentGraph();
  if (parent != nullptr) {
    return parent->GetSubGraph(name);
  }
  const auto entry = names_to_subgraph_.find(name);
  if (entry != names_to_subgraph_.end()) {
    auto registered = entry->second.quick_graph;
    return registered->data;
  }
  return nullptr;
}
/// Drop every subgraph registration: first the name map, then the implementation's subgraph list.
void ExecuteGraph::ClearAllSubGraph() {
  names_to_subgraph_.clear();
  graph_shared_->ClearAllSubGraph();
}
/// Snapshot of the nodes owned directly by this graph (subgraph nodes are not included).
std::vector<FastNode *> ExecuteGraph::GetDirectNode() const {
  std::vector<FastNode *> direct_nodes = graph_shared_->GetDirectNode();
  return direct_nodes;
}
/// Number of nodes owned directly by this graph, as reported by the implementation.
size_t ExecuteGraph::GetDirectNodesSize() const {
  const size_t direct_count = graph_shared_->GetDirectNodesSize();
  return direct_count;
}
/// Snapshot of all edges known to the graph implementation.
std::vector<FastEdge *> ExecuteGraph::GetAllEdges() const {
  std::vector<FastEdge *> edges = graph_shared_->GetAllEdges();
  return edges;
}
/// Snapshot of the subgraphs held by the graph implementation.
std::vector<ExecuteGraph *> ExecuteGraph::GetAllSubgraphs() const {
  std::vector<ExecuteGraph *> subgraph_list = graph_shared_->GetAllSubgraphs();
  return subgraph_list;
}
/**
 * Register `fast_node` as an input node of this graph after detaching it from any free list.
 * Returns nullptr (and reports E18888) when `fast_node` is null.
 */
FastNode *ExecuteGraph::AddInputNode(FastNode *fast_node) {
  if (fast_node != nullptr) {
    RemoveNodeFromNodesFree(fast_node);
    return graph_shared_->AddInputNode(fast_node);
  }
  REPORT_INNER_ERR_MSG("E18888", "The node is nullptr.");
  GE_LOGE("[Check][Param] The node is nullptr.");
  return nullptr;
}
/// Unregister `fast_node` as an input node; GRAPH_FAILED (with an E18888 report) for a null node.
graphStatus ExecuteGraph::RemoveInputNode(FastNode *const fast_node) {
  if (fast_node != nullptr) {
    return graph_shared_->RemoveInputNode(fast_node);
  }
  REPORT_INNER_ERR_MSG("E18888", "The node is nullptr.");
  GE_LOGE("[Check][Param] The node is nullptr.");
  return GRAPH_FAILED;
}
/**
 * Register output `index` of `fast_node` as a graph output after detaching the node from any free list.
 * Returns nullptr (and reports E18888) when `fast_node` is null.
 */
FastNode *ExecuteGraph::AddOutputNodeByIndex(FastNode *const fast_node, int32_t index) {
  if (fast_node != nullptr) {
    RemoveNodeFromNodesFree(fast_node);
    return graph_shared_->AddOutputNodeByIndex(fast_node, index);
  }
  REPORT_INNER_ERR_MSG("E18888", "The node is nullptr.");
  GE_LOGE("[Check][Param] The node is nullptr.");
  return nullptr;
}
/// Unregister `fast_node` as a graph output; GRAPH_FAILED (with an E18888 report) for a null node.
graphStatus ExecuteGraph::RemoveOutputNode(const FastNode *const fast_node) {
  if (fast_node != nullptr) {
    return graph_shared_->RemoveOutputNode(fast_node);
  }
  REPORT_INNER_ERR_MSG("E18888", "The node is nullptr.");
  GE_LOGE("[Check][Param] The node is nullptr.");
  return GRAPH_FAILED;
}
/// Find the node registered under `token`; nullptr when the implementation has no such node.
const FastNode *ExecuteGraph::FindNode(size_t token) const {
  const auto quick_node = graph_shared_->FindNode(token);
  if (quick_node == nullptr) {
    return nullptr;
  }
  return &(quick_node->data);
}
/**
 * Seed a topological sort.
 *
 * Fills `map_in_edge_num` with the in-edge count of every direct node that has an op desc. Appends
 * all zero in-degree nodes to `stack`: first the non-DATA roots, then the DATA roots, each group in
 * reverse graph order (BFSTopologicalSorting consumes `stack` from the back).
 *
 * Entries whose names appear in `inputs_order_` are then permuted among the positions they already
 * occupy, in descending inputs_order_ position. Consuming from the back therefore yields them in
 * inputs_order_ order. All other entries keep their positions.
 *
 * @return always GRAPH_SUCCESS.
 */
graphStatus ExecuteGraph::SortNodes(std::vector<FastNode *> &stack,
                                    std::map<FastNode *, uint32_t> &map_in_edge_num) const {
  std::vector<FastNode *> data_nodes_vec;
  std::vector<FastNode *> no_data_nodes_vec;
  for (const auto &node : graph_shared_->GetDirectNodeToModify()) {
    auto fast_node = &FastGraphUtils::GetNode(node);
    GE_IF_BOOL_EXEC(fast_node->GetOpDescBarePtr() == nullptr, continue);
    map_in_edge_num[fast_node] = static_cast<uint32_t>(fast_node->GetInEdgeSize());
    if (map_in_edge_num[fast_node] == 0U) {
      if ((strcmp(fast_node->GetOpDescBarePtr()->GetTypePtr(), DATA) != 0)) {
        no_data_nodes_vec.emplace_back(fast_node);
        continue;
      }
      data_nodes_vec.emplace_back(fast_node);
    }
  }
  (void)stack.insert(stack.end(), no_data_nodes_vec.rbegin(), no_data_nodes_vec.rend());
  (void)stack.insert(stack.end(), data_nodes_vec.rbegin(), data_nodes_vec.rend());

  // Position of each stack entry inside inputs_order_, or -1 when its name is not listed.
  // Computed once per entry instead of once per comparison.
  std::vector<int64_t> order_of;
  order_of.reserve(stack.size());
  for (const auto *const node : stack) {
    const auto it = std::find(inputs_order_.begin(), inputs_order_.end(), node->GetName());
    order_of.push_back((it == inputs_order_.end()) ? -1 : static_cast<int64_t>(it - inputs_order_.begin()));
  }
  // Selection sort over the listed entries only (descending position). The position of slot i must
  // travel with the node on every swap. Comparing against a stale position would leave the wrong node
  // in slot i: e.g. positions [0, 2, 1] would end as [1, 2, 0] instead of [2, 1, 0].
  for (size_t i = 0UL; i < stack.size(); ++i) {
    if (order_of[i] < 0) {
      continue;
    }
    for (size_t j = i + 1UL; j < stack.size(); ++j) {
      if ((order_of[j] >= 0) && (order_of[i] < order_of[j])) {
        std::swap(stack[i], stack[j]);
        std::swap(order_of[i], order_of[j]);
      }
    }
  }
  return GRAPH_SUCCESS;
}
/**
 * Decrement the pending in-edge counter of `node`. When it reaches zero, record the node in
 * `breadth_node_map` keyed by its name. Nodes absent from `map_in_edge_num` are ignored.
 */
void ExecuteGraph::GetOutNodesFromEdgesToMap(std::map<FastNode *, uint32_t> &map_in_edge_num, FastNode *node,
                                             std::map<std::string, FastNode *> &breadth_node_map) const {
  const auto found = map_in_edge_num.find(node);
  if (found == map_in_edge_num.end()) {
    return;
  }
  if (--found->second == 0U) {
    (void)breadth_node_map.emplace(node->GetName(), node);
  }
}
/**
 * Release the successors of `node`: every consumer reached through a data edge (entries with
 * dst_input == kControlEdgeIndex are skipped) or a control edge has its pending counter decremented.
 * Consumers that become ready are collected in `breadth_node_map`. Always returns GRAPH_SUCCESS.
 */
graphStatus ExecuteGraph::CollectBreadthOutNode(const FastNode *const node,
                                                std::map<FastNode *, uint32_t> &map_in_edge_num,
                                                std::map<std::string, FastNode *> &breadth_node_map) const {
  for (const auto &anchor_edges : node->GetAllOutDataEdgesRef()) {
    for (FastEdge *edge : anchor_edges) {
      if ((edge != nullptr) && (edge->dst_input != kControlEdgeIndex)) {
        GetOutNodesFromEdgesToMap(map_in_edge_num, edge->dst, breadth_node_map);
      }
    }
  }
  for (FastEdge *edge : node->GetAllOutControlEdgesRef()) {
    if (edge != nullptr) {
      GetOutNodesFromEdgesToMap(map_in_edge_num, edge->dst, breadth_node_map);
    }
  }
  return GRAPH_SUCCESS;
}
/**
 * Breadth-first topological sort of this graph's direct nodes, appended to `node_vec`.
 *
 * Roots come from SortNodes and are consumed from the back of `stack_input`. After each emitted
 * node, the consumers that became ready are pushed (in name order) onto a TopoSortStack. That stack
 * is always drained before the next root is taken. In memory-priority mode the stack is backed by
 * per-node status initialised from `compute_graph`. `reverse` is unused by this strategy.
 */
graphStatus ExecuteGraph::BFSTopologicalSorting(std::vector<FastNode *> &node_vec, const bool reverse,
                                                const ExecuteGraph *const compute_graph) const {
  GELOGD("Runing_Bfs_Sort: %s", GetName().c_str());
  (void)reverse;
  const bool is_mem_priority = IsMemoryPriority();
  std::vector<NodeStatus> reverse_dfs_nodes_info;
  if (is_mem_priority) {
    InitNodeStatus(compute_graph, reverse_dfs_nodes_info);
  }
  TopoSortStack<FastNode> ready_stack(&reverse_dfs_nodes_info, is_mem_priority);
  std::vector<FastNode *> stack_input;
  std::map<std::string, FastNode *> newly_ready;
  std::map<FastNode *, uint32_t> map_in_edge_num;
  GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed");
  while (!(stack_input.empty() && ready_stack.Empty())) {
    FastNode *node = nullptr;
    if (ready_stack.Empty()) {
      node = stack_input.back();
      stack_input.pop_back();
    } else {
      node = ready_stack.Pop();
    }
    node_vec.push_back(node);
    GE_CHECK_NOTNULL(node->GetOpDescBarePtr());
    GELOGD("node_vec.push_back %s", node->GetOpDescBarePtr()->GetName().c_str());
    (void)CollectBreadthOutNode(node, map_in_edge_num, newly_ready);
    for (const auto &entry : newly_ready) {
      (void)ready_stack.Push(entry.second);
    }
    newly_ready.clear();
  }
  return GRAPH_SUCCESS;
}
/**
 * Depth-first topological sort of this graph's direct nodes, appended to `node_vec`.
 *
 * Roots from SortNodes are pushed onto a TopoSortStack. The stack is constructed with the extra flags
 * `true` and `reverse`; NOTE(review): presumably DFS mode and reversed ordering, confirm in
 * TopoSortStack. After each popped node is emitted, consumers whose pending in-edge count drops to
 * zero are pushed: one batch per output anchor, then one batch for control successors. With
 * `reverse`, each batch is pushed in reverse order. In memory-priority mode the stack uses per-node
 * status initialised from `compute_graph`.
 */
graphStatus ExecuteGraph::DFSTopologicalSorting(std::vector<FastNode *> &node_vec, const bool reverse,
                                                const ExecuteGraph *const compute_graph) const {
  GELOGD("Runing_Dfs_Sort: %s", GetName().c_str());
  std::vector<FastNode *> stack;
  std::map<FastNode *, uint32_t> map_in_edge_num;
  GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed");
  const bool is_mem_priority = IsMemoryPriority();
  std::vector<NodeStatus> reverse_dfs_nodes_info;
  if (is_mem_priority) {
    InitNodeStatus(compute_graph, reverse_dfs_nodes_info);
  }
  TopoSortStack<FastNode> topo_sort_stack(&reverse_dfs_nodes_info, is_mem_priority, true, reverse);
  for (const auto &node : stack) {
    topo_sort_stack.Push(node);
  }
  std::vector<FastNode *> out_nodes;
  // Push one batch of newly ready nodes (reversed when requested) and reset the batch.
  const auto stack_push = [&reverse, &topo_sort_stack](std::vector<FastNode *> &tmp_out_nodes) {
    if (reverse) {
      std::reverse(tmp_out_nodes.begin(), tmp_out_nodes.end());
    }
    for (const auto &node : tmp_out_nodes) {
      topo_sort_stack.Push(node);
    }
    tmp_out_nodes.clear();
  };
  // Release the consumer at the far end of `edge`; ready consumers accumulate in `out_nodes`.
  const auto collect_ready = [&map_in_edge_num, &out_nodes](FastEdge *edge) {
    if (edge != nullptr) {
      GetOutNodesFromEdge(map_in_edge_num, edge->dst, out_nodes);
    }
  };
  while (!topo_sort_stack.Empty()) {
    FastNode *node = topo_sort_stack.Pop();
    node_vec.push_back(node);
    GE_CHECK_NOTNULL(node->GetOpDescBarePtr());
    auto &edges = node->GetAllOutDataEdgesRef();
    for (size_t i = 0UL; i < edges.size(); ++i) {
      std::for_each(edges[i].begin(), edges[i].end(), collect_ready);
      stack_push(out_nodes);
    }
    // Bind by reference: the container is only read here, so copying it for every visited node is
    // wasted allocation and copying.
    const auto &control_edges = node->GetAllOutControlEdgesRef();
    std::for_each(control_edges.begin(), control_edges.end(), collect_ready);
    stack_push(out_nodes);
  }
  return GRAPH_SUCCESS;
}
void ExecuteGraph::GetInNodes(const FastNode *const current, std::vector<FastNode *> &input_nodes) const {
auto &in_data_edges = current->GetAllInDataEdgesRef();
auto &ref = input_nodes;
for (size_t i = 0UL; i < in_data_edges.size(); i++) {
auto edge = in_data_edges[i];
if (edge != nullptr) {
ref.push_back(edge->src);
}
}
auto &in_control_edges = current->GetAllInControlEdgesRef();
std::for_each(in_control_edges.begin(), in_control_edges.end(), [&ref](FastEdge *edge) {
if (edge != nullptr) {
ref.push_back(edge->src);
}
});
}
// Reverse depth-first topological sort: starting from every direct node without out nodes (sinks),
// walk towards inputs (data and control) iteratively and emit nodes in post-order, so every node is
// appended to `node_vec` after all of its inputs. Inputs of a node are pushed in the order defined by
// NodeCmp over `reverse_dfs_nodes_info`. `reverse` is unused by this strategy.
// NOTE(review): node ids come from InitNodeStatus(compute_graph); indexing reverse_dfs_nodes_info by
// id assumes every reached input belongs to `compute_graph`.
// NOTE(review): std::set drops inputs that NodeCmp considers equivalent, and a node met again while
// still kWalking (a cycle reachable from a sink) is emitted early; confirm both are acceptable.
graphStatus ExecuteGraph::RDFSTopologicalSorting(std::vector<FastNode *> &node_vec, const bool reverse,
const ExecuteGraph *const compute_graph) const {
(void)reverse;
GELOGD("Runing_Reverse_Dfs_Sort: %s", GetName().c_str());
std::vector<NodeStatus> reverse_dfs_nodes_info;
// Also assigns each direct node's op-desc id = its slot in reverse_dfs_nodes_info.
InitNodeStatus(compute_graph, reverse_dfs_nodes_info);
for (const auto quick_node : graph_shared_->GetDirectNodeToModify()) {
auto node = &FastGraphUtils::GetNode(quick_node);
// Only sinks start a walk; everything else is reached through them.
if (!node->OutNodesIsEmpty()) {
continue;
}
std::vector<FastNode *> stack = {node};
while (!stack.empty()) {
const auto current = stack.back();
NodeStatus &reverse_dfs_node_info = reverse_dfs_nodes_info[current->GetOpDescBarePtr()->GetId()];
// First visit: mark as in progress and push its inputs above it; the node stays on the stack.
if (reverse_dfs_node_info.status == FastWalkStatus::kNotWalked) {
reverse_dfs_node_info.status = FastWalkStatus::kWalking;
std::vector<FastNode *> in_all_nodes;
GetInNodes(current, in_all_nodes);
NodeCmp<FastNode> cmp(&reverse_dfs_nodes_info);
std::set<FastNode *, NodeCmp<FastNode>> input_nodes{in_all_nodes.begin(), in_all_nodes.end(), cmp};
stack.insert(stack.end(), input_nodes.cbegin(), input_nodes.cend());
continue;
}
// Second visit (inputs done): emit once. Already-walked duplicates are just popped.
stack.pop_back();
if (reverse_dfs_node_info.status == FastWalkStatus::kWalking) {
reverse_dfs_node_info.status = FastWalkStatus::kWalked;
node_vec.emplace_back(current);
}
}
}
return GRAPH_SUCCESS;
}
// Topologically sort the direct nodes of this graph and commit the new order.
//
// The strategy (BFS / DFS / reverse DFS) is chosen by GetTopoSortingStrategy() and dispatched through
// a table of member-function pointers. `execute_graph` is forwarded to the strategy and to
// DelayTopoSort; `dfs_reverse` is forwarded as the strategy's `reverse` flag.
// If fewer nodes were emitted than the graph holds, every missing node is logged and GRAPH_FAILED is
// returned. The message assumes a cycle; note that SortNodes also skips nodes without an op desc.
// In memory-priority mode, or with reverse DFS, the order is post-processed by DelayTopoSort before
// being stored via SetNodesAfterSorting, after which the graph is flagged valid.
graphStatus ExecuteGraph::TopologicalSortingGraph(const ExecuteGraph *const execute_graph, const bool dfs_reverse) {
using TopoSortingStrategy = std::function<graphStatus(ExecuteGraph *, std::vector<FastNode *> &, const bool,
const ExecuteGraph *const compute_graph)>;
static const std::map<FastTopoSortingMode, TopoSortingStrategy> topo_sorting_strategy{
{FastTopoSortingMode::kBFS, &ExecuteGraph::BFSTopologicalSorting},
{FastTopoSortingMode::kDFS, &ExecuteGraph::DFSTopologicalSorting},
{FastTopoSortingMode::kRDFS, &ExecuteGraph::RDFSTopologicalSorting}};
std::vector<FastNode *> node_vec;
const auto use_topo_strategy = GetTopoSortingStrategy();
const auto it = topo_sorting_strategy.find(use_topo_strategy);
if (it == topo_sorting_strategy.end()) {
GELOGE(GRAPH_FAILED, "Cannot find topo sorting strategy of %d.", static_cast<int32_t>(use_topo_strategy));
return GRAPH_FAILED;
}
if (it->second(this, node_vec, dfs_reverse, execute_graph) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
// Not every node was emitted: report the ones that were never reached.
if (node_vec.size() != GetDirectNodesSize()) {
std::set<FastNode *> itered_nodes_set;
for (auto &node : node_vec) {
(void)itered_nodes_set.insert(node);
}
REPORT_INNER_ERR_MSG("E18888", "Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph:%s",
GetDirectNodesSize(), node_vec.size(), GetName().c_str());
GELOGW("[Check][Param] Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.",
GetDirectNodesSize(), node_vec.size());
for (auto node : graph_shared_->GetDirectNodeToModify()) {
if (itered_nodes_set.count(&FastGraphUtils::GetNode(node)) == 0UL) {
GELOGW("[Check][Param] The node %s does not itered when topological sorting",
FastGraphUtils::GetNode(node).GetName().c_str());
}
}
return GRAPH_FAILED;
}
if (IsMemoryPriority() || (use_topo_strategy == FastTopoSortingMode::kRDFS)) {
DelayTopoSort(node_vec, execute_graph);
}
auto ret = graph_shared_->SetNodesAfterSorting(node_vec);
if (ret != GRAPH_SUCCESS) {
return ret;
}
graph_shared_->SetValidFlag(true);
return GRAPH_SUCCESS;
}
/**
 * For each subgraph instance name of `op_desc` that resolves via GetSubGraph, record the subgraph in
 * `subgraphs` and prepend its direct nodes to `candidates`.
 * Names are visited back to front, so the first-named subgraph's nodes end up at the very front.
 */
void ExecuteGraph::GetAllNodesFromOpdesc(std::vector<std::shared_ptr<ExecuteGraph>> &subgraphs, const OpDesc &op_desc,
                                         std::deque<FastNode *> &candidates) const {
  const auto &subgraph_names = op_desc.GetSubgraphInstanceNames();
  for (auto name_it = subgraph_names.rbegin(); name_it != subgraph_names.rend(); ++name_it) {
    auto subgraph = GetSubGraph(*name_it);
    if (subgraph == nullptr) {
      continue;
    }
    subgraphs.emplace_back(subgraph->shared_from_this());
    const auto subgraph_nodes = subgraph->GetDirectNode();
    (void)candidates.insert(candidates.begin(), subgraph_nodes.begin(), subgraph_nodes.end());
  }
}
std::vector<FastNode *> ExecuteGraph::AllGraphNodes(std::vector<std::shared_ptr<ExecuteGraph>> &subgraphs,
const FastNodeFilter &fast_node_filter) const {
std::vector<FastNode *> all_nodes;
std::deque<FastNode *> candidates;
auto &ref = graph_shared_->GetDirectNodeToModify();
for (auto iter = ref.begin(); iter != ref.end(); ++iter) {
QuickNode *node = *iter;
candidates.push_back(&(node->data));
}
while (!candidates.empty()) {
FastNode *node = candidates.front();
candidates.pop_front();
if ((fast_node_filter == nullptr) || fast_node_filter(node)) {
all_nodes.emplace_back(node);
}
const auto op_desc = node->GetOpDescBarePtr();
if (op_desc != nullptr) {
GetAllNodesFromOpdesc(subgraphs, *op_desc, candidates);
}
}
return all_nodes;
}
// Topologically sort this graph and each subgraph held by its implementation.
// Afterwards, every node reachable through AllGraphNodes (this graph plus nested subgraphs, in
// owner-then-subgraph order) gets a consecutive op-desc id starting at 0. If AllGraphNodes discovered
// exactly as many subgraphs as are registered, the subgraphs are re-registered in discovery order.
// Otherwise the existing registration is kept and only a warning is logged.
// NOTE(review): SetId below assumes every collected node has a non-null op desc.
graphStatus ExecuteGraph::TopologicalSorting() {
auto ret = TopologicalSortingGraph(this, false);
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E18888", "Graph [%s] topological sort failed, saved to file black_box", GetName().c_str());
GELOGE(GRAPH_FAILED, "[Sort][Graph] Graph [%s] topological sort failed, saved to file black_box",
GetName().c_str());
return ret;
}
const auto &src_sub_graphs = graph_shared_->sub_graphs_;
if (src_sub_graphs.empty()) {
return GRAPH_SUCCESS;
}
// Each subgraph is sorted on its own, with itself as the execute graph.
for (auto sub_graph : src_sub_graphs) {
GE_CHECK_NOTNULL(sub_graph);
GE_CHECK_NOTNULL(FastGraphUtils::GetGraph(sub_graph));
ret = FastGraphUtils::GetGraph(sub_graph)->TopologicalSortingGraph(FastGraphUtils::GetGraph(sub_graph), false);
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E18888", "Sub graph[%s] topological sort failed, saved to file black_box",
FastGraphUtils::GetGraph(sub_graph)->GetName().c_str());
GELOGE(GRAPH_FAILED, "[Sort][Graph] Sub graph[%s] topological sort failed, saved to file black_box",
FastGraphUtils::GetGraph(sub_graph)->GetName().c_str());
return ret;
}
}
// Renumber ids across the whole graph hierarchy in traversal order.
std::vector<std::shared_ptr<ExecuteGraph>> subgraphs;
auto nodes = AllGraphNodes(subgraphs, nullptr);
int64_t i = 0LL;
for (auto iter = nodes.begin(); iter != nodes.end(); ++iter) {
FastNode *node = *iter;
node->GetOpDescBarePtr()->SetId(i);
++i;
}
// Re-register only when traversal saw exactly the registered subgraphs (by count).
if (src_sub_graphs.size() != subgraphs.size()) {
GELOGW("[TopoSort][CheckNodeSize] Keep original subgraph for graph size %zu not equal %zu.", src_sub_graphs.size(),
subgraphs.size());
return GRAPH_SUCCESS;
}
graph_shared_->ClearAllSubGraph();
names_to_subgraph_.clear();
std::for_each(subgraphs.begin(), subgraphs.end(),
[this](std::shared_ptr<ExecuteGraph> &subgraph) { (void) AddSubGraph(subgraph); });
return GRAPH_SUCCESS;
}
/// Rename this graph; the name is stored by the graph implementation.
void ExecuteGraph::SetName(const std::string &name) {
  auto &impl = *graph_shared_;
  impl.SetName(name);
}
std::string ExecuteGraph::GetName() const {
return graph_shared_->GetName();
}
/// Record `parent_graph` as this graph's parent in the graph implementation.
void ExecuteGraph::SetParentGraph(ExecuteGraph *const parent_graph) {
  auto &impl = *graph_shared_;
  impl.SetParentGraph(parent_graph);
}
/// Read-only access to this graph's parent graph (nullptr for a graph without parent).
const ExecuteGraph *ExecuteGraph::GetParentGraphBarePtr(void) const {
  const ExecuteGraph *const parent = graph_shared_->GetParentGraph();
  return parent;
}
/// Mutable access to this graph's parent graph (nullptr for a graph without parent).
ExecuteGraph *ExecuteGraph::GetParentGraphBarePtr(void) {
  ExecuteGraph *const parent = graph_shared_->GetParentGraph();
  return parent;
}
/**
 * Hand the list element of `fast_edge` back to the graph implementation for recycling.
 * Returns GRAPH_FAILED (and reports E18888) for a null edge.
 */
graphStatus ExecuteGraph::RecycleQuickEdge(const FastEdge *const fast_edge) {
  if (fast_edge == nullptr) {
    // The argument is an edge; say so, consistent with RemoveEdge / MoveEdgeToGraph.
    REPORT_INNER_ERR_MSG("E18888", "The edge is nullptr.");
    GE_LOGE("[Check][Param] The edge is nullptr.");
    return GRAPH_FAILED;
  }
  return graph_shared_->RecycleQuickEdge(FastGraphUtils::GetListElementAddr(fast_edge));
}
/**
 * Hand the list element of `fast_node` back to the graph implementation for recycling.
 * Returns GRAPH_FAILED (and reports E18888) for a null node.
 */
graphStatus ExecuteGraph::RecycleQuickNode(const FastNode *const fast_node) {
  if (fast_node != nullptr) {
    auto list_element = FastGraphUtils::GetListElementAddr(fast_node);
    return graph_shared_->RecycleQuickNode(list_element);
  }
  REPORT_INNER_ERR_MSG("E18888", "The node is nullptr.");
  GE_LOGE("[Check][Param] The node is nullptr.");
  return GRAPH_FAILED;
}
/// All nodes of this graph and of the subgraphs reachable from it, unfiltered (see AllGraphNodes).
std::vector<FastNode *> ExecuteGraph::GetAllNodes() const {
  std::vector<std::shared_ptr<ExecuteGraph>> visited_subgraphs;
  auto all_nodes = AllGraphNodes(visited_subgraphs, nullptr);
  return all_nodes;
}
/// Nodes of this graph and its reachable subgraphs that pass `fast_node_filter` (see AllGraphNodes).
std::vector<FastNode *> ExecuteGraph::GetAllNodes(const FastNodeFilter &fast_node_filter) const {
  std::vector<std::shared_ptr<ExecuteGraph>> visited_subgraphs;
  auto matching_nodes = AllGraphNodes(visited_subgraphs, fast_node_filter);
  return matching_nodes;
}
/// Replace the preferred input-node order (by node name) that SortNodes applies to root nodes.
void ExecuteGraph::SetInputsOrder(const std::vector<std::string> &inputs_order) {
  inputs_order_.assign(inputs_order.begin(), inputs_order.end());
}
/// Ask the graph implementation to reorder its nodes by op-desc id.
void ExecuteGraph::ReorderByNodeId() {
  auto &impl = *graph_shared_;
  impl.ReorderByNodeId();
}
/// Store `graph_id` in the graph implementation.
void ExecuteGraph::SetGraphId(size_t graph_id) {
  auto &impl = *graph_shared_;
  impl.SetGraphId(graph_id);
}
/// Graph id stored in the graph implementation.
size_t ExecuteGraph::GetGraphId() const {
  const size_t id = graph_shared_->GetGraphId();
  return id;
}
/// Mutable access to this graph's attribute storage.
ProtoAttrMap &ExecuteGraph::MutableAttrMap() {
  return this->attrs_;
}
/// Read-only access to this graph's attribute storage.
ConstProtoAttrMap &ExecuteGraph::GetAttrMap() const {
  return this->attrs_;
}
/// True when the graph implementation reports `node` as belonging to this graph.
bool ExecuteGraph::CheckNodeIsInGraph(const FastNode *const node) const {
  const bool owned = graph_shared_->CheckNodeIsInGraph(node);
  return owned;
}
/// True when the graph implementation reports `edge` as belonging to this graph.
bool ExecuteGraph::CheckEdgeIsInGraph(const FastEdge *const edge) const {
  const bool owned = graph_shared_->CheckEdgeIsInGraph(edge);
  return owned;
}
/**
 * Move `edge` into this graph through the graph implementation.
 * Returns GRAPH_FAILED (and reports E18888) for a null edge. Otherwise returns GRAPH_SUCCESS; the
 * implementation's own result is not inspected.
 */
graphStatus ExecuteGraph::MoveEdgeToGraph(const FastEdge *const edge) {
  if (edge != nullptr) {
    graph_shared_->MoveEdgeToGraph(edge);
    return GRAPH_SUCCESS;
  }
  REPORT_INNER_ERR_MSG("E18888", "The edge is nullptr.");
  GE_LOGE("[Check][Param] The edge is nullptr.");
  return GRAPH_FAILED;
}
}