* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <iostream>
#include <vector>
#include "gtest/gtest.h"
#include "platform/platform_infos_def.h"
#include "platform/platform_info.h"
#include "ge/es_graph_builder.h"
#include "es_math_ops.h"
#include "log/log.h"
#include "../../../op_graph/fusion_pass/castlike_fusion_pass.h"
#include "register/register_custom_pass.h"
using namespace std;
using namespace ge;
using namespace fe;
using namespace fusion;
using namespace ops;
namespace {
// Name string of the fusion pass exercised by this test suite.
const std::string kPassName{"CastlikeFusionPass"};
}  // namespace
class CastlikeFusionPassTest : public testing::Test {
protected:
static void SetUpTestCase()
{
PlatformInfo platformInfo;
OptionalInfo optiCompilationInfo;
platformInfo.soc_info.ai_core_cnt = 64;
platformInfo.str_info.short_soc_version = "Ascend950";
optiCompilationInfo.soc_version = "Ascend950";
PlatformInfoManager::Instance().platform_info_map_["Ascend950"] = platformInfo;
PlatformInfoManager::Instance().SetOptionalCompilationInfo(optiCompilationInfo);
}
void SetUp() override
{
PlatformInfo platformInfo;
OptionalInfo optiCompilationInfo;
platformInfo.soc_info.ai_core_cnt = 64;
platformInfo.str_info.short_soc_version = "Ascend950";
optiCompilationInfo.soc_version = "Ascend950";
PlatformInfoManager::Instance().platform_info_map_["Ascend950"] = platformInfo;
PlatformInfoManager::Instance().SetOptionalCompilationInfo(optiCompilationInfo);
}
void SetPlatform(const std::string& soc)
{
PlatformInfo platformInfo;
OptionalInfo optiCompilationInfo;
platformInfo.soc_info.ai_core_cnt = 64;
platformInfo.str_info.short_soc_version = soc;
optiCompilationInfo.soc_version = soc;
PlatformInfoManager::Instance().platform_info_map_[soc] = platformInfo;
PlatformInfoManager::Instance().SetOptionalCompilationInfo(optiCompilationInfo);
}
es::EsTensorHolder BuildCastLikeNode(es::EsGraphBuilder& builder, const es::EsTensorHolder& x,
const es::EsTensorHolder& y)
{
ge::Graph* graph = builder.GetCGraphBuilder()->GetGraph();
GNode castlike = es::CompliantNodeBuilder(graph)
.OpType("CastLike")
.Name("CastLike")
.IrDefInputs({
{"input", es::CompliantNodeBuilder::kEsIrInputRequired, ""},
{"target_type", es::CompliantNodeBuilder::kEsIrInputRequired, ""}
})
.IrDefOutputs({
{"output", es::CompliantNodeBuilder::kEsIrOutputRequired, ""}
})
.Build();
es::AddEdgeAndUpdatePeerDesc(*graph, *x.GetProducer(), x.GetProducerOutIndex(), castlike, 0);
es::AddEdgeAndUpdatePeerDesc(*graph, *y.GetProducer(), y.GetProducerOutIndex(), castlike, 1);
auto output = es::EsTensorHolder(builder.GetCGraphBuilder()->GetTensorHolderFromNode(castlike, 0));
return output;
}
static void InferShapeForTest(
DataType dtypeX, DataType dtypeY, Shape& shape,
es::EsTensorHolder& x, es::EsTensorHolder& y, es::EsTensorHolder& castlike)
{
TensorDesc xDesc;
x.GetProducer()->GetOutputDesc(0, xDesc);
xDesc.SetDataType(dtypeX);
xDesc.SetShape(shape);
x.GetProducer()->UpdateOutputDesc(0, xDesc);
TensorDesc yDesc;
y.GetProducer()->GetOutputDesc(0, yDesc);
yDesc.SetDataType(dtypeY);
yDesc.SetShape(shape);
y.GetProducer()->UpdateOutputDesc(0, yDesc);
castlike.GetProducer()->UpdateInputDesc(0, xDesc);
castlike.GetProducer()->UpdateInputDesc(1, yDesc);
castlike.GetProducer()->UpdateOutputDesc(0, yDesc);
}
};
TEST_F(CastlikeFusionPassTest, patternTest)
{
    // The pass must expose at least one pattern.
    // empty() avoids the signed/unsigned comparison of EXPECT_GT(size_t, int).
    CastlikeFusionPass pass;
    const std::vector<PatternUniqPtr> patterns = pass.Patterns();
    EXPECT_FALSE(patterns.empty());
}
TEST_F(CastlikeFusionPassTest, fusionSuccessFp16ToFp32)
{
    // Builds CastLike(x: fp16, y: fp32) and runs the pass.
    // Afterwards a "Cast" node must exist and the graph must contain exactly 4 nodes.
    std::vector<int64_t> dimsX{2, 32, 128};
    Shape shapeX(dimsX);
    auto graphBuilder = es::EsGraphBuilder("test");
    auto x = graphBuilder.CreateInput(0, "x", DT_FLOAT16, FORMAT_ND, dimsX);
    auto y = graphBuilder.CreateInput(1, "y", DT_FLOAT, FORMAT_ND, dimsX);
    auto output = BuildCastLikeNode(graphBuilder, x, y);
    InferShapeForTest(DT_FLOAT16, DT_FLOAT, shapeX, x, y, output);
    std::shared_ptr<Graph> graph = graphBuilder.BuildAndReset({output});
    // Automatic storage: a heap-allocated context would need an explicit delete.
    CustomPassContext passContext;
    CastlikeFusionPass pass;
    Status status = pass.Run(graph, passContext);
    EXPECT_TRUE(status == SUCCESS || status == GRAPH_NOT_CHANGED);
    bool findCast = false;
    int nodeCount = 0;
    // Bind by reference to avoid copying each GNode.
    for (auto& node : graph->GetAllNodes()) {
        ++nodeCount;
        AscendString type;
        node.GetType(type);
        if (type == "Cast") {
            findCast = true;
        }
    }
    EXPECT_TRUE(findCast);
    EXPECT_EQ(nodeCount, 4);
}
TEST_F(CastlikeFusionPassTest, fusionSuccessFp32ToFp16)
{
    // Builds CastLike(x: fp32, y: fp16) and runs the pass.
    // Afterwards a "Cast" node must exist and the graph must contain exactly 4 nodes.
    std::vector<int64_t> dimsX{2, 32, 128};
    Shape shapeX(dimsX);
    auto graphBuilder = es::EsGraphBuilder("test");
    auto x = graphBuilder.CreateInput(0, "x", DT_FLOAT, FORMAT_ND, dimsX);
    auto y = graphBuilder.CreateInput(1, "y", DT_FLOAT16, FORMAT_ND, dimsX);
    auto output = BuildCastLikeNode(graphBuilder, x, y);
    InferShapeForTest(DT_FLOAT, DT_FLOAT16, shapeX, x, y, output);
    std::shared_ptr<Graph> graph = graphBuilder.BuildAndReset({output});
    // Automatic storage: a heap-allocated context would need an explicit delete.
    CustomPassContext passContext;
    CastlikeFusionPass pass;
    Status status = pass.Run(graph, passContext);
    EXPECT_TRUE(status == SUCCESS || status == GRAPH_NOT_CHANGED);
    bool findCast = false;
    int nodeCount = 0;
    // Bind by reference to avoid copying each GNode.
    for (auto& node : graph->GetAllNodes()) {
        ++nodeCount;
        AscendString type;
        node.GetType(type);
        if (type == "Cast") {
            findCast = true;
        }
    }
    EXPECT_TRUE(findCast);
    EXPECT_EQ(nodeCount, 4);
}
TEST_F(CastlikeFusionPassTest, fusionSuccessBF16ToFp32)
{
    // Builds CastLike(x: bf16, y: fp32) and runs the pass.
    // Afterwards a "Cast" node must exist and the graph must contain exactly 4 nodes.
    std::vector<int64_t> dimsX{2, 32, 128};
    Shape shapeX(dimsX);
    auto graphBuilder = es::EsGraphBuilder("test");
    auto x = graphBuilder.CreateInput(0, "x", DT_BF16, FORMAT_ND, dimsX);
    auto y = graphBuilder.CreateInput(1, "y", DT_FLOAT, FORMAT_ND, dimsX);
    auto output = BuildCastLikeNode(graphBuilder, x, y);
    InferShapeForTest(DT_BF16, DT_FLOAT, shapeX, x, y, output);
    std::shared_ptr<Graph> graph = graphBuilder.BuildAndReset({output});
    // Automatic storage: a heap-allocated context would need an explicit delete.
    CustomPassContext passContext;
    CastlikeFusionPass pass;
    Status status = pass.Run(graph, passContext);
    EXPECT_TRUE(status == SUCCESS || status == GRAPH_NOT_CHANGED);
    bool findCast = false;
    int nodeCount = 0;
    // Bind by reference to avoid copying each GNode.
    for (auto& node : graph->GetAllNodes()) {
        ++nodeCount;
        AscendString type;
        node.GetType(type);
        if (type == "Cast") {
            findCast = true;
        }
    }
    EXPECT_TRUE(findCast);
    EXPECT_EQ(nodeCount, 4);
}