* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "hccl_sequence_adjust_pass.h"
#include <stack>
namespace ge {
// Entry point of the pass.
// Only a graph without a parent node is processed. For such a graph the hccl control relation is
// rebuilt when GetFunctionNodesWithHcclGroup reports at least one matching node.
// @param graph  graph to run the pass on; a null pointer is rejected before it is dereferenced.
// @return SUCCESS when the pass is skipped or completes. A null graph or a failing helper makes the
//         corresponding check macro (GE_CHECK_NOTNULL / GE_CHK_STATUS_RET) return early.
Status HcclSequenceAdjustPass::Run(ComputeGraphPtr graph) {
  // Guard the first dereference below against a null graph.
  GE_CHECK_NOTNULL(graph);
  if (graph->GetParentNode() != nullptr) {
    // A graph that has a parent node is left untouched by this pass.
    return SUCCESS;
  }
  GELOGD("Start to run HcclSequenceAdjustPass.");
  // The collected nodes are used only to decide whether the adjustment is needed at all.
  std::vector<NodePtr> func_nodes;
  GE_CHK_STATUS_RET(GetFunctionNodesWithHcclGroup(graph, func_nodes), "Get function nodes with hccl group failed.");
  if (func_nodes.empty()) {
    GELOGD("Cannot find function nodes with hccl group.");
    return SUCCESS;
  }
  GELOGD("Size of function node with hccl group is %zu.", func_nodes.size());
  GE_CHK_STATUS_RET(RebuildHcclControlRelation(graph), "rebuild hccl control relation failed.");
  GELOGD("Success to run HcclSequenceAdjustPass.");
  return SUCCESS;
}
// Collects the direct nodes of `graph` whose op desc reports at least one subgraph instance name and
// also carries the string attribute ATTR_NAME_HCCL_FUSED_GROUP.
// @param graph       graph whose direct nodes are scanned.
// @param func_nodes  output list; every matching node is appended in iteration order.
// @return SUCCESS once the scan is finished. A node without op desc trips GE_CHECK_NOTNULL.
Status HcclSequenceAdjustPass::GetFunctionNodesWithHcclGroup(const ComputeGraphPtr &graph,
                                                             std::vector<NodePtr> &func_nodes) const {
  GELOGD("Start to get function nodes with hccl group.");
  // Receives the attribute value; only the success of the lookup is used below.
  std::string fused_group;
  for (const auto &node : graph->GetDirectNode()) {
    const auto &op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    // Nodes without any subgraph instance name are skipped before the attribute is queried.
    if (op_desc->GetSubgraphInstanceNames().empty()) {
      continue;
    }
    if (!AttrUtils::GetStr(op_desc, ATTR_NAME_HCCL_FUSED_GROUP, fused_group)) {
      continue;
    }
    func_nodes.push_back(node);
  }
  GELOGD("Success to get function nodes with hccl group.");
  return SUCCESS;
}
bool HcclSequenceAdjustPass::HasRelationPath(const NodePtr &second_hccl, const NodePtr &last_hccl) const {
std::stack<NodePtr> out_nodes;
out_nodes.push(second_hccl);
std::set<NodePtr> node_seen;
while (!out_nodes.empty()) {
const auto one_node = out_nodes.top();
out_nodes.pop();
if (one_node == last_hccl) {
return true;
}
for (const auto &out_node : one_node->GetOutAllNodes()) {
if (node_seen.emplace(out_node).second) {
out_nodes.push(out_node);
}
}
}
return false;
}
// Adds a control edge between two HCOMALLREDUCE nodes picked from the direct nodes of `graph`.
//
// Candidates are HCOMALLREDUCE nodes that have no HCOMALLREDUCE among their out-control nodes:
//   - second_hccl: a candidate with an HCOMALLREDUCE among its in-control nodes;
//   - last_hccl:   a candidate with no HCOMALLREDUCE among its in-control nodes.
// When several nodes qualify for the same role, the one visited latest is kept.
//
// Nothing is linked when either role is unfilled, or when a path from second_hccl to last_hccl
// already exists (an edge last_hccl -> second_hccl would then close a cycle).
// @param graph  graph whose direct nodes are inspected and modified.
// @return SUCCESS when the edge is added or the link is skipped. A null control anchor or a failing
//         LinkTo is handled by GE_CHECK_NOTNULL / GE_CHK_STATUS_RET.
Status HcclSequenceAdjustPass::RebuildHcclControlRelation(const ComputeGraphPtr &graph) const {
  GELOGD("Start to rebulid hccl control relation of graph: %s.", graph->GetName().c_str());
  const auto is_all_reduce = [](const NodePtr &n) { return (n->GetType() == HCOMALLREDUCE); };
  NodePtr second_hccl = nullptr;
  NodePtr last_hccl = nullptr;
  for (const auto &node : graph->GetDirectNode()) {
    if (node->GetType() != HCOMALLREDUCE) {
      continue;
    }
    // A node that still controls another HCOMALLREDUCE is not a candidate.
    const auto &out_ctrl_nodes = node->GetOutControlNodes();
    if (std::any_of(out_ctrl_nodes.begin(), out_ctrl_nodes.end(), is_all_reduce)) {
      continue;
    }
    const auto &in_ctrl_nodes = node->GetInControlNodes();
    const bool controlled_by_hccl = std::any_of(in_ctrl_nodes.begin(), in_ctrl_nodes.end(), is_all_reduce);
    if (controlled_by_hccl) {
      second_hccl = node;
      GELOGD("Find second last hccl node: %s.", node->GetName().c_str());
    } else {
      last_hccl = node;
      GELOGD("Find last hccl node: %s.", node->GetName().c_str());
    }
  }

  const bool candidate_missing = (last_hccl == nullptr) || (second_hccl == nullptr);
  if (candidate_missing) {
    GELOGW("Cannot find optimizable HcomAllReduce nodes..");
    return SUCCESS;
  }
  // Skip the link when second_hccl already reaches last_hccl.
  const bool already_ordered = HasRelationPath(second_hccl, last_hccl);
  if (already_ordered) {
    GELOGW("Exist path from %s to %s, skip link.", second_hccl->GetName().c_str(), last_hccl->GetName().c_str());
    return SUCCESS;
  }

  const auto &out_ctrl_anchor = last_hccl->GetOutControlAnchor();
  GE_CHECK_NOTNULL(out_ctrl_anchor);
  const auto &in_ctrl_anchor = second_hccl->GetInControlAnchor();
  GE_CHECK_NOTNULL(in_ctrl_anchor);
  GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(in_ctrl_anchor),
                    "Add link from %s to %s failed.", last_hccl->GetName().c_str(), second_hccl->GetName().c_str());
  GELOGD("Add control edge from %s to %s.", last_hccl->GetName().c_str(), second_hccl->GetName().c_str());
  GELOGD("Success to rebulid hccl control relation of graph: %s.", graph->GetName().c_str());
  return SUCCESS;
}
// Registers a pass option entry named "HcclSequenceAdjustPass" with level OoLevel::kO3.
// NOTE(review): presumably this ties the pass to the O3 optimization level; confirm against the
// REG_PASS_OPTION definition.
REG_PASS_OPTION("HcclSequenceAdjustPass").LEVELS(OoLevel::kO3);
}