* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <iostream>
#include <vector>
#include "gtest/gtest.h"
#include "platform/platform_infos_def.h"
#include "platform/platform_info.h"
#include "ge/es_graph_builder.h"
#include "ge/compliant_node_builder.h"
#include "es_math_ops.h"
#include "log/log.h"
#include "../../../op_graph/fusion_pass/permute_fusion_pass.h"
using namespace std;
using namespace ge;
using namespace fe;
using namespace fusion;
using namespace ops;
namespace {
// Registered name of the fusion pass under test.
// NOTE(review): appears unused by the tests in this file — confirm before removing.
const std::string kPassName{"PermuteFusionPass"};
}  // namespace
class PermuteFusionPassTest : public testing::Test {
protected:
static void SetUpTestCase()
{
PlatformInfo platformInfo;
OptionalInfo optiCompilationInfo;
platformInfo.soc_info.ai_core_cnt = 64;
platformInfo.str_info.short_soc_version = "Ascend910_93";
optiCompilationInfo.soc_version = "Ascend910_93";
PlatformInfoManager::Instance().platform_info_map_["Ascend910_93"] = platformInfo;
PlatformInfoManager::Instance().SetOptionalCompilationInfo(optiCompilationInfo);
}
void SetUp() override
{
PlatformInfo platformInfo;
OptionalInfo optiCompilationInfo;
platformInfo.soc_info.ai_core_cnt = 64;
platformInfo.str_info.short_soc_version = "Ascend910_93";
optiCompilationInfo.soc_version = "Ascend910_93";
PlatformInfoManager::Instance().platform_info_map_["Ascend910_93"] = platformInfo;
PlatformInfoManager::Instance().SetOptionalCompilationInfo(optiCompilationInfo);
}
void SetPlatform(const std::string& soc)
{
PlatformInfo platformInfo;
OptionalInfo optiCompilationInfo;
platformInfo.soc_info.ai_core_cnt = 64;
platformInfo.str_info.short_soc_version = soc;
optiCompilationInfo.soc_version = soc;
PlatformInfoManager::Instance().platform_info_map_[soc] = platformInfo;
PlatformInfoManager::Instance().SetOptionalCompilationInfo(optiCompilationInfo);
}
std::shared_ptr<Graph> BuildPermuteGraphAndRunPass(
const std::vector<int64_t>& inputDims,
const std::vector<int64_t>& permAttr,
DataType dtype = DT_FLOAT,
Format format = FORMAT_ND)
{
Shape shapeX(inputDims);
auto graphBuilder = es::EsGraphBuilder("test");
auto x = graphBuilder.CreateInput(0, "x", dtype, format, inputDims);
auto* graph = graphBuilder.GetCGraphBuilder()->GetGraph();
auto permuteNode = es::CompliantNodeBuilder(graph)
.OpType("Permute")
.IrDefInputs({{"x", es::CompliantNodeBuilder::kEsIrInputRequired, ""}})
.IrDefOutputs({{"y", es::CompliantNodeBuilder::kEsIrOutputRequired, ""}})
.IrDefAttrs({
{"perm", es::CompliantNodeBuilder::kEsAttrOptional, "ListInt", es::CreateFrom(permAttr)},
})
.Build();
es::AddEdgeAndUpdatePeerDesc(*graph, *x.GetProducer(), x.GetProducerOutIndex(), permuteNode, 0);
TensorDesc xDesc;
x.GetProducer()->GetOutputDesc(0, xDesc);
xDesc.SetDataType(dtype);
xDesc.SetShape(shapeX);
xDesc.SetFormat(format);
x.GetProducer()->UpdateOutputDesc(0, xDesc);
auto y = graphBuilder.GetCGraphBuilder()->GetTensorHolderFromNode(permuteNode, 0);
std::shared_ptr<Graph> resultGraph = graphBuilder.BuildAndReset(
std::vector<es::EsTensorHolder>{es::EsTensorHolder(y)});
CustomPassContext passContext;
PermuteFusionPass pass;
pass.Run(resultGraph, passContext);
return resultGraph;
}
int CountNodeType(const std::shared_ptr<Graph>& graph, const char* typeName)
{
int count = 0;
for (auto node : graph->GetAllNodes()) {
AscendString type;
node.GetType(type);
if (type == typeName) {
count++;
}
}
return count;
}
bool HasNodeType(const std::shared_ptr<Graph>& graph, const char* typeName)
{
return CountNodeType(graph, typeName) > 0;
}
};
// Checks that PermuteFusionPass::Patterns() returns at least one pattern.
TEST_F(PermuteFusionPassTest, patternTest)
{
    PermuteFusionPass pass;
    std::vector<PatternUniqPtr> patterns = pass.Patterns();
    // EXPECT_GT(patterns.size(), 0) compares size_t with int inside gtest's templated helper (-Wsign-compare);
    // checking emptiness expresses the same intent without the mixed-sign comparison.
    EXPECT_FALSE(patterns.empty());
}
// Runs the pass on Ascend910_93 with a DT_INT32 input and expects a TransposeD node in the result.
// NOTE(review): the name says "Fail", yet the test asserts TransposeD is present — confirm whether INT32 is
// actually rejected by the pass, or whether this name predates a behavior change.
TEST_F(PermuteFusionPassTest, unsupportedDtypeFail)
{
    SetPlatform("Ascend910_93");
    const std::vector<int64_t> dims{4, 3, 2, 1};
    const std::vector<int64_t> perm{0, 2, 3, 1};
    const auto graph = BuildPermuteGraphAndRunPass(dims, perm, DT_INT32);
    EXPECT_TRUE(HasNodeType(graph, "TransposeD"));
}
// With the SoC set to "UnknownPlatform", the result must contain a Transpose node and no TransposeD node.
TEST_F(PermuteFusionPassTest, unsupportedPlatformFail)
{
    SetPlatform("UnknownPlatform");
    const std::vector<int64_t> dims{4, 3, 2, 1};
    const std::vector<int64_t> perm{0, 2, 3, 1};
    const auto graph = BuildPermuteGraphAndRunPass(dims, perm);
    EXPECT_TRUE(HasNodeType(graph, "Transpose"));
    EXPECT_FALSE(HasNodeType(graph, "TransposeD"));
}
// Ascend910_93, 4-D float input, perm {0, 2, 3, 1}: the result must use TransposeD and contain no Transpose.
TEST_F(PermuteFusionPassTest, fusionSuccess910Normal)
{
    SetPlatform("Ascend910_93");
    const std::vector<int64_t> dims{4, 3, 2, 1};
    const std::vector<int64_t> perm{0, 2, 3, 1};
    const auto graph = BuildPermuteGraphAndRunPass(dims, perm);
    EXPECT_TRUE(HasNodeType(graph, "TransposeD"));
    EXPECT_FALSE(HasNodeType(graph, "Transpose"));
}
// Ascend910_93, input {2, 3, 4, 5}, perm {0, 3, 2, 1}: exactly one TransposeD and no Transpose are expected.
TEST_F(PermuteFusionPassTest, fusionSuccess910Special)
{
    SetPlatform("Ascend910_93");
    const std::vector<int64_t> dims{2, 3, 4, 5};
    const std::vector<int64_t> perm{0, 3, 2, 1};
    const auto graph = BuildPermuteGraphAndRunPass(dims, perm);
    const int transposeDCount = CountNodeType(graph, "TransposeD");
    EXPECT_EQ(transposeDCount, 1);
    const int transposeCount = CountNodeType(graph, "Transpose");
    EXPECT_EQ(transposeCount, 0);
}
// Ascend950, 4-D float input, perm {0, 2, 3, 1}: the result must use Transpose and contain no TransposeD.
TEST_F(PermuteFusionPassTest, fusionSuccess950Normal)
{
    SetPlatform("Ascend950");
    const std::vector<int64_t> dims{4, 3, 2, 1};
    const std::vector<int64_t> perm{0, 2, 3, 1};
    const auto graph = BuildPermuteGraphAndRunPass(dims, perm);
    EXPECT_TRUE(HasNodeType(graph, "Transpose"));
    EXPECT_FALSE(HasNodeType(graph, "TransposeD"));
}
// Ascend950, input {2, 3, 4, 5}, perm {0, 3, 2, 1}: exactly one Transpose and no TransposeD are expected.
TEST_F(PermuteFusionPassTest, fusionSuccess950Special)
{
    SetPlatform("Ascend950");
    const std::vector<int64_t> dims{2, 3, 4, 5};
    const std::vector<int64_t> perm{0, 3, 2, 1};
    const auto graph = BuildPermuteGraphAndRunPass(dims, perm);
    const int transposeCount = CountNodeType(graph, "Transpose");
    EXPECT_EQ(transposeCount, 1);
    const int transposeDCount = CountNodeType(graph, "TransposeD");
    EXPECT_EQ(transposeDCount, 0);
}
// Ascend910_93 with a 3-D input {4, 3, 2} and perm {0, 2, 1}: a TransposeD node is expected in the result.
TEST_F(PermuteFusionPassTest, fusionSuccess3dShape)
{
    SetPlatform("Ascend910_93");
    const std::vector<int64_t> dims{4, 3, 2};
    const std::vector<int64_t> perm{0, 2, 1};
    const auto graph = BuildPermuteGraphAndRunPass(dims, perm);
    EXPECT_TRUE(HasNodeType(graph, "TransposeD"));
}
// Ascend910_93 with a DT_FLOAT16 4-D input and perm {0, 2, 3, 1}: a TransposeD node is expected in the result.
TEST_F(PermuteFusionPassTest, fusionSuccessFp16)
{
    SetPlatform("Ascend910_93");
    const std::vector<int64_t> dims{4, 3, 2, 1};
    const std::vector<int64_t> perm{0, 2, 3, 1};
    const auto graph = BuildPermuteGraphAndRunPass(dims, perm, DT_FLOAT16);
    EXPECT_TRUE(HasNodeType(graph, "TransposeD"));
}
// Ascend310, 4-D float input, perm {0, 2, 3, 1}: the result must use Transpose and contain no TransposeD.
TEST_F(PermuteFusionPassTest, nonTransposeDPlatformUsesTranspose)
{
    SetPlatform("Ascend310");
    const std::vector<int64_t> dims{4, 3, 2, 1};
    const std::vector<int64_t> perm{0, 2, 3, 1};
    const auto graph = BuildPermuteGraphAndRunPass(dims, perm);
    EXPECT_TRUE(HasNodeType(graph, "Transpose"));
    EXPECT_FALSE(HasNodeType(graph, "TransposeD"));
}