* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstring>
#include "node_matcher.h"
#include "graph/utils/constant_utils.h"
#include "graph/utils/op_type_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "framework/common/debug/ge_log.h"
#include "common/checker.h"
#include "common/math/math_util.h"
#include "graph/ir_definitions_recover.h"
#include "exe_graph/lowering/bg_ir_attrs.h"
namespace ge {
namespace fusion{
namespace {
// Element-wise comparison of two tensor payloads viewed as arrays of T.
// Walks the first `shape_size` elements of both buffers and stops at the first
// pair for which IsEqualWith reports a difference.
// NOTE(review): the caller is expected to guarantee that both buffers hold at
// least `shape_size` elements of type T - nothing here verifies it.
template<typename T>
bool IsTensorDataEqualWith(ConstGeTensorPtr &a_tensor, ConstGeTensorPtr &b_tensor, size_t shape_size) {
  const auto *lhs_elems = reinterpret_cast<const T *>(a_tensor->GetData().data());
  const auto *rhs_elems = reinterpret_cast<const T *>(b_tensor->GetData().data());
  size_t idx = 0U;
  while (idx < shape_size) {
    if (IsEqualWith(lhs_elems[idx], rhs_elems[idx])) {
      ++idx;
      continue;
    }
    return false;
  }
  return true;
}
// Returns true when two constant tensors have the same shape, the same data
// type and byte-identical payloads.
//   a_tensor / b_tensor: tensors to compare; both must be non-null.
bool IsTensorEqualWith(ConstGeTensorPtr &a_tensor, ConstGeTensorPtr &b_tensor) {
  const auto &a_tensor_desc = a_tensor->GetTensorDesc();
  const auto &b_tensor_desc = b_tensor->GetTensorDesc();
  if (!TensorUtils::IsShapeEqual(a_tensor_desc.GetShape(), b_tensor_desc.GetShape())) {
    return false;
  }
  if (a_tensor_desc.GetDataType() != b_tensor_desc.GetDataType()) {
    return false;
  }
  // Fetch the payload accessors once instead of on every use.
  const auto &a_data = a_tensor->GetData();
  const auto &b_data = b_tensor->GetData();
  const auto data_size = a_data.GetSize();
  if (data_size != b_data.GetSize()) {
    return false;
  }
  // Two empty payloads are equal. Returning here also keeps possibly-null data
  // pointers away from memcmp, which is undefined behaviour even for a zero count.
  if (data_size == 0U) {
    return true;
  }
  return memcmp(a_data.data(), b_data.data(), data_size) == 0;
}
// Resolves `node` to the constant node that actually backs it.
//   - `node` itself when ConstantUtils::IsRealConst accepts its OpDesc;
//   - for a subgraph inner data node, and only when `enable_cross_subgraph` is
//     set, the node returned by NodeUtils::GetInNodeCrossSubgraph, provided
//     that node is a real const;
//   - nullptr in every other case.
NodePtr GetRealConst(const NodePtr &node, bool enable_cross_subgraph) {
  if (ConstantUtils::IsRealConst(node->GetOpDesc())) {
    return node;
  }
  if (!OpTypeUtils::IsSubgraphInnerData(node->GetOpDesc())) {
    return nullptr;
  }
  if (!enable_cross_subgraph) {
    return nullptr;
  }
  auto real_node = ge::NodeUtils::GetInNodeCrossSubgraph(node);
  GE_ASSERT_NOTNULL(real_node);
  if (ConstantUtils::IsRealConst(real_node->GetOpDesc())) {
    return real_node;
  }
  GELOGI("[Matching][Const]The type of %s is not constant.", real_node->GetTypePtr());
  return nullptr;
}
// Renders a list of attribute names for log output as "[{name0}{name1}...]".
// An empty list yields "[]".
std::string AttrNamesToString(const std::vector<std::string> &attr_names) {
  std::string rendered("[");
  for (const auto &name : attr_names) {
    rendered += "{";
    rendered += name;
    rendered += "}";
  }
  rendered += "]";
  return rendered;
}
// Checks that two nodes declare the same IR attribute names, ignoring order.
//   p_node:       node logged as "P node"; used here only for log output.
//   t_node:       node logged as "T node"; its IR attr names are read from its OpDesc.
//   p_attr_names: IR attr names of p_node, supplied by the caller.
// Returns true when both name lists have the same length and one is a
// permutation of the other.
// NOTE(review): GE_WARN_ASSERT is expected to log the given message and return
// false from this function when its condition fails - confirm in common/checker.h.
bool IsAttrNamesMatch(const NodePtr &p_node, const NodePtr &t_node, const std::vector<std::string> &p_attr_names) {
const auto &t_attr_names = t_node->GetOpDesc()->GetIrAttrNames();
// The counts must agree first: the three-iterator std::is_permutation below
// reads as many elements from t_attr_names as p_attr_names contains.
GE_WARN_ASSERT(
p_attr_names.size() == t_attr_names.size(),
"[AttrMiss] Ir attr num is not match. P node[%s][%s] attr names: %s attr num is %zu, T node [%s][%s] attr names"
":%s attr num is %zu.",
p_node->GetNamePtr(), p_node->GetTypePtr(), AttrNamesToString(p_attr_names).c_str(), p_attr_names.size(),
t_node->GetNamePtr(), t_node->GetTypePtr(), AttrNamesToString(t_attr_names).c_str(), t_attr_names.size());
// Same names in any order counts as a match.
GE_WARN_ASSERT(
std::is_permutation(p_attr_names.begin(), p_attr_names.end(), t_attr_names.begin()),
"[AttrMiss] Ir attr names is not match. P node[%s][%s] attr names: %s, T node [%s][%s] attr names: %s.",
p_node->GetNamePtr(), p_node->GetTypePtr(), AttrNamesToString(p_attr_names).c_str(), t_node->GetNamePtr(),
t_node->GetTypePtr(), AttrNamesToString(t_attr_names).c_str());
return true;
}
// Compares the IR attribute values of two nodes, position by position.
//   p_node / t_node: nodes whose IR attr values are fetched via gert::bg::GetAllIrAttrs.
//   ir_attr_names:   IR attr names; its size is the expected number of values
//                    and its entries are used to label log output.
// Returns true only when both nodes yield exactly ir_attr_names.size() value
// buffers and every pair of buffers at the same index is byte-identical.
// NOTE(review): values are paired by index, which assumes both nodes list their
// IR attrs in the same order - confirm against GetAllIrAttrs.
bool IsAttrValuesMatch(const NodePtr &p_node, const NodePtr &t_node, const std::vector<std::string> &ir_attr_names) {
  const size_t attr_num = ir_attr_names.size();
  std::vector<std::vector<uint8_t>> p_attr_values;
  std::vector<std::vector<uint8_t>> t_attr_values;
  GE_WARN_ASSERT(gert::bg::GetAllIrAttrs(p_node, p_attr_values));
  GE_WARN_ASSERT(gert::bg::GetAllIrAttrs(t_node, t_attr_values));
  // The node-name placeholder used to be "[%]", an invalid conversion that left
  // the format string out of step with its four arguments.
  GE_WARN_ASSERT(p_attr_values.size() == attr_num,
                 "P node[%s][%s] ir attr value num:[%zu] is not equal with ir attr def num:[%zu]",
                 p_node->GetNamePtr(), p_node->GetTypePtr(), p_attr_values.size(), attr_num);
  GE_WARN_ASSERT(t_attr_values.size() == attr_num,
                 "T node[%s][%s] ir attr value num:[%zu] is not equal with ir attr def num:[%zu]",
                 t_node->GetNamePtr(), t_node->GetTypePtr(), t_attr_values.size(), attr_num);
  for (size_t i = 0U; i < attr_num; ++i) {
    const auto &attr_name = ir_attr_names[i];
    const auto &p_attr_value_buf = p_attr_values[i];
    const auto &t_attr_value_buf = t_attr_values[i];
    if (p_attr_value_buf.size() != t_attr_value_buf.size()) {
      GELOGD(
          "[AttrMiss] Ir attr value size is not match. Attr name [%s], P node[%s][%s] attr value size: %zu, T node "
          "[%s][%s] attr value size: %zu.",
          attr_name.c_str(), p_node->GetNamePtr(), p_node->GetTypePtr(), p_attr_value_buf.size(),
          t_node->GetNamePtr(), t_node->GetTypePtr(), t_attr_value_buf.size());
      return false;
    }
    // Sizes are equal here, so vector comparison is a plain byte comparison; it
    // also avoids handing memcmp the possibly-null data() of an empty buffer.
    if (p_attr_value_buf != t_attr_value_buf) {
      GELOGD("[AttrMiss] Ir attr value is not match. Attr name [%s], P node[%s][%s], T node [%s][%s].",
             attr_name.c_str(), p_node->GetNamePtr(), p_node->GetTypePtr(), t_node->GetNamePtr(),
             t_node->GetTypePtr());
      return false;
    }
  }
  return true;
}
// Returns true when the two nodes agree on both their IR attribute names and
// their IR attribute values. Values are only examined once the names matched.
bool IsIrAttrMatch(const NodePtr &p_node, const NodePtr &t_node) {
  const auto &pattern_ir_attr_names = p_node->GetOpDesc()->GetIrAttrNames();
  return IsAttrNamesMatch(p_node, t_node, pattern_ir_attr_names) &&
         IsAttrValuesMatch(p_node, t_node, pattern_ir_attr_names);
}
}
// Decides whether node `t_node` matches node `p_node`.
// The op types must always be identical. When enable_ir_attr_match_ is set, the
// IR definitions of both nodes are recovered and their IR attribute names and
// values must match as well.
bool NormalNodeMatcher::IsMatch(const NodePtr &p_node, const NodePtr &t_node) const {
  // The type check applies in both modes: without it, two nodes of different op
  // types whose IR attrs happen to coincide (for example both have none) would
  // be reported as a match when IR attr matching is enabled.
  if (strcmp(p_node->GetTypePtr(), t_node->GetTypePtr()) != 0) {
    return false;
  }
  if (!enable_ir_attr_match_) {
    return true;
  }
  GE_ASSERT_GRAPH_SUCCESS(RecoverOpDescIrDefinition(p_node->GetOpDesc()));
  GE_ASSERT_GRAPH_SUCCESS(RecoverOpDescIrDefinition(t_node->GetOpDesc()));
  return IsIrAttrMatch(p_node, t_node);
}
bool ConstantMatcher::IsMatch(const NodePtr &p_node, const NodePtr &t_node) const {
auto real_const = GetRealConst(t_node, enable_match_cross_subgraph_);
if (real_const == nullptr) {
return false;
}
if (!enable_value_match_) {
return true;
}
ConstGeTensorPtr p_tensor;
if (!ConstantUtils::GetWeight(p_node->GetOpDesc(), 0, p_tensor)) {
return true;
}
ConstGeTensorPtr t_tensor;
if (!ConstantUtils::GetWeight(real_const->GetOpDesc(), 0, t_tensor)) {
GELOGW("[Matching][Const]%s has no weight which is invalid const.", t_node->GetNamePtr());
return false;
}
return IsTensorEqualWith(p_tensor, t_tensor);
}
}
}