* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include "ascir.h"
#include "ascir_ops.h"
#include "schedule_utils.h"
#define private public
#include "optimize/pre_process/improve_precision.h"
#include "optimize/pre_process/pre_process_config.h"
#include "optimize/pre_process/pre_process.h"
#undef private
#include "common/platform_context.h"
#include "ascgraph_info_complete.h"
#include "tests/framework/easy_asc_graph/asc_graph_builder.h"
#include "runtime_stub.h"
using namespace af;
using namespace af::ascir_op;
using af::ops::IsOps;
using af::ops::One;
using af::testing::Sym;
using af::testing::AscGraphBuilder;
using namespace af::pre_process;
namespace {
// Returns how many nodes of the compute graph behind `graph` report a type equal to `type`.
size_t CountNodesByType(AscGraph &graph, const std::string &type) {
  const auto compute_graph = AscGraphUtils::GetComputeGraph(graph);
  size_t matched = 0U;
  for (const auto &candidate : compute_graph->GetAllNodes()) {
    matched += (candidate->GetType() == type) ? 1U : 0U;
  }
  return matched;
}
// Returns true when some node named `node_name` has an op desc whose output 0
// carries `expected_dtype`; returns false when no such node is found.
bool CheckNodeOutputDtype(AscGraph &graph, const std::string &node_name, ge::DataType expected_dtype) {
  const auto compute_graph = AscGraphUtils::GetComputeGraph(graph);
  for (const auto &candidate : compute_graph->GetAllNodes()) {
    if (!(candidate->GetName() == node_name)) {
      continue;
    }
    const auto op_desc = candidate->GetOpDesc();
    if (op_desc == nullptr) {
      // A name match without an op desc cannot satisfy the check; keep scanning.
      continue;
    }
    if (op_desc->GetOutputDesc(0).GetDataType() == expected_dtype) {
      return true;
    }
  }
  return false;
}
// Fixture for the improve-precision system tests.
//
// Each test starts and finishes with the platform context and the pre-process
// configuration reset and with AUTOFUSE_FLAGS removed from the environment.
// Tests that need flags set the variable themselves and then call
// PreProcessConfig::Instance().Reset().
class TestImprovePrecisionST : public ::testing::Test {
 protected:
  void SetUp() override {
    ge::PlatformContext::GetInstance().Reset();
    // Clear any AUTOFUSE_FLAGS inherited from the launching shell or from an
    // earlier suite before the configuration is reset; otherwise the first test
    // of this suite could run with flags it never asked for (TearDown only
    // protects the tests that come after it).
    unsetenv("AUTOFUSE_FLAGS");
    PreProcessConfig::Instance().Reset();
    auto runtime_stub = std::make_shared<ge::RuntimeStub>();
    ge::RuntimeStub::SetInstance(runtime_stub);
    dlog_setlevel(ASCGEN_MODULE_NAME, DLOG_ERROR, 0);
  }

  void TearDown() override {
    ge::PlatformContext::GetInstance().Reset();
    // Remove the variable before resetting the configuration so the reset does
    // not observe flags set by the test that just finished.
    unsetenv("AUTOFUSE_FLAGS");
    PreProcessConfig::Instance().Reset();
    dlog_setlevel(ASCGEN_MODULE_NAME, DLOG_ERROR, 0);
  }
};
}
// DT_FLOAT16 input feeding Abs -> Add -> Mul -> Sub: every compute node must
// report DT_FLOAT after the pass and at least one Cast node must be present.
TEST_F(TestImprovePrecisionST, ComplexFp16Chain_AllPromotedToFp32) {
  auto asc_graph = AscGraphBuilder("st_complex_fp16_chain")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Abs("abs0", "load0")
                       .Add("add0", "abs0", "abs0")
                       .Mul("mul0", "add0", "add0")
                       .Sub("sub0", "mul0", "mul0")
                       .Store("store0", "sub0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const auto is_fp32 = [&asc_graph](const char *node_name) {
    return CheckNodeOutputDtype(asc_graph, node_name, ge::DT_FLOAT);
  };
  EXPECT_TRUE(is_fp32("abs0"));
  EXPECT_TRUE(is_fp32("add0"));
  EXPECT_TRUE(is_fp32("mul0"));
  EXPECT_TRUE(is_fp32("sub0"));
  EXPECT_GE(CountNodesByType(asc_graph, Cast::Type), 1U);
}
// Add with two DT_FLOAT16 inputs: the Add must report DT_FLOAT and the graph
// must contain at least two Cast nodes afterwards.
TEST_F(TestImprovePrecisionST, TwoInputsBothFp16_BothPromoted) {
  auto asc_graph = AscGraphBuilder("st_two_inputs_fp16")
                       .Loops({Sym("s0")})
                       .Data("data1", 0, ge::DT_FLOAT16)
                       .Data("data2", 1, ge::DT_FLOAT16)
                       .Load("load1", "data1")
                       .Load("load2", "data2")
                       .Add("add0", "load1", "load2")
                       .Store("store0", "add0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  EXPECT_TRUE(CheckNodeOutputDtype(asc_graph, "add0", ge::DT_FLOAT));
  EXPECT_GE(CountNodesByType(asc_graph, Cast::Type), 2U);
}
// Add with one DT_FLOAT16 and one DT_BF16 input: the Add must report DT_FLOAT
// and at least two Cast nodes must exist afterwards.
TEST_F(TestImprovePrecisionST, MixedFp16Bf16Inputs_BothPromotedToFp32) {
  auto asc_graph = AscGraphBuilder("st_mixed_fp16_bf16")
                       .Loops({Sym("s0")})
                       .Data("data1", 0, ge::DT_FLOAT16)
                       .Data("data2", 1, ge::DT_BF16)
                       .Load("load1", "data1")
                       .Load("load2", "data2")
                       .Add("add0", "load1", "load2")
                       .Store("store0", "add0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const size_t cast_total = CountNodesByType(asc_graph, Cast::Type);
  EXPECT_TRUE(CheckNodeOutputDtype(asc_graph, "add0", ge::DT_FLOAT));
  EXPECT_GE(cast_total, 2U);
}
// A Cast to DT_FLOAT16 sits between a DT_FLOAT16 data input and Abs; after the
// pass Abs must report DT_FLOAT.
// NOTE(review): the test name mentions the cast being removed, but only the
// dtype of abs0 is asserted here.
TEST_F(TestImprovePrecisionST, IdentityFp16Cast_RemovedAndPromoted) {
  auto asc_graph = AscGraphBuilder("st_identity_cast")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Cast("cast0", "load0", ge::DT_FLOAT16)
                       .Abs("abs0", "cast0")
                       .Store("store0", "abs0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const bool abs_is_fp32 = CheckNodeOutputDtype(asc_graph, "abs0", ge::DT_FLOAT);
  EXPECT_TRUE(abs_is_fp32);
}
// Abs followed by a Cast to DT_FLOAT16 in front of the Store; after the pass
// Abs must report DT_FLOAT.
// NOTE(review): the fate of cast0 (named "Removed" in the test title) is not
// asserted by this test.
TEST_F(TestImprovePrecisionST, Fp32ToFp16CastBeforeStore_RemovedAndAbsPromoted) {
  auto asc_graph = AscGraphBuilder("st_fp32_to_fp16_before_store")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Abs("abs0", "load0")
                       .Cast("cast0", "abs0", ge::DT_FLOAT16)
                       .Store("store0", "cast0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const bool abs_is_fp32 = CheckNodeOutputDtype(asc_graph, "abs0", ge::DT_FLOAT);
  EXPECT_TRUE(abs_is_fp32);
}
// DT_FLOAT16 scalar multiplied with a DT_FLOAT16 load, then added to itself:
// the scalar, the Mul and the Add must all report DT_FLOAT after the pass.
TEST_F(TestImprovePrecisionST, ScalarFp16Promoted_DownstreamAllFp32) {
  auto asc_graph = AscGraphBuilder("st_scalar_fp16_downstream")
                       .Loops({Sym("s0")})
                       .Scalar("scalar0", "2.0", ge::DT_FLOAT16)
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Mul("mul0", "scalar0", "load0")
                       .Add("add0", "mul0", "mul0")
                       .Store("store0", "add0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const auto is_fp32 = [&asc_graph](const char *node_name) {
    return CheckNodeOutputDtype(asc_graph, node_name, ge::DT_FLOAT);
  };
  EXPECT_TRUE(is_fp32("scalar0"));
  EXPECT_TRUE(is_fp32("mul0"));
  EXPECT_TRUE(is_fp32("add0"));
}
// Graph that is DT_FLOAT end to end: no Cast node may appear and the compute
// nodes must still report DT_FLOAT.
TEST_F(TestImprovePrecisionST, AllFp32Graph_NoModification) {
  auto asc_graph = AscGraphBuilder("st_all_fp32")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT)
                       .Load("load0", "data0")
                       .Abs("abs0", "load0")
                       .Mul("mul0", "abs0", "abs0")
                       .Store("store0", "mul0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const auto is_fp32 = [&asc_graph](const char *node_name) {
    return CheckNodeOutputDtype(asc_graph, node_name, ge::DT_FLOAT);
  };
  EXPECT_EQ(CountNodesByType(asc_graph, Cast::Type), 0U);
  EXPECT_TRUE(is_fp32("abs0"));
  EXPECT_TRUE(is_fp32("mul0"));
}
// Four unary ops in sequence (Abs -> Neg -> Exp -> Sqrt) on a DT_FLOAT16 input:
// each must report DT_FLOAT and at least one Cast node must exist afterwards.
TEST_F(TestImprovePrecisionST, DeepOtherChain_AllPromotedToFp32) {
  auto asc_graph = AscGraphBuilder("st_deep_other_chain")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Abs("abs0", "load0")
                       .Neg("neg0", "abs0")
                       .Exp("exp0", "neg0")
                       .Sqrt("sqrt0", "exp0")
                       .Store("store0", "sqrt0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const auto is_fp32 = [&asc_graph](const char *node_name) {
    return CheckNodeOutputDtype(asc_graph, node_name, ge::DT_FLOAT);
  };
  EXPECT_TRUE(is_fp32("abs0"));
  EXPECT_TRUE(is_fp32("neg0"));
  EXPECT_TRUE(is_fp32("exp0"));
  EXPECT_TRUE(is_fp32("sqrt0"));
  const size_t cast_total = CountNodesByType(asc_graph, Cast::Type);
  EXPECT_GE(cast_total, 1U);
}
// DT_BF16 input and output around Abs -> Mul: both compute nodes must report
// DT_FLOAT and at least one Cast node must exist afterwards.
TEST_F(TestImprovePrecisionST, Bf16FullPipeline_AllPromotedToFp32) {
  auto asc_graph = AscGraphBuilder("st_bf16_full")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_BF16)
                       .Load("load0", "data0")
                       .Abs("abs0", "load0")
                       .Mul("mul0", "abs0", "abs0")
                       .Store("store0", "mul0")
                       .Output("output0", "store0", 0, ge::DT_BF16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const auto is_fp32 = [&asc_graph](const char *node_name) {
    return CheckNodeOutputDtype(asc_graph, node_name, ge::DT_FLOAT);
  };
  EXPECT_TRUE(is_fp32("abs0"));
  EXPECT_TRUE(is_fp32("mul0"));
  const size_t cast_total = CountNodesByType(asc_graph, Cast::Type);
  EXPECT_GE(cast_total, 1U);
}
// With the precision blacklist flag set to "all", a Load -> Store graph on
// DT_FLOAT16 must come out of the pass without any Cast node.
TEST_F(TestImprovePrecisionST, AllBlacklist_AllNodesSupportFp16_Skip) {
  // Fail fast if the flag could not be exported; otherwise the test would
  // silently run against the default configuration.
  ASSERT_EQ(setenv("AUTOFUSE_FLAGS", "--autofuse_enhance_precision_blacklist=all", 1), 0);
  PreProcessConfig::Instance().Reset();

  auto graph = AscGraphBuilder("st_blacklist_all_skip")
                   .Loops({Sym("s0")})
                   .Data("data0", 0, ge::DT_FLOAT16)
                   .Load("load0", "data0")
                   .Store("store0", "load0")
                   .Output("output0", "store0", 0, ge::DT_FLOAT16)
                   .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(graph), ge::SUCCESS);
  EXPECT_EQ(CountNodesByType(graph, Cast::Type), 0U);
}
// With only "Abs" on the precision blacklist, the Mul that follows the Abs must
// still report DT_FLOAT after the pass.
// NOTE(review): the dtype of abs0 itself is not asserted, so the "skipped" half
// of the test name is not verified here.
TEST_F(TestImprovePrecisionST, PartialBlacklist_SpecificOpSkipped) {
  // Fail fast if the flag could not be exported; otherwise the test would
  // silently run against the default configuration.
  ASSERT_EQ(setenv("AUTOFUSE_FLAGS", "--autofuse_enhance_precision_blacklist=Abs", 1), 0);
  PreProcessConfig::Instance().Reset();

  auto graph = AscGraphBuilder("st_partial_blacklist")
                   .Loops({Sym("s0")})
                   .Data("data0", 0, ge::DT_FLOAT16)
                   .Load("load0", "data0")
                   .Abs("abs0", "load0")
                   .Mul("mul0", "abs0", "abs0")
                   .Store("store0", "mul0")
                   .Output("output0", "store0", 0, ge::DT_FLOAT16)
                   .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(graph), ge::SUCCESS);
  EXPECT_TRUE(CheckNodeOutputDtype(graph, "mul0", ge::DT_FLOAT));
}
// Load wired straight into Store on DT_FLOAT16: the pass must succeed and leave
// the graph without any Cast node.
TEST_F(TestImprovePrecisionST, LoadDirectToStore_NoImprovement) {
  auto asc_graph = AscGraphBuilder("st_load_direct_store")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Store("store0", "load0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const size_t cast_total = CountNodesByType(asc_graph, Cast::Type);
  EXPECT_EQ(cast_total, 0U);
}
// The Load already has a Cast consumer (to DT_FLOAT16) in front of Abs; after
// the pass Abs must report DT_FLOAT.
// NOTE(review): the number of Cast nodes is not asserted, so "NoDuplicateCast"
// is not actually checked by this test.
TEST_F(TestImprovePrecisionST, LoadWithExistingCastPeer_NoDuplicateCast) {
  auto asc_graph = AscGraphBuilder("st_load_with_cast_peer")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Cast("cast0", "load0", ge::DT_FLOAT16)
                       .Abs("abs0", "cast0")
                       .Store("store0", "abs0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const bool abs_is_fp32 = CheckNodeOutputDtype(asc_graph, "abs0", ge::DT_FLOAT);
  EXPECT_TRUE(abs_is_fp32);
}
// Drives the graph through PreProcess::Run instead of calling the pass
// directly; Abs must report DT_FLOAT afterwards.
TEST_F(TestImprovePrecisionST, PreProcessEntryPoint_Succeeds) {
  auto asc_graph = AscGraphBuilder("st_preprocess_entry")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Abs("abs0", "load0")
                       .Store("store0", "abs0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(PreProcess::Run(asc_graph), ge::SUCCESS);

  const bool abs_is_fp32 = CheckNodeOutputDtype(asc_graph, "abs0", ge::DT_FLOAT);
  EXPECT_TRUE(abs_is_fp32);
}
// A blacklist value that ends with ';' must still yield both "Abs" and "Mul"
// in the parsed blacklist.
TEST_F(TestImprovePrecisionST, PreProcessConfig_TrailingPunctuation) {
  // Fail fast if the flag could not be exported; otherwise the lookups below
  // would be checking the default configuration.
  ASSERT_EQ(setenv("AUTOFUSE_FLAGS", "--autofuse_enhance_precision_blacklist=Abs,Mul;", 1), 0);
  PreProcessConfig::Instance().Reset();

  const auto &bl = PreProcessConfig::Instance().GetImprovePrecisionBlacklist();
  EXPECT_TRUE(bl.find("Abs") != bl.end());
  EXPECT_TRUE(bl.find("Mul") != bl.end());
}
// The blacklist flag is placed between two unrelated ';'-separated flags; both
// "Abs" and "Mul" must still be found in the parsed blacklist.
TEST_F(TestImprovePrecisionST, PreProcessConfig_MultipleFlagsWithSemicolon) {
  // Fail fast if the flags could not be exported; otherwise the lookups below
  // would be checking the default configuration.
  ASSERT_EQ(setenv("AUTOFUSE_FLAGS",
                   "--other_flag=xyz;--autofuse_enhance_precision_blacklist=Abs,Mul;--yet_another=123", 1),
            0);
  PreProcessConfig::Instance().Reset();

  const auto &bl = PreProcessConfig::Instance().GetImprovePrecisionBlacklist();
  EXPECT_TRUE(bl.find("Abs") != bl.end());
  EXPECT_TRUE(bl.find("Mul") != bl.end());
}
// Abs on a DT_FLOAT16 input stored to a DT_FLOAT16 output; after the pass Abs
// must report DT_FLOAT.
// NOTE(review): the presence of a Cast in front of the Store (as the test name
// suggests) is not asserted.
TEST_F(TestImprovePrecisionST, StoreDtypeMismatch_CastInsertedBeforeStore) {
  auto asc_graph = AscGraphBuilder("st_store_dtype_mismatch")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Abs("abs0", "load0")
                       .Store("store0", "abs0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const bool abs_is_fp32 = CheckNodeOutputDtype(asc_graph, "abs0", ge::DT_FLOAT);
  EXPECT_TRUE(abs_is_fp32);
}
// Load -> Cast(DT_FLOAT16) -> Store with no compute op in between: the graph
// must contain no Cast node once the pass has run.
TEST_F(TestImprovePrecisionST, CastBeforeStorePeer_DtypeNotChanged) {
  auto asc_graph = AscGraphBuilder("st_cast_store_peer")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Cast("cast0", "load0", ge::DT_FLOAT16)
                       .Store("store0", "cast0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const size_t cast_total = CountNodesByType(asc_graph, Cast::Type);
  EXPECT_EQ(cast_total, 0U);
}
// DT_BF16 scalar multiplied by itself: both the scalar and the Mul must report
// DT_FLOAT after the pass.
TEST_F(TestImprovePrecisionST, ScalarBf16PromotedToFp32) {
  auto asc_graph = AscGraphBuilder("st_scalar_bf16")
                       .Loops({Sym("s0")})
                       .Scalar("scalar0", "1.5", ge::DT_BF16)
                       .Mul("mul0", "scalar0", "scalar0")
                       .Store("store0", "mul0")
                       .Output("output0", "store0", 0, ge::DT_BF16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const auto is_fp32 = [&asc_graph](const char *node_name) {
    return CheckNodeOutputDtype(asc_graph, node_name, ge::DT_FLOAT);
  };
  EXPECT_TRUE(is_fp32("scalar0"));
  EXPECT_TRUE(is_fp32("mul0"));
}
// One Load feeds two consumers: a Cast(DT_FLOAT16) -> Abs branch that reaches
// the Store, and a Mul that reads the Load directly. Both Abs and Mul must
// report DT_FLOAT after the pass.
TEST_F(TestImprovePrecisionST, MultiOutputLoad_CastOnOneBranch) {
  auto asc_graph = AscGraphBuilder("st_multi_output_load")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Cast("cast0", "load0", ge::DT_FLOAT16)
                       .Abs("abs0", "cast0")
                       .Mul("mul0", "load0", "load0")
                       .Store("store0", "abs0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const auto is_fp32 = [&asc_graph](const char *node_name) {
    return CheckNodeOutputDtype(asc_graph, node_name, ge::DT_FLOAT);
  };
  EXPECT_TRUE(is_fp32("abs0"));
  EXPECT_TRUE(is_fp32("mul0"));
}
// Graph made of Data, Load, Store and Output only (DT_FLOAT16): the pass must
// succeed without producing any Cast node.
TEST_F(TestImprovePrecisionST, OnlyLoadAndStore_NoChange) {
  auto asc_graph = AscGraphBuilder("st_only_load_store")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Store("store0", "load0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const size_t cast_total = CountNodesByType(asc_graph, Cast::Type);
  EXPECT_EQ(cast_total, 0U);
}
// A single Relu between Load and Store on DT_FLOAT16: Relu must report DT_FLOAT
// after the pass.
TEST_F(TestImprovePrecisionST, SingleOtherNode_CastInsertedBeforeOther) {
  auto asc_graph = AscGraphBuilder("st_single_other")
                       .Loops({Sym("s0")})
                       .Data("data0", 0, ge::DT_FLOAT16)
                       .Load("load0", "data0")
                       .Relu("relu0", "load0")
                       .Store("store0", "relu0")
                       .Output("output0", "store0", 0, ge::DT_FLOAT16)
                       .Build();

  ASSERT_EQ(ImprovePrecisionForAscGraph(asc_graph), ge::SUCCESS);

  const bool relu_is_fp32 = CheckNodeOutputDtype(asc_graph, "relu0", ge::DT_FLOAT);
  EXPECT_TRUE(relu_is_fp32);
}