* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/passes/memory_conflict/hccl_continuous_memcpy_pass.h"
#include <string>
#include "framework/common/debug/log.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "common/plugin/ge_make_unique_util.h"
#include "framework/common/framework_types_internal.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/op_type_utils.h"
#include "runtime/rt.h"
#include "checker.h"
namespace {
const int32_t kAnchorAssignRefIndex = 0;
const int32_t kAnchorAssignValueIndex = 1;
const int32_t kAnchorIdentityIndex = 0;
bool IsInputContinuous(const ge::OpDescPtr &op_desc) {
bool is_input_continuous = false;
(void)ge::AttrUtils::GetBool(op_desc, ge::ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous);
return is_input_continuous;
}
bool HasMultiPeerNodeRequireInputContinuous(const ge::OutDataAnchorPtr &out_data_anchor) {
size_t count = 0U;
for (const auto &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
if (peer_in_anchor == nullptr || peer_in_anchor->GetOwnerNode() == nullptr) {
continue;
}
const auto &peer_in_node = peer_in_anchor->GetOwnerNode();
if (IsInputContinuous(peer_in_node->GetOpDesc())) {
count++;
}
}
return (count > 1U);
}
}
namespace ge {
Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(graph);
for (const auto &node : graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Node with nullptr op_desc exist in Param graph:%s, check invalid",
graph->GetName().c_str());
GELOGE(INTERNAL_ERROR, "[Get][OpDesc] failed, Node with nullptr op_desc exist in Param graph:%s.",
graph->GetName().c_str());
return INTERNAL_ERROR;
}
Status ret = ContinuousInputProcess(graph, node);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Call][ContinuousInputProcess] failed, node_name:%s.", node->GetName().c_str());
return ret;
}
ret = P2pmemInputProcess(graph, node);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Call][P2pmemInputProcess] failed, node_name:%s.", node->GetName().c_str());
return ret;
}
}
return SUCCESS;
}
Status HcclContinuousMemcpyPass::ContinuousInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
auto op_desc = node->GetOpDesc();
if (IsInputContinuous(op_desc)) {
GELOGI("continuous input op is:%s.", op_desc->GetName().c_str());
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
if (hccl_in_anchor == nullptr) {
continue;
}
auto src_out_anchor = hccl_in_anchor->GetPeerOutAnchor();
if (src_out_anchor == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Node:%s(%s) input:%d anchor, peer anchor is nullptr, check invalid",
node->GetName().c_str(), node->GetType().c_str(),
hccl_in_anchor->GetIdx());
GELOGE(INTERNAL_ERROR, "[Get][PeerOutAnchor] failed, Node:%s(%s) input:%d anchor, peer anchor is nullptr",
node->GetName().c_str(), node->GetType().c_str(), hccl_in_anchor->GetIdx());
return INTERNAL_ERROR;
}
if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType()) ||
HasMultiPeerNodeRequireInputContinuous(src_out_anchor)) {
Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Modify][EdgeConnection] between %s and %s failed.",
src_out_anchor->GetOwnerNode()->GetName().c_str(), node->GetName().c_str());
return ret;
}
}
}
}
return SUCCESS;
}
Status HcclContinuousMemcpyPass::P2pmemInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
auto op_desc = node->GetOpDesc();
std::vector<int64_t> input_memory_types;
(void) ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_MEM_TYPE_LIST, input_memory_types);
if (input_memory_types.empty()) {
return SUCCESS;
}
for (uint32_t index = 0; index < input_memory_types.size() && index < op_desc->GetInputsSize(); index++) {
if (input_memory_types[index] != RT_MEMORY_P2P_DDR) {
continue;
}
GELOGD("p2p input op is:%s.", op_desc->GetName().c_str());
auto hccl_in_anchor = node->GetInDataAnchor(index);
if (hccl_in_anchor == nullptr) {
continue;
}
auto src_out_anchor = hccl_in_anchor->GetPeerOutAnchor();
if (src_out_anchor == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Node:%s(%s) input:%u anchor, peer anchor is nullptr, check invalid",
node->GetName().c_str(), node->GetType().c_str(), index);
GELOGE(INTERNAL_ERROR, "[Get][PeerOutAnchor] failed, Node:%s(%s) input:%u anchor, peer anchor is nullptr",
node->GetName().c_str(), node->GetType().c_str(), index);
return INTERNAL_ERROR;
}
if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) {
Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Modify][EdgeConnection] between %s and %s failed.",
src_out_anchor->GetOwnerNode()->GetName().c_str(), node->GetName().c_str());
return ret;
}
}
}
return SUCCESS;
}
bool HcclContinuousMemcpyPass::IsDataNode(const std::string& node_type) const {
return (node_type == CONSTANTOP) || OpTypeUtils::IsDataNode(node_type) || (node_type == CONSTANT) ||
OpTypeUtils::IsVarLikeNode(node_type);
}
OpDescPtr HcclContinuousMemcpyPass::CreateIdentityOpDesc(const OutDataAnchorPtr &out_data_anchor) {
NodePtr pre_node = out_data_anchor->GetOwnerNode();
OpDescPtr pre_op_desc = pre_node->GetOpDesc();
if (pre_op_desc == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "OpDesc in node is nullptr, check invalid");
GELOGE(INTERNAL_ERROR, "[Get][OpDesc] failed, OpDesc of pre node is invalid.");
return nullptr;
}
std::string node_name = pre_node->GetName() + "_HcclContinuousMemcpyPass_" + IDENTITY;
node_name = CheckDuplicateName(node_name);
OpDescBuilder op_desc_builder(node_name, IDENTITY);
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx());
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build();
if (identity_op_desc == nullptr) {
return nullptr;
}
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false);
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NAME_CANNOT_BE_DELETED, true);
return identity_op_desc;
}
std::string HcclContinuousMemcpyPass::CheckDuplicateName(const std::string &node_name) {
std::string tmp_name = node_name;
auto iter = node_num_map_.find(tmp_name);
if (iter != node_num_map_.end()) {
tmp_name = tmp_name + "_" + std::to_string(iter->second);
(iter->second)++;
} else {
node_num_map_[tmp_name] = 1;
}
return tmp_name;
}
Status HcclContinuousMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph,
const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor) {
GE_CHECK_NOTNULL(src_out_anchor->GetOwnerNode());
GE_CHECK_NOTNULL(hccl_in_anchor->GetOwnerNode());
Status ret = InsertIdentityBeforeHccl(src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Add][Identity] failed, var_node:%s, hccl_node:%s.",
src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
return ret;
}
ret = InsertAssignAfterBroadcastIfNeed(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Add][Assign] failed, var_node:%s, hccl_node:%s.",
src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
return ret;
}
return SUCCESS;
}
Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor) {
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
OpDescPtr identity_op = CreateIdentityOpDesc(src_out_anchor);
GE_CHECK_NOTNULL(identity_op);
auto identity_node = GraphUtils::InsertNodeBefore(hccl_in_anchor,
identity_op, kAnchorIdentityIndex, kAnchorIdentityIndex);
if (identity_node == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.",
identity_op->GetName().c_str(), identity_op->GetType().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetType().c_str(),
hccl_in_anchor->GetIdx());
GELOGE(INTERNAL_ERROR, "[Insert][Node] %s(%s) before %s(%s) on index:%d input anchor failed.",
identity_op->GetName().c_str(), identity_op->GetType().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), hccl_in_anchor->GetOwnerNode()->GetType().c_str(),
hccl_in_anchor->GetIdx());
return FAILED;
}
return SUCCESS;
}
Status HcclContinuousMemcpyPass::InsertAssignAfterBroadcastIfNeed(const ComputeGraphPtr &graph,
const OutDataAnchorPtr &var_out_anchor,
const InDataAnchorPtr &hccl_in_anchor) {
if (hccl_in_anchor->GetOwnerNode()->GetType() != HCOMBROADCAST) {
GELOGD("%s not broadcast, no need to insert assign node", hccl_in_anchor->GetOwnerNode()->GetName().c_str());
return SUCCESS;
}
if (!OpTypeUtils::IsVarLikeNode(var_out_anchor->GetOwnerNode()->GetType())) {
GELOGD("%s not variable, no need to insert assign node", var_out_anchor->GetOwnerNode()->GetName().c_str());
return SUCCESS;
}
GELOGI("after op %s and op %s need insert assign op.", var_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
for (auto peer_in_anchor : var_out_anchor->GetPeerInDataAnchors()) {
if (peer_in_anchor->GetOwnerNode()->GetType() == ASSIGN) {
GELOGD("variable %s out assign node is exist.", var_out_anchor->GetOwnerNode()->GetName().c_str());
return SUCCESS;
}
}
NodePtr assign_node = CreateAssignNode(graph, var_out_anchor);
GE_CHECK_NOTNULL(assign_node);
OutDataAnchorPtr hccl_out_anchor = hccl_in_anchor->GetOwnerNode()->GetOutDataAnchor(hccl_in_anchor->GetIdx());
GE_CHECK_NOTNULL(hccl_out_anchor);
Status ret = hccl_out_anchor->LinkTo(assign_node->GetInDataAnchor(kAnchorAssignValueIndex));
GE_ASSERT_SUCCESS(ret, "Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed",
hccl_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_out_anchor->GetOwnerNode()->GetType().c_str(), hccl_out_anchor->GetIdx(),
assign_node->GetName().c_str(), assign_node->GetType().c_str(), kAnchorAssignValueIndex);
ret = var_out_anchor->LinkTo(assign_node->GetInDataAnchor(kAnchorAssignRefIndex));
GE_ASSERT_SUCCESS(ret, "Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed",
var_out_anchor->GetOwnerNode()->GetName().c_str(),
var_out_anchor->GetOwnerNode()->GetType().c_str(), var_out_anchor->GetIdx(),
assign_node->GetName().c_str(), assign_node->GetType().c_str(), kAnchorAssignRefIndex);
OutControlAnchorPtr assign_out_control_anchor = assign_node->GetOutControlAnchor();
GE_CHECK_NOTNULL(assign_out_control_anchor);
for (auto in_data_anchor : hccl_out_anchor->GetPeerInDataAnchors()) {
if (in_data_anchor->GetOwnerNode()->GetName() == assign_node->GetName()) {
continue;
}
ret = assign_out_control_anchor->LinkTo(in_data_anchor->GetOwnerNode()->GetInControlAnchor());
if (ret != SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Op:%s(%s) link control to op:%s(%s) failed",
assign_out_control_anchor->GetOwnerNode()->GetName().c_str(),
assign_out_control_anchor->GetOwnerNode()->GetType().c_str(),
in_data_anchor->GetOwnerNode()->GetName().c_str(),
in_data_anchor->GetOwnerNode()->GetType().c_str());
GELOGE(INTERNAL_ERROR, "[Add][Edge] Op:%s(%s) link control to op:%s(%s) failed",
assign_out_control_anchor->GetOwnerNode()->GetName().c_str(),
assign_out_control_anchor->GetOwnerNode()->GetType().c_str(),
in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetOwnerNode()->GetType().c_str());
return FAILED;
}
}
for (auto in_control_anchor : hccl_out_anchor->GetOwnerNode()->GetOutControlAnchor()->GetPeerInControlAnchors()) {
if (in_control_anchor->GetOwnerNode()->GetName() == assign_node->GetName()) {
continue;
}
ret = assign_out_control_anchor->LinkTo(in_control_anchor);
if (ret != SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Op:%s(%s) link control to op:%s(%s) failed",
assign_out_control_anchor->GetOwnerNode()->GetName().c_str(),
assign_out_control_anchor->GetOwnerNode()->GetType().c_str(),
in_control_anchor->GetOwnerNode()->GetName().c_str(),
in_control_anchor->GetOwnerNode()->GetType().c_str());
GELOGE(INTERNAL_ERROR, "[Add][Edge] Op:%s(%s) link control to op:%s(%s) failed",
assign_out_control_anchor->GetOwnerNode()->GetName().c_str(),
assign_out_control_anchor->GetOwnerNode()->GetType().c_str(),
in_control_anchor->GetOwnerNode()->GetName().c_str(),
in_control_anchor->GetOwnerNode()->GetType().c_str());
return FAILED;
}
}
return SUCCESS;
}
NodePtr HcclContinuousMemcpyPass::CreateAssignNode(const ComputeGraphPtr &graph,
const OutDataAnchorPtr &out_data_anchor) {
GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
NodePtr pre_node = out_data_anchor->GetOwnerNode();
OpDescPtr pre_op_desc = pre_node->GetOpDesc();
if (pre_op_desc == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "OpDesc in node is nullptr, check invalid");
GELOGE(INTERNAL_ERROR, "[Get][OpDesc] failed, OpDesc of pre node is invalid.");
return nullptr;
}
std::string node_name = pre_node->GetName() + "_" + ASSIGN;
node_name = CheckDuplicateName(node_name);
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), ASSIGN);
if (op_desc == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "New OpDesc failed");
GELOGE(INTERNAL_ERROR, "[New][OpDesc] failed.");
return nullptr;
}
GELOGI("Create Assign op:%s.", op_desc->GetName().c_str());
if (!AttrUtils::SetBool(op_desc, ATTR_NEED_COMPILE, true)) {
REPORT_INNER_ERR_MSG("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NEED_COMPILE.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NEED_COMPILE.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str());
return nullptr;
}
graphStatus ret = op_desc->AddInputDesc("ref", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Add input desc to op:%s(%s) failed, name:ref",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "[Add][InputDesc] to op:%s(%s) failed, name:ref",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
return nullptr;
}
ret = op_desc->AddInputDesc("value", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Add input desc to op:%s(%s) failed, name:value",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "[Add][InputDesc] to op:%s(%s) failed, name:value",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
return nullptr;
}
ret = op_desc->AddOutputDesc("ref", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Add output desc to op:%s(%s) failed, name:ref",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "[Add][OutputDesc] to op:%s(%s) failed, name:ref",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
return nullptr;
}
NodePtr assign_node = graph->AddNode(op_desc);
if (assign_node == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Add node:%s(%s) to graph:%s failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
GELOGE(INTERNAL_ERROR, "[Add][Node] %s(%s) to graph:%s failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
return nullptr;
}
return assign_node;
}
Status HcclContinuousMemcpyPass::ClearStatus() {
node_num_map_.clear();
return SUCCESS;
}
REG_PASS_OPTION("HcclContinuousMemcpyPass").LEVELS(OoLevel::kO0);
}