* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/passes/feature/iterator_op_pass.h"
#include <memory>
#include <sstream>
#include <string>
#include "framework/common/debug/log.h"
#include "framework/common/debug/ge_log.h"
#include "common/plugin/ge_make_unique_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "common/omg_util/omg_util.h"
#include "graph/graph.h"
#include "graph/node.h"
#include "graph/passes/pass_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/manager/graph_var_manager.h"
#include "graph/ge_context.h"
#include "graph/manager/util/rt_context_util.h"
namespace ge {
// Largest loop-iteration count representable with one value of headroom below INT32_MAX.
// NOTE(review): nothing in this file references this constant; confirm whether it is still needed.
constexpr int32_t kMaxIterationsPerLoop{INT32_MAX - 1};
// Entry point of the pass.
// Only acts when PassUtils::IsNeedTrainIteFlowCtrl(graph) holds; otherwise returns SUCCESS untouched.
// For every direct node whose original type is ITERATORV2 or ITERATOR, or whose op type is GETNEXT:
//   1. a MemcpyAsync node is inserted behind it (InsertMemcpyAsyncNode),
//   2. a cycle event is set on that copy node,
//   3. the copy node and then the iterator node each get a stream label equal to their own name.
// @param graph graph to transform; must not be null
// @return SUCCESS, or the first failing status (null check, GetOriginalType, SetCycleEvent, SetStreamLabel)
Status IteratorOpPass::Run(ge::ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  const bool needs_flow_ctrl = PassUtils::IsNeedTrainIteFlowCtrl(graph);
  if (!needs_flow_ctrl) {
    return SUCCESS;
  }

  GELOGD("GetNextOpPass begin");
  std::string type;  // reused across iterations; filled by GetOriginalType
  for (ge::NodePtr &node : graph->GetDirectNode()) {
    GE_CHK_STATUS_RET(GetOriginalType(node, type));
    auto op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    const std::string current_op_type = op_desc->GetType();
    const bool is_iterator_like =
        (type == ITERATORV2) || (type == ITERATOR) || (current_op_type == GETNEXT);
    if (!is_iterator_like) {
      continue;
    }

    ge::NodePtr memcpy_async_node = InsertMemcpyAsyncNode(node, graph);
    GE_CHECK_NOTNULL(memcpy_async_node);

    const auto cycle_ret = SetCycleEvent(memcpy_async_node);
    if (cycle_ret != ge::SUCCESS) {
      REPORT_INNER_ERR_MSG("E19999", "Set cycle event to op:%s(%s) failed",
                           memcpy_async_node->GetName().c_str(), memcpy_async_node->GetType().c_str());
      GELOGE(cycle_ret, "[Set][CycleEvent] to op:%s(%s) failed",
             memcpy_async_node->GetName().c_str(), memcpy_async_node->GetType().c_str());
      return cycle_ret;
    }

    // Label the copy node first, then the iterator node; each is labelled with its own name.
    for (const ge::NodePtr &labelled : {memcpy_async_node, node}) {
      const auto label_ret = SetStreamLabel(labelled, labelled->GetName());
      if (label_ret != ge::SUCCESS) {
        REPORT_INNER_ERR_MSG("E19999", "Set stream label:%s to op:%s(%s) failed",
                             labelled->GetName().c_str(), labelled->GetName().c_str(),
                             labelled->GetType().c_str());
        GELOGE(label_ret, "[Set][StreamLabel] %s to op:%s(%s) failed",
               labelled->GetName().c_str(), labelled->GetName().c_str(),
               labelled->GetType().c_str());
        return label_ret;
      }
    }
    GELOGI("Set independent loop for iterator node success");
  }
  GELOGD("GetNextOpPass end");
  return SUCCESS;
}
// Inserts a MemcpyAsync node behind `pre_node` and reroutes pre_node's consumers through it.
//   - Data: for every output anchor i of pre_node, each edge pre_node:i -> peer is replaced by
//     memcpy:i -> peer, and then the edge pre_node:i -> memcpy:i is added.
//   - Control: every out-control edge pre_node -> peer is moved to memcpy -> peer.
// @param pre_node node whose outputs are routed through the new copy node
// @param graph    graph the copy node is inserted into (via graph->InsertNode)
// @return the inserted node, or nullptr on any failure.
// NOTE(review): a failure after rewiring has started returns nullptr without restoring edges that were
// already moved, so the graph may be left partially rewired.
ge::NodePtr IteratorOpPass::InsertMemcpyAsyncNode(const ge::NodePtr &pre_node, const ge::ComputeGraphPtr &graph) {
  GE_CHK_BOOL_EXEC(pre_node != nullptr, GELOGW("Pre node is null."); return nullptr);
  GE_CHK_BOOL_EXEC(graph != nullptr, GELOGW("graph is null."); return nullptr);
  ge::OpDescPtr memcpy_async_op_desc = CreateMemcpyAsyncOp(pre_node);
  GE_CHK_BOOL_EXEC(memcpy_async_op_desc != nullptr, GELOGW("Create memcpyAsync op fail."); return nullptr);
  auto memcpy_async_node = graph->InsertNode(pre_node, memcpy_async_op_desc);
  GE_CHK_BOOL_EXEC(memcpy_async_node != nullptr,
                   REPORT_INNER_ERR_MSG("E19999", "Add node:%s(%s) to graph:%s failed",
                                        memcpy_async_op_desc->GetName().c_str(),
                                        memcpy_async_op_desc->GetType().c_str(), graph->GetName().c_str());
                   return nullptr,
                   "[Add][Node] %s(%s) to graph:%s failed", memcpy_async_op_desc->GetName().c_str(),
                   memcpy_async_op_desc->GetType().c_str(), graph->GetName().c_str());

  // Reroute data edges: the copy node's input/output index i mirrors pre_node's output index i.
  for (auto &out_anchor : pre_node->GetAllOutDataAnchors()) {
    if (out_anchor == nullptr) {
      continue;
    }
    GELOGI("Graph add memcpyAsync op in edge, index:%d.", out_anchor->GetIdx());
    for (auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_IF_BOOL_EXEC(peer_in_anchor == nullptr, GELOGW("peer_in_anchor is nullptr"); return nullptr);
      ge::graphStatus status = GraphUtils::RemoveEdge(out_anchor, peer_in_anchor);
      GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS,
                       REPORT_INNER_ERR_MSG("E19999",
                                            "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
                                            pre_node->GetName().c_str(),
                                            pre_node->GetType().c_str(), out_anchor->GetIdx(),
                                            peer_in_anchor->GetOwnerNode()->GetName().c_str(),
                                            peer_in_anchor->GetOwnerNode()->GetType().c_str(),
                                            peer_in_anchor->GetIdx());
                       return nullptr,
                       "[Remove][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
                       pre_node->GetName().c_str(), pre_node->GetType().c_str(), out_anchor->GetIdx(),
                       peer_in_anchor->GetOwnerNode()->GetName().c_str(),
                       peer_in_anchor->GetOwnerNode()->GetType().c_str(),
                       peer_in_anchor->GetIdx());
      status = GraphUtils::AddEdge(memcpy_async_node->GetOutDataAnchor(out_anchor->GetIdx()), peer_in_anchor);
      // Log tag "[Add][Edge]" added for consistency with every other failure message in this pass.
      GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS,
                       REPORT_INNER_ERR_MSG("E19999",
                                            "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
                                            memcpy_async_node->GetName().c_str(), memcpy_async_node->GetType().c_str(),
                                            out_anchor->GetIdx(), peer_in_anchor->GetOwnerNode()->GetName().c_str(),
                                            peer_in_anchor->GetOwnerNode()->GetType().c_str(), peer_in_anchor->GetIdx());
                       return nullptr,
                       "[Add][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
                       memcpy_async_node->GetName().c_str(), memcpy_async_node->GetType().c_str(),
                       out_anchor->GetIdx(), peer_in_anchor->GetOwnerNode()->GetName().c_str(),
                       peer_in_anchor->GetOwnerNode()->GetType().c_str(), peer_in_anchor->GetIdx());
      GELOGI("Graph add memcpyAsync op out edge, src index:%d, dst index:%d, dst node: %s.", out_anchor->GetIdx(),
             peer_in_anchor->GetIdx(), peer_in_anchor->GetOwnerNode()->GetName().c_str());
    }
    // Feed pre_node's output i into the copy node's input i.
    const ge::graphStatus in_status =
        GraphUtils::AddEdge(out_anchor, memcpy_async_node->GetInDataAnchor(out_anchor->GetIdx()));
    GE_CHK_BOOL_EXEC(in_status == GRAPH_SUCCESS,
                     REPORT_INNER_ERR_MSG("E19999",
                                          "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
                                          pre_node->GetName().c_str(), pre_node->GetType().c_str(), out_anchor->GetIdx(),
                                          memcpy_async_node->GetName().c_str(), memcpy_async_node->GetType().c_str(),
                                          out_anchor->GetIdx());
                     return nullptr,
                     "[Add][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
                     pre_node->GetName().c_str(), pre_node->GetType().c_str(), out_anchor->GetIdx(),
                     memcpy_async_node->GetName().c_str(), memcpy_async_node->GetType().c_str(),
                     out_anchor->GetIdx());
  }

  // Move pre_node's out-control edges onto the copy node.
  OutControlAnchorPtr out_ctrl_anchor = pre_node->GetOutControlAnchor();
  if (out_ctrl_anchor != nullptr) {
    for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
      // The failure paths below dereference the peer's owner node, so a null peer is rejected up front
      // (mirrors the handling of null data peers above).
      GE_IF_BOOL_EXEC(peer_in_ctrl_anchor == nullptr, GELOGW("peer_in_ctrl_anchor is nullptr"); return nullptr);
      ge::graphStatus status = GraphUtils::RemoveEdge(out_ctrl_anchor, peer_in_ctrl_anchor);
      GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS,
                       REPORT_INNER_ERR_MSG("E19999",
                                            "Remove control edge between op:%s(%s) and op:%s(%s) failed",
                                            pre_node->GetName().c_str(), pre_node->GetType().c_str(),
                                            peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
                                            peer_in_ctrl_anchor->GetOwnerNode()->GetType().c_str());
                       return nullptr,
                       "[Remove][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
                       pre_node->GetName().c_str(), pre_node->GetType().c_str(),
                       peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
                       peer_in_ctrl_anchor->GetOwnerNode()->GetType().c_str());
      status = GraphUtils::AddEdge(memcpy_async_node->GetOutControlAnchor(), peer_in_ctrl_anchor);
      GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS,
                       REPORT_INNER_ERR_MSG("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
                                            memcpy_async_node->GetName().c_str(), memcpy_async_node->GetType().c_str(),
                                            peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
                                            peer_in_ctrl_anchor->GetOwnerNode()->GetType().c_str());
                       return nullptr,
                       "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
                       memcpy_async_node->GetName().c_str(), memcpy_async_node->GetType().c_str(),
                       peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
                       peer_in_ctrl_anchor->GetOwnerNode()->GetType().c_str());
      GELOGI("Graph add memcpyAsync op out ctrl edge, dst node: %s.",
             peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
    }
  }
  GELOGI("Insert memcpyAsync op success.");
  return memcpy_async_node;
}
// Creates the OpDesc for the MemcpyAsync node placed behind `pre_node`.
// The op is named "<pre_node name>_MemcpyAsync" with type MEMCPYASYNC. For each output i of pre_node,
// it receives one input desc and one output desc, both copied from pre_node's output desc i.
// @param pre_node node whose outputs the copy op mirrors; must be non-null with a non-null OpDesc
// @return the new OpDesc, or nullptr on failure
ge::OpDescPtr IteratorOpPass::CreateMemcpyAsyncOp(const ge::NodePtr &pre_node) const {
  GE_CHK_BOOL_EXEC(pre_node != nullptr, return nullptr, "Input param invalid.");
  const std::string node_name = pre_node->GetName() + "_MemcpyAsync";
  ge::OpDescPtr op_desc = MakeShared<OpDesc>(node_name, MEMCPYASYNC);
  if (op_desc == nullptr) {
    REPORT_INNER_ERR_MSG("E19999", "New OpDesc failed");
    GELOGE(FAILED, "[New][OpDesc] failed");
    return nullptr;
  }
  GELOGI("Create memcpyAsync op:%s.", op_desc->GetName().c_str());
  const ge::OpDescPtr pre_node_op_desc = pre_node->GetOpDesc();
  GE_CHK_BOOL_EXEC(pre_node_op_desc != nullptr,
                   REPORT_INNER_ERR_MSG("E19999", "OpDesc in node is nullptr, check invalid");
                   return nullptr, "[Get][OpDesc] failed, OpDesc of pre_node is invalid.");
  const size_t out_size = pre_node_op_desc->GetOutputsSize();
  GELOGI("Create memcpyAsync op, pre_node out_size: %zu.", out_size);
  for (size_t i = 0U; i < out_size; ++i) {
    // Fetch the desc once: it describes both input i and output i of the copy op.
    const auto out_desc = pre_node_op_desc->GetOutputDesc(static_cast<uint32_t>(i));
    // Failures name the op actually being modified (the new MemcpyAsync op) and the failing index.
    GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(out_desc) == GRAPH_SUCCESS,
                     REPORT_INNER_ERR_MSG("E19999", "Add input desc:%zu to op:%s(%s) failed", i,
                                          op_desc->GetName().c_str(), op_desc->GetType().c_str());
                     return nullptr,
                     "[Add][InputDesc] %zu to op:%s(%s) failed", i,
                     op_desc->GetName().c_str(), op_desc->GetType().c_str());
    GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(out_desc) == GRAPH_SUCCESS,
                     REPORT_INNER_ERR_MSG("E19999", "Add output desc:%zu to op:%s(%s) failed", i,
                                          op_desc->GetName().c_str(), op_desc->GetType().c_str());
                     return nullptr,
                     "[Add][OutputDesc] %zu to op:%s(%s) failed", i,
                     op_desc->GetName().c_str(), op_desc->GetType().c_str());
  }
  return op_desc;
}
// Pass-option registration for "IteratorOpPass".
// NOTE(review): LEVELS(OoLevel::kO1) presumably associates this pass with optimization level O1; the
// exact semantics live in the REG_PASS_OPTION definition elsewhere — confirm there.
REG_PASS_OPTION("IteratorOpPass").LEVELS(OoLevel::kO1);
}