/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/passes/variable_optimize/variable_ref_delete_op_pass.h"
#include <string>
#include "graph/utils/node_utils.h"
#include "graph/utils/op_type_utils.h"
#include "common/checker.h"
namespace ge {
/// @brief Process every variable-ref node that is a direct node of `graph`.
///
/// A direct node counts as a variable ref when its op type is variable-like and its op desc carries the
/// REF_VAR_SRC_VAR_NAME string attribute. The attribute value must match the name of a variable-like node
/// found among all nodes of the root graph; each such ref is then handed to DealVariableRef.
///
/// @param graph graph whose direct nodes are scanned.
/// @return SUCCESS when every variable ref was handled; FAILED when a referenced source variable is unknown
///         or DealVariableRef reports an error; the null-check macros return their own status on null input.
Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  auto root_graph = GraphUtils::FindRootGraph(graph);
  GE_CHECK_NOTNULL(root_graph);

  // Names of all variable-like nodes reachable from the root graph.
  std::unordered_set<std::string> known_var_names;
  for (const auto &candidate : root_graph->GetAllNodes()) {
    if (!OpTypeUtils::IsVarLikeNode(candidate->GetType())) {
      continue;
    }
    known_var_names.insert(candidate->GetName());
  }

  for (auto &cur_node : graph->GetDirectNode()) {
    auto cur_desc = cur_node->GetOpDesc();
    GE_CHECK_NOTNULL(cur_desc);
    if (!OpTypeUtils::IsVarLikeNode(cur_desc->GetType())) {
      continue;
    }
    // The attribute is only looked up for variable-like nodes, same as the short-circuit it replaces.
    std::string src_var_name;
    if (!ge::AttrUtils::GetStr(cur_desc, REF_VAR_SRC_VAR_NAME, src_var_name)) {
      continue;
    }

    if (known_var_names.find(src_var_name) == known_var_names.end()) {
      REPORT_INNER_ERR_MSG("E19999", "Cannot find source variable[%s] of variable ref[%s], check invalid",
                           src_var_name.c_str(), cur_node->GetName().c_str());
      GELOGE(FAILED, "[Check][Param] Cannot find source variable[%s] of variable ref[%s]",
             src_var_name.c_str(), cur_node->GetName().c_str());
      return FAILED;
    }

    const Status deal_ret = DealVariableRef(graph, cur_node, src_var_name);
    if (deal_ret == SUCCESS) {
      continue;
    }
    GELOGE(deal_ret, "[Deal][VariableRef] [%s] in graph:%s failed", cur_node->GetName().c_str(),
           graph->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}
/// @brief Remove a single variable-ref node, recording its source variable name on the producer's output desc.
///
/// Steps:
///  1. Locate the node feeding input 0 of `variable_ref` and the output index it is connected through.
///  2. Pick the output desc to annotate: by default the producer's output desc at that index; when the producer
///     is a PARTITIONEDCALL, the output desc of the subgraph node that feeds the subgraph output-node input whose
///     ATTR_NAME_PARENT_NODE_INDEX equals that index.
///  3. Unless the producer's name already equals `ref_var_src_var_name`, set REF_VAR_SRC_VAR_NAME on that desc.
///  4. Isolate `variable_ref` (io map {0}) and remove it from `graph` without relinking.
///
/// @param graph                 graph that owns `variable_ref`.
/// @param variable_ref          variable-ref node to delete.
/// @param ref_var_src_var_name  name of the source variable the ref points at.
/// @return SUCCESS on success; FAILED when input 0 is missing or isolate/remove fails; the null-check and assert
///         macros return their own non-SUCCESS status.
Status VariableRefDeleteOpPass::DealVariableRef(ge::ComputeGraphPtr &graph, ge::NodePtr &variable_ref,
                                                const std::string &ref_var_src_var_name) const {
  GE_CHECK_NOTNULL(graph);
  GE_CHECK_NOTNULL(variable_ref);
  auto in_anchor0 = variable_ref->GetInDataAnchor(0);
  if (in_anchor0 == nullptr) {
    REPORT_INNER_ERR_MSG("E19999", "Node:%s(%s) has no input anchor, check invalid",
                         variable_ref->GetName().c_str(), variable_ref->GetType().c_str());
    GELOGE(FAILED, "[Get][InDataAnchor] failed, variable_ref [%s] no input", variable_ref->GetName().c_str());
    return FAILED;
  }

  // Query the peer anchor once and reuse it for the index and the owner node.
  const auto peer_out_anchor = in_anchor0->GetPeerOutAnchor();
  GE_CHECK_NOTNULL(peer_out_anchor);
  const int32_t index = peer_out_anchor->GetIdx();
  NodePtr peer_node = peer_out_anchor->GetOwnerNode();
  // The owner node is dereferenced right below, so make sure it exists.
  GE_CHECK_NOTNULL(peer_node);
  auto op_desc = peer_node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  auto out_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));

  if (peer_node->GetType() == PARTITIONEDCALL) {
    // Redirect the annotation to the node inside the subgraph that produces this output.
    for (const auto &netoutput_node : ge::NodeUtils::GetSubgraphOutputNodes(*peer_node)) {
      const auto netoutput_desc = netoutput_node->GetOpDesc();
      GE_CHECK_NOTNULL(netoutput_desc);
      for (size_t i = 0U; i < netoutput_desc->GetAllInputsSize(); ++i) {
        auto input_desc = netoutput_desc->MutableInputDesc(static_cast<uint32_t>(i));
        if (input_desc == nullptr) {
          continue;
        }
        // idx stays -1 when the attribute is absent, so such inputs never match.
        int32_t idx = -1;
        (void) AttrUtils::GetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, idx);
        if (idx != index) {
          continue;
        }
        const auto sub_in_anchor = netoutput_node->GetInDataAnchor(static_cast<int32_t>(i));
        GE_ASSERT_NOTNULL(sub_in_anchor);
        const auto out_anchor = sub_in_anchor->GetPeerOutAnchor();
        GE_CHECK_NOTNULL(out_anchor);
        const auto inner_node = out_anchor->GetOwnerNode();
        GE_CHECK_NOTNULL(inner_node);
        const auto inner_desc = inner_node->GetOpDesc();
        GE_CHECK_NOTNULL(inner_desc);
        out_desc = inner_desc->MutableOutputDesc(static_cast<uint32_t>(out_anchor->GetIdx()));
        // Leaves only the inner loop; remaining subgraph output nodes are still visited.
        break;
      }
    }
  }

  if (ref_var_src_var_name != op_desc->GetName()) {
    // Checked here rather than earlier so the path that skips the attribute keeps its behaviour.
    GE_CHECK_NOTNULL(out_desc);
    bool is_set_str = ge::AttrUtils::SetStr(out_desc, REF_VAR_SRC_VAR_NAME, ref_var_src_var_name);
    GE_ASSERT_TRUE(is_set_str, "[Set][Attr] %s to output:%d desc of op:%s(%s) failed", REF_VAR_SRC_VAR_NAME.c_str(),
                   index, op_desc->GetName().c_str(), op_desc->GetType().c_str());
    GELOGI("[%s-%d]: add attr [REF_VAR_SRC_VAR_NAME: %s ] ", peer_node->GetName().c_str(), index,
           ref_var_src_var_name.c_str());
  }

  if (GraphUtils::IsolateNode(variable_ref, {0}) != GRAPH_SUCCESS) {
    REPORT_INNER_ERR_MSG("E19999", "Isolate node:%s(%s) failed",
                         variable_ref->GetName().c_str(), variable_ref->GetType().c_str());
    GELOGE(INTERNAL_ERROR, "[Isolate][Node] name:%s, type:%s failed", variable_ref->GetName().c_str(),
           variable_ref->GetType().c_str());
    return FAILED;
  }
  if (GraphUtils::RemoveNodeWithoutRelink(graph, variable_ref) != GRAPH_SUCCESS) {
    REPORT_INNER_ERR_MSG("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
                         variable_ref->GetName().c_str(), variable_ref->GetType().c_str(), graph->GetName().c_str());
    GELOGE(INTERNAL_ERROR, "[Remove][Node] %s(%s) without relink in graph:%s failed",
           variable_ref->GetName().c_str(), variable_ref->GetType().c_str(), graph->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}
// Registers the pass option named "VariableRefDeleteOpPass" with level OoLevel::kO3.
// NOTE(review): presumably this ties the pass to the O3 optimization level — confirm against REG_PASS_OPTION.
REG_PASS_OPTION("VariableRefDeleteOpPass").LEVELS(OoLevel::kO3);
}