* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include "asc_graph_builder.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "pre_process/scalar_broadcast_insert.h"
#include "autoschedule/autoschedule.h"
#include "ascgraph_info_complete.h"
#include "platform_context.h"
#include "runtime_stub.h"
using namespace af;
using namespace af::testing;
using namespace af::pre_process;
using namespace optimize::autoschedule;
namespace {
// Test fixture: before every test case a freshly created ge::RuntimeStub is
// installed as the runtime-stub singleton.
class TestScalarBroadcastInsert : public ::testing::Test {
 protected:
  void SetUp() override {
    auto runtime_stub = std::make_shared<ge::RuntimeStub>();
    ge::RuntimeStub::SetInstance(runtime_stub);
  }
  void TearDown() override {}
};

// Returns the number of nodes in the compute graph behind `graph` whose op type
// equals `type`.
size_t CountNodesByType(AscGraph &graph, const std::string &type) {
  size_t matches = 0U;
  const auto compute_graph = AscGraphUtils::GetComputeGraph(graph);
  for (const auto &candidate : compute_graph->GetAllNodes()) {
    matches += (candidate->GetType() == type) ? 1U : 0U;
  }
  return matches;
}

// Returns true when some node in the compute graph behind `graph` is named `name`.
bool HasNodeWithName(AscGraph &graph, const std::string &name) {
  const auto compute_graph = AscGraphUtils::GetComputeGraph(graph);
  for (const auto &candidate : compute_graph->GetAllNodes()) {
    if (candidate->GetName() != name) {
      continue;
    }
    return true;
  }
  return false;
}

// Returns true when output data anchor 0 of the node named `src_name` has a
// peer input anchor owned by the node named `dst_name`. Only output 0 is
// inspected; a source node without that anchor yields false immediately.
bool IsConnected(AscGraph &graph, const std::string &src_name, const std::string &dst_name) {
  const auto compute_graph = AscGraphUtils::GetComputeGraph(graph);
  for (const auto &candidate : compute_graph->GetAllNodes()) {
    if (candidate->GetName() != src_name) {
      continue;
    }
    const auto first_output = candidate->GetOutDataAnchor(0);
    if (first_output == nullptr) {
      return false;
    }
    for (const auto &consumer_anchor : first_output->GetPeerInDataAnchors()) {
      if ((consumer_anchor != nullptr) && (consumer_anchor->GetOwnerNode()->GetName() == dst_name)) {
        return true;
      }
    }
  }
  return false;
}

// Returns how many peer input anchors hang off output data anchor 0 of the
// node named `node_name`; 0 when the node is absent or has no such anchor.
size_t CountOutDataEdges(AscGraph &graph, const std::string &node_name) {
  const auto compute_graph = AscGraphUtils::GetComputeGraph(graph);
  for (const auto &candidate : compute_graph->GetAllNodes()) {
    if (candidate->GetName() == node_name) {
      const auto first_output = candidate->GetOutDataAnchor(0);
      return (first_output == nullptr) ? 0U : first_output->GetPeerInDataAnchors().size();
    }
  }
  return 0U;
}
}  // namespace
// A scalar feeding a compute node directly must get exactly one Broadcast
// spliced in between, so that add0 consumes the Broadcast instead of scalar0.
TEST_F(TestScalarBroadcastInsert, ScalarDirectToCompute_InsertsBroadcast) {
  auto graph = AscGraphBuilder("test_scalar_to_add")
                   .Loops({Sym("s0"), Sym("s1")})
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .Scalar("scalar0", "1.0")
                   .Add("add0", "load0", "scalar0")
                   .Store("store0", "add0")
                   .Output("output0", "store0")
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  size_t brc_before = CountNodesByType(graph, ascir_op::Broadcast::Type);
  ASSERT_EQ(brc_before, 0U);
  auto ret = InsertBroadcastAfterScalarForAscGraph(graph);
  ASSERT_EQ(ret, ge::SUCCESS);
  // One scalar with one direct consumer -> exactly one inserted Broadcast.
  size_t brc_after = CountNodesByType(graph, ascir_op::Broadcast::Type);
  ASSERT_EQ(brc_after, 1U);
  EXPECT_FALSE(IsConnected(graph, "scalar0", "add0"));
  // The inserted Broadcast must sit on the scalar0 -> add0 path.
  std::string brc_name;
  for (const auto &node : AscGraphUtils::GetComputeGraph(graph)->GetAllNodes()) {
    if (node->GetType() == ascir_op::Broadcast::Type) {
      brc_name = node->GetName();
      break;
    }
  }
  ASSERT_FALSE(brc_name.empty());
  EXPECT_TRUE(IsConnected(graph, "scalar0", brc_name));
  EXPECT_TRUE(IsConnected(graph, brc_name, "add0"));
}
// A scalar that already reaches its consumer through an explicit Broadcast
// (scalar0 -> brc0 -> add0) must be left alone: no extra Broadcast is inserted
// and the existing wiring is kept.
TEST_F(TestScalarBroadcastInsert, ScalarAlreadyHasBroadcast_NoInsert) {
  auto graph = AscGraphBuilder("test_scalar_with_brc")
                   .Loops({Sym("s0"), Sym("s1")})
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .Scalar("scalar0", "1.0")
                   .Broadcast("brc0", "scalar0", {Sym("s0"), Sym("s1")})
                   .Add("add0", "load0", "brc0")
                   .Store("store0", "add0")
                   .Output("output0", "store0")
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  size_t brc_before = CountNodesByType(graph, ascir_op::Broadcast::Type);
  // Precondition: the builder produced exactly the one explicit Broadcast.
  ASSERT_EQ(brc_before, 1U);
  auto ret = InsertBroadcastAfterScalarForAscGraph(graph);
  ASSERT_EQ(ret, ge::SUCCESS);
  size_t brc_after = CountNodesByType(graph, ascir_op::Broadcast::Type);
  EXPECT_EQ(brc_after, brc_before);
  EXPECT_TRUE(IsConnected(graph, "scalar0", "brc0"));
  EXPECT_TRUE(IsConnected(graph, "brc0", "add0"));
}
// A scalar feeding a Store directly must also get exactly one Broadcast spliced
// in between, so that store0 consumes the Broadcast instead of scalar0.
TEST_F(TestScalarBroadcastInsert, ScalarDirectToStore_InsertsBroadcast) {
  auto graph = AscGraphBuilder("test_scalar_to_store")
                   .Loops({Sym("s0")})
                   .Scalar("scalar0", "2.0")
                   .Store("store0", "scalar0")
                   .Output("output0", "store0")
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  auto ret = InsertBroadcastAfterScalarForAscGraph(graph);
  ASSERT_EQ(ret, ge::SUCCESS);
  // One scalar with one direct consumer -> exactly one inserted Broadcast.
  size_t brc_after = CountNodesByType(graph, ascir_op::Broadcast::Type);
  ASSERT_EQ(brc_after, 1U);
  EXPECT_FALSE(IsConnected(graph, "scalar0", "store0"));
  // The inserted Broadcast must sit on the scalar0 -> store0 path.
  std::string brc_name;
  for (const auto &node : AscGraphUtils::GetComputeGraph(graph)->GetAllNodes()) {
    if (node->GetType() == ascir_op::Broadcast::Type) {
      brc_name = node->GetName();
      break;
    }
  }
  ASSERT_FALSE(brc_name.empty());
  EXPECT_TRUE(IsConnected(graph, "scalar0", brc_name));
  EXPECT_TRUE(IsConnected(graph, brc_name, "store0"));
}
// scalar0 has no consumers in this graph, so the pass is expected to succeed
// without adding any Broadcast node.
TEST_F(TestScalarBroadcastInsert, ScalarNoDownstream_NoInsert) {
  auto graph = AscGraphBuilder("test_scalar_no_downstream")
                   .Loops({Sym("s0")})
                   .Scalar("scalar0", "1.0")
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .Store("store0", "load0")
                   .Output("output0", "store0")
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ASSERT_EQ(InsertBroadcastAfterScalarForAscGraph(graph), ge::SUCCESS);
  EXPECT_EQ(CountNodesByType(graph, ascir_op::Broadcast::Type), 0U);
}
// A scalar with two direct compute consumers (add0 and mul0) must get a single
// shared Broadcast that feeds both consumers, replacing the direct edges.
TEST_F(TestScalarBroadcastInsert, ScalarMultipleDownstreams_InsertsOneSharedBroadcast) {
  auto graph = AscGraphBuilder("test_scalar_multi_downstream")
                   .Loops({Sym("s0")})
                   .Data("data0", 0)
                   .Data("data1", 1)
                   .Load("load0", "data0")
                   .Load("load1", "data1")
                   .Scalar("scalar0", "1.0")
                   .Add("add0", "load0", "scalar0")
                   .Mul("mul0", "load1", "scalar0")
                   .Store("store0", "add0")
                   .Store("store1", "mul0")
                   .Output("output0", "store0")
                   .Output("output1", "store1", 1)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  auto ret = InsertBroadcastAfterScalarForAscGraph(graph);
  ASSERT_EQ(ret, ge::SUCCESS);
  size_t brc_after = CountNodesByType(graph, ascir_op::Broadcast::Type);
  ASSERT_EQ(brc_after, 1U);
  // Both direct scalar edges must have been rerouted.
  EXPECT_FALSE(IsConnected(graph, "scalar0", "add0"));
  EXPECT_FALSE(IsConnected(graph, "scalar0", "mul0"));
  std::string brc_name;
  for (const auto &node : AscGraphUtils::GetComputeGraph(graph)->GetAllNodes()) {
    if (node->GetType() == ascir_op::Broadcast::Type) {
      brc_name = node->GetName();
      break;
    }
  }
  ASSERT_FALSE(brc_name.empty());
  // The one Broadcast is fed by scalar0 and feeds exactly add0 and mul0.
  EXPECT_TRUE(IsConnected(graph, "scalar0", brc_name));
  EXPECT_EQ(CountOutDataEdges(graph, brc_name), 2U);
  EXPECT_TRUE(IsConnected(graph, brc_name, "add0"));
  EXPECT_TRUE(IsConnected(graph, brc_name, "mul0"));
}
// A scalar whose only consumer is a graph Output is not broadcast; the
// scalar0 -> output0 edge must be kept intact.
TEST_F(TestScalarBroadcastInsert, ScalarDirectToOutput_NoInsert) {
  auto graph = AscGraphBuilder("test_scalar_to_output")
                   .Loops({Sym("s0")})
                   .Scalar("scalar0", "1.0")
                   .Output("output0", "scalar0")
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  auto ret = InsertBroadcastAfterScalarForAscGraph(graph);
  ASSERT_EQ(ret, ge::SUCCESS);
  size_t brc_after = CountNodesByType(graph, ascir_op::Broadcast::Type);
  EXPECT_EQ(brc_after, 0U);
  EXPECT_TRUE(IsConnected(graph, "scalar0", "output0"));
}
// A graph without any Scalar node: the pass must succeed and add no Broadcast.
// NOTE(review): only the Broadcast count is checked here, not the rest of the graph.
TEST_F(TestScalarBroadcastInsert, NoScalarInGraph_NoChange) {
  auto graph = AscGraphBuilder("test_no_scalar")
                   .Loops({Sym("s0")})
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .Abs("abs0", "load0")
                   .Store("store0", "abs0")
                   .Output("output0", "store0")
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ASSERT_EQ(InsertBroadcastAfterScalarForAscGraph(graph), ge::SUCCESS);
  EXPECT_EQ(CountNodesByType(graph, ascir_op::Broadcast::Type), 0U);
}
// Two scalars: scalar0 feeds add0 directly, while scalar1 already goes through
// the explicit brc0. Only scalar0 is expected to receive a new Broadcast.
TEST_F(TestScalarBroadcastInsert, MixedScalarOnlyInsertsForDirectOnes) {
  auto graph = AscGraphBuilder("test_mixed_scalar")
                   .Loops({Sym("s0"), Sym("s1")})
                   .Data("data0", 0)
                   .Data("data1", 1)
                   .Load("load0", "data0")
                   .Load("load1", "data1")
                   .Scalar("scalar0", "1.0")
                   .Add("add0", "load0", "scalar0")
                   .Scalar("scalar1", "2.0")
                   .Broadcast("brc0", "scalar1", {Sym("s0"), Sym("s1")})
                   .Mul("mul0", "load1", "brc0")
                   .Store("store0", "add0")
                   .Store("store1", "mul0")
                   .Output("output0", "store0")
                   .Output("output1", "store1", 1)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  // Precondition: only the explicit brc0 exists before the pass.
  ASSERT_EQ(CountNodesByType(graph, ascir_op::Broadcast::Type), 1U);
  ASSERT_EQ(InsertBroadcastAfterScalarForAscGraph(graph), ge::SUCCESS);
  // brc0 plus the single Broadcast inserted for scalar0.
  EXPECT_EQ(CountNodesByType(graph, ascir_op::Broadcast::Type), 2U);
  EXPECT_FALSE(IsConnected(graph, "scalar0", "add0"));
  EXPECT_TRUE(IsConnected(graph, "scalar1", "brc0"));
}
// The Broadcast inserted after a DT_FLOAT16 scalar must itself produce DT_FLOAT16.
TEST_F(TestScalarBroadcastInsert, BroadcastDtypeMatchesScalar) {
  auto graph = AscGraphBuilder("test_brc_dtype")
                   .Loops({Sym("s0")})
                   .Scalar("scalar0", "1.0", ge::DT_FLOAT16)
                   .Abs("abs0", "scalar0")
                   .Store("store0", "abs0")
                   .Output("output0", "store0", 0, ge::DT_FLOAT16)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  auto ret = InsertBroadcastAfterScalarForAscGraph(graph);
  ASSERT_EQ(ret, ge::SUCCESS);
  // Without this precondition, a pass that inserted nothing would skip the
  // loop below entirely and the dtype check would pass vacuously.
  ASSERT_EQ(CountNodesByType(graph, ascir_op::Broadcast::Type), 1U);
  EXPECT_FALSE(IsConnected(graph, "scalar0", "abs0"));
  for (const auto &node : AscGraphUtils::GetComputeGraph(graph)->GetAllNodes()) {
    if (node->GetType() == ascir_op::Broadcast::Type) {
      auto desc = node->GetOpDesc();
      ASSERT_TRUE(desc != nullptr);
      auto output_desc = desc->MutableOutputDesc(0);
      ASSERT_TRUE(output_desc != nullptr);
      EXPECT_EQ(output_desc->GetDataType(), ge::DT_FLOAT16);
      break;
    }
  }
}
// Two independent scalars, each with one direct compute consumer: the pass is
// expected to insert one Broadcast per scalar and cut both direct edges.
TEST_F(TestScalarBroadcastInsert, MultipleScalarsAllGetBroadcast) {
  auto graph = AscGraphBuilder("test_multi_scalars")
                   .Loops({Sym("s0")})
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .Scalar("scalar0", "1.0")
                   .Scalar("scalar1", "2.0")
                   .Add("add0", "load0", "scalar0")
                   .Mul("mul0", "load0", "scalar1")
                   .Store("store0", "add0")
                   .Store("store1", "mul0")
                   .Output("output0", "store0")
                   .Output("output1", "store1", 1)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ASSERT_EQ(InsertBroadcastAfterScalarForAscGraph(graph), ge::SUCCESS);
  const size_t inserted_broadcasts = CountNodesByType(graph, ascir_op::Broadcast::Type);
  EXPECT_EQ(inserted_broadcasts, 2U);
  EXPECT_FALSE(IsConnected(graph, "scalar0", "add0"));
  EXPECT_FALSE(IsConnected(graph, "scalar1", "mul0"));
}
// Expects no Broadcast to be inserted for this graph.
// NOTE(review): Data("data0", 0) takes no producer, so scalar0 is never wired
// to data0 here; this currently exercises the same "scalar without consumers"
// path as ScalarNoDownstream_NoInsert. Confirm the intended topology.
TEST_F(TestScalarBroadcastInsert, ScalarDirectToData_NoInsert) {
  auto graph = AscGraphBuilder("test_scalar_to_data")
                   .Loops({Sym("s0")})
                   .Scalar("scalar0", "1.0")
                   .Data("data0", 0)
                   .Output("output0", "data0")
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ASSERT_EQ(InsertBroadcastAfterScalarForAscGraph(graph), ge::SUCCESS);
  EXPECT_EQ(CountNodesByType(graph, ascir_op::Broadcast::Type), 0U);
}
// scalar0 feeds both a compute node (add0) and a graph Output (output1). Only
// the compute edge is expected to be rerouted through a new Broadcast; the
// Output edge stays direct.
TEST_F(TestScalarBroadcastInsert, ScalarToBothComputeAndOutput_OnlyInterceptsCompute) {
  auto graph = AscGraphBuilder("test_scalar_compute_output")
                   .Loops({Sym("s0")})
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .Scalar("scalar0", "1.0")
                   .Add("add0", "load0", "scalar0")
                   .Store("store0", "add0")
                   .Output("output0", "store0")
                   .Output("output1", "scalar0", 1)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  const auto status = InsertBroadcastAfterScalarForAscGraph(graph);
  ASSERT_EQ(status, ge::SUCCESS);
  const size_t broadcast_count = CountNodesByType(graph, ascir_op::Broadcast::Type);
  EXPECT_EQ(broadcast_count, 1U);
  EXPECT_FALSE(IsConnected(graph, "scalar0", "add0"));
  EXPECT_TRUE(IsConnected(graph, "scalar0", "output1"));
}
// load0 reads data0 with shape {1, s1} inside the {s0, s1} loop nest. After
// AutoSchedule, the two schedule axes of load0 and abs0 are expected to be
// swapped relative to load0's original order.
TEST_F(TestScalarBroadcastInsert, BroadcastReorderBasic) {
  auto s0 = Sym(4);
  auto s1 = Sym(512 * 1024);
  auto graph = AscGraphBuilder("broadcast_reorder_basic")
                   .Loops({s0, s1})
                   .Data("data0", 0, {af::sym::kSymbolOne, s1}, {af::sym::kSymbolZero, af::sym::kSymbolOne}, ge::DT_FLOAT16)
                   .Load("load0", "data0", {af::sym::kSymbolOne, s1}, {af::sym::kSymbolZero, af::sym::kSymbolOne})
                   .Abs("abs0", "load0")
                   .Store("store0", "abs0")
                   .Output("y", "store0", 0, ge::DT_FLOAT16)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  auto load_node = graph.FindNode("load0");
  ASSERT_NE(load_node, nullptr);
  const auto axis_before = load_node->attr.sched.axis;
  ASSERT_EQ(axis_before.size(), 2UL);
  std::vector<AutoScheduleOutput> schedule_outputs;
  AutoSchedule scheduler(graph, schedule_outputs);
  EXPECT_EQ(scheduler.DoAutoSchedule(), ge::SUCCESS);
  const std::vector<int64_t> swapped_axis = {axis_before[1], axis_before[0]};
  EXPECT_EQ(load_node->attr.sched.axis, swapped_axis);
  auto abs_node = graph.FindNode("abs0");
  ASSERT_NE(abs_node, nullptr);
  EXPECT_EQ(abs_node->attr.sched.axis, swapped_axis);
}
// Same topology as BroadcastReorderBasic but with s0 = 32; per the test name the
// broadcast extent is then too large for the reorder, so load0 is expected to
// keep its original schedule-axis order.
TEST_F(TestScalarBroadcastInsert, BroadcastReorderSkipBrcTooLarge) {
  auto s0 = Sym(32);
  auto s1 = Sym(512 * 1024);
  auto graph = AscGraphBuilder("broadcast_reorder_large_brc")
                   .Loops({s0, s1})
                   .Data("data0", 0, {af::sym::kSymbolOne, s1}, {af::sym::kSymbolZero, af::sym::kSymbolOne}, ge::DT_FLOAT16)
                   .Load("load0", "data0", {af::sym::kSymbolOne, s1}, {af::sym::kSymbolZero, af::sym::kSymbolOne})
                   .Abs("abs0", "load0")
                   .Store("store0", "abs0")
                   .Output("y", "store0", 0, ge::DT_FLOAT16)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  auto load_node = graph.FindNode("load0");
  ASSERT_NE(load_node, nullptr);
  auto original_axis = load_node->attr.sched.axis;
  // Precondition: without two axes there is nothing to reorder, and the
  // "unchanged" check below would hold vacuously.
  ASSERT_EQ(original_axis.size(), 2UL);
  std::vector<AutoScheduleOutput> results;
  AutoSchedule schedule(graph, results);
  EXPECT_EQ(schedule.DoAutoSchedule(), ge::SUCCESS);
  EXPECT_EQ(load_node->attr.sched.axis, original_axis);
}
// Same topology as BroadcastReorderBasic but with a tiny s1 = 8; per the test
// name the data is then too small to benefit, so load0 is expected to keep its
// original schedule-axis order.
TEST_F(TestScalarBroadcastInsert, BroadcastReorderSkipDataTooSmall) {
  auto s0 = Sym(4);
  auto s1 = Sym(8);
  auto graph = AscGraphBuilder("broadcast_reorder_small_data")
                   .Loops({s0, s1})
                   .Data("data0", 0, {af::sym::kSymbolOne, s1}, {af::sym::kSymbolZero, af::sym::kSymbolOne}, ge::DT_FLOAT16)
                   .Load("load0", "data0", {af::sym::kSymbolOne, s1}, {af::sym::kSymbolZero, af::sym::kSymbolOne})
                   .Abs("abs0", "load0")
                   .Store("store0", "abs0")
                   .Output("y", "store0", 0, ge::DT_FLOAT16)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  auto load_node = graph.FindNode("load0");
  ASSERT_NE(load_node, nullptr);
  auto original_axis = load_node->attr.sched.axis;
  // Precondition: without two axes there is nothing to reorder, and the
  // "unchanged" check below would hold vacuously.
  ASSERT_EQ(original_axis.size(), 2UL);
  std::vector<AutoScheduleOutput> results;
  AutoSchedule schedule(graph, results);
  EXPECT_EQ(schedule.DoAutoSchedule(), ge::SUCCESS);
  EXPECT_EQ(load_node->attr.sched.axis, original_axis);
}