* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/passes/control_flow_and_stream/next_iteration_pass.h"
#include "common/plugin/ge_make_unique_util.h"
#include "common/omg_util/omg_util.h"
#include "graph/utils/node_utils.h"
namespace ge {
namespace {
// Value stored in ATTR_NAME_STREAM_SWITCH_TYPE on Switch nodes that belong to a while-loop group.
constexpr int64_t kLoopType{1};
// Upper bound on how many nodes are inspected downstream of a Switch data output when looking for an
// exit-type node (see HandleSwitchExitNodes).
constexpr uint8_t kMaxTransOp{3};
// Required input-data-node / output-data-anchor / output-node count of every node on that downstream walk.
constexpr uint8_t kTransOpIoSize{1};
}  // namespace
/**
 * Entry point of NextIterationPass.
 * 1. Groups every direct Enter/RefEnter node by frame name (plus batch label, see GroupEnterNode).
 * 2. Collects the Merge/NextIteration/Switch/LoopCond nodes of each group (FindWhileGroups).
 * 3. Verifies each group is complete (VerifyWhileGroup).
 * 4. Inserts StreamActive nodes and breaks the NextIteration->Merge data edges (HandleWhileGroup).
 * @param graph graph to process; a null graph is rejected by GE_CHECK_NOTNULL
 * @return SUCCESS; INTERNAL_ERROR when grouping, discovery or verification fails; FAILED when handling fails
 */
Status NextIterationPass::Run(ComputeGraphPtr graph) {
  GELOGD("NextIterationPass Enter");
  // Guard the dereferences below; the pass cannot do anything useful on a null graph.
  GE_CHECK_NOTNULL(graph);
  for (auto &node : graph->GetDirectNode()) {
    // Bind by reference: the type string is only compared, never modified.
    const std::string &type = node->GetType();
    if ((type != ENTER) && (type != REFENTER)) {
      continue;
    }
    if (GroupEnterNode(node) != SUCCESS) {
      GELOGE(INTERNAL_ERROR, "[Group][EnterNode] %s failed.", node->GetName().c_str());
      return INTERNAL_ERROR;
    }
  }
  if (FindWhileGroups() != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "[Find][WhileGroups] in graph:%s failed.", graph->GetName().c_str());
    return INTERNAL_ERROR;
  }
  if (!VerifyWhileGroup()) {
    GELOGE(INTERNAL_ERROR, "[Verify][WhileGroup] in graph:%s failed.", graph->GetName().c_str());
    return INTERNAL_ERROR;
  }
  if (HandleWhileGroup(graph) != SUCCESS) {
    GELOGE(FAILED, "[Handle][WhileGroup] in graph:%s failed.", graph->GetName().c_str());
    return FAILED;
  }
  GELOGD("NextIterationPass Leave");
  return SUCCESS;
}
/**
 * Files an Enter/RefEnter node into the loop group keyed by its ENTER_ATTR_FRAME_NAME attribute.
 * When ATTR_NAME_BATCH_LABEL is present, it is appended to the key. The group is created on first use.
 * @param enter_node Enter/RefEnter node to register
 * @return SUCCESS; FAILED when the frame name is missing/empty or the group cannot be allocated
 */
Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
  OpDescPtr enter_desc = enter_node->GetOpDesc();
  GE_CHECK_NOTNULL(enter_desc);

  std::string frame_name;
  const bool has_frame_name = ge::AttrUtils::GetStr(enter_desc, ENTER_ATTR_FRAME_NAME, frame_name);
  if ((!has_frame_name) || frame_name.empty()) {
    REPORT_INNER_ERR_MSG("E19999", "Get Attr:%s from op:%s(%s) failed", ENTER_ATTR_FRAME_NAME.c_str(),
                         enter_desc->GetName().c_str(), enter_desc->GetType().c_str());
    GELOGE(FAILED, "[Get][Attr] %s from op:%s(%s) failed", ENTER_ATTR_FRAME_NAME.c_str(),
           enter_desc->GetName().c_str(), enter_desc->GetType().c_str());
    return FAILED;
  }

  // The group key is the frame name, optionally suffixed with the batch label.
  std::string batch_label;
  if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
    frame_name.append(batch_label);
  }

  const auto existing = loop_group_map_.find(frame_name);
  if (existing != loop_group_map_.end()) {
    existing->second->enter_nodes.emplace_back(enter_node);
    return SUCCESS;
  }

  LoopCondGroupPtr new_group = MakeShared<LoopCondGroup>();
  if (new_group == nullptr) {
    REPORT_INNER_ERR_MSG("E19999", "New LoopCondGroup failed");
    GELOGE(FAILED, "[New][LoopCondGroup] failed.");
    return FAILED;
  }
  new_group->enter_nodes.emplace_back(enter_node);
  loop_group_map_[frame_name] = new_group;
  return SUCCESS;
}
/**
 * For each loop group, walks the outputs of its Enter nodes and, for every Merge/RefMerge found:
 * - records the (Merge, NextIteration) pair, where NextIteration is a data input of the Merge;
 * - locates a Switch among the Merge's data outputs (optional) and tags it with ATTR_NAME_STREAM_SWITCH_TYPE;
 * - locates the LoopCond feeding that Switch and checks that all Switches of the group share one LoopCond.
 * @return SUCCESS; INTERNAL_ERROR when a required neighbour is missing or the attribute cannot be set;
 *         FAILED when a group references more than one LoopCond
 */
Status NextIterationPass::FindWhileGroups() {
  for (const auto &loop_group_iter : loop_group_map_) {
    const std::string &frame_name = loop_group_iter.first;
    const auto &group = loop_group_iter.second;
    for (const auto &enter_node : group->enter_nodes) {
      for (const auto &out_node : enter_node->GetOutAllNodes()) {
        std::string type;
        GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "[Get][OriginalType] failed.");
        const bool is_merge = (type == MERGE) || (type == REFMERGE);
        if (!is_merge) {
          continue;
        }

        NodePtr iteration_node = nullptr;
        if (FindTargetNode(out_node, NEXTITERATION, true, iteration_node) != SUCCESS) {
          GELOGE(INTERNAL_ERROR, "[Get][NextIterationNode] failed, frame_name:%s", frame_name.c_str());
          return INTERNAL_ERROR;
        }
        group->merge_next_pairs.emplace_back(out_node, iteration_node);

        // A Merge without a downstream Switch is tolerated; FindTargetNode leaves the pointer null then.
        NodePtr branch_node = nullptr;
        if (FindTargetNode(out_node, SWITCH, false, branch_node) != SUCCESS) {
          GELOGE(INTERNAL_ERROR, "[Get][SwitchNode] failed, frame_name:%s.", frame_name.c_str());
          return INTERNAL_ERROR;
        }
        if (branch_node == nullptr) {
          continue;
        }

        const bool attr_set = AttrUtils::SetInt(branch_node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_TYPE, kLoopType);
        if (!attr_set) {
          REPORT_INNER_ERR_MSG("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_STREAM_SWITCH_TYPE.c_str(),
                               branch_node->GetName().c_str(), branch_node->GetType().c_str());
          GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_STREAM_SWITCH_TYPE.c_str(),
                 branch_node->GetName().c_str(), branch_node->GetType().c_str());
          return INTERNAL_ERROR;
        }

        NodePtr cond_node = nullptr;
        if (FindTargetNode(branch_node, LOOPCOND, true, cond_node) != SUCCESS) {
          GELOGE(INTERNAL_ERROR, "[Get][LoopCondNode] failed, frame_name:%s.", frame_name.c_str());
          return INTERNAL_ERROR;
        }
        group->switch_nodes.emplace_back(branch_node);

        // Every Switch of one frame must be driven by the same LoopCond.
        const bool conflicting_cond = (group->loop_cond != nullptr) && (group->loop_cond != cond_node);
        if (conflicting_cond) {
          REPORT_INNER_ERR_MSG("E19999", "Multi LoopCond nodes exist, frame_name:%s, check invalid", frame_name.c_str());
          GELOGE(FAILED, "[Check][Param] Multi LoopCond nodes exist, frame_name:%s.", frame_name.c_str());
          return FAILED;
        }
        group->loop_cond = cond_node;
      }
    }
  }
  return SUCCESS;
}
/**
 * Checks that every collected loop group is complete. The checks are: a non-empty frame name,
 * a resolved LoopCond, and non-null nodes in every (Merge, NextIteration) pair.
 * @return true when all groups pass; false on the first violation (an E19999 error is reported)
 */
bool NextIterationPass::VerifyWhileGroup() const {
  for (const auto &loop_group_iter : loop_group_map_) {
    const std::string &frame_name = loop_group_iter.first;
    const auto &group = loop_group_iter.second;

    if (frame_name.empty()) {
      REPORT_INNER_ERR_MSG("E19999", "Verify while group failed, frame_name is empty");
      GELOGE(INTERNAL_ERROR, "[Check][Param] Verify while group failed, frame_name is empty.");
      return false;
    }

    if (group->loop_cond == nullptr) {
      REPORT_INNER_ERR_MSG("E19999", "Verify while group failed, LoopCond is null, frame_name:%s.", frame_name.c_str());
      GELOGE(INTERNAL_ERROR, "[Check][Param] Verify while group failed, LoopCond is null, frame_name:%s.",
             frame_name.c_str());
      return false;
    }

    for (const auto &merge_next : group->merge_next_pairs) {
      const bool has_null_member = (merge_next.first == nullptr) || (merge_next.second == nullptr);
      if (!has_null_member) {
        continue;
      }
      REPORT_INNER_ERR_MSG("E19999", "Verify while group failed, merge_node/next_node is null, frame_name:%s.",
                           frame_name.c_str());
      GELOGE(INTERNAL_ERROR, "[Check][Param] Verify while group failed, merge_node/next_node is null, frame_name:%s.",
             frame_name.c_str());
      return false;
    }
  }
  return true;
}
/**
 * Rewires every verified loop group. For each group it:
 * - creates two StreamActive nodes, "<cond>_Enter_StreamActive" and "<cond>_Next_StreamActive";
 * - adds control edges Enter -> enter-active -> Merge and NextIteration -> next-active;
 * - removes the NextIteration -> Merge data edges (BreakNextIteration);
 * - sets the LoopCond name as the active label list of both StreamActive nodes;
 * - tags all involved nodes with the LoopCond's op id as control-flow group.
 * @param graph graph that receives the new StreamActive nodes
 * @return SUCCESS, or INTERNAL_ERROR on the first failing step
 */
Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
  for (const auto &group_iter : loop_group_map_) {
    const LoopCondGroup &group = *group_iter.second;
    const std::string &cond_name = group.loop_cond->GetName();
    const int64_t group_index = group.loop_cond->GetOpDesc()->GetId();
    GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());

    // Both nodes are created before either is checked, matching the original creation order.
    NodePtr active_on_enter = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
    NodePtr active_on_next = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
    if ((active_on_enter == nullptr) || (active_on_next == nullptr)) {
      GELOGE(INTERNAL_ERROR, "[Create][ActiveNode] failed, cond_name:%s.", cond_name.c_str());
      return INTERNAL_ERROR;
    }

    for (const auto &enter : group.enter_nodes) {
      const auto edge_status = GraphUtils::AddEdge(enter->GetOutControlAnchor(), active_on_enter->GetInControlAnchor());
      if (edge_status != GRAPH_SUCCESS) {
        REPORT_INNER_ERR_MSG("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
                             enter->GetName().c_str(), enter->GetType().c_str(),
                             active_on_enter->GetName().c_str(), active_on_enter->GetType().c_str());
        GELOGE(INTERNAL_ERROR, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
               enter->GetName().c_str(), enter->GetType().c_str(),
               active_on_enter->GetName().c_str(), active_on_enter->GetType().c_str());
        return INTERNAL_ERROR;
      }
      SetControlFlowGroup(enter, group_index);
    }

    for (const auto &merge_next : group.merge_next_pairs) {
      // Non-const copies: BreakNextIteration takes the merge node by non-const reference.
      NodePtr merge = merge_next.first;
      NodePtr iteration = merge_next.second;

      const auto merge_edge = GraphUtils::AddEdge(active_on_enter->GetOutControlAnchor(), merge->GetInControlAnchor());
      if (merge_edge != GRAPH_SUCCESS) {
        REPORT_INNER_ERR_MSG("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
                             active_on_enter->GetName().c_str(), active_on_enter->GetType().c_str(),
                             merge->GetName().c_str(), merge->GetType().c_str());
        GELOGE(INTERNAL_ERROR, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
               active_on_enter->GetName().c_str(), active_on_enter->GetType().c_str(),
               merge->GetName().c_str(), merge->GetType().c_str());
        return INTERNAL_ERROR;
      }

      const auto next_edge = GraphUtils::AddEdge(iteration->GetOutControlAnchor(), active_on_next->GetInControlAnchor());
      if (next_edge != GRAPH_SUCCESS) {
        REPORT_INNER_ERR_MSG("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
                             iteration->GetName().c_str(), iteration->GetType().c_str(),
                             active_on_next->GetName().c_str(), active_on_next->GetType().c_str());
        GELOGE(INTERNAL_ERROR, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
               iteration->GetName().c_str(), iteration->GetType().c_str(),
               active_on_next->GetName().c_str(), active_on_next->GetType().c_str());
        return INTERNAL_ERROR;
      }

      if (BreakNextIteration(iteration, merge) != SUCCESS) {
        GELOGE(INTERNAL_ERROR, "[Break][NextIteration] failed, next_node:%s, merge_node:%s",
               iteration->GetName().c_str(), merge->GetName().c_str());
        return INTERNAL_ERROR;
      }
      SetControlFlowGroup(iteration, group_index);
      SetControlFlowGroup(merge, group_index);
    }

    // Short-circuit preserved: the next-active label is only set when the enter-active one succeeded.
    const bool labels_set = (SetActiveLabelList(active_on_enter, {cond_name}) == SUCCESS) &&
                            (SetActiveLabelList(active_on_next, {cond_name}) == SUCCESS);
    if (!labels_set) {
      GELOGE(INTERNAL_ERROR, "[Set][ActiveLabelList] failed, cond_name:%s.", cond_name.c_str());
      return INTERNAL_ERROR;
    }

    SetControlFlowGroup(group.loop_cond, group_index);
    SetControlFlowGroup(active_on_enter, group_index);
    SetControlFlowGroup(active_on_next, group_index);
    HandleSwitchExitNodes(group, group_index);
  }
  return SUCCESS;
}
/**
 * Tags every Switch of `loop_group` with `group_index`. It also tags, for each Switch data output,
 * the first node whose type is in kExitOpTypes found along a short downstream chain.
 * The walk inspects at most kMaxTransOp nodes. It stops as soon as a node does not have exactly
 * kTransOpIoSize input data nodes, output data anchors and output nodes (data + control).
 * @param loop_group group whose switch_nodes are processed
 * @param group_index control-flow group id passed to SetControlFlowGroup
 */
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) const {
  for (const auto &switch_node : loop_group.switch_nodes) {
    SetControlFlowGroup(switch_node, group_index);
    for (auto node : switch_node->GetOutDataNodes()) {
      for (uint8_t i = 0; i < kMaxTransOp; ++i) {
        // Only follow simple pass-through nodes; anything with fan-in/fan-out ends the search.
        if ((node->GetInDataNodes().size() != kTransOpIoSize) ||
            (node->GetAllOutDataAnchorsSize() != kTransOpIoSize)) {
          break;
        }
        if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
          SetControlFlowGroup(node, group_index);
          break;
        }
        const auto &all_nodes = node->GetOutAllNodes();
        if (all_nodes.size() != kTransOpIoSize) {
          break;
        }
        node = all_nodes.at(0);
      }
    }
  }
}
/**
 * Builds a StreamActive op named `name` and adds it to `graph`. The node's switch-branch label is set to `name`.
 * @param graph graph that receives the node
 * @param name op name, also used as the switch-branch label
 * @return the new node, or nullptr when allocation, insertion or labeling fails (the failure is logged)
 */
NodePtr NextIterationPass::CreateActiveNode(ComputeGraphPtr &graph, const std::string &name) const {
  OpDescPtr active_desc = MakeShared<OpDesc>(name, STREAMACTIVE);
  if (active_desc == nullptr) {
    REPORT_INNER_ERR_MSG("E19999", "New OpDesc failed");
    GELOGE(FAILED, "[New][OpDesc] failed");
    return nullptr;
  }
  GELOGI("Create StreamActive op:%s.", active_desc->GetName().c_str());

  NodePtr created = graph->AddNode(active_desc);
  if (created == nullptr) {
    REPORT_INNER_ERR_MSG("E19999", "Add node:%s(%s) to graph:%s failed",
                         active_desc->GetName().c_str(), active_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(INTERNAL_ERROR, "[Add][Node] %s(%s) to graph:%s failed",
           active_desc->GetName().c_str(), active_desc->GetType().c_str(), graph->GetName().c_str());
    return nullptr;
  }

  if (SetSwitchBranchNodeLabel(created, name) == SUCCESS) {
    return created;
  }
  GELOGE(INTERNAL_ERROR, "[Set][SwitchBranchNodeLabel] %s to node:%s(%s) failed",
         name.c_str(), active_desc->GetName().c_str(), active_desc->GetType().c_str());
  return nullptr;
}
/**
 * Removes every data edge from `next_node` into `merge_node`. After each removed edge, it records
 * `next_node` on the Merge via SetNextIteration.
 * @param next_node NextIteration node whose outgoing data edges into the Merge are cut
 * @param merge_node Merge node that receives the NEXT_ITERATION attribute
 * @return SUCCESS (also when no such edge exists); PARAM_INVALID for a null argument;
 *         INTERNAL_ERROR when an edge cannot be removed or the attribute cannot be set
 */
Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr &merge_node) const {
  if ((merge_node == nullptr) || (next_node == nullptr)) {
    // Report like the other parameter checks in this pass (see FindTargetNode), not only log.
    REPORT_INNER_ERR_MSG("E19999", "Param merge_node or next_node is nullptr, check invalid");
    GELOGE(PARAM_INVALID, "[Check][Param] merge node or next node is nullptr.");
    return PARAM_INVALID;
  }
  for (const auto &in_anchor : merge_node->GetAllInDataAnchors()) {
    OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor();
    if ((out_anchor == nullptr) || (out_anchor->GetOwnerNode() != next_node)) {
      continue;
    }
    // RemoveEdge returns a graphStatus, so compare with GRAPH_SUCCESS (as the AddEdge checks in this file do).
    if (GraphUtils::RemoveEdge(out_anchor, in_anchor) != GRAPH_SUCCESS) {
      REPORT_INNER_ERR_MSG("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
                           out_anchor->GetOwnerNode()->GetName().c_str(), out_anchor->GetOwnerNode()->GetType().c_str(),
                           out_anchor->GetIdx(),
                           merge_node->GetName().c_str(), merge_node->GetType().c_str(), in_anchor->GetIdx());
      GELOGE(INTERNAL_ERROR, "[Remove][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
             out_anchor->GetOwnerNode()->GetName().c_str(), out_anchor->GetOwnerNode()->GetType().c_str(),
             out_anchor->GetIdx(), merge_node->GetName().c_str(), merge_node->GetType().c_str(), in_anchor->GetIdx());
      return INTERNAL_ERROR;
    }
    if (SetNextIteration(merge_node, next_node) != SUCCESS) {
      REPORT_INNER_ERR_MSG("E19999", "Set attr NEXT_ITERATION value:%s to node:%s(%s) failed",
                           next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str());
      GELOGE(INTERNAL_ERROR, "[Set][Attr] NEXT_ITERATION value:%s to node:%s(%s) failed",
             next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str());
      return INTERNAL_ERROR;
    }
  }
  return SUCCESS;
}
/**
 * Finds the first data neighbour of `node` whose original type matches `target_type`.
 * LOOPCOND is matched exactly. Any other target also accepts its "Ref" variant (e.g. Switch / RefSwitch).
 * @param node node whose neighbours are searched; must not be null
 * @param target_type op type to look for
 * @param is_input true searches the input data nodes, false searches the output data nodes
 * @param target_node [out] the match; reset to nullptr first, so it stays nullptr when nothing matches
 * @return SUCCESS when a match is found, or when target_type is SWITCH (a missing Switch is tolerated);
 *         PARAM_INVALID for a null node; INTERNAL_ERROR when a required target is missing;
 *         the GetOriginalType failure is propagated through GE_CHK_STATUS_RET
 */
Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
                                         NodePtr &target_node) const {
  // Never let a caller-supplied value masquerade as a match.
  target_node = nullptr;
  if (node == nullptr) {
    REPORT_INNER_ERR_MSG("E19999", "Param node is nullptr, check invalid");
    GELOGE(PARAM_INVALID, "[Check][Param] node is nullptr.");
    return PARAM_INVALID;
  }
  std::vector<NodePtr> nodes;
  if (is_input) {
    for (const auto &tmp_node : node->GetInDataNodes()) {
      nodes.emplace_back(tmp_node);
    }
  } else {
    for (const auto &tmp_node : node->GetOutDataNodes()) {
      nodes.emplace_back(tmp_node);
    }
  }
  // The former LOOPCOND special case was shadowed by the generic branch, which still accepted "RefLoopCond".
  // Make the exact-match intent effective, and build the Ref name once instead of per iteration.
  const bool allow_ref_variant = (target_type != LOOPCOND);
  const std::string ref_target_type = "Ref" + target_type;
  for (const auto &tmp_node : nodes) {
    std::string type;
    GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "[Get][NodeType] failed.");
    if ((type == target_type) || (allow_ref_variant && (type == ref_target_type))) {
      target_node = tmp_node;
      break;
    }
  }
  if ((target_type != SWITCH) && (target_node == nullptr)) {
    REPORT_INNER_ERR_MSG("E19999", "Find target_type:%s node around node:%s(%s) failed",
                         target_type.c_str(), node->GetName().c_str(), node->GetType().c_str());
    GELOGE(INTERNAL_ERROR, "[Check][Param] Find target_type:%s node around node:%s(%s) failed",
           target_type.c_str(), node->GetName().c_str(), node->GetType().c_str());
    return INTERNAL_ERROR;
  }
  return SUCCESS;
}
/**
 * Discards every loop group collected by a previous Run.
 * NOTE(review): presumably invoked by the pass framework between runs — confirm against the caller.
 * @return always SUCCESS
 */
Status NextIterationPass::ClearStatus() {
  loop_group_map_.erase(loop_group_map_.begin(), loop_group_map_.end());
  return SUCCESS;
}
// Registers option metadata for NextIterationPass and tags it with optimization level OoLevel::kO0.
// NOTE(review): the exact LEVELS semantics (e.g. whether higher levels also enable the pass) are defined
// by the REG_PASS_OPTION machinery elsewhere — confirm before relying on them.
REG_PASS_OPTION("NextIterationPass").LEVELS(OoLevel::kO0);
}