* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/passes/feature/parallel_group_pass.h"
#include <queue>
#include "framework/common/debug/ge_log.h"
#include "common/plugin/ge_make_unique_util.h"
#include "common/checker.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
namespace ge {
namespace {
// NOTE(review): kMaxRecursionDepth appears unused in this file — confirm and drop if nothing else needs it.
constexpr int32_t kMaxRecursionDepth = 10;
// Value of ATTR_NAME_STREAM_SWITCH_TYPE that IsWhileStreamSwitch treats as a while-loop switch.
constexpr int64_t kLoopType = 1;
// Node filter passed to ComputeGraph::GetNodes: selects nodes whose OpDesc lists at least one subgraph
// instance name. Nodes without an OpDesc are rejected instead of being dereferenced.
const auto with_subgraph_node_filter = [](const Node &node) {
  const auto &op_desc = node.GetOpDesc();
  return (op_desc != nullptr) && !op_desc->GetSubgraphInstanceNames().empty();
};
}
// Entry point of the pass.
// Only a root graph (no parent graph) is processed; subgraphs are returned from untouched here because
// ProcessGraphGroupNodes reaches them from the root. The graph is topologically sorted before and after
// the control edges for parallel groups are inserted.
// Returns SUCCESS, FAILED when either topological sort fails, or INTERNAL_ERROR when group processing fails.
Status ParallelGroupPass::Run(ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  const bool is_root_graph = (graph->GetParentGraph() == nullptr);
  if (!is_root_graph) {
    return SUCCESS;
  }
  GELOGD("ParallelGroupPass running");

  // NOTE(review): presumably sorted first so that direct nodes are visited in topological order — confirm.
  if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
    GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
    REPORT_INNER_ERR_MSG("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
                         graph->GetName().c_str());
    return FAILED;
  }

  std::unordered_set<std::string> found_groups;
  const Status process_ret = ProcessGraphGroupNodes(graph, found_groups);
  if (process_ret != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "[Process][Graph]Process group nodes of graph %s failed.", graph->GetName().c_str());
    return INTERNAL_ERROR;
  }

  // Re-sort because new control edges were added between group members.
  if (graph->TopologicalSorting() == GRAPH_SUCCESS) {
    return SUCCESS;
  }
  GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
  REPORT_INNER_ERR_MSG("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
                       graph->GetName().c_str());
  return FAILED;
}
// Chains the members of every parallel group found among the direct nodes of `graph` with control edges.
//
// A direct node is a group member when:
//   - it carries ATTR_NAME_PARALLEL_GROUP, or
//   - it owns subgraphs and pnodes_2_parallel_groups already records groups for it.
// Members of one group are linked in GetDirectNode() order via ReplaceWithSwitchAndMerge, which redirects
// edges through StreamSwitch/StreamMerge nodes when a member sits in a switch branch.
//
// @param graph                     graph whose direct nodes are processed.
// @param pnodes_2_parallel_groups  in: groups already collected for subgraph-owning nodes;
//                                  out: this graph's groups are merged into the entry of
//                                  graph->GetParentNode() (nullptr for the root graph).
// @return SUCCESS, or FAILED when switch/merge analysis or edge insertion fails.
Status ParallelGroupPass::ProcessOnGraph(
    const ComputeGraphPtr &graph,
    std::map<ge::NodePtr, std::unordered_set<std::string>> &pnodes_2_parallel_groups) const {
  auto candidates = graph->GetDirectNode();
  std::unordered_set<std::string> parallel_groups;
  std::map<std::string, std::vector<NodePtr>> group_nodes;
  for (const auto &node : candidates) {
    OpDescPtr op_desc = node->GetOpDesc();
    if (op_desc == nullptr) {
      continue;
    }
    std::string group_name;
    if (AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) {
      group_nodes[group_name].push_back(node);
      parallel_groups.insert(group_name);
      GELOGD("Find group node:%s, group_name:%s", node->GetName().c_str(), group_name.c_str());
    }
    const auto &subgraph_name = op_desc->GetSubgraphInstanceNames();
    if (subgraph_name.empty()) {
      continue;
    }
    // A node owning subgraphs represents, at this level, every group found inside those subgraphs
    // (they are expected to have been processed before this graph).
    const auto iter = pnodes_2_parallel_groups.find(node);
    if (iter != pnodes_2_parallel_groups.end()) {
      for (const auto &sub_parallel_group : iter->second) {
        parallel_groups.insert(sub_parallel_group);
        group_nodes[sub_parallel_group].emplace_back(node);
      }
    }
  }
  std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> node_2_switch_merge;
  if (ProcessGroupNodeInSwitch(graph, node_2_switch_merge) != SUCCESS) {
    GELOGE(FAILED, "[Process][Node]Process group node in switch failed, graph:%s.", graph->GetName().c_str());
    return FAILED;
  }
  for (const auto &itr : group_nodes) {
    const auto &nodes = itr.second;
    if (nodes.empty()) {
      continue;
    }
    NodePtr pre_node = nodes[0];
    NodePtr cur_node = nullptr;
    for (std::size_t i = 1; i < nodes.size(); i++) {
      cur_node = nodes[i];
      GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
      if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) {
        GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.",
               pre_node->GetName().c_str(), cur_node->GetName().c_str());
        return FAILED;
      }
      pre_node = cur_node;
    }
  }
  // A parent node may own several subgraphs (e.g. multiple branches), each processed by its own call.
  // Merge into the parent's entry instead of overwriting it, so groups collected from sibling subgraphs
  // processed earlier are not lost.
  auto &parent_groups = pnodes_2_parallel_groups[graph->GetParentNode()];
  parent_groups.insert(parallel_groups.begin(), parallel_groups.end());
  return SUCCESS;
}
// Processes parallel groups bottom-up across the whole graph hierarchy.
//
// Every node (recursively) that owns subgraphs is visited in reverse GetNodes() order, and each of its
// subgraphs is processed with ProcessOnGraph. The root graph itself is processed last.
// NOTE(review): this relies on GetNodes(false, ...) listing a parent node before the nodes of its
// subgraphs, so the reverse walk handles inner subgraphs first — confirm against ComputeGraph::GetNodes.
//
// @param graph            root graph.
// @param parallel_groups  out: the group names collected for the root graph.
// @return SUCCESS, or the failure produced by the null/assert checks.
Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph,
                                                 std::unordered_set<std::string> &parallel_groups) const {
  auto pnodes_with_subgraph = graph->GetNodes(false, with_subgraph_node_filter, nullptr);
  std::map<ge::NodePtr, std::unordered_set<std::string>> pnodes_2_parallel_groups;
  // Walk backwards by decrementing only after checking against begin(). Forming `end() - 1` on an empty
  // range, or `begin() - 1` at all, is undefined behaviour.
  for (auto iter = pnodes_with_subgraph.end(); iter != pnodes_with_subgraph.begin();) {
    --iter;
    const auto &pnode = *iter;
    const auto &subgraph_name = pnode->GetOpDesc()->GetSubgraphInstanceNames();
    for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) {
      const auto &sub_graph = graph->GetSubgraph(*name_iter);
      GE_CHECK_NOTNULL(sub_graph);
      GE_ASSERT_SUCCESS(ProcessOnGraph(sub_graph, pnodes_2_parallel_groups));
    }
  }
  GE_ASSERT_SUCCESS(ProcessOnGraph(graph, pnodes_2_parallel_groups));
  // The root graph has no parent node, so ProcessOnGraph stores its groups under nullptr.
  parallel_groups = pnodes_2_parallel_groups[nullptr];
  return SUCCESS;
}
// Adds a control edge pre_node -> cur_node unless it would be redundant.
// The edge is skipped (and SUCCESS is returned) in three cases:
//   - pre_node and cur_node are the same node;
//   - cur_node is of type DATA;
//   - pre_node is already one of cur_node's input nodes (data or control).
// Otherwise the result of GraphUtils::AddEdge is returned.
Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) const {
  if (pre_node == cur_node) {
    GELOGD("Pre_node and cur_node are same, no need add anchor");
    return SUCCESS;
  }
  if (cur_node->GetType() == DATA) {
    return SUCCESS;
  }

  bool already_linked = false;
  for (const auto &in_node : cur_node->GetInAllNodes()) {
    if (in_node == pre_node) {
      already_linked = true;
      break;
    }
  }
  if (already_linked) {
    GELOGD("Node:%s and %s already linked", pre_node->GetName().c_str(),
           cur_node->GetName().c_str());
    return SUCCESS;
  }

  GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
  return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), cur_node->GetInControlAnchor());
}
// For every branch StreamSwitch among the direct nodes of `graph`, records which parallel-group nodes
// lie downstream of it, together with the switch's input (cast) node and the highest-id StreamMerge
// reachable from it.
//
// Switches classified by IsBigSmallLoopStreamSwitch or IsWhileStreamSwitch are skipped.
//
// @param graph               graph whose direct nodes are scanned.
// @param node_2_switch_merge out: group node -> (set of cast nodes feeding its switches, merge node).
// @return SUCCESS; FAILED when a switch reaches no merge, when group nodes sit under branches carrying
//         more than one stream label (reported as switch nesting), or when mapping fails.
Status ParallelGroupPass::ProcessGroupNodeInSwitch(ComputeGraphPtr graph,
    std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) const {
  auto direct_nodes = graph->GetDirectNode();
  for (const auto &node : direct_nodes) {
    const std::string type = node->GetType();
    if (type != STREAMSWITCH) {
      continue;
    }
    GE_CHECK_NOTNULL(node->GetOpDesc());
    if (IsBigSmallLoopStreamSwitch(node->GetOpDesc()) || IsWhileStreamSwitch(node->GetOpDesc())) {
      continue;
    }
    std::vector<NodePtr> merge_nodes;
    std::set<NodePtr> group_nodes;
    std::set<std::string> stream_labels;
    FindGroupNodeAndMerge(node, group_nodes, merge_nodes, stream_labels);
    if (merge_nodes.empty() || (!group_nodes.empty() && stream_labels.size() > 1)) {
      GELOGE(FAILED, "[Process][Node]Cannot find merge node or exist switch nestification, switch node:%s,"
             "merge_vec size:%zu, stream_labels size:%zu, graph:%s.", node->GetName().c_str(),
             merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
      REPORT_INNER_ERR_MSG("E19999", "Cannot find merge node or exist switch nest, switch node:%s,"
                           "merge_vec size: %zu, stream_labels size: %zu, graph:%s.", node->GetName().c_str(),
                           merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
      return FAILED;
    }
    // Order merges by op id so that back() is the one with the largest id. std::sort requires a strict
    // weak ordering, so nodes without an OpDesc are consistently placed first. The previous
    // "return false if either is null" made them equivalent to everything, which breaks transitivity
    // and is undefined behaviour.
    std::sort(merge_nodes.begin(), merge_nodes.end(), [](const NodePtr &a, const NodePtr &b) -> bool {
      const OpDescPtr a_desc = a->GetOpDesc();
      const OpDescPtr b_desc = b->GetOpDesc();
      if ((a_desc == nullptr) || (b_desc == nullptr)) {
        return (a_desc == nullptr) && (b_desc != nullptr);
      }
      return (a_desc->GetId() < b_desc->GetId());
    });
    // The node feeding the switch's data input 0 is what gets recorded (and later linked) for this switch.
    NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0);
    GE_CHECK_NOTNULL(cast_node);
    if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, cast_node, node, node_2_switch_merge) != SUCCESS) {
      GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str());
      return FAILED;
    }
  }
  return SUCCESS;
}
// Breadth-first walk over everything downstream of a StreamSwitch node, stopping at StreamMerge nodes.
//
// @param stream_switch_node  node the walk starts from.
// @param group_nodes         out: every reached non-merge node whose OpDesc has ATTR_NAME_PARALLEL_GROUP.
// @param merge_nodes         out: every StreamMerge reached. A merge is appended once per edge leading
//                            to it, so duplicates are possible; the walk does not continue past a merge.
// @param stream_labels       out: ATTR_NAME_STREAM_LABEL of each non-merge node the first time it is visited.
void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set<NodePtr> &group_nodes,
                                              std::vector<NodePtr> &merge_nodes,
                                              std::set<std::string> &stream_labels) const {
  std::set<NodePtr> seen;
  std::deque<NodePtr> pending{stream_switch_node};
  while (!pending.empty()) {
    const NodePtr current = pending.front();
    pending.pop_front();
    for (const auto &next : current->GetOutAllNodes()) {
      const std::string next_type = next->GetType();
      if (next_type == STREAMMERGE) {
        merge_nodes.emplace_back(next);
        continue;
      }
      const auto &next_desc = next->GetOpDesc();
      const bool is_group_member = (next_desc != nullptr) && next_desc->HasAttr(ATTR_NAME_PARALLEL_GROUP);
      if (is_group_member) {
        group_nodes.emplace(next);
      }
      const bool first_visit = seen.insert(next).second;
      if (!first_visit) {
        continue;
      }
      pending.push_back(next);
      std::string label;
      if (ge::AttrUtils::GetStr(next->GetOpDesc(), ATTR_NAME_STREAM_LABEL, label)) {
        stream_labels.insert(label);
      }
    }
  }
}
// Records, for each group node found under one switch, the switch's cast node and the switch's merge
// node (merge_nodes.back()). merge_nodes must be non-empty; the caller checks this.
//
// A group node seen under an earlier switch keeps its existing entry: the new cast node is added to its
// set, but only if both switches end at the same merge node.
// @return SUCCESS, or FAILED when a group node would map to two different merge nodes.
Status ParallelGroupPass::MappingNodeToSwitchAndMerge(const std::set<NodePtr> &group_nodes,
    const std::vector<NodePtr> &merge_nodes, const NodePtr &cast_node, const NodePtr &switch_node,
    std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) const {
  for (const auto &group_node : group_nodes) {
    const auto found = node_2_switch_merge.find(group_node);
    if (found == node_2_switch_merge.end()) {
      node_2_switch_merge.emplace(group_node, std::make_pair(std::set<NodePtr>{cast_node}, merge_nodes.back()));
      continue;
    }
    auto &recorded = found->second;
    const auto &recorded_merge = recorded.second;
    GELOGD("Find group node: %s in switch %s and merge %s.",
           group_node->GetName().c_str(), switch_node->GetName().c_str(), recorded_merge->GetName().c_str());
    const NodePtr &new_merge = merge_nodes.back();
    if (recorded_merge != new_merge) {
      GELOGE(FAILED, "[Mapping][Node]Has two different merge nodes: %s and %s, graph's structure is invalid",
             recorded_merge->GetName().c_str(), new_merge->GetName().c_str());
      REPORT_INNER_ERR_MSG("E19999", "Has two different merge nodes: %s and %s, graph's structure is invalid",
                           recorded_merge->GetName().c_str(), new_merge->GetName().c_str());
      return FAILED;
    }
    recorded.first.insert(cast_node);
  }
  return SUCCESS;
}
// Orders two consecutive members of a parallel group, routing the control edge through switch/merge
// nodes when either member lies under a StreamSwitch (per node_2_switch_merge):
//   - neither under a switch: pre_node -> cur_node;
//   - only pre_node:  pre's merge -> cur_node;
//   - both, sharing no switch: pre's merge -> each cast node recorded for cur_node;
//     if they share a switch, no edge is added;
//   - only cur_node: for each recorded cast node, pre_node -> cast node. The exception is when pre_node
//     has a larger op id and is reachable from that cast node; then cur's merge -> pre_node is added.
// @return SUCCESS, or FAILED / the AddCtrlEdge result when an edge cannot be added.
Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cur_node,
    const std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) const {
  const auto pre_entry = node_2_switch_merge.find(pre_node);
  const auto cur_entry = node_2_switch_merge.find(cur_node);
  const bool pre_in_branch = (pre_entry != node_2_switch_merge.end());
  const bool cur_in_branch = (cur_entry != node_2_switch_merge.end());

  if (!pre_in_branch && !cur_in_branch) {
    return AddCtrlEdge(pre_node, cur_node);
  }

  if (!cur_in_branch) {
    // Only pre_node is under a switch: cur_node waits for that switch's merge.
    return AddCtrlEdge(pre_entry->second.second, cur_node);
  }

  if (pre_in_branch) {
    // Both nodes are under switches.
    if (HasSameSwitch(pre_entry->second.first, cur_entry->second.first)) {
      return SUCCESS;
    }
    const NodePtr &pre_merge = pre_entry->second.second;
    for (const auto &switch_node : cur_entry->second.first) {
      if (AddCtrlEdge(pre_merge, switch_node) == SUCCESS) {
        continue;
      }
      GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
             pre_merge->GetName().c_str(), switch_node->GetName().c_str());
      REPORT_INNER_ERR_MSG("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
                           pre_merge->GetName().c_str(), switch_node->GetName().c_str());
      return FAILED;
    }
    return SUCCESS;
  }

  // Only cur_node is under a switch.
  GE_CHECK_NOTNULL(pre_node->GetOpDesc());
  const NodePtr &cur_merge = cur_entry->second.second;
  for (const auto &switch_node : cur_entry->second.first) {
    GE_CHECK_NOTNULL(switch_node->GetOpDesc());
    const int64_t pre_id = pre_node->GetOpDesc()->GetId();
    const int64_t switch_id = switch_node->GetOpDesc()->GetId();
    // If pre_node is downstream of the cast node, linking pre_node -> cast node would point backwards;
    // make pre_node wait for the merge instead.
    const bool pre_is_downstream = (pre_id > switch_id) && IsIndirectConnect(switch_node, pre_node);
    const NodePtr first_node = pre_is_downstream ? cur_merge : pre_node;
    const NodePtr second_node = pre_is_downstream ? pre_node : switch_node;
    if (AddCtrlEdge(first_node, second_node) != SUCCESS) {
      GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
             first_node->GetName().c_str(), second_node->GetName().c_str());
      REPORT_INNER_ERR_MSG("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
                           first_node->GetName().c_str(), second_node->GetName().c_str());
      return FAILED;
    }
  }
  return SUCCESS;
}
bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1,
const std::set<NodePtr> &switch_set2) const {
for (const auto &node1 : switch_set1) {
const auto itr = switch_set2.find(node1);
if (itr != switch_set2.end()) {
return true;
}
}
return false;
}
// Returns true when the StreamSwitch op desc does not carry ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG.
// NOTE(review): per its name such a switch is taken to be a big/small-loop switch rather than a branch
// switch — confirm with the code that sets the flag.
bool ParallelGroupPass::IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc) const {
  const bool has_true_branch_flag = AttrUtils::HasAttr(switch_op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG);
  return !has_true_branch_flag;
}
// Returns true when the switch's ATTR_NAME_STREAM_SWITCH_TYPE attribute is present and equals kLoopType.
bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) const {
  int64_t stream_switch_type = -1;
  if (!AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type)) {
    return false;
  }
  return stream_switch_type == kLoopType;
}
bool ParallelGroupPass::IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b) const {
if (node_a == nullptr || node_b == nullptr || node_b->GetOpDesc() == nullptr) {
GELOGW("node_a or node_b is nullptr.");
return false;
}
int64_t end_id = node_b->GetOpDesc()->GetId();
std::queue<NodePtr> nodes;
nodes.push(node_a);
while (!nodes.empty()) {
NodePtr tmp_node = nodes.front();
nodes.pop();
if ((tmp_node == nullptr) || (tmp_node->GetOpDesc() == nullptr) || (tmp_node->GetOpDesc()->GetId() > end_id)) {
continue;
}
if (tmp_node == node_b) {
return true;
}
for (const auto &out_node : tmp_node->GetOutAllNodes()) {
nodes.push(out_node);
}
}
return false;
}
// Registers the option entry for this pass under the name "ParallelGroupPass" with level OoLevel::kO0.
// NOTE(review): whether LEVELS() means "only at O0" or "O0 and above" is defined by REG_PASS_OPTION — confirm.
REG_PASS_OPTION("ParallelGroupPass").LEVELS(OoLevel::kO0);
}