* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "node_priority_calculator.h"
#include <algorithm>
#include <cinttypes>
#include <limits>
#include <map>
#include <queue>
#include <set>
#include <vector>
#include "common/checker.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/fast_node_utils.h"
#include "exe_graph/lowering/exe_graph_attrs.h"
#include "exe_graph/runtime/continuous_buffer.h"
#include "exe_graph/runtime/compute_node_info.h"
#include "exe_graph/runtime/continuous_vector.h"
#include "core/builder/node_types.h"
namespace gert {
namespace bg {
namespace {
// Sentinel meaning "no priority assigned": every entry of NodePriorities starts at this value,
// and the passes below compare against it to detect unresolved nodes.
constexpr int64_t kInitPriority = std::numeric_limits<int64_t>::max();
// Multiplier applied to a compute node id to get the base launch priority of that node.
constexpr int64_t kPriorityExpansion = 10;
// Amount subtracted from the launch priority of atomic launch nodes.
constexpr int64_t kPriorityDecrease = 5;
// Dense priority table indexed by node topo id (the value returned by op_desc->GetId()).
// Every slot starts at kInitPriority. `seen_nodes` records which ids have been written through Set().
// It is kept separately from the value because Set() may legitimately store kInitPriority itself.
// For example, MarkPriorityByPreNode stores it when no ancestor priority is found, so "value == kInitPriority"
// and "never set" are different states.
// Ids are not bounds-checked: callers must pass ids in [0, node_num).
struct NodePriorities {
  explicit NodePriorities(size_t node_num)
      : node_ids_to_priority(node_num, kInitPriority), seen_nodes(node_num, false) {}
  // Stores `priority` for `node_id` and marks the id as set.
  void Set(int64_t node_id, int64_t priority) {
    const auto index = static_cast<size_t>(node_id);
    node_ids_to_priority[index] = priority;
    seen_nodes[index] = true;
  }
  // Returns the stored priority, or kInitPriority if nothing was stored for `node_id`.
  int64_t Get(int64_t node_id) const {
    return node_ids_to_priority[static_cast<size_t>(node_id)];
  }
  // Returns whether Set() has been called for `node_id`. This is const (it only reads), so it also
  // works on a const table.
  bool HasSet(int64_t node_id) const {
    return seen_nodes[static_cast<size_t>(node_id)];
  }
  std::vector<int64_t> node_ids_to_priority;
  std::vector<bool> seen_nodes;
};
// Appends `node` to the BFS `queue` only the first time it is offered; `seen` tracks every node ever enqueued.
void PushQueue(ge::FastNode *const node, std::queue<ge::FastNode *> &queue, std::set<ge::FastNode *> &seen) {
  const bool first_visit = seen.insert(node).second;
  if (!first_visit) {
    return;
  }
  queue.push(node);
}
// Enqueues (via PushQueue) the source node of every non-null incoming edge of `priority_node`.
// Data-edge sources are offered first, then control-edge sources.
void PushQueueForAncestors(const ge::FastNode *const priority_node, std::queue<ge::FastNode *> &nodes,
                           std::set<ge::FastNode *> &seen) {
  for (const auto data_edge : priority_node->GetAllInDataEdgesRef()) {
    if (data_edge == nullptr) {
      continue;
    }
    PushQueue(data_edge->src, nodes, seen);
  }
  for (const auto control_edge : priority_node->GetAllInControlEdgesRef()) {
    if (control_edge == nullptr) {
      continue;
    }
    PushQueue(control_edge->src, nodes, seen);
  }
}
// Gives `priority_node` the value `priority`. `priority_node` itself is always overwritten.
// It then walks the node's ancestors breadth-first over data-edge and control-edge sources.
// Every ancestor that has no priority yet receives the same value.
// The walk does not continue past an ancestor that already has a priority, so values set by earlier calls are kept.
void MarkAncestorsPriorities(const ge::FastNode *const priority_node, int64_t priority,
                             NodePriorities &node_ids_to_priority) {
  std::queue<ge::FastNode *> frontier;
  std::set<ge::FastNode *> enqueued;
  PushQueueForAncestors(priority_node, frontier, enqueued);
  GELOGD("Set priority %" PRId64 " to priority node %s", priority, priority_node->GetNamePtr());
  node_ids_to_priority.Set(priority_node->GetOpDescBarePtr()->GetId(), priority);
  while (!frontier.empty()) {
    ge::FastNode *const ancestor = frontier.front();
    frontier.pop();
    const int64_t ancestor_id = ancestor->GetOpDescBarePtr()->GetId();
    if (!node_ids_to_priority.HasSet(ancestor_id)) {
      GELOGD("Set priority %" PRId64 " to node %s for priority node %s", priority, ancestor->GetNamePtr(),
             priority_node->GetNamePtr());
      node_ids_to_priority.Set(ancestor_id, priority);
      PushQueueForAncestors(ancestor, frontier, enqueued);
    }
  }
}
// Assigns a priority to `priority_node` and to every not-yet-prioritized ancestor, based on the
// ancestors that already have one.
//
// How it works:
// - A BFS walks upward from `priority_node` and stops at ancestors that already have a priority (HasSet).
// - Among those, it takes the largest value that is not the kInitPriority sentinel.
//   The name `lowest_priority` suggests a larger value means a lower priority.
// - `priority_node` and every unset ancestor met on the way receive that value.
// - Send-events nodes receive value + 1.
// - If no real priority is found, they are all stored as kInitPriority. CalcNodeExecutionPriorities later
//   resolves such nodes through SetPriorityToParentNode.
void MarkPriorityByPreNode(ge::FastNode *const priority_node, NodePriorities &node_ids_to_priority) {
  std::queue<ge::FastNode *> nodes;
  std::set<ge::FastNode *> seen;
  PushQueueForAncestors(priority_node, nodes, seen);
  int64_t lowest_priority = kInitPriority;
  // Keyed (and therefore iterated) by node id, so assignments and logs follow a deterministic order.
  std::map<int64_t, ge::FastNode *> unset_ancestors_to_node;
  unset_ancestors_to_node[priority_node->GetOpDescBarePtr()->GetId()] = priority_node;
  while (!nodes.empty()) {
    ge::FastNode *const node = nodes.front();
    nodes.pop();
    const int64_t node_id = node->GetOpDescBarePtr()->GetId();
    if (!node_ids_to_priority.HasSet(node_id)) {
      PushQueueForAncestors(node, nodes, seen);
      unset_ancestors_to_node[node_id] = node;
      continue;
    }
    const int64_t ancestor_priority = node_ids_to_priority.Get(node_id);
    // An ancestor that was "set" to the sentinel carries no ordering information. Ignoring it keeps the
    // result independent of BFS visiting order; previously such a value could reset or dominate the max
    // depending on where it was met.
    if (ancestor_priority == kInitPriority) {
      continue;
    }
    lowest_priority =
        (lowest_priority == kInitPriority) ? ancestor_priority : std::max(lowest_priority, ancestor_priority);
  }
  for (const auto &node_id_to_node : unset_ancestors_to_node) {
    int64_t priority = lowest_priority;
    // Send-events nodes are placed one step after the resolved priority. Skip this while nothing was
    // resolved: the sentinel is INT64_MAX, and adding 1 to it would be signed overflow (undefined behavior).
    if ((lowest_priority != kInitPriority) && IsSendEventsNode(node_id_to_node.second->GetTypePtr())) {
      priority = lowest_priority + 1;
    }
    GELOGD("Set priority %" PRId64 " to node %s for priority node %s", priority, node_id_to_node.second->GetNamePtr(),
           priority_node->GetNamePtr());
    node_ids_to_priority.Set(node_id_to_node.first, priority);
  }
}
// Assigns priorities to the nodes inside the subgraphs of an If/Case `parent_node`.
//
// The parent's own priority seeds both bounds. Every subgraph node whose priority is not kInitPriority
// widens them. Afterwards:
//   * switch-notify nodes get the smallest value seen;
//   * branch-done and wait-anyone nodes get the largest value seen;
//   * any other node still at kInitPriority is resolved from its ancestors via MarkPriorityByPreNode.
// NOTE(review): this presumably brackets the branch bodies, assuming a smaller value executes earlier -- confirm.
// A subgraph that cannot be fetched is skipped with a warning instead of being dereferenced.
void MarkIfSubGraphPriorities(const ge::FastNode *parent_node, NodePriorities &node_ids_to_priority) {
  const auto op_desc = parent_node->GetOpDescBarePtr();
  const int64_t priority = node_ids_to_priority.Get(op_desc->GetId());
  int64_t smallest_value = priority;
  int64_t largest_value = priority;
  const auto &sub_graph_indexes = op_desc->GetSubgraphNameIndexes();
  std::vector<const ge::FastNode *> before_branch_execute_nodes;
  std::vector<const ge::FastNode *> after_branch_execute_nodes;
  std::vector<ge::FastNode *> other_unset_nodes;
  for (const auto &index : sub_graph_indexes) {
    const auto &graph = ge::FastNodeUtils::GetSubgraphFromNode(parent_node, index.second);
    if (graph == nullptr) {
      GELOGW("Failed to get subgraph of node %s, skip it", parent_node->GetNamePtr());
      continue;
    }
    for (const auto node : graph->GetAllNodes()) {
      const auto node_type = node->GetTypePtr();
      const int64_t current_priority = node_ids_to_priority.Get(node->GetOpDescBarePtr()->GetId());
      if (IsBranchDone(node_type) || IsWaitAnyone(node_type)) {
        after_branch_execute_nodes.push_back(node);
      } else if (IsSwitchNotifyNode(node_type)) {
        before_branch_execute_nodes.push_back(node);
      } else if (current_priority == kInitPriority) {
        other_unset_nodes.emplace_back(node);
      }
      if (current_priority != kInitPriority) {
        smallest_value = std::min(smallest_value, current_priority);
        largest_value = std::max(largest_value, current_priority);
      }
    }
  }
  for (const auto node : before_branch_execute_nodes) {
    node_ids_to_priority.Set(node->GetOpDescBarePtr()->GetId(), smallest_value);
    GELOGD("control node %s set priority %" PRId64, node->GetNamePtr(), smallest_value);
  }
  for (const auto node : after_branch_execute_nodes) {
    node_ids_to_priority.Set(node->GetOpDescBarePtr()->GetId(), largest_value);
    GELOGD("control node %s set priority %" PRId64, node->GetNamePtr(), largest_value);
  }
  for (const auto node : other_unset_nodes) {
    MarkPriorityByPreNode(node, node_ids_to_priority);
  }
}
// Returns true when `node` has an op desc whose topo id, read as size_t, is below `count`.
// Otherwise a GE_ASSERT_* macro logs and returns its failure result (presumably false for a bool return).
// A negative id wraps to a huge size_t and therefore also fails the check.
bool CheckNodeIdValid(ge::FastNode *const node, size_t count) {
  const auto desc = node->GetOpDescBarePtr();
  GE_ASSERT_NOTNULL(desc);
  const int64_t topo_id = desc->GetId();
  const bool in_range = static_cast<size_t>(topo_id) < count;
  GE_ASSERT_TRUE(in_range, "The node %s topo id %" PRId64 " exceeds max count %zu", desc->GetNamePtr(), topo_id,
                 count);
  return true;
}
// Climbs from `node` to the parent node of its owner graph, repeatedly, until it reaches a node whose
// priority is not kInitPriority.
// - If such a node is found, its priority is assigned to every node visited on the way, including `node`.
// - If the chain ends at a graph without a parent node first, nothing is changed.
// Fails if an op desc, extend info or owner graph is missing, or if the same node id is visited twice.
ge::graphStatus SetPriorityToParentNode(ge::FastNode *node, NodePriorities &priorities) {
  std::set<int64_t> visited_ids;
  int64_t resolved = kInitPriority;
  for (ge::FastNode *cursor = node; cursor != nullptr;) {
    const auto desc = cursor->GetOpDescBarePtr();
    GE_ASSERT_NOTNULL(desc);
    resolved = priorities.Get(desc->GetId());
    if (resolved != kInitPriority) {
      break;
    }
    GE_ASSERT_TRUE(visited_ids.insert(desc->GetId()).second, "Circle found when iter parent nodes for node %s",
                   cursor->GetNamePtr());
    const auto info = cursor->GetExtendInfo();
    GE_ASSERT_NOTNULL(info);
    const auto owner = info->GetOwnerGraphBarePtr();
    GE_ASSERT_NOTNULL(owner);
    cursor = owner->GetParentNodeBarePtr();
  }
  if (resolved == kInitPriority) {
    return ge::GRAPH_SUCCESS;
  }
  for (const int64_t id : visited_ids) {
    priorities.Set(id, resolved);
  }
  return ge::GRAPH_SUCCESS;
}
}
// Initializes frame_ from `frame`; the actual work happens in CalcNodeExecutionPriorities().
NodePriorityCalculator::NodePriorityCalculator(const GraphFrame &frame)
    : frame_(frame) {}
// Computes an execution priority for every node in `main_graph_nodes` and stores it in the node's
// "priority" attribute.
//
// Phases:
//   1. Launch nodes and nodes owning subgraphs each get a launch priority.
//      - If the node carries kComputeNodeIndex, the value is
//        compute_node_id * kPriorityExpansion + (kPriorityExpansion - 1), minus kPriorityDecrease for
//        atomic launch nodes.
//      - Otherwise the value is one more than the previous launch node's value (0 for the first).
//      These nodes and their unset ancestors are then marked in ascending launch-priority order.
//   2. Free nodes and send-events nodes are marked from their already-prioritized ancestors.
//   3. If/Case nodes have their subgraph nodes handled by MarkIfSubGraphPriorities.
//   4. Nodes still at kInitPriority take the priority of the nearest parent node that has one.
// @param main_graph_nodes nodes to prioritize; each must be non-null and have an op desc with a topo id
//        below root_all_nodes_cnt.
// @param root_all_nodes_cnt size of the id-indexed priority table.
// @return GRAPH_SUCCESS, or the failure result of a GE_ASSERT_* check.
ge::graphStatus NodePriorityCalculator::CalcNodeExecutionPriorities(const std::vector<ge::FastNode *> &main_graph_nodes,
                                                                    const size_t root_all_nodes_cnt) {
  // Maps a compute-node index (kComputeNodeIndex attr value) to that compute node's topo id.
  std::vector<int64_t> index_2_compute_node_id;
  for (const auto &compute_node : frame_.GetIndexesToNode()) {
    const auto op_desc = compute_node->GetOpDescBarePtr();
    GE_ASSERT_NOTNULL(op_desc);
    index_2_compute_node_id.emplace_back(op_desc->GetId());
    GELOGD("Get compute node info, node name: %s, id: %" PRId64 "", compute_node->GetNamePtr(), op_desc->GetId());
  }
  NodePriorities node_ids_to_priority(root_all_nodes_cnt);
  std::vector<ge::FastNode *> pending_mark_nodes;
  std::vector<ge::FastNode *> if_sub_graph_nodes;
  // A multimap: several launch nodes may share one launch priority.
  std::multimap<int64_t, ge::FastNode *> launch_priority_to_launch_node;
  int64_t launch_priority = 0;
  for (const auto node : main_graph_nodes) {
    GE_ASSERT_NOTNULL(node);
    GE_ASSERT_TRUE(CheckNodeIdValid(node, root_all_nodes_cnt));
    if (IsLaunchOrHasSubGraphNode(node)) {
      if (IsIfOrCaseType(node->GetTypePtr())) {
        if_sub_graph_nodes.push_back(node);
      }
      int64_t compute_node_index = 0;
      if (ge::AttrUtils::GetInt(node->GetOpDescBarePtr(), kComputeNodeIndex, compute_node_index)) {
        // The size_t cast also rejects negative indexes, which wrap to huge values.
        GE_ASSERT_TRUE(static_cast<size_t>(compute_node_index) < index_2_compute_node_id.size());
        const int64_t compute_node_id = index_2_compute_node_id[compute_node_index];
        launch_priority = (compute_node_id * kPriorityExpansion) + (kPriorityExpansion - 1);
        if (IsAtomicLaunchNode(node->GetTypePtr())) {
          launch_priority -= kPriorityDecrease;
        }
      }
      (void)launch_priority_to_launch_node.emplace(launch_priority++, node);
    } else if (IsFreeNode(node->GetTypePtr()) || IsSendEventsNode(node->GetTypePtr())) {
      pending_mark_nodes.emplace_back(node);
    }
    // Any other node may be reached later as an ancestor in the marking passes, or resolved from its parent
    // node in the final fallback.
  }
  for (const auto &node_info : launch_priority_to_launch_node) {
    MarkAncestorsPriorities(node_info.second, node_info.first, node_ids_to_priority);
  }
  for (const auto node : pending_mark_nodes) {
    MarkPriorityByPreNode(node, node_ids_to_priority);
  }
  for (const auto node : if_sub_graph_nodes) {
    MarkIfSubGraphPriorities(node, node_ids_to_priority);
  }
  // This must finish for all nodes before the attributes are written: resolving one node can also fill in
  // its unset parent nodes.
  for (const auto node : main_graph_nodes) {
    if (node_ids_to_priority.Get(node->GetOpDescBarePtr()->GetId()) == kInitPriority) {
      GE_ASSERT_SUCCESS(SetPriorityToParentNode(node, node_ids_to_priority));
    }
  }
  for (const auto node : main_graph_nodes) {
    const auto op_desc = node->GetOpDescBarePtr();
    GE_ASSERT_NOTNULL(op_desc);
    GE_ASSERT_TRUE(ge::AttrUtils::SetInt(op_desc, "priority", node_ids_to_priority.Get(op_desc->GetId())));
  }
  return ge::GRAPH_SUCCESS;
}
}
}