* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* @brief Permute fusion pass (Permute --> TransposeD / Transpose)
* @details
* Pattern: matches a Permute node.
* Replacement:
* - 310B/310P/910/910B/910_93 platforms: Permute --> TransposeD (perm as attr)
* - All other platforms (950 and newer): Permute --> Transpose (perm as input)
*/
#include <vector>
#include <string>
#include "ge/es_graph_builder.h"
#include "ge/compliant_node_builder.h"
#include "es_math_ops.h"
#include "platform/platform_info.h"
#include "ge/ge_utils.h"
#include "log/log.h"
#include "permute_fusion_pass.h"
using namespace ge;
using namespace fe;
using namespace ge::fusion;
namespace ops {
const std::string kFusionPassName = "PermuteFusionPass";
const int64_t kCapturePermuteIdx = 0L;
const std::set<std::string> kTransposeDPlatformList = {
"Ascend310B", "Ascend310P", "Ascend910", "Ascend910B", "Ascend910_93"};
static bool IsTransposeDPlatform()
{
PlatformInfo platformInfo;
OptionalInfo optionalInfo;
if (unlikely(PlatformInfoManager::Instance().GetPlatformInfoWithOutSocVersion(
platformInfo, optionalInfo) != SUCCESS)) {
OP_LOGE(kFusionPassName.c_str(), "Get platform_info failed.");
return false;
}
const std::string soc = platformInfo.str_info.short_soc_version;
return kTransposeDPlatformList.count(soc) > 0;
}
static bool GetPermutePermAttr(const GNode& permuteNode, std::vector<int64_t>& permList)
{
if (permuteNode.GetAttr("perm", permList) == GRAPH_SUCCESS) {
return true;
}
if (permuteNode.GetAttr("order", permList) == GRAPH_SUCCESS) {
OP_LOGD(kFusionPassName.c_str(), "Using 'order' attr as perm.");
return true;
}
OP_LOGE(kFusionPassName.c_str(), "Failed to get perm/order attr from Permute node.");
return false;
}
static void GetInputsInfo(
const std::vector<SubgraphInput>& subgraphInputs,
std::vector<Shape>& inputShapes,
std::vector<DataType>& inputDtypes,
std::vector<Format>& inputFormats)
{
for (const auto& subgraphInput : subgraphInputs) {
auto matchNode = subgraphInput.GetAllInputs().at(0);
TensorDesc tensorDesc;
matchNode.node.GetInputDesc(matchNode.index, tensorDesc);
if (tensorDesc.GetShape().GetDims().empty()) {
auto srcInfo = matchNode.node.GetInDataNodesAndPortIndexs(0);
if (srcInfo.first != nullptr) {
GNode srcNode = *srcInfo.first;
srcNode.GetOutputDesc(srcInfo.second, tensorDesc);
}
}
inputShapes.emplace_back(tensorDesc.GetShape());
inputDtypes.emplace_back(tensorDesc.GetDataType());
inputFormats.emplace_back(tensorDesc.GetFormat());
}
}
static Status InferShape(const GraphUniqPtr& replaceGraph, const std::vector<SubgraphInput>& subgraphInputs)
{
OP_LOGD(kFusionPassName.c_str(), "Begin infershape for replacement.");
std::vector<Shape> inputShapes;
for (const auto& subgraphInput : subgraphInputs) {
auto matchNode = subgraphInput.GetAllInputs().at(0);
TensorDesc tensorDesc;
matchNode.node.GetInputDesc(matchNode.index, tensorDesc);
if (tensorDesc.GetShape().GetDims().empty()) {
auto srcInfo = matchNode.node.GetInDataNodesAndPortIndexs(0);
if (srcInfo.first != nullptr) {
GNode srcNode = *srcInfo.first;
srcNode.GetOutputDesc(srcInfo.second, tensorDesc);
}
}
inputShapes.emplace_back(tensorDesc.GetShape());
}
return GeUtils::InferShape(*replaceGraph, inputShapes);
}
static es::EsTensorHolder BuildTransposeDNode(
es::EsGraphBuilder& replaceGraphBuilder,
GNode inputNode,
int32_t inputIndex,
const std::vector<int64_t>& permAttr)
{
auto* graph = replaceGraphBuilder.GetCGraphBuilder()->GetGraph();
auto transposeDNode = es::CompliantNodeBuilder(graph)
.OpType("TransposeD")
.IrDefInputs({{"x", es::CompliantNodeBuilder::kEsIrInputRequired, ""}})
.IrDefOutputs({{"y", es::CompliantNodeBuilder::kEsIrOutputRequired, ""}})
.IrDefAttrs({
{"perm", es::CompliantNodeBuilder::kEsAttrRequired, "ListInt", es::CreateFrom(permAttr)},
})
.Build();
if (es::AddEdgeAndUpdatePeerDesc(
*graph, inputNode, inputIndex, transposeDNode, 0) != GRAPH_SUCCESS) {
OP_LOGE(kFusionPassName.c_str(), "Failed to add edge for TransposeD input.");
return es::EsTensorHolder();
}
auto output = replaceGraphBuilder.GetCGraphBuilder()->GetTensorHolderFromNode(transposeDNode, 0);
return es::EsTensorHolder(output);
}
static es::EsTensorHolder BuildTransposeNode(
es::EsGraphBuilder& replaceGraphBuilder,
GNode inputNode,
int32_t inputIndex,
const std::vector<int64_t>& permAttr)
{
auto* graph = replaceGraphBuilder.GetCGraphBuilder()->GetGraph();
auto permConst = replaceGraphBuilder.CreateConst(
permAttr,
std::vector<int64_t>{static_cast<int64_t>(permAttr.size())});
auto transposeNode = es::CompliantNodeBuilder(graph)
.OpType("Transpose")
.IrDefInputs({
{"x", es::CompliantNodeBuilder::kEsIrInputRequired, ""},
{"perm", es::CompliantNodeBuilder::kEsIrInputRequired, ""},
})
.IrDefOutputs({{"y", es::CompliantNodeBuilder::kEsIrOutputRequired, ""}})
.Build();
if (es::AddEdgeAndUpdatePeerDesc(
*graph, inputNode, inputIndex, transposeNode, 0) != GRAPH_SUCCESS) {
OP_LOGE(kFusionPassName.c_str(), "Failed to add edge for Transpose x input.");
return es::EsTensorHolder();
}
if (es::AddEdgeAndUpdatePeerDesc(
*graph, *permConst.GetProducer(), permConst.GetProducerOutIndex(),
transposeNode, 1) != GRAPH_SUCCESS) {
OP_LOGE(kFusionPassName.c_str(), "Failed to add edge for Transpose perm input.");
return es::EsTensorHolder();
}
auto output = replaceGraphBuilder.GetCGraphBuilder()->GetTensorHolderFromNode(transposeNode, 0);
return es::EsTensorHolder(output);
}
std::vector<PatternUniqPtr> PermuteFusionPass::Patterns()
{
OP_LOGD(kFusionPassName.c_str(), "Enter Patterns for PermuteFusionPass.");
std::vector<PatternUniqPtr> patternGraphs;
auto graphBuilder = es::EsGraphBuilder("PermuteFusionPass");
auto x = graphBuilder.CreateInput(0);
auto* graph = graphBuilder.GetCGraphBuilder()->GetGraph();
auto permuteNode = es::CompliantNodeBuilder(graph)
.OpType("Permute")
.IrDefInputs({{"x", es::CompliantNodeBuilder::kEsIrInputRequired, ""}})
.IrDefOutputs({{"y", es::CompliantNodeBuilder::kEsIrOutputRequired, ""}})
.Build();
if (es::AddEdgeAndUpdatePeerDesc(
*graph, *x.GetProducer(), x.GetProducerOutIndex(), permuteNode, 0) != GRAPH_SUCCESS) {
OP_LOGE(kFusionPassName.c_str(), "Failed to add edge in pattern for Permute.");
return patternGraphs;
}
auto y = graphBuilder.GetCGraphBuilder()->GetTensorHolderFromNode(permuteNode, 0);
auto builtGraph = graphBuilder.BuildAndReset(std::vector<es::EsTensorHolder>{es::EsTensorHolder(y)});
auto pattern = std::make_unique<Pattern>(std::move(*builtGraph));
NodeIo nodeIo = {y->GetProducer(), 0};
pattern->CaptureTensor(nodeIo);
patternGraphs.emplace_back(std::move(pattern));
return patternGraphs;
}
bool PermuteFusionPass::MeetRequirements(const std::unique_ptr<MatchResult>& matchResult)
{
OP_LOGD(kFusionPassName.c_str(), "Enter MeetRequirements for PermuteFusionPass.");
NodeIo permuteNodeIo;
if (unlikely(matchResult->GetCapturedTensor(kCapturePermuteIdx, permuteNodeIo) != SUCCESS)) {
OP_LOGE(kFusionPassName.c_str(), "Failed to get captured tensor.");
return false;
}
AscendString nodeType;
permuteNodeIo.node.GetType(nodeType);
if (nodeType != "Permute") {
OP_LOGE(kFusionPassName.c_str(), "Node type %s is not Permute, skip.", nodeType.GetString());
return false;
}
return true;
}
GraphUniqPtr PermuteFusionPass::Replacement(const std::unique_ptr<MatchResult>& matchResult)
{
OP_LOGD(kFusionPassName.c_str(), "Enter Replacement for PermuteFusionPass.");
NodeIo permuteNodeIo;
if (matchResult->GetCapturedTensor(kCapturePermuteIdx, permuteNodeIo) != SUCCESS) {
OP_LOGE(kFusionPassName.c_str(), "Failed to get captured tensor in Replacement.");
return nullptr;
}
std::vector<int64_t> permList;
if (!GetPermutePermAttr(permuteNodeIo.node, permList)) {
OP_LOGE(kFusionPassName.c_str(), "Cannot extract perm attribute, skip.");
return nullptr;
}
bool isTransposeDPlatform = IsTransposeDPlatform();
std::vector<SubgraphInput> subgraphInputs;
matchResult->ToSubgraphBoundary()->GetAllInputs(subgraphInputs);
std::vector<Shape> inputShapes;
std::vector<DataType> inputDtypes;
std::vector<Format> inputFormats;
GetInputsInfo(subgraphInputs, inputShapes, inputDtypes, inputFormats);
auto inDims = inputShapes[0].GetDims();
auto replaceGraphBuilder = es::EsGraphBuilder("replacement");
auto xTensor = replaceGraphBuilder.CreateInput(
0, "x", inputDtypes[0], inputFormats[0], inDims);
if (isTransposeDPlatform) {
GNode xNode = *xTensor.GetProducer();
int32_t xIndex = xTensor.GetProducerOutIndex();
auto output = BuildTransposeDNode(replaceGraphBuilder, xNode, xIndex, permList);
if (output.GetCTensorHolder() == nullptr) {
OP_LOGE(kFusionPassName.c_str(), "Failed to build TransposeD node.");
return nullptr;
}
auto replaceGraph = replaceGraphBuilder.BuildAndReset(
std::vector<es::EsTensorHolder>{output});
if (InferShape(replaceGraph, subgraphInputs) != SUCCESS) {
OP_LOGE(kFusionPassName.c_str(), "Infershape for replacement (TransposeD) failed.");
return nullptr;
}
return replaceGraph;
} else {
GNode xNode = *xTensor.GetProducer();
int32_t xIndex = xTensor.GetProducerOutIndex();
auto output = BuildTransposeNode(replaceGraphBuilder, xNode, xIndex, permList);
if (output.GetCTensorHolder() == nullptr) {
OP_LOGE(kFusionPassName.c_str(), "Failed to build Transpose node.");
return nullptr;
}
auto replaceGraph = replaceGraphBuilder.BuildAndReset(
std::vector<es::EsTensorHolder>{output});
if (InferShape(replaceGraph, subgraphInputs) != SUCCESS) {
OP_LOGE(kFusionPassName.c_str(), "Infershape for replacement (Transpose) failed.");
return nullptr;
}
return replaceGraph;
}
}
REG_FUSION_PASS(PermuteFusionPass).Stage(CustomPassStage::kCompatibleInherited);
}