* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/passes/standard_optimize/remove_same_const_pass.h"
#include <sstream>
#include <string>
#include <set>
#include "common/base64.h"
#include "common/checker.h"
#include "common/compile_profiling/ge_call_wrapper.h"
#include "graph/utils/node_utils.h"
namespace ge {
namespace {
// Stages used to build the CSE key of a const node. Each enumerator selects
// the piece of information that the key helpers append for that stage.
enum class KeyLevel {
  kControlInNodes = 0,  // node type plus the ordered names of its control-input nodes
  kWeightSize = 1,      // size reported by the weight's tensor data
  kWeightDesc = 2,      // serialized tensor descriptor of the weight
  kAllAttrsSize = 3,    // number of entries in the node's attribute map
  kAllAttrsValue = 4,   // serialized content of the node's attribute map
};
// Attribute-name predicate: true for every attribute name except ATTR_NAME_WEIGHTS.
const AttrNameFilter attr_weights_filter = [](const std::string &attr_name) -> bool {
  const bool is_weights_attr = (attr_name == ATTR_NAME_WEIGHTS);
  return !is_weights_attr;
};
// Builds the CSE key fragment for one attribute-related key level.
//
// node      : node the fragment is built for; only dereferenced for kControlInNodes.
// key_level : selects which fragment is produced.
// attrs     : attribute map of the node; read for kAllAttrsSize / kAllAttrsValue.
// Returns the fragment, or an empty string for any other key level.
std::string GetCseKeyByAttrs(const NodePtr &node, const KeyLevel key_level,
                             const std::map<std::string, AnyValue> &attrs) {
  std::string ss;
  switch (key_level) {
    case KeyLevel::kControlInNodes: {
      (void)ss.append("type-").append(node->GetType()).append("-");
      // A std::set keeps the names ordered, so the fragment does not depend on
      // the order in which the control-input nodes are enumerated.
      std::set<std::string> control_in_node_names;
      for (const auto &src_node : node->GetInControlNodes()) {
        (void)control_in_node_names.insert(src_node->GetName());
      }
      for (const auto &name : control_in_node_names) {
        // Length-prefix every name. Node names may themselves contain '-', so a
        // plain '-' join is ambiguous: {"a-b", "c"} and {"a", "b-c"} would both
        // produce "a-b-c-" and nodes with different control inputs would share a key.
        (void)ss.append(std::to_string(name.size())).append(":").append(name).append("-");
      }
      break;
    }
    case KeyLevel::kAllAttrsSize:
      (void)ss.append("attr-size-").append(std::to_string(attrs.size()));
      break;
    case KeyLevel::kAllAttrsValue:
      (void)ss.append("attrs-").append(AttrUtils::GetAllAttrsStr(attrs));
      break;
    default:
      // Tensor-related levels produce no attribute fragment.
      break;
  }
  return ss;
}
std::string GetCseKeyByTensor(const KeyLevel key_level, const GeTensorPtr &weight) {
std::string ss;
switch (key_level) {
case KeyLevel::kWeightSize: {
const size_t output_size = weight->GetData().GetSize();
(void)ss.append("weight-size-").append(std::to_string(output_size));
break;
}
case KeyLevel::kWeightDesc: {
std::map<std::string, AnyValue> tensor_desc;
tensor_desc.emplace(std::make_pair("weight_tensor_desc", AnyValue::CreateFrom(weight->GetTensorDesc())));
(void)ss.append("tensor_desc-").append(AttrUtils::GetAllAttrsStr(tensor_desc));
break;
}
default:;
}
return ss;
}
// Refines every group of `level_to_nodes` that holds more than one node by
// appending the attribute fragment for `key_level` to the group's key and
// re-bucketing the nodes into `new_level_to_nodes`.
//
// For kAllAttrsSize the node's attributes (filtered by attr_weights_filter) are
// fetched and cached in `node_attrs`; for any other level the cache entry must
// already exist. Groups with a single node are dropped.
Status FilterSameKeyNodesByAttr(const std::map<std::string, std::vector<NodePtr>> &level_to_nodes,
                                const KeyLevel &key_level,
                                std::map<NodePtr, std::map<std::string, AnyValue>> &node_attrs,
                                std::map<std::string, std::vector<NodePtr>> &new_level_to_nodes) {
  const bool collect_attrs = (key_level == KeyLevel::kAllAttrsSize);
  for (const auto &group : level_to_nodes) {
    const std::string &parent_key = group.first;
    const std::vector<NodePtr> &candidates = group.second;
    if (candidates.size() > 1U) {
      for (const auto &node : candidates) {
        if (!collect_attrs) {
          GE_ASSERT_TRUE(node_attrs.find(node) != node_attrs.end());
        } else {
          const auto &op_desc = node->GetOpDesc();
          GE_CHECK_NOTNULL(op_desc);
          node_attrs[node] = AttrUtils::GetAllAttrsWithFilter(op_desc, attr_weights_filter);
        }
        const std::string fragment = GetCseKeyByAttrs(node, key_level, node_attrs[node]);
        new_level_to_nodes[parent_key + fragment].emplace_back(node);
      }
    }
  }
  return SUCCESS;
}
// Refines every group of `level_to_nodes` that holds more than one node by
// appending the tensor fragment for `key_level` to the group's key and
// re-bucketing the nodes into `new_level_to_nodes`.
//
// For kWeightSize the weight tensor is read from the node's ATTR_NAME_WEIGHTS
// attribute and cached in `node_tensor`; for any other level the cache entry
// must already exist. Groups with a single node are dropped.
Status FilterSameKeyNodesByTensor(const std::map<std::string, std::vector<NodePtr>> &level_to_nodes,
                                  const KeyLevel &key_level, std::map<NodePtr, GeTensorPtr> &node_tensor,
                                  std::map<std::string, std::vector<NodePtr>> &new_level_to_nodes) {
  const bool load_weight = (key_level == KeyLevel::kWeightSize);
  for (const auto &group : level_to_nodes) {
    const std::string &parent_key = group.first;
    const std::vector<NodePtr> &candidates = group.second;
    if (candidates.size() > 1U) {
      for (const auto &node : candidates) {
        if (!load_weight) {
          GE_ASSERT_TRUE(node_tensor.find(node) != node_tensor.end());
        } else {
          const auto &op_desc = node->GetOpDesc();
          GE_CHECK_NOTNULL(op_desc);
          GeTensorPtr weight;
          // The return value is ignored on purpose; the null check below covers failure.
          (void)AttrUtils::MutableTensor(op_desc, ATTR_NAME_WEIGHTS, weight);
          GE_CHECK_NOTNULL(weight);
          node_tensor[node] = weight;
        }
        const std::string fragment = GetCseKeyByTensor(key_level, node_tensor[node]);
        new_level_to_nodes[parent_key + fragment].emplace_back(node);
      }
    }
  }
  return SUCCESS;
}
// Splits every multi-node group of `level_to_nodes` by the actual bytes of the
// nodes' weight data.
//
// level_to_nodes     : groups to refine; groups with a single node are skipped.
// node_tensor        : weight tensor of every node; an entry must exist for each node.
// new_level_to_nodes : output; for each group key, maps the data pointer of the
//                      first node seen with a given byte content to all nodes
//                      sharing that content.
// Returns SUCCESS, or the failure status of the check macros when a node has no
// cached tensor or the cached tensor is null.
//
// NOTE(review): the byte comparison uses the current node's data size for both
// buffers, so it presumably relies on the group key already separating nodes by
// weight size - confirm against the callers.
Status FilterSameKeyNodesByValue(const std::map<std::string, std::vector<NodePtr>> &level_to_nodes,
                                 std::map<NodePtr, GeTensorPtr> &node_tensor,
                                 std::map<std::string, std::map<void *, std::vector<NodePtr>>> &new_level_to_nodes) {
  for (const auto &pair : level_to_nodes) {
    if (pair.second.size() <= 1U) {
      continue;
    }
    for (const auto &node : pair.second) {
      const auto iter_weight = node_tensor.find(node);
      GE_ASSERT_TRUE(iter_weight != node_tensor.end());
      const auto &weight = iter_weight->second;
      GE_CHECK_NOTNULL(weight);
      void *const values = static_cast<void *>(weight->MutableData().data());
      const size_t weight_size = weight->GetData().size();
      const auto iter = new_level_to_nodes.find(pair.first);
      if (iter == new_level_to_nodes.end()) {
        std::map<void *, std::vector<NodePtr>> value_nodes;
        value_nodes[values].emplace_back(node);
        new_level_to_nodes[pair.first] = std::move(value_nodes);
        GELOGD("Add const node: %s, key: %s.", node->GetName().c_str(), pair.first.c_str());
        continue;
      }
      auto &value_nodes = iter->second;
      bool has_same_value = false;
      for (auto &value_pair : value_nodes) {
        void *const cur_value = value_pair.first;
        // Passing a null pointer to memcmp is undefined behaviour even for a
        // zero length, and empty weights may expose a null data pointer. Settle
        // the trivial cases first and only compare bytes for two valid buffers.
        const bool is_same = (values == cur_value) || (weight_size == 0U) ||
                             ((values != nullptr) && (cur_value != nullptr) &&
                              (memcmp(values, cur_value, weight_size) == 0));
        if (is_same) {
          GELOGD("Const node: %s, has same key: %s.", node->GetName().c_str(), pair.first.c_str());
          value_pair.second.emplace_back(node);
          has_same_value = true;
          break;
        }
      }
      if (!has_same_value) {
        // First node with this byte content: its data pointer becomes the bucket key.
        value_nodes[values].emplace_back(node);
      }
    }
  }
  return SUCCESS;
}
// Refines `level_to_nodes` by the nodes' weight tensors in three steps - weight
// size, tensor descriptor, then raw weight bytes - and writes every resulting
// bucket that still holds more than one node into `new_level_to_nodes`.
// The final key is the refined key plus "weight-ptr-" and the numeric value of
// the bucket's representative data pointer.
Status FilterSameKeyNodesByGeTensor(const std::map<std::string, std::vector<NodePtr>> &level_to_nodes,
                                    std::map<std::string, std::vector<NodePtr>> &new_level_to_nodes) {
  // Weight tensors looked up in the first step are reused by the later steps.
  std::map<NodePtr, GeTensorPtr> node_tensor;

  std::map<std::string, std::vector<NodePtr>> weight_size_to_nodes;
  GE_ASSERT_SUCCESS(
      FilterSameKeyNodesByTensor(level_to_nodes, KeyLevel::kWeightSize, node_tensor, weight_size_to_nodes));

  std::map<std::string, std::vector<NodePtr>> tensor_desc_to_nodes;
  GE_ASSERT_SUCCESS(
      FilterSameKeyNodesByTensor(weight_size_to_nodes, KeyLevel::kWeightDesc, node_tensor, tensor_desc_to_nodes));

  std::map<std::string, std::map<void *, std::vector<NodePtr>>> weight_value_to_nodes;
  GE_ASSERT_SUCCESS(FilterSameKeyNodesByValue(tensor_desc_to_nodes, node_tensor, weight_value_to_nodes));

  for (const auto &key_and_buckets : weight_value_to_nodes) {
    for (const auto &bucket : key_and_buckets.second) {
      const std::vector<NodePtr> &same_value_nodes = bucket.second;
      if (same_value_nodes.size() <= 1U) {
        continue;
      }
      std::string key = key_and_buckets.first;
      (void)key.append("weight-ptr-").append(std::to_string(PtrToValue(bucket.first)));
      auto &merged_nodes = new_level_to_nodes[key];
      (void)merged_nodes.insert(merged_nodes.end(), same_value_nodes.begin(), same_value_nodes.end());
    }
  }
  return SUCCESS;
}
// Redirects the output anchors of `to_be_removed_node` to `node`, then unlinks
// `to_be_removed_node` and removes it from `graph`.
// If the two nodes expose a different number of output data anchors, nothing
// is changed, a warning is logged and SUCCESS is returned.
// Returns INTERNAL_ERROR when the anchor replacement or the node removal fails.
Status ReplaceConstNode(const ComputeGraphPtr &graph, const NodePtr &node, const NodePtr &to_be_removed_node) {
  const auto removed_out_count = to_be_removed_node->GetAllOutDataAnchorsSize();
  if (removed_out_count != node->GetAllOutDataAnchorsSize()) {
    GELOGW("The const node %s and %s have the same CSE key, but different output anchor count, skip to fusion them",
           node->GetName().c_str(), to_be_removed_node->GetName().c_str());
    return SUCCESS;
  }

  // Identity mapping: output i of the removed node is taken over by output i of `node`.
  std::vector<int32_t> output_map;
  output_map.reserve(removed_out_count);
  for (size_t idx = 0U; idx < removed_out_count; ++idx) {
    output_map.push_back(static_cast<int32_t>(idx));
  }

  const auto replace_ret = GraphUtils::ReplaceNodeAnchors(node, to_be_removed_node, {}, output_map);
  if (replace_ret != GRAPH_SUCCESS) {
    REPORT_INNER_ERR_MSG("E19999", "Replace to_be_removed_node:%s(%s)'s anchor by node:%s(%s) failed",
                         to_be_removed_node->GetName().c_str(), to_be_removed_node->GetType().c_str(),
                         node->GetName().c_str(), node->GetType().c_str());
    GELOGE(INTERNAL_ERROR, "[Replace][Anchors] of to_be_removed_node:%s(%s) by node:%s(%s) failed",
           to_be_removed_node->GetName().c_str(), to_be_removed_node->GetType().c_str(), node->GetName().c_str(),
           node->GetType().c_str());
    return INTERNAL_ERROR;
  }

  NodeUtils::UnlinkAll(*to_be_removed_node);

  const auto remove_ret = GraphUtils::RemoveNodeWithoutRelink(graph, to_be_removed_node);
  if (remove_ret != GRAPH_SUCCESS) {
    REPORT_INNER_ERR_MSG("E19999", "Remove to_be_removed_node:%s(%s) without relink in graph:%s failed",
                         to_be_removed_node->GetName().c_str(), to_be_removed_node->GetType().c_str(),
                         graph->GetName().c_str());
    GELOGE(INTERNAL_ERROR, "[Remove][Node] %s(%s) without relink in graph:%s failed",
           to_be_removed_node->GetName().c_str(), to_be_removed_node->GetType().c_str(), graph->GetName().c_str());
    return INTERNAL_ERROR;
  }

  GELOGI("Remove const to_be_removed_node %s by RemoveSameConstPass, replace it with node %s",
         to_be_removed_node->GetName().c_str(), node->GetName().c_str());
  return SUCCESS;
}
// For every group holding more than one node, keeps the first node and replaces
// each remaining node of the group with it via ReplaceConstNode.
// Stops and returns the failing status as soon as one replacement fails.
Status RemoveDuplicateNodes(const ComputeGraphPtr &graph,
                            const std::map<std::string, std::vector<NodePtr>> &level_last_to_nodes) {
  for (const auto &entry : level_last_to_nodes) {
    const std::vector<NodePtr> &same_nodes = entry.second;
    if (same_nodes.size() <= 1U) {
      continue;
    }
    const auto &saved_node = same_nodes.front();
    // The key is base64-encoded before it is logged.
    GELOGD("The const node %s cse key %s", saved_node->GetName().c_str(),
           ge::base64::EncodeToBase64(entry.first).c_str());
    for (size_t idx = 1U; idx < same_nodes.size(); ++idx) {
      const auto &duplicate = same_nodes[idx];
      GE_CHK_STATUS_RET(ReplaceConstNode(graph, saved_node, duplicate), "Replace const %s by node %s failed",
                        duplicate->GetName().c_str(), saved_node->GetName().c_str());
    }
  }
  return SUCCESS;
}
// True when the node's type equals CONSTANT or CONSTANTOP.
bool IsConstType(const NodePtr &node) {
  const auto &node_type = node->GetType();
  if (node_type == CONSTANT) {
    return true;
  }
  return node_type == CONSTANTOP;
}
}
// Merges const nodes of `graph` that are considered identical.
//
// Direct const nodes with a known shape are first grouped by type and
// control-input nodes. The groups are then refined by weight tensor, attribute
// count and attribute values. Every node left in a group besides the first is
// replaced by that first node.
Status RemoveSameConstPass::Run(ComputeGraphPtr graph) {
  GELOGD("Begin to run RemoveSameConstPass on the graph");
  GE_CHECK_NOTNULL(graph);

  std::map<std::string, std::vector<NodePtr>> level0_to_nodes;
  for (const auto &node : graph->GetDirectNode()) {
    if (IsConstType(node)) {
      bool is_unknown = false;
      if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown) != GRAPH_SUCCESS) {
        GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(),
               node->GetType().c_str());
        continue;
      }
      if (is_unknown) {
        GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(),
               node->GetType().c_str());
        continue;
      }
      // No attribute map is needed for the control-input level, hence the empty one.
      const std::string level0_key = GetCseKeyByAttrs(node, KeyLevel::kControlInNodes, {});
      level0_to_nodes[level0_key].emplace_back(node);
    }
  }

  std::map<std::string, std::vector<NodePtr>> level1_to_nodes;
  GE_ASSERT_SUCCESS(FilterSameKeyNodesByGeTensor(level0_to_nodes, level1_to_nodes));

  // Attributes fetched at the size level are reused at the value level.
  std::map<NodePtr, std::map<std::string, AnyValue>> node_attrs;
  std::map<std::string, std::vector<NodePtr>> level2_to_nodes;
  GE_ASSERT_SUCCESS(FilterSameKeyNodesByAttr(level1_to_nodes, KeyLevel::kAllAttrsSize, node_attrs, level2_to_nodes));

  std::map<std::string, std::vector<NodePtr>> level3_to_nodes;
  GE_ASSERT_SUCCESS(FilterSameKeyNodesByAttr(level2_to_nodes, KeyLevel::kAllAttrsValue, node_attrs, level3_to_nodes));

  GE_ASSERT_SUCCESS(RemoveDuplicateNodes(graph, level3_to_nodes));
  GELOGD("Success to run RemoveSameConstPass on the graph %s.", graph->GetName().c_str());
  return SUCCESS;
}
// Registers an option entry for "RemoveSameConstPass" with level OoLevel::kO3.
// NOTE(review): presumably this ties the pass to the O3 optimization level - confirm against REG_PASS_OPTION.
REG_PASS_OPTION("RemoveSameConstPass").LEVELS(OoLevel::kO3);
}