* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include "asc_graph_builder.h"
#include "task_generator/cast_optimization_pass.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "platform_context.h"
#include "runtime_stub.h"
using namespace ge;
using namespace af::testing;
using namespace af::ascir_op;
namespace af {
// Fixture shared by every CastOptimizationPass test case in this file.
// It holds no state; both hooks are intentionally empty.
class TestCastOptimizationPass : public ::testing::Test {
 protected:
  // Nothing to prepare before a test case runs.
  void SetUp() override {}

  // Nothing to release after a test case runs.
  void TearDown() override {}
};
namespace {
// Counts the nodes of the compute graph behind `graph` whose GetType() equals `type`.
size_t CountNodesByType(AscGraph &graph, const std::string &type) {
  size_t matches = 0U;
  for (const auto &candidate : AscGraphUtils::GetComputeGraph(graph)->GetAllNodes()) {
    // Add one for a matching node, zero otherwise.
    matches += (candidate->GetType() == type) ? 1U : 0U;
  }
  return matches;
}
// Returns the dtype of the first output tensor of the node called `name`.
// DT_UNDEFINED is returned when no node with that name can be cast to AscNode
// and has at least one output.
DataType GetNodeOutputDtype(AscGraph &graph, const std::string &name) {
  for (const auto &candidate : AscGraphUtils::GetComputeGraph(graph)->GetAllNodes()) {
    if (candidate->GetName() != name) {
      continue;
    }
    const auto asc_node = std::dynamic_pointer_cast<AscNode>(candidate);
    if (asc_node == nullptr || asc_node->outputs().empty()) {
      // A name match that is unusable does not stop the scan.
      continue;
    }
    return (*asc_node->outputs().begin())->attr.dtype;
  }
  return DT_UNDEFINED;
}
// Checks the presence (should_exist == true) or absence (should_exist == false)
// of every node name listed in `names`.
// Each violated expectation is reported as a non-fatal gtest failure that names the
// offending node, so all entries are always checked.
void AssertNodesExist(AscGraph &graph, const std::vector<std::string> &names, bool should_exist) {
  if (names.empty()) {
    // Nothing to verify; do not touch the graph at all.
    return;
  }
  // The node list does not change while we check, so fetch it once instead of
  // re-traversing the graph for every requested name.
  auto &&all_nodes = AscGraphUtils::GetComputeGraph(graph)->GetAllNodes();
  for (const auto &name : names) {
    bool found = false;
    for (const auto &node : all_nodes) {
      if (node->GetName() == name) {
        found = true;
        break;
      }
    }
    if (should_exist) {
      EXPECT_TRUE(found) << "node " << name << " should exist but not found";
    } else {
      EXPECT_FALSE(found) << "node " << name << " should not exist but found";
    }
  }
}
// Builds a graph of the form
//   data{0,1} -> load{0,1} -> cast_input{0,1} -> concat0 (axis argument 0) -> cast_out0 -> store0 -> output0
// `data_dtype` is the dtype of both Data nodes, `cast_input_dtype` the target dtype of both
// input casts, and `cast_out_dtype` the target dtype of the cast placed after the concat.
AscGraph BuildTwoInputCastConcatGraph(const std::string &name,
                                      DataType data_dtype,
                                      DataType cast_input_dtype,
                                      DataType cast_out_dtype) {
  return AscGraphBuilder(name)
      .Loops({Sym(128), Sym(64)})
      // Two inputs sharing the same dtype.
      .Data("data0", 0, data_dtype)
      .Data("data1", 1, data_dtype)
      .Load("load0", "data0")
      .Load("load1", "data1")
      // One cast in front of each concat input.
      .Cast("cast_input0", "load0", cast_input_dtype)
      .Cast("cast_input1", "load1", cast_input_dtype)
      .Concat("concat0", {"cast_input0", "cast_input1"}, 0)
      // Cast applied to the concat result.
      .Cast("cast_out0", "concat0", cast_out_dtype)
      .Store("store0", "cast_out0")
      .Output("output0", "store0", 0)
      .Build();
}
// Builds a FLOAT graph that concatenates two loads (axis argument 1) and casts the result to FLOAT16.
// The first loop extent is the named symbol "s0"; the second is s1_val + s2_val, where
// `s1_val` and `s2_val` are the constant extents used for load0 and load1 respectively.
// NOTE(review): the two brace lists passed to Load look like {sizes} and {strides} - confirm against AscGraphBuilder.
AscGraph BuildSymbolicAxisConcatCastGraph(const std::string &name, int64_t s1_val, int64_t s2_val) {
  auto outer_extent = af::Symbol("s0");
  auto first_extent = af::Symbol(s1_val);
  auto second_extent = af::Symbol(s2_val);
  auto concat_extent = first_extent + second_extent;
  return AscGraphBuilder(name)
      .Loops({outer_extent, concat_extent})
      .Data("data0", 0, DT_FLOAT)
      .Data("data1", 1, DT_FLOAT)
      .Load("load0", "data0", {outer_extent, first_extent}, {first_extent, af::sym::kSymbolOne})
      .Load("load1", "data1", {outer_extent, second_extent}, {second_extent, af::sym::kSymbolOne})
      .Concat("concat0", {"load0", "load1"}, 1)
      .Cast("cast_out0", "concat0", DT_FLOAT16)
      .Store("store0", "cast_out0")
      .Output("output0", "store0", 0)
      .Build();
}
// Builds a FLOAT graph: load0 + load1 -> add0 -> concat0 (single input, axis argument 1)
// -> cast_out0 (FLOAT16) -> store0 -> output0.
// When `discontinuous` is true, load0 receives {s0 * 4 * 4, 4 * 4} as its last argument;
// otherwise it receives {4, 1}. load1 always receives {8, 1}.
// NOTE(review): that last Load argument is presumably the stride list - confirm against AscGraphBuilder.
AscGraph BuildDowncastDiscontinuityTestGraph(const std::string &name, bool discontinuous) {
  auto outer_extent = af::Symbol("s0");
  auto first_extent = af::Symbol(4);
  auto second_extent = af::Symbol(8);
  auto concat_extent = first_extent + second_extent;
  auto factor = af::Symbol(4);
  auto big_stride = outer_extent * first_extent * factor;
  auto inner_stride = first_extent * factor;
  return AscGraphBuilder(name)
      .Loops({outer_extent, concat_extent})
      .Data("data0", 0, DT_FLOAT)
      .Data("data1", 1, DT_FLOAT)
      .Load("load0",
            "data0",
            {outer_extent, first_extent},
            discontinuous ? std::vector<Expression>{big_stride, inner_stride}
                          : std::vector<Expression>{first_extent, af::sym::kSymbolOne})
      .Load("load1", "data1", {outer_extent, second_extent}, {second_extent, af::sym::kSymbolOne})
      .Add("add0", "load0", "load1")
      .Concat("concat0", {"add0"}, 1)
      .Cast("cast_out0", "concat0", DT_FLOAT16)
      .Store("store0", "cast_out0")
      .Output("output0", "store0", 0)
      .Build();
}
// Builds a FLOAT16 graph whose two loads are each cast to FLOAT, concatenated
// (axis argument 1) and cast back to FLOAT16 before being stored.
// When `discontinuous` is true the two constant extents are 3 and 5 and load0 receives
// {s0 * 3 * 4, 3 * 4} as its last argument; otherwise the extents are 4 and 8 and load0
// receives {4, 1}.
// NOTE(review): that last Load argument is presumably the stride list - confirm against AscGraphBuilder.
AscGraph BuildUpcastDiscontinuityTestGraph(const std::string &name, bool discontinuous) {
  auto outer_extent = af::Symbol("s0");
  auto first_extent = af::Symbol(discontinuous ? 3 : 4);
  auto second_extent = af::Symbol(discontinuous ? 5 : 8);
  auto concat_extent = first_extent + second_extent;
  auto factor = af::Symbol(4);
  auto big_stride = outer_extent * first_extent * factor;
  auto inner_stride = first_extent * factor;
  return AscGraphBuilder(name)
      .Loops({outer_extent, concat_extent})
      .Data("data0", 0, DT_FLOAT16)
      .Data("data1", 1, DT_FLOAT16)
      .Load("load0",
            "data0",
            {outer_extent, first_extent},
            discontinuous ? std::vector<Expression>{big_stride, inner_stride}
                          : std::vector<Expression>{first_extent, af::sym::kSymbolOne})
      .Cast("cast_input0", "load0", DT_FLOAT)
      .Load("load1", "data1", {outer_extent, second_extent}, {second_extent, af::sym::kSymbolOne})
      .Cast("cast_input1", "load1", DT_FLOAT)
      .Concat("concat0", {"cast_input0", "cast_input1"}, 1)
      .Cast("cast_out0", "concat0", DT_FLOAT16)
      .Store("store0", "cast_out0")
      .Output("output0", "store0", 0)
      .Build();
}
}
// Graph without any Concat node: the pass must succeed and the graph must still hold no Cast node.
TEST_F(TestCastOptimizationPass, NoConcatInGraph_NoChange) {
  auto graph = AscGraphBuilder("test_no_concat")
                   .Loops({Sym(128)})
                   .Data("data0", 0, DT_FLOAT)
                   .Load("load0", "data0")
                   .Abs("abs0", "load0")
                   .Store("store0", "abs0")
                   .Output("output0", "store0")
                   .Build();

  const auto status = af::optimize::CastOptimizationPass::Run(graph);
  ASSERT_EQ(status, SUCCESS);

  const size_t cast_count = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(cast_count, 0U);
}
// Concat feeds a Store directly (no Cast behind it): no Cast may appear and concat0 stays FLOAT.
TEST_F(TestCastOptimizationPass, ConcatWithoutCastOutput_NoOptimize) {
  auto graph = AscGraphBuilder("test_concat_no_cast")
                   .Loops({Sym(128)})
                   .Data("data0", 0, DT_FLOAT)
                   .Data("data1", 1, DT_FLOAT)
                   .Load("load0", "data0")
                   .Load("load1", "data1")
                   .Concat("concat0", {"load0", "load1"}, 0)
                   .Store("store0", "concat0")
                   .Output("output0", "store0")
                   .Build();

  const auto status = af::optimize::CastOptimizationPass::Run(graph);
  ASSERT_EQ(status, SUCCESS);

  const size_t cast_count = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(cast_count, 0U);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT);
}
// concat0 is consumed by both cast0 and abs0: the single Cast must survive the pass.
TEST_F(TestCastOptimizationPass, ConcatWithMultipleOutputs_NoOptimize) {
  auto graph = AscGraphBuilder("test_concat_multi_out")
                   .Loops({Sym(128)})
                   .Data("data0", 0, DT_FLOAT)
                   .Load("load0", "data0")
                   .Concat("concat0", {"load0"}, 0)
                   .Cast("cast0", "concat0", DT_FLOAT16)
                   .Store("store0", "cast0")
                   .Abs("abs0", "concat0")
                   .Store("store1", "abs0")
                   .Output("output0", "store0")
                   .Output("output1", "store1", 1)
                   .Build();

  const auto status = af::optimize::CastOptimizationPass::Run(graph);
  ASSERT_EQ(status, SUCCESS);

  const size_t cast_count = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(cast_count, 1U);
  const std::vector<std::string> kept_nodes = {"cast0"};
  AssertNodesExist(graph, kept_nodes, true);
}
// The only consumer of concat0 is an Abs node: the pass must not introduce any Cast.
TEST_F(TestCastOptimizationPass, ConcatWithNonCastOutput_NoOptimize) {
  auto graph = AscGraphBuilder("test_concat_non_cast_out")
                   .Loops({Sym(128)})
                   .Data("data0", 0, DT_FLOAT)
                   .Load("load0", "data0")
                   .Concat("concat0", {"load0"}, 0)
                   .Abs("abs0", "concat0")
                   .Store("store0", "abs0")
                   .Output("output0", "store0")
                   .Build();

  const auto status = af::optimize::CastOptimizationPass::Run(graph);
  ASSERT_EQ(status, SUCCESS);

  const size_t cast_count = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(cast_count, 0U);
}
// FLOAT concat followed by a FLOAT16 cast: afterwards cast_out0 is gone, concat0 produces
// FLOAT16 and the graph holds two Cast nodes.
TEST_F(TestCastOptimizationPass, Downcast_PullUpCastBeforeInputs) {
  auto graph = AscGraphBuilder("test_downcast_pullup")
                   .Loops({Sym(128), Sym(64)})
                   .Data("data0", 0, DT_FLOAT)
                   .Data("data1", 1, DT_FLOAT)
                   .Load("load0", "data0")
                   .Load("load1", "data1")
                   .Concat("concat0", {"load0", "load1"}, 0)
                   .Cast("cast_out0", "concat0", DT_FLOAT16)
                   .Store("store0", "cast_out0")
                   .Output("output0", "store0", 0)
                   .Build();

  // Sanity check on the input graph before it is rewritten.
  const size_t casts_before = CountNodesByType(graph, ascir_op::Cast::Type);
  ASSERT_EQ(casts_before, 1U);

  const auto status = af::optimize::CastOptimizationPass::Run(graph);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
  const size_t casts_after = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(casts_after, 2U);
}
// One concat input is FLOAT16 cast up to FLOAT, the other is plain FLOAT, and the concat result
// is cast to FLOAT16. Afterwards both named casts are gone, concat0 produces FLOAT16 and
// exactly one Cast node remains.
TEST_F(TestCastOptimizationPass, Downcast_ReverseInputCastRemoved) {
  auto graph = AscGraphBuilder("test_downcast_reverse_input")
                   .Loops({Sym(128), Sym(64)})
                   .Data("data0", 0, DT_FLOAT16)
                   .Data("data1", 1, DT_FLOAT)
                   .Load("load0", "data0")
                   .Load("load1", "data1")
                   .Cast("cast_input0", "load0", DT_FLOAT)
                   .Concat("concat0", {"cast_input0", "load1"}, 0)
                   .Cast("cast_out0", "concat0", DT_FLOAT16)
                   .Store("store0", "cast_out0")
                   .Output("output0", "store0", 0)
                   .Build();

  // Sanity check on the input graph before it is rewritten.
  const size_t casts_before = CountNodesByType(graph, ascir_op::Cast::Type);
  ASSERT_EQ(casts_before, 2U);

  const auto status = af::optimize::CastOptimizationPass::Run(graph);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0", "cast_input0"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
  const size_t casts_after = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(casts_after, 1U);
}
// FLOAT16 data cast to FLOAT on both inputs, then cast back to FLOAT16 after the concat:
// all three casts are expected to disappear and concat0 to produce FLOAT16.
TEST_F(TestCastOptimizationPass, Downcast_AllReverseInputCastsEliminated) {
  auto graph = BuildTwoInputCastConcatGraph("test_downcast_all_reverse", DT_FLOAT16, DT_FLOAT, DT_FLOAT16);

  // Sanity check on the input graph before it is rewritten.
  const size_t casts_before = CountNodesByType(graph, ascir_op::Cast::Type);
  ASSERT_EQ(casts_before, 3U);

  const auto status = af::optimize::CastOptimizationPass::Run(graph);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0", "cast_input0", "cast_input1"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
  const size_t casts_after = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(casts_after, 0U);
}
// FLOAT16 concat followed by a cast to FLOAT, with no cast on the inputs:
// cast_out0 must remain and stay the only Cast node.
TEST_F(TestCastOptimizationPass, Upcast_NoReverseInputCast_NotOptimized) {
  auto graph = AscGraphBuilder("test_upcast_no_reverse")
                   .Loops({Sym(128)})
                   .Data("data0", 0, DT_FLOAT16)
                   .Data("data1", 1, DT_FLOAT16)
                   .Load("load0", "data0")
                   .Load("load1", "data1")
                   .Concat("concat0", {"load0", "load1"}, 0)
                   .Cast("cast_out0", "concat0", DT_FLOAT)
                   .Store("store0", "cast_out0")
                   .Output("output0", "store0", 0)
                   .Build();

  const auto status = af::optimize::CastOptimizationPass::Run(graph);
  ASSERT_EQ(status, SUCCESS);

  const size_t cast_count = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(cast_count, 1U);
  const std::vector<std::string> kept_nodes = {"cast_out0"};
  AssertNodesExist(graph, kept_nodes, true);
}
// FLOAT data cast to FLOAT16 on the inputs and back to FLOAT after the concat, run with
// second argument 0 (which the test name associates with the transpose algorithm):
// cast_out0 must stay and concat0 must keep producing FLOAT16.
TEST_F(TestCastOptimizationPass, Upcast_TransposeAlg_NoOptimize) {
  auto graph = BuildTwoInputCastConcatGraph("test_upcast_transpose", DT_FLOAT, DT_FLOAT16, DT_FLOAT);

  const auto status = af::optimize::CastOptimizationPass::Run(graph, 0);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> kept_nodes = {"cast_out0"};
  AssertNodesExist(graph, kept_nodes, true);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
}
// Same graph as the transpose variant but run with second argument 1 (which the test name
// associates with the gather algorithm): all three casts vanish and concat0 produces FLOAT.
TEST_F(TestCastOptimizationPass, Upcast_GatherAlg_AllReverseCastsEliminated) {
  auto graph = BuildTwoInputCastConcatGraph("test_upcast_gather", DT_FLOAT, DT_FLOAT16, DT_FLOAT);

  const auto status = af::optimize::CastOptimizationPass::Run(graph, 1);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0", "cast_input0", "cast_input1"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT);
  const size_t casts_after = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(casts_after, 0U);
}
// Two FLOAT16 inputs cast up to FLOAT plus one plain FLOAT input, concatenated and cast to FLOAT16.
// Afterwards the three named casts are gone, concat0 produces FLOAT16 and one Cast node remains.
TEST_F(TestCastOptimizationPass, MixedReverseAndNonReverseInputCast) {
  auto graph = AscGraphBuilder("test_mixed_input_cast")
                   .Loops({Sym(128), Sym(64)})
                   .Data("data0", 0, DT_FLOAT16)
                   .Data("data1", 1, DT_FLOAT16)
                   .Data("data2", 2, DT_FLOAT)
                   .Load("load0", "data0")
                   .Load("load1", "data1")
                   .Load("load2", "data2")
                   .Cast("cast_input0", "load0", DT_FLOAT)
                   .Cast("cast_input1", "load1", DT_FLOAT)
                   .Concat("concat0", {"cast_input0", "cast_input1", "load2"}, 0)
                   .Cast("cast_out0", "concat0", DT_FLOAT16)
                   .Store("store0", "cast_out0")
                   .Output("output0", "store0", 0)
                   .Build();

  // Sanity check on the input graph before it is rewritten.
  const size_t casts_before = CountNodesByType(graph, ascir_op::Cast::Type);
  ASSERT_EQ(casts_before, 3U);

  const auto status = af::optimize::CastOptimizationPass::Run(graph);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0", "cast_input0", "cast_input1"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
  const size_t casts_after = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(casts_after, 1U);
}
// Extents 3 and 5 with second argument 0 (transpose algorithm per the test name):
// cast_out0 is expected to be removed and concat0 to produce FLOAT16.
TEST_F(TestCastOptimizationPass, Downcast_TransposeAlg_IgnoreAlignment_Optimize) {
  auto graph = BuildSymbolicAxisConcatCastGraph("test_transpose_ignore_alignment", 3, 5);

  const auto status = af::optimize::CastOptimizationPass::Run(graph, 0);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
}
// Extents 3 and 5 with second argument 1 (gather algorithm per the test name):
// cast_out0 must remain and concat0 must keep producing FLOAT.
TEST_F(TestCastOptimizationPass, Downcast_GatherAlg_AlignmentDegradation_NoOptimize) {
  auto graph = BuildSymbolicAxisConcatCastGraph("test_gather_alignment_degradation", 3, 5);

  const auto status = af::optimize::CastOptimizationPass::Run(graph, 1);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> kept_nodes = {"cast_out0"};
  AssertNodesExist(graph, kept_nodes, true);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT);
}
// Extents 4 and 8 with second argument 1 (gather algorithm per the test name):
// cast_out0 is removed, concat0 produces FLOAT16 and two Cast nodes exist afterwards.
TEST_F(TestCastOptimizationPass, Downcast_GatherAlg_NoAlignmentDegradation_Optimize) {
  auto graph = BuildSymbolicAxisConcatCastGraph("test_no_degradation", 4, 8);

  const auto status = af::optimize::CastOptimizationPass::Run(graph, 1);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
  const size_t casts_after = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(casts_after, 2U);
}
// relu0 is wired to both concat inputs. After the pass cast_out0 is gone, concat0 produces
// FLOAT16 and the graph still holds exactly one Cast node.
TEST_F(TestCastOptimizationPass, Downcast_SharedNonCastInput_OnlyOneCastInserted) {
  auto graph = AscGraphBuilder("test_shared_noncast_downcast")
                   .Loops({Sym(128), Sym(64)})
                   .Data("data0", 0, DT_FLOAT)
                   .Load("load0", "data0")
                   .Relu("relu0", "load0")
                   .Concat("concat0", {"relu0", "relu0"}, 0)
                   .Cast("cast_out0", "concat0", DT_FLOAT16)
                   .Store("store0", "cast_out0")
                   .Output("output0", "store0", 0)
                   .Build();

  // Sanity check on the input graph before it is rewritten.
  const size_t casts_before = CountNodesByType(graph, ascir_op::Cast::Type);
  ASSERT_EQ(casts_before, 1U);

  const auto status = af::optimize::CastOptimizationPass::Run(graph);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
  const size_t casts_after = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(casts_after, 1U);
}
// Downcast graph built with discontinuous == true, run with second argument 1:
// cast_out0 is expected to be removed and concat0 to produce FLOAT16.
TEST_F(TestCastOptimizationPass, Downcast_GatherAlg_MultipleDiscontinuitiesViaMultiInputNode_Optimize) {
  auto graph = BuildDowncastDiscontinuityTestGraph("test_multi_input_discontinuity", true);

  const auto status = af::optimize::CastOptimizationPass::Run(graph, 1);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
}
// Downcast graph built with discontinuous == false, run with second argument 1:
// cast_out0 is expected to be removed and concat0 to produce FLOAT16.
TEST_F(TestCastOptimizationPass, Downcast_GatherAlg_NoDiscontinuity_Optimize) {
  auto graph = BuildDowncastDiscontinuityTestGraph("test_no_discontinuity", false);

  const auto status = af::optimize::CastOptimizationPass::Run(graph, 1);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
}
// Upcast-style graph built with discontinuous == true, run with second argument 1:
// all three casts must remain and concat0 must keep producing FLOAT.
TEST_F(TestCastOptimizationPass, Upcast_GatherAlg_MultipleDiscontinuitiesViaMultiInputNode_NoOptimize) {
  auto graph = BuildUpcastDiscontinuityTestGraph("test_upcast_multi_input_discontinuity", true);

  const auto status = af::optimize::CastOptimizationPass::Run(graph, 1);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> kept_nodes = {"cast_out0", "cast_input0", "cast_input1"};
  AssertNodesExist(graph, kept_nodes, true);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT);
}
// Upcast-style graph built with discontinuous == false, run with second argument 1:
// all three casts vanish, concat0 produces FLOAT16 and no Cast node is left.
TEST_F(TestCastOptimizationPass, Upcast_GatherAlg_NoDiscontinuity_AllReverseCastsEliminated) {
  auto graph = BuildUpcastDiscontinuityTestGraph("test_upcast_no_discontinuity", false);

  const auto status = af::optimize::CastOptimizationPass::Run(graph, 1);
  ASSERT_EQ(status, SUCCESS);

  const std::vector<std::string> removed_nodes = {"cast_out0", "cast_input0", "cast_input1"};
  AssertNodesExist(graph, removed_nodes, false);
  const DataType concat_dtype = GetNodeOutputDtype(graph, "concat0");
  EXPECT_EQ(concat_dtype, DT_FLOAT16);
  const size_t casts_after = CountNodesByType(graph, ascir_op::Cast::Type);
  EXPECT_EQ(casts_after, 0U);
}
}