* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/passes/format_optimize/transop_symmetry_elimination_pass.h"
#include "formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"
#include "common/op/transop_util.h"
#include "common/checker.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/type_utils.h"
#include "framework/common/framework_types_internal.h"
namespace {
// Signature of a format-specific symmetry predicate: takes the leading (src) and trailing (dst) node and
// returns whether the pair may be treated as symmetric.
using FormatSymmFunc = std::function<bool(const ge::NodePtr &, const ge::NodePtr &)>;
// Node types this pass looks at; Run() returns SUCCESS immediately for any other type.
const std::set<std::string> white_list_op{ge::TRANSPOSED, ge::RESHAPE, ge::REFORMAT, ge::CAST,
ge::TRANSDATA, ge::SQUEEZEV2, ge::UNSQUEEZEV2};
// Pairs of *different* node types that AreNodeTypeSymmetry() accepts as symmetric (in both directions).
const std::unordered_map<std::string, std::string> symm_node_pairs{{ge::SQUEEZEV2, ge::UNSQUEEZEV2},
{ge::UNSQUEEZEV2, ge::SQUEEZEV2}};
// Expected rank of the NCHW-side shape and the axis indices used to read C, H and W from its dims.
constexpr size_t NCHW_DIM_NUM = 4U;
constexpr size_t NCHW_C_AXIS = 1U;
constexpr size_t NCHW_H_AXIS = 2U;
constexpr size_t NCHW_W_AXIS = 3U;
// Format-specific symmetry check registered for a pair whose first node converts NC1HWC0 -> NCHW.
//
// Compares the 4-D output shape of src_node with the 4-D input shape of dst_node:
//   * H*W must be equal on both sides, and both shapes must be fully known;
//   * if src_node changes the element count between its input and output ("padding"), the C axis must
//     be equal as well.
//
// Returns false when any tensor descriptor is missing or when either compared shape is not 4-D.
//
// NOTE: this function returns bool, so GE_CHECK_NOTNULL must not be used here. That macro returns a
// non-zero status code, which converts to `true` and would report a broken pair as symmetric. The
// GE_ASSERT_* macros return a value that converts to `false` in a bool function.
bool Nc1hwc02NchwAreSymmetry(const ge::NodePtr &src_node, const ge::NodePtr &dst_node) {
  const auto &src_in_desc = src_node->GetOpDesc()->MutableInputDesc(0);
  const auto &src_out_desc = src_node->GetOpDesc()->MutableOutputDesc(0);
  const auto &dst_in_desc = dst_node->GetOpDesc()->MutableInputDesc(0);
  const auto &dst_out_desc = dst_node->GetOpDesc()->MutableOutputDesc(0);
  GE_ASSERT_NOTNULL(src_in_desc);
  GE_ASSERT_NOTNULL(src_out_desc);
  GE_ASSERT_NOTNULL(dst_in_desc);
  GE_ASSERT_NOTNULL(dst_out_desc);

  const auto &src_out_dims = src_out_desc->GetShape().GetDims();
  const bool is_src_out_dims_unknown = src_out_desc->GetShape().IsUnknownShape();
  const auto &dst_in_dims = dst_in_desc->GetShape().GetDims();
  const bool is_dst_in_dims_unknown = dst_in_desc->GetShape().IsUnknownShape();
  // The axis constants below index into these vectors, so both must have exactly NCHW rank.
  GE_ASSERT_TRUE(src_out_dims.size() == NCHW_DIM_NUM);
  GE_ASSERT_TRUE(dst_in_dims.size() == NCHW_DIM_NUM);

  // Products of unknown dims are meaningless, so an unknown shape on either side never matches.
  const bool is_hw_size_equal = !is_src_out_dims_unknown && !is_dst_in_dims_unknown &&
                                ((src_out_dims[NCHW_H_AXIS] * src_out_dims[NCHW_W_AXIS]) ==
                                 (dst_in_dims[NCHW_H_AXIS] * dst_in_dims[NCHW_W_AXIS]));
  const bool is_c_axis_same = (src_out_dims[NCHW_C_AXIS] == dst_in_dims[NCHW_C_AXIS]);
  const bool is_padding = (src_in_desc->GetShape().GetShapeSize() != src_out_desc->GetShape().GetShapeSize());
  GELOGI("enter to %s 5hd_nchw %s judge, is_hw_size_equal is %d, is_c_axis_same is %d, is_padding is %d",
         src_node->GetNamePtr(), dst_node->GetNamePtr(), is_hw_size_equal, is_c_axis_same, is_padding);

  // The C axis only has to match when src_node changed the element count.
  return is_hw_size_equal && (!is_padding || is_c_axis_same);
}
// Symmetry check registered for a pair whose first node converts NCHW -> NC1HWC0.
// It is the mirror image of Nc1hwc02NchwAreSymmetry, so the same predicate is evaluated with the
// two nodes swapped.
bool Nchw2Nc1hwc0AreSymmetry(const ge::NodePtr &src_node, const ge::NodePtr &dst_node) {
  const bool is_mirrored_pair_symmetric = Nc1hwc02NchwAreSymmetry(dst_node, src_node);
  return is_mirrored_pair_symmetric;
}
// Dispatch table for shape-mismatched TransData pairs, keyed by the primary input format of the first node.
// The mapped pair holds the primary output format the first node must produce and the predicate that
// decides whether the two nodes are still symmetric. Looked up in IsTransdataMemLayoutSymmetry().
const std::unordered_map<ge::Format, std::pair<ge::Format, FormatSymmFunc>> symm_format_pairs_func{
{ge::FORMAT_NCHW, {ge::FORMAT_NC1HWC0, Nchw2Nc1hwc0AreSymmetry}},
{ge::FORMAT_NC1HWC0, {ge::FORMAT_NCHW, Nc1hwc02NchwAreSymmetry}}};
// Tells whether the types of two adjacent nodes allow them to cancel each other out.
// Accepted combinations are a pair listed in symm_node_pairs, or two nodes of the very same type.
// Anything else is logged at debug level and rejected.
bool AreNodeTypeSymmetry(const ge::NodePtr &src_node, const ge::NodePtr &dst_node) {
  const std::string src_type = src_node->GetType();
  const std::string dst_type = dst_node->GetType();

  const auto partner = symm_node_pairs.find(src_type);
  const bool is_listed_pair = (partner != symm_node_pairs.cend()) && (partner->second == dst_type);
  const bool is_matching = is_listed_pair || (src_type == dst_type);
  if (!is_matching) {
    GELOGD("Pre node %s type %s is not equal with node %s type %s. Ignore pass.", src_node->GetName().c_str(),
           src_type.c_str(), dst_node->GetName().c_str(), dst_type.c_str());
  }
  return is_matching;
}
}
namespace ge {
// Pass entry point for a single node.
//
// Nodes whose type is not in white_list_op are skipped. For every consumer connected to any data output
// of `node`, the pair (node, consumer) is tested with CheckCanBeEliminated() and, when it qualifies,
// removed through EliminateTransOp().
//
// Returns SUCCESS when nothing failed; the first failing status from EliminateTransOp() otherwise.
// A null node, op desc, anchor or owner node is rejected by GE_CHECK_NOTNULL.
Status TransOpSymmetryEliminationPass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  const bool is_candidate_type = (white_list_op.count(node->GetType()) > 0U);
  if (!is_candidate_type) {
    return SUCCESS;
  }

  GELOGD("Symmetry Elimination Pass in");
  for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
    GE_CHECK_NOTNULL(out_anchor);
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL(peer_in_anchor);
      GE_CHECK_NOTNULL(peer_in_anchor->GetOwnerNode());
      GE_CHECK_NOTNULL(peer_in_anchor->GetOwnerNode()->GetOpDesc());
      if (CheckCanBeEliminated(node, peer_in_anchor)) {
        NodePtr consumer = peer_in_anchor->GetOwnerNode();
        const Status eliminate_status = EliminateTransOp(node, out_anchor, consumer, peer_in_anchor);
        if (eliminate_status == SUCCESS) {
          continue;
        }
        GELOGW("Eliminate %s and %s failed, ignore current pass.", node->GetName().c_str(),
               consumer->GetName().c_str());
        return eliminate_status;
      }
    }
  }
  GELOGD("Symmetry Elimination Pass end");
  return SUCCESS;
}
// Decides whether src_node and the node owning dst_in_anchor form a removable symmetric pair.
//
// The pair is rejected when:
//   * the node types are not symmetric (see AreNodeTypeSymmetry);
//   * dst_in_anchor is not the transform data input index reported by TransOpUtil::GetTransOpDataIndex;
//   * Reshape: input 0 of src_node has unknown rank or more than one unknown dimension;
//   * TransposeD: the two permutations do not undo each other (see JudgeTransposeDBack2Raw);
//   * TransData: input 0 of src_node has unknown rank;
//   * SqueezeV2/UnsqueezeV2: the axis attributes of the two nodes differ;
//   * TransOpUtil::IsPrecisionLoss(src_node) is true, or the tensor descriptors are not symmetric.
//
// Returns true only when every applicable check passes. A missing node or op desc yields false.
//
// NOTE: this function returns bool, so null checks use GE_ASSERT_NOTNULL. GE_CHECK_NOTNULL returns a
// non-zero status code that would convert to `true` and wrongly allow the elimination.
bool TransOpSymmetryEliminationPass::CheckCanBeEliminated(const ge::NodePtr &src_node,
                                                          const InDataAnchorPtr &dst_in_anchor) {
  auto dst_node = dst_in_anchor->GetOwnerNode();
  GE_ASSERT_NOTNULL(dst_node);
  // Checked once up front: every type-specific branch below, and DescAreSymmetry, dereferences them.
  GE_ASSERT_NOTNULL(src_node->GetOpDesc());
  GE_ASSERT_NOTNULL(dst_node->GetOpDesc());

  if (!AreNodeTypeSymmetry(src_node, dst_node)) {
    return false;
  }
  if (dst_in_anchor->GetIdx() != TransOpUtil::GetTransOpDataIndex(src_node)) {
    GELOGD("Next node %s type %s input %d is not for transform. Ignore pass.", dst_node->GetName().c_str(),
           dst_node->GetType().c_str(), dst_in_anchor->GetIdx());
    return false;
  }

  if (src_node->GetType() == ge::RESHAPE) {
    const auto unknown_dims_num = GetUnknownDimsNum(src_node->GetOpDesc()->GetInputDesc(0));
    // At most one unknown dimension is tolerated; unknown rank is never tolerated.
    if ((unknown_dims_num == UNKNOWN_DIM_NUM) || (unknown_dims_num > 1)) {
      GELOGD("Pre node %s is reshape op which input is dynamic shape and has more than one unknown dimension. "
             "Ignore pass.",
             src_node->GetName().c_str());
      return false;
    }
  } else if (src_node->GetType() == ge::TRANSPOSED) {
    if (!JudgeTransposeDBack2Raw(src_node, dst_node)) {
      GELOGD("Two Transpose op src node %s dst node %s will change the raw data. Ignore pass.",
             src_node->GetName().c_str(), dst_node->GetName().c_str());
      return false;
    }
  } else if (src_node->GetType() == ge::TRANSDATA) {
    const auto unknown_dims_num = GetUnknownDimsNum(src_node->GetOpDesc()->GetInputDesc(0));
    if (unknown_dims_num == UNKNOWN_DIM_NUM) {
      GELOGD("Pre node %s is transdata op which input is dynamic shape and all dimension are unknown(-2). Ignore pass.",
             src_node->GetName().c_str());
      return false;
    }
  } else if ((src_node->GetType() == ge::SQUEEZEV2) || (src_node->GetType() == ge::UNSQUEEZEV2)) {
    // A failed attribute read leaves the vector empty; the two nodes then only match if both are empty.
    std::vector<int32_t> src_axis;
    std::vector<int32_t> dst_axis;
    (void)AttrUtils::GetListInt(src_node->GetOpDesc(), ATTR_NAME_AXIS, src_axis);
    (void)AttrUtils::GetListInt(dst_node->GetOpDesc(), ATTR_NAME_AXIS, dst_axis);
    if (src_axis != dst_axis) {
      GELOGD("Src node %s aixs not equal with dst node %s. Ignore pass.", src_node->GetName().c_str(),
             dst_node->GetName().c_str());
      return false;
    }
  }
  return (!TransOpUtil::IsPrecisionLoss(src_node)) && DescAreSymmetry(src_node, dst_node);
}
bool TransOpSymmetryEliminationPass::DescAreSymmetry(const NodePtr &src_node, const NodePtr &dst_node) {
const auto &src_input_desc = src_node->GetOpDesc()->MutableInputDesc(0);
const auto &src_output_desc = src_node->GetOpDesc()->MutableOutputDesc(0);
const auto &dst_input_desc = dst_node->GetOpDesc()->MutableInputDesc(0);
const auto &dst_output_desc = dst_node->GetOpDesc()->MutableOutputDesc(0);
GE_CHECK_NOTNULL(src_input_desc);
GE_CHECK_NOTNULL(src_output_desc);
GE_CHECK_NOTNULL(dst_input_desc);
GE_CHECK_NOTNULL(dst_output_desc);
const auto &src_input_dtype = src_input_desc->GetDataType();
const auto &src_input_format = src_input_desc->GetFormat();
const auto &src_input_origin_format = src_input_desc->GetOriginFormat();
const auto &src_input_shape = src_input_desc->GetShape().GetDims();
const auto &src_input_origin_shape = src_input_desc->GetOriginShape().GetDims();
const auto &src_output_format = src_output_desc->GetFormat();
const auto &dst_input_format = dst_input_desc->GetFormat();
const auto &dst_output_dtype = dst_output_desc->GetDataType();
const auto &dst_output_format = dst_output_desc->GetFormat();
const auto &dst_output_origin_format = dst_output_desc->GetOriginFormat();
const auto &dst_output_shape = dst_output_desc->GetShape().GetDims();
const auto &dst_output_origin_shape = dst_output_desc->GetOriginShape().GetDims();
bool is_symmetry = true;
if (src_node->GetType() == CAST && dst_node->GetType() == CAST) {
bool is_format_symmetry =
(src_input_format == dst_output_format) || (dst_output_format == FORMAT_ND) || (src_input_format == FORMAT_ND);
is_symmetry = (src_input_dtype == dst_output_dtype) && is_format_symmetry;
} else {
bool is_format_continuously = (src_output_format == dst_input_format);
is_symmetry = (src_input_dtype == dst_output_dtype) && (src_input_format == dst_output_format) &&
(src_input_shape == dst_output_shape);
is_symmetry = is_symmetry && is_format_continuously;
if (src_input_origin_format == dst_output_origin_format) {
is_symmetry = is_symmetry && (src_input_origin_shape == dst_output_origin_shape);
}
}
const bool need_judge_again = !is_symmetry && (src_node->GetType() == TRANSDATA) &&
(dst_node->GetType() == TRANSDATA) && (src_input_shape != dst_output_shape);
if (need_judge_again) {
is_symmetry = IsTransdataMemLayoutSymmetry(src_node, dst_node);
}
GELOGI("Desc check ret is %d."
"Src node %s input type: %s primary_format: %s sub_format: %d shape: [%s], origin_shape: [%s],"
"output primary_format: %s. Dst node %s input primary_format: %s,"
"output type: %s primary_format: %s sub_format: %d shape: [%s], origin_shape: [%s].",
is_symmetry, src_node->GetName().c_str(), TypeUtils::DataTypeToSerialString(src_input_dtype).c_str(),
TypeUtils::FormatToSerialString(src_input_format).c_str(), GetSubFormat(src_input_format),
formats::ShapeToString(src_input_shape).c_str(), formats::ShapeToString(src_input_origin_shape).c_str(),
TypeUtils::FormatToSerialString(src_output_format).c_str(),
dst_node->GetName().c_str(), TypeUtils::FormatToSerialString(dst_input_format).c_str(),
TypeUtils::DataTypeToSerialString(dst_output_dtype).c_str(),
TypeUtils::FormatToSerialString(dst_output_format).c_str(), GetSubFormat(dst_output_format),
formats::ShapeToString(dst_output_shape).c_str(), formats::ShapeToString(dst_output_origin_shape).c_str());
return is_symmetry;
}
bool TransOpSymmetryEliminationPass::IsTransdataMemLayoutSymmetry(const NodePtr &src_node, const NodePtr &dst_node) {
auto src_in_desc = src_node->GetOpDesc()->GetInputDescPtr(0);
auto src_out_desc = src_node->GetOpDesc()->GetOutputDescPtr(0);
auto dst_in_desc = dst_node->GetOpDesc()->GetInputDescPtr(0);
auto dst_out_desc = dst_node->GetOpDesc()->GetOutputDescPtr(0);
GE_CHECK_NOTNULL(src_in_desc);
GE_CHECK_NOTNULL(src_out_desc);
GE_CHECK_NOTNULL(dst_in_desc);
GE_CHECK_NOTNULL(dst_out_desc);
const bool is_format_dt_shapesize_continuously = ((src_in_desc->GetFormat() == dst_out_desc->GetFormat()) &&
(src_out_desc->GetFormat() == dst_in_desc->GetFormat()) &&
(src_in_desc->GetDataType() == dst_out_desc->GetDataType()) &&
(src_out_desc->GetDataType() == dst_in_desc->GetDataType()) &&
(src_in_desc->GetShape().GetShapeSize() == dst_out_desc->GetShape().GetShapeSize()) &&
(src_out_desc->GetShape().GetShapeSize() == dst_in_desc->GetShape().GetShapeSize()));
if (!is_format_dt_shapesize_continuously) {
return false;
}
GELOGI("shape is not equal, try to judge %s and %s can fusion, src_format %d(%s), dst_format %d(%s)",
src_node->GetNamePtr(), dst_node->GetNamePtr(), src_in_desc->GetFormat(),
TypeUtils::FormatToSerialString(src_in_desc->GetFormat()).c_str(), src_out_desc->GetFormat(),
TypeUtils::FormatToSerialString(src_out_desc->GetFormat()).c_str());
const Format src_in_primary_format =
static_cast<Format>(GetPrimaryFormat(static_cast<int32_t>(src_in_desc->GetFormat())));
const Format src_out_primary_format =
static_cast<Format>(GetPrimaryFormat(static_cast<int32_t>(src_out_desc->GetFormat())));
const auto &it = symm_format_pairs_func.find(src_in_primary_format);
if (it != symm_format_pairs_func.end() && it->second.first == src_out_primary_format) {
return it->second.second(src_node, dst_node);
}
GELOGI("shape is not equal, src_format %s, dst_format %s skip judge",
TypeUtils::FormatToSerialString(src_in_desc->GetFormat()).c_str(),
TypeUtils::FormatToSerialString(src_out_desc->GetFormat()).c_str());
return false;
}
// Counts the dimensions of node_desc's shape that equal UNKNOWN_DIM.
// As soon as a dimension equal to UNKNOWN_DIM_NUM is met, UNKNOWN_DIM_NUM itself is returned instead
// of a count.
int32_t TransOpSymmetryEliminationPass::GetUnknownDimsNum(const GeTensorDesc& node_desc) {
  const auto shape_copy = node_desc.GetShape();
  const auto dim_values = shape_copy.GetDims();
  int32_t dynamic_dim_count = 0;
  for (size_t axis = 0U; axis < dim_values.size(); ++axis) {
    if (dim_values[axis] == UNKNOWN_DIM_NUM) {
      return static_cast<int32_t>(UNKNOWN_DIM_NUM);
    }
    if (dim_values[axis] == UNKNOWN_DIM) {
      ++dynamic_dim_count;
    }
  }
  return dynamic_dim_count;
}
// Tells whether the permutation of dst_node undoes the permutation of src_node, i.e. whether
// src_perm[dst_perm[i]] == i holds for every axis i.
//
// Both permutations are read from the PERMUTE_ATTR_PERM attribute; a failed read leaves the vector empty.
// Returns false when the permutations differ in length or when dst_perm contains an index outside
// [0, size). Negative indices are rejected explicitly because they would otherwise be used to index
// src_perm out of bounds.
bool TransOpSymmetryEliminationPass::JudgeTransposeDBack2Raw(const NodePtr &src_node, const NodePtr &dst_node) {
  std::vector<int64_t> src_node_perm;
  (void)AttrUtils::GetListInt(src_node->GetOpDesc(), ge::PERMUTE_ATTR_PERM, src_node_perm);
  std::vector<int64_t> dst_node_perm;
  (void)AttrUtils::GetListInt(dst_node->GetOpDesc(), ge::PERMUTE_ATTR_PERM, dst_node_perm);
  if (src_node_perm.size() != dst_node_perm.size()) {
    return false;
  }

  const auto perm_size = static_cast<int64_t>(src_node_perm.size());
  for (size_t src_index = 0U; src_index < src_node_perm.size(); ++src_index) {
    const int64_t mapped_axis = dst_node_perm[src_index];
    if ((mapped_axis < 0) || (mapped_axis >= perm_size)) {
      return false;
    }
    if (src_node_perm[static_cast<size_t>(mapped_axis)] != static_cast<int64_t>(src_index)) {
      return false;
    }
  }
  return true;
}
// Removes the symmetric pair (src_node, dst_node) from the data path between them.
//
// Steps, in order:
//   1. unlink src_out_anchor from dst_in_anchor;
//   2. connect the producer feeding src_node's transform data input directly to dst_in_anchor;
//   3. copy src_node's incoming control edges to dst_node;
//   4. turn every other data input of src_node into a control edge towards dst_node;
//   5. isolate and delete dst_node;
//   6. delete src_node too if it has no data consumers left (RemoveTransOpWithoutOutput).
//
// Returns SUCCESS, or the status of the first step that failed. No rollback is performed on failure.
Status TransOpSymmetryEliminationPass::EliminateTransOp(NodePtr &src_node, const OutDataAnchorPtr &src_out_anchor,
NodePtr &dst_node, const InDataAnchorPtr &dst_in_anchor) {
// Step 1: cut the edge src_node -> dst_node.
auto ret = src_out_anchor->Unlink(dst_in_anchor);
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999",
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed",
src_out_anchor->GetOwnerNode()->GetName().c_str(),
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(),
dst_in_anchor->GetOwnerNode()->GetName().c_str(),
dst_in_anchor->GetOwnerNode()->GetType().c_str(), dst_in_anchor->GetIdx());
GELOGE(FAILED, "[Unlink][DataAnchor] from %s(%s)(index:%d) to %s(%s)(index:%d) failed.",
src_out_anchor->GetOwnerNode()->GetName().c_str(),
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(),
dst_in_anchor->GetOwnerNode()->GetName().c_str(),
dst_in_anchor->GetOwnerNode()->GetType().c_str(), dst_in_anchor->GetIdx());
return ret;
}
// Step 2: locate the producer of src_node's transform data input and wire it to dst_in_anchor.
auto data_idx = TransOpUtil::GetTransOpDataIndex(src_node);
auto in_anchor = src_node->GetInDataAnchor(data_idx);
GE_CHECK_NOTNULL(in_anchor);
GE_CHECK_NOTNULL(in_anchor->GetPeerOutAnchor());
auto pre_normal_node = in_anchor->GetPeerOutAnchor()->GetOwnerNode();
ret = GraphUtils::AddEdge(in_anchor->GetPeerOutAnchor(), dst_in_anchor);
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
pre_normal_node->GetName().c_str(), pre_normal_node->GetType().c_str(),
in_anchor->GetPeerOutAnchor()->GetIdx(),
dst_in_anchor->GetOwnerNode()->GetName().c_str(),
dst_in_anchor->GetOwnerNode()->GetType().c_str(), dst_in_anchor->GetIdx());
GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
pre_normal_node->GetName().c_str(), pre_normal_node->GetType().c_str(),
in_anchor->GetPeerOutAnchor()->GetIdx(),
dst_in_anchor->GetOwnerNode()->GetName().c_str(),
dst_in_anchor->GetOwnerNode()->GetType().c_str(), dst_in_anchor->GetIdx());
return ret;
}
// Step 3: copy src_node's incoming control edges onto dst_node.
ret = GraphUtils::CopyInCtrlEdges(src_node, dst_node);
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Copy in control edge from node:%s(%s) to node:%s(%s) failed",
src_node->GetName().c_str(), src_node->GetType().c_str(),
dst_node->GetName().c_str(), dst_node->GetType().c_str());
GELOGE(FAILED, "[Copy][InCtrlEdges] from node:%s(%s) to node:%s(%s) failed",
src_node->GetName().c_str(), src_node->GetType().c_str(),
dst_node->GetName().c_str(), dst_node->GetType().c_str());
return ret;
}
// Step 4: every data input of src_node other than the producer (matched by node name) becomes a
// control edge into dst_node.
for (const auto &in_node : src_node->GetInDataNodes()) {
if (in_node->GetName() == pre_normal_node->GetName()) { continue; }
ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
in_node->GetName().c_str(), in_node->GetType().c_str(),
dst_node->GetName().c_str(), dst_node->GetType().c_str());
GELOGE(FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
in_node->GetName().c_str(), in_node->GetType().c_str(),
dst_node->GetName().c_str(), dst_node->GetType().c_str());
return ret;
}
}
// Step 5: remove dst_node from the graph.
// NOTE(review): `{0}` presumably maps dst_node's output 0 to its input 0 when reconnecting — confirm
// against IsolateAndDeleteNode.
ret = IsolateAndDeleteNode(dst_node, {0});
if (ret != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E19999", "Isolate and delete node:%s(%s) failed",
dst_node->GetName().c_str(), dst_node->GetType().c_str());
GELOGE(INTERNAL_ERROR, "[IsolateAndDelete][Node] failed, node name:%s, node type:%s ",
dst_node->GetName().c_str(), dst_node->GetType().c_str());
return ret;
}
GELOGI("Trans op symmetry eliminate successfully. Node %s has been removed.", dst_node->GetName().c_str());
// Step 6: src_node is removed as well once nothing consumes its data output any more.
ret = RemoveTransOpWithoutOutput(pre_normal_node, src_node);
if (ret != GRAPH_SUCCESS) {
GELOGE(ret, "[Call][RemoveTransOpWithoutOutput] for node:%s(%s) failed",
src_node->GetName().c_str(), src_node->GetType().c_str());
return ret;
}
return SUCCESS;
}
// Deletes trans_node when it no longer has any data consumer.
//
// Its outgoing control edges are first copied to pre_node, then the node is isolated and deleted.
// A trans_node that still has data consumers is left untouched.
//
// Returns SUCCESS, or the status of the step that failed.
Status TransOpSymmetryEliminationPass::RemoveTransOpWithoutOutput(NodePtr &pre_node, NodePtr &trans_node) {
  const bool still_consumed = (trans_node->GetOutDataNodesSize() != 0);
  if (still_consumed) {
    return SUCCESS;
  }

  const Status copy_status = GraphUtils::CopyOutCtrlEdges(trans_node, pre_node);
  if (copy_status != GRAPH_SUCCESS) {
    REPORT_INNER_ERR_MSG("E19999", "Copy out control edge from node:%s(%s) to node:%s(%s) failed",
                         trans_node->GetName().c_str(), trans_node->GetType().c_str(),
                         pre_node->GetName().c_str(), pre_node->GetType().c_str());
    GELOGE(FAILED, "[Copy][OutCtrlEdges] from %s to %s failed.", trans_node->GetName().c_str(),
           pre_node->GetName().c_str());
    return copy_status;
  }

  const Status delete_status = IsolateAndDeleteNode(trans_node, {});
  if (delete_status != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "[IsolateAndDelete][Node] %s(%s) failed", trans_node->GetName().c_str(),
           trans_node->GetType().c_str());
    return delete_status;
  }

  GELOGI("Trans op symmetry eliminate successfully. Node %s has been removed.", trans_node->GetName().c_str());
  return SUCCESS;
}
// Registers the option entry named "TransOpSymmetryEliminationPass" with level OoLevel::kO3.
// NOTE(review): presumably this ties the pass to the O3 optimization level — confirm against the
// REG_PASS_OPTION definition.
REG_PASS_OPTION("TransOpSymmetryEliminationPass").LEVELS(OoLevel::kO3);
}