* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include <ascendc_ir.h>
#include "ascir.h"
#include <ascir_ops.h>
#include <ascir_utils.h>
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "graph/debug/ge_attr_define.h"
#define private public
#include "optimize.h"
#include "platform_context.h"
#include "optimize/graph_pass/pass_runner_handler.h"
#undef private
#include "asc_tensor_utils.h"
#include "ascgraph_info_complete.h"
#include "ascir_ops_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_op_types.h"
#include "graph/normal_graph/ge_tensor_impl.h"
#include "codegen.h"
#include "fusion/autofuse_attrs.h"
#include "task_generator/transpose_schedule_case_generator.h"
#include "task_generator/reduce_schedule_case_generator.h"
#include "ascgraph_info_complete.h"
#include "schedule_result.h"
#include "attribute_group/attr_group_shape_env.h"
#include "autoschedule/tiling_group.h"
#include "expression/testcase/source_stub.h"
#include "util/mem_utils.h"
#include "platform/platform_factory.h"
#include "runtime_stub.h"
#include "asc_graph_builder.h"
#include "codegen.h"
#include "optimize/graph_pass/pass_utils.h"
using namespace af;
using namespace af::ops;
using namespace af::ascir_op;
using ge::RuntimeStub;
using ge::GraphInputShapeSourceStub;
using ge::InputValueSumSourceStub;
namespace {
class GraphBuilder {
public:
explicit GraphBuilder(const std::string &name) {
graph_ = std::make_shared<ComputeGraph>(name);
}
GraphBuilder(const std::string &name, const std::string &node_type) {
graph_ = std::make_shared<ComputeGraph>(name);
node_type_ = node_type;
}
NodePtr AddNode(const std::string &name, const std::string &type, const int in_cnt, const int out_cnt,
const std::vector<int64_t> shape = {1, 1, 1, 1}) {
auto tensor_desc = std::make_shared<GeTensorDesc>();
tensor_desc->SetShape(GeShape(std::move(shape)));
tensor_desc->SetFormat(ge::FORMAT_NCHW);
tensor_desc->SetDataType(ge::DT_FLOAT);
auto op_desc = std::make_shared<OpDesc>(name, (node_type_ == "") ? type : "AscGraph");
for (std::int32_t i = 0; i < in_cnt; ++i) {
op_desc->AddInputDesc(tensor_desc->Clone());
}
for (std::int32_t i = 0; i < out_cnt; ++i) {
op_desc->AddOutputDesc(tensor_desc->Clone());
}
op_desc->AddInferFunc([](Operator &op) { return ge::GRAPH_SUCCESS; });
return graph_->AddNode(op_desc);
}
void AddDataEdge(const NodePtr &src_node, const std::int32_t src_idx, const NodePtr &dst_node,
const std::int32_t dst_idx) {
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}
ComputeGraphPtr GetGraph() {
graph_->TopologicalSorting();
return graph_;
}
private:
ComputeGraphPtr graph_;
std::string node_type_;
};
}
class TestOptimizer : public ::testing::Test {
protected:
void SetUp() override {
ge::PlatformContext::GetInstance().Reset();
auto stub_v1 = std::make_shared<RuntimeStub>();
RuntimeStub::SetInstance(stub_v1);
dlog_setlevel(ASCGEN_MODULE_NAME, DLOG_ERROR, 0);
}
void TearDown() override {
ge::PlatformContext::GetInstance().Reset();
dlog_setlevel(ASCGEN_MODULE_NAME, DLOG_ERROR, 0);
}
optimize::Optimizer optimizer;
TestOptimizer() : optimizer(optimize::OptimizerOptions{}) {}
static std::stringstream &SizeExprListStr(std::stringstream &ss, const af::AscGraph &graph,
const std::vector<af::Expression> &size_expr_list) {
for (auto &size_expr : size_expr_list) {
ss << std::string(size_expr.Str().get()) << ", ";
}
return ss;
}
static std::stringstream &AxisListStr(std::stringstream &ss, af::AscGraph &graph,
const std::vector<af::AxisId> &axis_list) {
for (auto axis_id : axis_list) {
ss << graph.FindAxis(axis_id)->name << ", ";
}
return ss;
}
};
// Two Store -> Workspace -> Load -> Output chains fed by the same Data node.
// After BufQueAlloc each Workspace output must carry the same tensor_id as the Store output that feeds it.
TEST_F(TestOptimizer, TwoWorkspace) {
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data x_op("x", graph);
  Store store_op1("store1");
  Store store_op2("store2");
  Workspace workspace_op1("workspace1");
  Workspace workspace_op2("workspace2");
  Load load_op1("load1");
  Load load_op2("load2");
  Output y_op1("y1");
  Output y_op2("y2");

  x_op.y.dtype = ge::DT_FLOAT16;
  store_op1.x = x_op.y;
  store_op1.y.dtype = ge::DT_FLOAT16;
  *store_op1.y.axis = {z0.id, z1.id};
  x_op.ir_attr.SetIndex(0);

  workspace_op1.x = store_op1.y;
  workspace_op1.y.dtype = ge::DT_FLOAT16;
  *workspace_op1.y.axis = {z0.id, z1.id};

  store_op2.x = x_op.y;
  store_op2.y.dtype = ge::DT_FLOAT16;
  *store_op2.y.axis = {z0.id, z1.id};

  workspace_op2.x = store_op2.y;
  workspace_op2.y.dtype = ge::DT_FLOAT16;
  *workspace_op2.y.axis = {z0.id, z1.id};

  load_op1.x = workspace_op1.y;
  load_op1.y.dtype = ge::DT_FLOAT16;
  *load_op1.y.axis = {z0.id, z1.id};

  load_op2.x = workspace_op2.y;
  load_op2.y.dtype = ge::DT_FLOAT16;
  *load_op2.y.axis = {z0.id, z1.id};

  y_op1.x = load_op1.y;
  y_op2.x = load_op2.y;
  y_op1.ir_attr.SetIndex(0);
  y_op2.ir_attr.SetIndex(1);

  // Only the nodes that are inspected below are looked up; fail fast instead of dereferencing a null node.
  auto workspace1 = graph.FindNode("workspace1");
  auto workspace2 = graph.FindNode("workspace2");
  auto store1 = graph.FindNode("store1");
  auto store2 = graph.FindNode("store2");
  ASSERT_NE(workspace1, nullptr);
  ASSERT_NE(workspace2, nullptr);
  ASSERT_NE(store1, nullptr);
  ASSERT_NE(store2, nullptr);

  // The allocation status used to be dropped; the tensor-id comparison is only meaningful if it succeeded.
  EXPECT_EQ(optimizer.BufQueAlloc(graph, graph), ge::SUCCESS);

  EXPECT_EQ(workspace1->outputs[0].attr.mem.tensor_id, store1->outputs[0].attr.mem.tensor_id);
  EXPECT_EQ(workspace2->outputs[0].attr.mem.tensor_id, store2->outputs[0].attr.mem.tensor_id);
}
// Data -> Load -> Abs -> Store -> Output over two axes (z0, z1) with contiguous strides.
// After Optimize the single implementation graph is expected to describe the abs output on the axes
// z0z1TB / z0z1Tb / z0z1t with the repeats and strides spelled out in the expectations below.
TEST_F(TestOptimizer, ReOrderMergeAxisGraph_scheduler) {
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  af::ascir_op::Data data_i("data_i", graph);
  data_i.attr.sched.axis = {z0.id, z1.id};
  data_i.y.dtype = ge::DT_FLOAT16;
  *data_i.y.axis = {z0.id, z1.id};
  data_i.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data_i.ir_attr.SetIndex(0);

  af::ascir_op::Load load_i("load_i");
  load_i.x = data_i.y;
  load_i.attr.sched.axis = {z0.id, z1.id};
  *load_i.y.axis = {z0.id, z1.id};
  *load_i.y.repeats = {s0, s1};
  *load_i.y.strides = {s1, One};
  load_i.attr.api.compute_type = af::ComputeType::kComputeLoad;

  af::ascir_op::Abs abs("abs");
  graph.AddNode(abs);
  abs.x = load_i.y;
  abs.attr.sched.axis = {z0.id, z1.id};
  abs.y.dtype = ge::DT_FLOAT16;
  *abs.y.axis = {z0.id, z1.id};
  *abs.y.repeats = {s0, s1};
  *abs.y.strides = {s1, One};
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;

  af::ascir_op::Store store("store");
  store.x = abs.y;
  store.attr.sched.axis = {z0.id, z1.id};
  store.y.dtype = ge::DT_FLOAT16;
  *store.y.axis = {z0.id, z1.id};
  *store.y.repeats = {s0, s1};
  *store.y.strides = {s1, One};
  store.attr.api.compute_type = af::ComputeType::kComputeStore;

  af::ascir_op::Output y("y");
  y.x = store.y;
  y.attr.sched.axis = {z0.id, z1.id};
  y.y.dtype = ge::DT_FLOAT16;
  *y.y.axis = {z0.id, z1.id};
  y.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y.ir_attr.SetIndex(0);

  // Dedicated optimizer instance; named so that it does not shadow the fixture member `optimizer`.
  optimize::Optimizer local_optimizer(optimize::OptimizerOptions{});
  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = local_optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal on purpose: everything below indexes into the result.
  ASSERT_EQ(res, ge::SUCCESS);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);

  auto optimize_graph = fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0];
  auto abs_sched = optimize_graph.FindNode("abs");
  ASSERT_NE(abs_sched, nullptr);

  std::stringstream ss;
  SizeExprListStr(ss, optimize_graph, abs_sched->outputs[0].attr.repeats);
  std::stringstream ss1;
  SizeExprListStr(ss1, optimize_graph, abs_sched->outputs[0].attr.strides);
  std::stringstream ss2;
  AxisListStr(ss2, optimize_graph, abs_sched->outputs[0].attr.axis);

  EXPECT_EQ(ss.str(), "(s0 * s1 / (z0z1Tb_size * z0z1t_size)), z0z1Tb_size, z0z1t_size, ");
  EXPECT_EQ(ss1.str(), "(z0z1Tb_size * z0z1t_size), z0z1t_size, 1, ");
  EXPECT_EQ(ss2.str(), "z0z1TB, z0z1Tb, z0z1t, ");
}
// Data -> Load -> Abs -> Store -> Workspace("y") -> Load.
// After BufQueAlloc the Workspace output must carry the same tensor_id as the Store output feeding it.
TEST_F(TestOptimizer, BufQueAlloc_WhenOutputNode_WillUseInputTensorAsOutput) {
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto z0 = graph.CreateAxis("z0", s0);

  Data x_op("x", graph);
  x_op.y.dtype = ge::DT_FLOAT16;

  Load load_op("load");
  load_op.x = x_op.y;
  load_op.y.dtype = ge::DT_FLOAT16;
  *load_op.y.axis = {z0.id};

  af::ascir_op::Abs vec_op("vec");
  vec_op.x = load_op.y;
  vec_op.y.dtype = ge::DT_FLOAT16;
  *vec_op.y.axis = {z0.id};

  Store store_op("store");
  store_op.x = vec_op.y;
  store_op.y.dtype = ge::DT_FLOAT16;
  *store_op.y.axis = {z0.id};

  Workspace y_op("y");
  y_op.x = store_op.y;

  Load load_op2("load2");
  load_op2.x = y_op.y;

  auto store = graph.FindNode("store");
  auto y = graph.FindNode("y");
  // Fail fast instead of dereferencing a null node below.
  ASSERT_NE(store, nullptr);
  ASSERT_NE(y, nullptr);

  // The allocation status used to be dropped; the tensor-id comparison is only meaningful if it succeeded.
  EXPECT_EQ(optimizer.BufQueAlloc(graph, graph), ge::SUCCESS);

  EXPECT_EQ(y->outputs[0].attr.mem.tensor_id, store->outputs[0].attr.mem.tensor_id);
}
// Data -> Load -> Broadcast -> Store -> Workspace -> Load.
// BufQueAlloc must succeed and leave exactly one temporary buffer on the Broadcast node.
TEST_F(TestOptimizer, BufQueAlloc_TempBuffer) {
  af::AscGraph graph("test_graph");
  auto size0 = graph.CreateSizeVar("s0");
  auto axis0 = graph.CreateAxis("z0", size0);

  Data input("x", graph);
  input.y.dtype = ge::DT_FLOAT16;
  input.ir_attr.SetIndex(0);

  Load first_load("load");
  first_load.x = input.y;
  first_load.y.dtype = ge::DT_FLOAT16;
  *first_load.y.axis = {axis0.id};

  af::ascir_op::Broadcast broadcast("brc");
  broadcast.x = first_load.y;
  broadcast.y.dtype = ge::DT_FLOAT16;
  *broadcast.y.axis = {axis0.id};

  Store store("store");
  store.x = broadcast.y;
  store.y.dtype = ge::DT_FLOAT16;
  *store.y.axis = {axis0.id};

  Workspace workspace("y");
  workspace.x = store.y;

  Load second_load("load2");
  second_load.x = workspace.y;

  const Status alloc_status = optimizer.BufQueAlloc(graph, graph);
  ASSERT_EQ(alloc_status, ge::SUCCESS);

  auto brc_node = graph.FindNode("brc");
  ASSERT_NE(brc_node, nullptr);
  ASSERT_EQ(brc_node->attr.tmp_buffers.size(), 1);
}
// Scalar -> Store -> Output where the Store spans axis z0 (size 128).
// After Optimize the single implementation graph must contain both a Broadcast node and a Store node.
TEST_F(TestOptimizer, ConstantToStoreNeedBroadCast) {
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar(128);
  auto z0 = graph.CreateAxis("z0", s0);

  Scalar const_op("const", graph);
  Store store_op("store");
  Output output_op("output");

  const_op.attr.sched.axis = {z0.id};
  const_op.ir_attr.SetValue("998.998f");
  const_op.y.dtype = ge::DT_FLOAT;
  *const_op.y.strides = {af::ops::One};
  *const_op.y.repeats = {s0};

  store_op.attr.sched.axis = {z0.id};
  store_op.x = const_op.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id};
  *store_op.y.strides = {af::ops::One};
  *store_op.y.repeats = {s0};

  output_op.attr.sched.axis = {z0.id};
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  // Dedicated optimizer instance; named so that it does not shadow the fixture member `optimizer`.
  optimize::Optimizer local_optimizer(optimize::OptimizerOptions{});
  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = local_optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal on purpose: everything below indexes into the result.
  ASSERT_EQ(res, ge::SUCCESS);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);

  auto optimize_graph = fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0];
  auto cg = af::AscGraphUtils::GetComputeGraph(optimize_graph);

  auto found_broadcast = cg->FindFirstNodeMatchType(af::ascir_op::Broadcast::Type);
  ASSERT_NE(found_broadcast, nullptr);
  // NOTE(review): the wrappers below are constructed but never inspected; kept as in the original test.
  auto asc_broadcast = af::AscNode(found_broadcast->GetOpDesc(), nullptr);

  auto found_store = cg->FindFirstNodeMatchType(af::ascir_op::Store::Type);
  ASSERT_NE(found_store, nullptr);
  auto asc_store = af::AscNode(found_store->GetOpDesc(), nullptr);
}
// Scalar -> Store -> Output with no axes declared on any node.
// After GraphPass a Broadcast node must exist, and both it and the Store must have a non-empty sched axis list.
TEST_F(TestOptimizer, ScalarConstantToStore) {
  af::AscGraph graph("scalar_const_graph");

  Scalar scalar("const", graph);
  Store store("store");
  Output output("output");

  scalar.ir_attr.SetValue("998.998f");
  scalar.y.dtype = ge::DT_FLOAT;
  scalar.attr.api.compute_type = af::ComputeType::kComputeInvalid;

  store.x = scalar.y;
  store.y.dtype = ge::DT_FLOAT;
  store.attr.api.compute_type = af::ComputeType::kComputeElewise;

  output.x = store.y;
  output.y.dtype = ge::DT_FLOAT;
  output.ir_attr.SetIndex(0);
  output.attr.api.compute_type = af::ComputeType::kComputeInvalid;

  optimize::Optimizer graph_optimizer(optimize::OptimizerOptions{});
  ASSERT_EQ(graph_optimizer.GraphPass(graph), ge::SUCCESS);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);

  auto broadcast_node = compute_graph->FindFirstNodeMatchType(af::ascir_op::Broadcast::Type);
  ASSERT_NE(broadcast_node, nullptr);
  auto broadcast_view = af::AscNode(broadcast_node->GetOpDesc(), compute_graph);
  EXPECT_FALSE(broadcast_view.attr.sched.axis.empty());

  auto store_node = compute_graph->FindFirstNodeMatchType(af::ascir_op::Store::Type);
  ASSERT_NE(store_node, nullptr);
  auto store_view = af::AscNode(store_node->GetOpDesc(), compute_graph);
  EXPECT_FALSE(store_view.attr.sched.axis.empty());
}
// One Data node (index 11) feeds two Loads ("a" and "d") that are concatenated and stored.
// After GraphPass the Data node that owns each Load's input must still report index 11.
TEST_F(TestOptimizer, SplitMultiOutputsData) {
  af::AscGraph graph("multi_outputs");
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(32);
  auto s2 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT;
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data1.ir_attr.SetIndex(11);

  Load a("a");
  a.attr.sched.axis = {z0.id, z1.id};
  a.x = data1.y;
  *a.y.axis = {z0.id, z1.id};
  a.y.dtype = ge::DT_FLOAT;
  *a.y.strides = {s1, af::ops::One};
  *a.y.repeats = {s0, s1};

  Load d("d");
  d.attr.sched.axis = {z0.id, z1.id};
  d.x = data1.y;
  *d.y.axis = {z0.id, z1.id};
  d.y.dtype = ge::DT_FLOAT;
  *d.y.strides = {s2, af::ops::One};
  *d.y.repeats = {s0, s2};

  Concat c("c");
  c.attr.sched.axis = {z0.id, z1.id};
  c.attr.api.compute_type = af::ComputeType::kComputeConcat;
  c.x = {a.y, d.y};
  *c.y.axis = {z0.id, z1.id};
  c.y.dtype = ge::DT_FLOAT;
  *c.y.strides = {s1 + s2, af::ops::One};
  *c.y.repeats = {s0, s1 + s2};

  Store f("f");
  f.attr.sched.axis = {z0.id, z1.id};
  f.x = c.y;
  *f.y.axis = {z0.id, z1.id};
  f.y.dtype = ge::DT_FLOAT;
  *f.y.strides = {s1, af::ops::One};
  *f.y.repeats = {s0, s1};

  Output output3("output3");
  output3.x = f.y;
  output3.y.dtype = ge::DT_FLOAT;
  output3.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  output3.ir_attr.SetIndex(2);

  EXPECT_EQ(optimizer.GraphPass(graph), 0);

  auto load0_node = graph.FindNode("a");
  auto load1_node = graph.FindNode("d");
  ASSERT_NE(load0_node, nullptr);
  ASSERT_NE(load1_node, nullptr);

  // Walk back from each Load input to the node that owns the tensor.
  auto data0_node = dynamic_cast<af::AscNode *>(af::ascir::AscTensorUtils::GetOwner(load0_node->inputs[0]));
  auto data1_node = dynamic_cast<af::AscNode *>(af::ascir::AscTensorUtils::GetOwner(load1_node->inputs[0]));
  ASSERT_NE(data0_node, nullptr);
  ASSERT_NE(data1_node, nullptr);
  ASSERT_NE(data0_node->attr.ir_attr, nullptr);
  ASSERT_NE(data1_node->attr.ir_attr, nullptr);

  auto ir_attr0 = data0_node->attr.ir_attr->DownCastTo<af::AscDataIrAttrDef>();
  ASSERT_NE(ir_attr0, nullptr);
  auto ir_attr1 = data1_node->attr.ir_attr->DownCastTo<af::AscDataIrAttrDef>();
  ASSERT_NE(ir_attr1, nullptr);

  // Start from a sentinel so that a GetIndex call that leaves its argument untouched yields a deterministic
  // mismatch instead of a read of an uninitialised value.
  int64_t idx0 = -1;
  int64_t idx1 = -1;
  ir_attr0->GetIndex(idx0);
  ir_attr1->GetIndex(idx1);
  EXPECT_EQ(idx0, 11);
  EXPECT_EQ(idx1, 11);
}
// Graph with a Concat ("c") of Exp(Load(data1)) and Load(data2), plus three Stores:
//   f <- Exp(b)   (first concat input branch)
//   i <- d        (second concat input branch)
//   h <- Exp(c)   (concat output branch)
// After GraphPass, nodes on each branch must use the same axis list as the corresponding Concat input / output.
TEST_F(TestOptimizer, TestSplitConcatAxisPass) {
  af::AscGraph graph("concat_axis_graph");
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(32);
  auto s2 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT;
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data1.ir_attr.SetIndex(0);

  Data data2("data2", graph);
  data2.y.dtype = ge::DT_FLOAT;
  data2.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data2.ir_attr.SetIndex(1);

  Load a("a");
  a.attr.sched.axis = {z0.id, z1.id};
  a.x = data1.y;
  *a.y.axis = {z0.id, z1.id};
  a.y.dtype = ge::DT_FLOAT;
  *a.y.strides = {s1, af::ops::One};
  *a.y.repeats = {s0, s1};

  Exp b("b");
  b.attr.sched.axis = {z0.id, z1.id};
  b.x = a.y;
  *b.y.axis = {z0.id, z1.id};
  b.y.dtype = ge::DT_FLOAT;
  *b.y.strides = {s1, af::ops::One};
  *b.y.repeats = {s0, s1};

  Load d("d");
  d.attr.sched.axis = {z0.id, z1.id};
  d.x = data2.y;
  *d.y.axis = {z0.id, z1.id};
  d.y.dtype = ge::DT_FLOAT;
  *d.y.strides = {s2, af::ops::One};
  *d.y.repeats = {s0, s2};

  Concat c("c");
  c.attr.sched.axis = {z0.id, z1.id};
  c.attr.api.compute_type = af::ComputeType::kComputeConcat;
  c.x = {b.y, d.y};
  *c.y.axis = {z0.id, z1.id};
  c.y.dtype = ge::DT_FLOAT;
  *c.y.strides = {s1 + s2, af::ops::One};
  *c.y.repeats = {s0, s1 + s2};

  Exp e("e");
  e.attr.sched.axis = {z0.id, z1.id};
  e.x = b.y;
  *e.y.axis = {z0.id, z1.id};
  e.y.dtype = ge::DT_FLOAT;
  *e.y.strides = {s1, af::ops::One};
  *e.y.repeats = {s0, s1};

  Store f("f");
  f.attr.sched.axis = {z0.id, z1.id};
  f.x = e.y;
  *f.y.axis = {z0.id, z1.id};
  f.y.dtype = ge::DT_FLOAT;
  *f.y.strides = {s1, af::ops::One};
  *f.y.repeats = {s0, s1};

  Store i("i");
  i.attr.sched.axis = {z0.id, z1.id};
  i.x = d.y;
  *i.y.axis = {z0.id, z1.id};
  i.y.dtype = ge::DT_FLOAT;
  *i.y.strides = {s2, af::ops::One};
  *i.y.repeats = {s0, s2};

  Exp g("g");
  g.attr.sched.axis = {z0.id, z1.id};
  g.x = c.y;
  *g.y.axis = {z0.id, z1.id};
  g.y.dtype = ge::DT_FLOAT;
  *g.y.strides = {s1 + s2, af::ops::One};
  *g.y.repeats = {s0, s1 + s2};

  Store h("h");
  h.attr.sched.axis = {z0.id, z1.id};
  h.x = g.y;
  *h.y.axis = {z0.id, z1.id};
  h.y.dtype = ge::DT_FLOAT;
  *h.y.strides = {s1 + s2, af::ops::One};
  *h.y.repeats = {s0, s1 + s2};

  Output output1("output1");
  output1.x = f.y;
  output1.y.dtype = ge::DT_FLOAT;
  output1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  output1.ir_attr.SetIndex(0);

  Output output2("output2");
  output2.x = h.y;
  output2.y.dtype = ge::DT_FLOAT;
  output2.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  output2.ir_attr.SetIndex(1);

  Output output3("output3");
  output3.x = i.y;
  output3.y.dtype = ge::DT_FLOAT;
  output3.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  output3.ir_attr.SetIndex(2);

  ::ascir::utils::DumpGraph(graph, "Before_");
  EXPECT_EQ(optimizer.GraphPass(graph), 0);
  ::ascir::utils::DumpGraph(graph, "After_");

  auto a_node = graph.FindNode("a");
  auto c_node = graph.FindNode("c");
  auto f_node = graph.FindNode("f");
  auto h_node = graph.FindNode("h");
  auto i_node = graph.FindNode("i");
  // Fail fast if the pass removed or renamed a node, instead of dereferencing a null pointer below.
  ASSERT_NE(a_node, nullptr);
  ASSERT_NE(c_node, nullptr);
  ASSERT_NE(f_node, nullptr);
  ASSERT_NE(h_node, nullptr);
  ASSERT_NE(i_node, nullptr);

  // Reference axis lists taken from the Concat node itself.
  auto concat_out_axis = c_node->attr.sched.axis;
  auto concat_in0_axis = c_node->inputs[0].attr.axis;
  auto concat_in1_axis = c_node->inputs[1].attr.axis;

  EXPECT_EQ(a_node->attr.sched.axis, concat_in0_axis);
  EXPECT_EQ(a_node->outputs[0].attr.axis, concat_in0_axis);
  EXPECT_EQ(f_node->attr.sched.axis, concat_in0_axis);
  EXPECT_EQ(f_node->outputs[0].attr.axis, concat_in0_axis);
  EXPECT_EQ(h_node->attr.sched.axis, concat_out_axis);
  EXPECT_EQ(h_node->outputs[0].attr.axis, concat_out_axis);
  EXPECT_EQ(i_node->attr.sched.axis, concat_in1_axis);
  EXPECT_EQ(i_node->outputs[0].attr.axis, concat_in1_axis);
}
// Concat along the first axis (z3, size s0 + s1): Add(load1, load2) with s0 rows and Abs(load3) with s1 rows.
// Optimize runs with AUTOFUSE_DFX_FLAGS set; the test then checks the Store offsets of every implementation
// graph, the input/output node lists, and the axis names of the first two schedule groups.
TEST_F(TestOptimizer, ConcatFirstDim) {
  af::AscGraph graph("concat_1st_dim_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto z2 = graph.CreateAxis("z2", s2);
  auto z3 = graph.CreateAxis("z3", s0 + s1);

  Data x1_op("x1", graph);
  x1_op.ir_attr.SetIndex(0);
  Data x2_op("x2", graph);
  x2_op.ir_attr.SetIndex(1);
  Data x3_op("x3", graph);
  x3_op.ir_attr.SetIndex(2);
  Load load_op1("load1");
  Load load_op2("load2");
  Load load_op3("load3");

  std::vector<Data> all_data{x1_op, x2_op, x3_op};
  std::vector<Load> all_load{load_op1, load_op2, load_op3};
  for (size_t i = 0U; i < all_data.size(); ++i) {
    auto &x_op = all_data[i];
    auto &load_op = all_load[i];
    x_op.y.dtype = ge::DT_FLOAT16;
    load_op.x = x_op.y;
    load_op.attr.sched.axis = {z3.id, z2.id};
    load_op.y.dtype = ge::DT_FLOAT16;
    *load_op.y.axis = {z3.id, z2.id};
    *load_op.y.strides = {s2, af::ops::One};
    *load_op.y.repeats = {s0, s2};
  }
  // The third load covers the s1-row part of the concat.
  load_op3.attr.sched.axis = {z3.id, z2.id};
  *load_op3.y.axis = {z3.id, z2.id};
  *load_op3.y.repeats = {s1, s2};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z3.id, z2.id};
  add_op.x1 = load_op1.y;
  add_op.x2 = load_op2.y;
  add_op.y.dtype = ge::DT_FLOAT16;
  *add_op.y.axis = {z3.id, z2.id};
  *add_op.y.strides = {s2, af::ops::One};
  *add_op.y.repeats = {s0, s2};

  af::ascir_op::Abs abs_op("abs");
  abs_op.attr.sched.axis = {z3.id, z2.id};
  abs_op.x = load_op3.y;
  abs_op.y.dtype = ge::DT_FLOAT16;
  *abs_op.y.axis = {z3.id, z2.id};
  *abs_op.y.strides = {s2, af::ops::One};
  *abs_op.y.repeats = {s1, s2};

  af::ascir_op::Concat concat_op("concat");
  concat_op.attr.sched.axis = {z3.id, z2.id};
  concat_op.x = {add_op.y, abs_op.y};
  concat_op.y.dtype = ge::DT_FLOAT16;
  *concat_op.y.axis = {z3.id, z2.id};
  *concat_op.y.repeats = {s0 + s1, s2};
  *concat_op.y.strides = {s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z3.id, z2.id};
  store_op.x = concat_op.y;
  store_op.y.dtype = ge::DT_FLOAT16;
  *store_op.y.axis = {z3.id, z2.id};
  *store_op.y.repeats = {s0 + s1, s2};
  *store_op.y.strides = {s2, af::ops::One};

  Output y_op("y");
  y_op.x = store_op.y;
  y_op.ir_attr.SetIndex(0);

  // Sets AUTOFUSE_DFX_FLAGS for the guarded scope and always clears it again, even when a fatal assertion
  // returns from the test early, so the setting cannot leak into later tests.
  struct DfxFlagsGuard {
    DfxFlagsGuard() {
      setenv("AUTOFUSE_DFX_FLAGS", "codegen_compile_debug=true;debug_dir=./TestDump", 1);
      ::ascir::utils::ResetDumpConfig();
    }
    ~DfxFlagsGuard() {
      unsetenv("AUTOFUSE_DFX_FLAGS");
      ::ascir::utils::ResetDumpConfig();
    }
  };

  ::ascir::FusedScheduledResult fused_scheduled_result;
  std::vector<af::Expression> offsets;
  std::vector<af::Expression> expect = {af::Symbol(0), (s0 * s2), af::Symbol(0), (s0 * s2)};
  {
    DfxFlagsGuard dfx_flags_guard;
    Status res = optimizer.Optimize(graph, fused_scheduled_result);
    // Fatal on purpose: everything below indexes into the result.
    ASSERT_EQ(res, 0);
    ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
    ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);

    // Collect the offset attribute of every Store in every implementation graph, in iteration order.
    const auto &result_in_scope = fused_scheduled_result.node_idx_to_scheduled_results[0][0];
    for (const auto &schedule_group : result_in_scope.schedule_groups) {
      for (auto &sub_impl_graph : schedule_group.impl_graphs) {
        for (const auto &sub_node : sub_impl_graph.GetAllNodes()) {
          if (sub_node->GetType() == "Store") {
            af::Expression offset;
            EXPECT_EQ(sub_node->attr.ir_attr->GetAttrValue("offset", offset), 0);
            offsets.emplace_back(offset);
          }
        }
      }
    }
  }

  // Guard the index into `expect`: more collected offsets than expectations would read out of bounds.
  ASSERT_LE(offsets.size(), expect.size());
  for (size_t i = 0; i < offsets.size(); ++i) {
    EXPECT_SYMBOL_EQ(offsets[i], expect[i]);
  }

  ASSERT_EQ(fused_scheduled_result.input_nodes.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.output_nodes.size(), 1UL);
  EXPECT_EQ(fused_scheduled_result.workspace_nodes.size(), 0UL);
  EXPECT_EQ(fused_scheduled_result.input_nodes[0]->GetName(), "x1");
  EXPECT_EQ(fused_scheduled_result.input_nodes[1]->GetName(), "x2");
  EXPECT_EQ(fused_scheduled_result.input_nodes[2]->GetName(), "x3");
  EXPECT_EQ(fused_scheduled_result.output_nodes[0]->GetName(), "y");

  auto &schedule_result = fused_scheduled_result.node_idx_to_scheduled_results[0][0];
  ASSERT_GE(schedule_result.schedule_groups.size(), 2UL);
  ASSERT_FALSE(schedule_result.schedule_groups[0].impl_graphs.empty());
  ASSERT_FALSE(schedule_result.schedule_groups[1].impl_graphs.empty());

  std::set<std::string> axis_names_0;
  std::set<std::string> axis_names_1;
  for (const auto &axis : schedule_result.schedule_groups[0].impl_graphs[0].GetAllAxis()) {
    axis_names_0.emplace(axis->name);
  }
  for (const auto &axis : schedule_result.schedule_groups[1].impl_graphs[0].GetAllAxis()) {
    axis_names_1.emplace(axis->name);
  }
  std::set<std::string> expected_0{"z3z2_1", "z3z2_1T", "z3z2_1TB", "z3z2_1Tb", "z3z2_1t"};
  std::set<std::string> expected_1{"z3z2_0", "z3z2_0T", "z3z2_0TB", "z3z2_0Tb", "z3z2_0t"};
  EXPECT_EQ(axis_names_0, expected_0);
  EXPECT_EQ(axis_names_1, expected_1);
}
// Concat along the last axis (z1, size s1 + s2): inputs are Add(load1, load2) of width s1, then Abs(load3) of
// width s2. Optimize must produce three scheduled results; the Store offsets of the second one are compared
// against {0, s1}.
TEST_F(TestOptimizer, ConcatTailDim) {
  af::AscGraph graph("concat_last_dim_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1 + s2);

  Data x1_op("x1", graph);
  x1_op.ir_attr.SetIndex(0);
  Data x2_op("x2", graph);
  x2_op.ir_attr.SetIndex(1);
  Data x3_op("x3", graph);
  x3_op.ir_attr.SetIndex(2);
  Load load_op1("load1");
  Load load_op2("load2");
  Load load_op3("load3");

  std::vector<Data> all_data{x1_op, x2_op, x3_op};
  std::vector<Load> all_load{load_op1, load_op2, load_op3};
  for (size_t i = 0U; i < all_data.size(); ++i) {
    auto &x_op = all_data[i];
    auto &load_op = all_load[i];
    x_op.y.dtype = ge::DT_FLOAT16;
    load_op.x = x_op.y;
    load_op.attr.sched.axis = {z0.id, z1.id};
    load_op.y.dtype = ge::DT_FLOAT16;
    *load_op.y.axis = {z0.id, z1.id};
    *load_op.y.repeats = {s0, s1};
    *load_op.y.strides = {s1, af::ops::One};
  }
  // The third load covers the s2-wide part of the concat.
  load_op3.attr.sched.axis = {z0.id, z1.id};
  *load_op3.y.axis = {z0.id, z1.id};
  *load_op3.y.repeats = {s0, s2};
  *load_op3.y.strides = {s2, af::ops::One};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z0.id, z1.id};
  add_op.x1 = load_op1.y;
  add_op.x2 = load_op2.y;
  add_op.y.dtype = ge::DT_FLOAT16;
  *add_op.y.axis = {z0.id, z1.id};
  *add_op.y.strides = {s1, af::ops::One};
  *add_op.y.repeats = {s0, s1};

  af::ascir_op::Abs abs_op("abs");
  abs_op.attr.sched.axis = {z0.id, z1.id};
  abs_op.x = load_op3.y;
  abs_op.y.dtype = ge::DT_FLOAT16;
  *abs_op.y.axis = {z0.id, z1.id};
  *abs_op.y.strides = {s2, af::ops::One};
  *abs_op.y.repeats = {s0, s2};

  af::ascir_op::Concat concat_op("concat");
  concat_op.attr.sched.axis = {z0.id, z1.id};
  concat_op.x = {add_op.y, abs_op.y};
  concat_op.y.dtype = ge::DT_FLOAT16;
  *concat_op.y.axis = {z0.id, z1.id};
  *concat_op.y.repeats = {s0, s1 + s2};
  *concat_op.y.strides = {s1 + s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = concat_op.y;
  store_op.y.dtype = ge::DT_FLOAT16;
  *store_op.y.axis = {z0.id, z1.id};
  *store_op.y.repeats = {s0, s1 + s2};
  *store_op.y.strides = {s1 + s2, af::ops::One};
  store_op.ir_attr.SetOffset(af::Symbol(0));

  Output y_op("y");
  y_op.x = store_op.y;
  y_op.ir_attr.SetIndex(0);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  ASSERT_EQ(res, 0);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 3UL);

  auto &schedule_result = fused_scheduled_result.node_idx_to_scheduled_results[0][1];
  std::vector<af::Expression> offsets;
  std::vector<af::Expression> expect = {af::Symbol(0), s1};
  // Collect the Store offset from the first implementation graph of every schedule group.
  for (const auto &schedule_group : schedule_result.schedule_groups) {
    ASSERT_FALSE(schedule_group.impl_graphs.empty());
    auto &sub_impl_graph = schedule_group.impl_graphs.front();
    for (const auto &sub_node : sub_impl_graph.GetAllNodes()) {
      if (sub_node->GetType() == "Store") {
        af::Expression offset;
        EXPECT_EQ(sub_node->attr.ir_attr->GetAttrValue("offset", offset), 0);
        offsets.emplace_back(offset);
      }
    }
  }
  // Guard the index into `expect`: more collected offsets than expectations would read out of bounds.
  ASSERT_LE(offsets.size(), expect.size());
  for (size_t i = 0; i < offsets.size(); ++i) {
    EXPECT_SYMBOL_EQ(offsets[i], expect[i]);
  }

  ASSERT_EQ(fused_scheduled_result.input_nodes.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.output_nodes.size(), 1UL);
  EXPECT_EQ(fused_scheduled_result.workspace_nodes.size(), 0UL);
  EXPECT_EQ(fused_scheduled_result.input_nodes[0]->GetName(), "x1");
  EXPECT_EQ(fused_scheduled_result.input_nodes[1]->GetName(), "x2");
  EXPECT_EQ(fused_scheduled_result.input_nodes[2]->GetName(), "x3");
  EXPECT_EQ(fused_scheduled_result.output_nodes[0]->GetName(), "y");
}
// Same graph as the tail-dim concat case, but the Concat inputs are swapped: Abs(load3) of width s2 comes
// first, Add(load1, load2) of width s1 second. The Store offsets of the second scheduled result are compared
// against {s2, 0}.
TEST_F(TestOptimizer, ConcatTailDim_OutputOrderReversed) {
  af::AscGraph graph("concat_last_dim_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1 + s2);

  Data x1_op("x1", graph);
  x1_op.ir_attr.SetIndex(0);
  Data x2_op("x2", graph);
  x2_op.ir_attr.SetIndex(1);
  Data x3_op("x3", graph);
  x3_op.ir_attr.SetIndex(2);
  Load load_op1("load1");
  Load load_op2("load2");
  Load load_op3("load3");

  std::vector<Data> all_data{x1_op, x2_op, x3_op};
  std::vector<Load> all_load{load_op1, load_op2, load_op3};
  for (size_t i = 0U; i < all_data.size(); ++i) {
    auto &x_op = all_data[i];
    auto &load_op = all_load[i];
    x_op.y.dtype = ge::DT_FLOAT16;
    load_op.x = x_op.y;
    load_op.attr.sched.axis = {z0.id, z1.id};
    load_op.y.dtype = ge::DT_FLOAT16;
    *load_op.y.axis = {z0.id, z1.id};
    *load_op.y.repeats = {s0, s1};
    *load_op.y.strides = {s1, af::ops::One};
  }
  // The third load covers the s2-wide part of the concat.
  load_op3.attr.sched.axis = {z0.id, z1.id};
  *load_op3.y.axis = {z0.id, z1.id};
  *load_op3.y.repeats = {s0, s2};
  *load_op3.y.strides = {s2, af::ops::One};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z0.id, z1.id};
  add_op.x1 = load_op1.y;
  add_op.x2 = load_op2.y;
  add_op.y.dtype = ge::DT_FLOAT16;
  *add_op.y.axis = {z0.id, z1.id};
  *add_op.y.strides = {s1, af::ops::One};
  *add_op.y.repeats = {s0, s1};

  af::ascir_op::Abs abs_op("abs");
  abs_op.attr.sched.axis = {z0.id, z1.id};
  abs_op.x = load_op3.y;
  abs_op.y.dtype = ge::DT_FLOAT16;
  *abs_op.y.axis = {z0.id, z1.id};
  *abs_op.y.strides = {s2, af::ops::One};
  *abs_op.y.repeats = {s0, s2};

  af::ascir_op::Concat concat_op("concat");
  concat_op.attr.sched.axis = {z0.id, z1.id};
  concat_op.x = {abs_op.y, add_op.y};
  concat_op.y.dtype = ge::DT_FLOAT16;
  *concat_op.y.axis = {z0.id, z1.id};
  *concat_op.y.repeats = {s0, s1 + s2};
  *concat_op.y.strides = {s1 + s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = concat_op.y;
  store_op.y.dtype = ge::DT_FLOAT16;
  *store_op.y.axis = {z0.id, z1.id};
  *store_op.y.repeats = {s0, s1 + s2};
  *store_op.y.strides = {s1 + s2, af::ops::One};
  store_op.ir_attr.SetOffset(af::Symbol(0));

  Output y_op("y");
  y_op.x = store_op.y;
  y_op.ir_attr.SetIndex(0);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal on purpose: the code below indexes element [0][1] of the result.
  ASSERT_EQ(res, 0);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 3UL);

  auto &schedule_result = fused_scheduled_result.node_idx_to_scheduled_results[0][1];
  std::vector<af::Expression> offsets;
  std::vector<af::Expression> expect = {s2, af::Symbol(0)};
  // Collect the Store offset from the first implementation graph of every schedule group.
  for (const auto &schedule_group : schedule_result.schedule_groups) {
    ASSERT_FALSE(schedule_group.impl_graphs.empty());
    auto &sub_impl_graph = schedule_group.impl_graphs.front();
    for (const auto &sub_node : sub_impl_graph.GetAllNodes()) {
      if (sub_node->GetType() == "Store") {
        af::Expression offset;
        EXPECT_EQ(sub_node->attr.ir_attr->GetAttrValue("offset", offset), 0);
        offsets.emplace_back(offset);
      }
    }
  }
  // Guard the index into `expect`: more collected offsets than expectations would read out of bounds.
  ASSERT_LE(offsets.size(), expect.size());
  for (size_t i = 0; i < offsets.size(); ++i) {
    EXPECT_SYMBOL_EQ(offsets[i], expect[i]);
  }

  ASSERT_EQ(fused_scheduled_result.input_nodes.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.output_nodes.size(), 1UL);
  EXPECT_EQ(fused_scheduled_result.workspace_nodes.size(), 0UL);
  EXPECT_EQ(fused_scheduled_result.input_nodes[0]->GetName(), "x1");
  EXPECT_EQ(fused_scheduled_result.input_nodes[1]->GetName(), "x2");
  EXPECT_EQ(fused_scheduled_result.input_nodes[2]->GetName(), "x3");
  EXPECT_EQ(fused_scheduled_result.output_nodes[0]->GetName(), "y");
}
// Concat along the last axis where both inputs are the same Add output (width s1 each, total s1 + s1).
// Three extra axes (tmp, tmp1, tmp2) are created and then removed from the graph's axis list, leaving only
// z0 and z1. After Optimize, the axis ids of each inspected implementation graph must equal their position in
// the axis list, and the Store offsets of the second scheduled result are compared against {0, s1}.
TEST_F(TestOptimizer, ConcatTailDim_sharing_input) {
  af::AscGraph graph("concat_last_dim_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto tmp = graph.CreateAxis("tmp", s0);
  auto z0 = graph.CreateAxis("z0", s0);
  auto tmp1 = graph.CreateAxis("tmp1", s0);
  auto z1 = graph.CreateAxis("z1", s1 + s1);
  auto tmp2 = graph.CreateAxis("tmp2", s0);

  Data x1_op("x1", graph);
  x1_op.ir_attr.SetIndex(0);
  Data x2_op("x2", graph);
  x2_op.ir_attr.SetIndex(1);
  Load load_op1("load1");
  Load load_op2("load2");

  std::vector<Data> all_data{x1_op, x2_op};
  std::vector<Load> all_load{load_op1, load_op2};
  for (size_t i = 0U; i < all_data.size(); ++i) {
    auto &x_op = all_data[i];
    auto &load_op = all_load[i];
    x_op.y.dtype = ge::DT_FLOAT16;
    load_op.x = x_op.y;
    load_op.attr.sched.axis = {z0.id, z1.id};
    load_op.y.dtype = ge::DT_FLOAT16;
    *load_op.y.axis = {z0.id, z1.id};
    *load_op.y.repeats = {s0, s1};
    *load_op.y.strides = {s1, af::ops::One};
  }

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z0.id, z1.id};
  add_op.x1 = load_op1.y;
  add_op.x2 = load_op2.y;
  add_op.y.dtype = ge::DT_FLOAT16;
  *add_op.y.axis = {z0.id, z1.id};
  *add_op.y.strides = {s1, af::ops::One};
  *add_op.y.repeats = {s0, s1};

  af::ascir_op::Concat concat_op("concat");
  concat_op.attr.sched.axis = {z0.id, z1.id};
  concat_op.x = {add_op.y, add_op.y};
  concat_op.y.dtype = ge::DT_FLOAT16;
  *concat_op.y.axis = {z0.id, z1.id};
  *concat_op.y.repeats = {s0, s1 + s1};
  *concat_op.y.strides = {s1 + s1, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = concat_op.y;
  store_op.y.dtype = ge::DT_FLOAT16;
  *store_op.y.axis = {z0.id, z1.id};
  *store_op.y.repeats = {s0, s1 + s1};
  *store_op.y.strides = {s1 + s1, af::ops::One};
  store_op.ir_attr.SetOffset(af::Symbol(0));

  Output y_op("y");
  y_op.x = store_op.y;
  y_op.ir_attr.SetIndex(0);

  // Drop tmp2, tmp1 and tmp (positions 4, 2, 0), erasing from the back so earlier positions stay valid.
  auto axis = graph.GetAllAxis();
  ASSERT_GE(axis.size(), 5UL);
  axis.erase(axis.begin() + 4);
  axis.erase(axis.begin() + 2);
  axis.erase(axis.begin());
  const auto graph_attr = af::AscGraphUtils::GetComputeGraph(graph)->GetOrCreateAttrsGroup<af::AscGraphAttr>();
  graph_attr->axis = axis;

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal on purpose: the code below indexes element [0][1] of the result.
  ASSERT_EQ(res, 0);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 3UL);

  auto &schedule_result = fused_scheduled_result.node_idx_to_scheduled_results[0][1];
  std::vector<af::Expression> offsets;
  std::vector<af::Expression> expect = {af::Symbol(0), s1};
  for (const auto &schedule_group : schedule_result.schedule_groups) {
    ASSERT_FALSE(schedule_group.impl_graphs.empty());
    auto &sub_impl_graph = schedule_group.impl_graphs.front();
    // Axis ids must match their position in the axis list.
    auto res_axis = sub_impl_graph.GetAllAxis();
    for (size_t i = 0; i < res_axis.size(); i++) {
      EXPECT_EQ(res_axis[i]->id, static_cast<af::AxisId>(i));
    }
    for (const auto &sub_node : sub_impl_graph.GetAllNodes()) {
      if (sub_node->GetType() == "Store") {
        af::Expression offset;
        EXPECT_EQ(sub_node->attr.ir_attr->GetAttrValue("offset", offset), 0);
        offsets.emplace_back(offset);
      }
    }
  }
  // Guard the index into `expect`: more collected offsets than expectations would read out of bounds.
  ASSERT_LE(offsets.size(), expect.size());
  for (size_t i = 0; i < offsets.size(); ++i) {
    EXPECT_SYMBOL_EQ(offsets[i], expect[i]);
  }

  ASSERT_EQ(fused_scheduled_result.input_nodes.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.output_nodes.size(), 1UL);
  EXPECT_EQ(fused_scheduled_result.workspace_nodes.size(), 0UL);
  EXPECT_EQ(fused_scheduled_result.input_nodes[0]->GetName(), "x1");
  EXPECT_EQ(fused_scheduled_result.input_nodes[1]->GetName(), "x2");
  EXPECT_EQ(fused_scheduled_result.output_nodes[0]->GetName(), "y");
}
// Fills `graph` with a two-input / two-output AscBackend test graph:
//   <prefix>_data0 -> <prefix>_load0 \
//                                     <prefix>_add -> <prefix>_store0 -> <prefix>_out0
//   <prefix>_data1 -> <prefix>_load1 /             \-> <prefix>_store1 -> <prefix>_out1
// graph:    graph to populate; axes z0..z{axis_num-1} with sizes s0..s{axis_num-1} are created on it.
// prefix:   string prepended to every node name.
// axis_num: number of axes; every tensor uses all of them with strides computed as the product of the
//           repeats of the following axes (innermost stride is One).
// load0 additionally gets the offset symbol "s999".
void CreateAscBackendGraphTwoInTwoOut(std::shared_ptr<af::AscGraph> &graph, const std::string &prefix,
                                      int64_t axis_num = 2) {
  std::vector<int64_t> axis_ids;
  std::vector<af::Expression> repeats;
  for (int64_t i = 0; i < axis_num; ++i) {
    const af::Expression exp = graph->CreateSizeVar("s" + std::to_string(i));
    const auto axis = graph->CreateAxis("z" + std::to_string(i), exp);
    // Use the id reported by the graph instead of assuming it equals the loop counter.
    axis_ids.push_back(axis.id);
    repeats.push_back(exp);
  }
  std::vector<af::Expression> strides(repeats.size(), One);
  // Walk from the second-innermost axis outwards; the loop does not run when axis_num <= 1.
  for (int64_t i = axis_num - 2; i >= 0; --i) {
    strides[i] = repeats[i + 1] * strides[i + 1];
  }

  // Every node that carries a tensor layout shares the same axes, repeats and strides.
  const auto apply_layout = [&axis_ids, &repeats, &strides](auto &op) {
    op.attr.sched.axis = axis_ids;
    *op.y.axis = axis_ids;
    *op.y.repeats = repeats;
    *op.y.strides = strides;
  };

  af::ascir_op::Data data0((prefix + "_data0").c_str(), *graph);
  apply_layout(data0);
  data0.ir_attr.SetIndex(0);
  data0.y.dtype = ge::DT_INT8;

  af::ascir_op::Load load0((prefix + "_load0").c_str());
  load0.x = data0.y;
  apply_layout(load0);
  load0.ir_attr.SetOffset(af::Symbol("s999"));

  af::ascir_op::Data data1((prefix + "_data1").c_str(), *graph);
  apply_layout(data1);
  data1.ir_attr.SetIndex(1);
  data1.y.dtype = ge::DT_INT8;

  af::ascir_op::Load load1((prefix + "_load1").c_str());
  load1.x = data1.y;
  apply_layout(load1);

  af::ascir_op::Add add((prefix + "_add").c_str());
  add.x1 = load0.y;
  add.x2 = load1.y;
  apply_layout(add);

  af::ascir_op::Store store0((prefix + "_store0").c_str());
  store0.x = add.y;
  apply_layout(store0);

  af::ascir_op::Output y0((prefix + "_out0").c_str());
  y0.x = store0.y;
  y0.ir_attr.SetIndex(0);
  y0.y.dtype = ge::DT_FLOAT16;

  // Second consumer of the Add result.
  af::ascir_op::Store store1((prefix + "_store1").c_str());
  store1.x = add.y;
  apply_layout(store1);

  af::ascir_op::Output y1((prefix + "_out1").c_str());
  y1.x = store1.y;
  y1.ir_attr.SetIndex(1);
  y1.y.dtype = ge::DT_FLOAT16;
}
// Fills `graph` with a two-input / one-output AscBackend test graph:
//   <prefix>_data0 -> <prefix>_load0 \
//                                     <prefix>_add -> <prefix>_store0 -> <prefix>_out0
//   <prefix>_data1 -> <prefix>_load1 /
// graph:    graph to populate; axes z0..z{axis_num-1} with sizes s0..s{axis_num-1} are created on it.
// prefix:   string prepended to every node name.
// axis_num: number of axes; every tensor uses all of them with strides computed as the product of the
//           repeats of the following axes (innermost stride is One).
void CreateAscBackendGraphTwoInOneOut(std::shared_ptr<af::AscGraph> &graph, const std::string &prefix,
                                      int64_t axis_num = 2) {
  std::vector<int64_t> axis_ids;
  std::vector<af::Expression> repeats;
  for (int64_t i = 0; i < axis_num; ++i) {
    const af::Expression exp = graph->CreateSizeVar("s" + std::to_string(i));
    const auto axis = graph->CreateAxis("z" + std::to_string(i), exp);
    // Use the id reported by the graph instead of assuming it equals the loop counter.
    axis_ids.push_back(axis.id);
    repeats.push_back(exp);
  }
  std::vector<af::Expression> strides(repeats.size(), One);
  // Walk from the second-innermost axis outwards; the loop does not run when axis_num <= 1.
  for (int64_t i = axis_num - 2; i >= 0; --i) {
    strides[i] = repeats[i + 1] * strides[i + 1];
  }

  // Every node that carries a tensor layout shares the same axes, repeats and strides.
  const auto apply_layout = [&axis_ids, &repeats, &strides](auto &op) {
    op.attr.sched.axis = axis_ids;
    *op.y.axis = axis_ids;
    *op.y.repeats = repeats;
    *op.y.strides = strides;
  };

  af::ascir_op::Data data0((prefix + "_data0").c_str(), *graph);
  apply_layout(data0);
  data0.ir_attr.SetIndex(0);
  data0.y.dtype = ge::DT_INT8;

  af::ascir_op::Load load0((prefix + "_load0").c_str());
  load0.x = data0.y;
  apply_layout(load0);

  af::ascir_op::Data data1((prefix + "_data1").c_str(), *graph);
  apply_layout(data1);
  data1.ir_attr.SetIndex(1);
  data1.y.dtype = ge::DT_INT8;

  af::ascir_op::Load load1((prefix + "_load1").c_str());
  load1.x = data1.y;
  apply_layout(load1);

  af::ascir_op::Add add((prefix + "_add").c_str());
  add.x1 = load0.y;
  add.x2 = load1.y;
  apply_layout(add);

  af::ascir_op::Store store0((prefix + "_store0").c_str());
  store0.x = add.y;
  apply_layout(store0);

  af::ascir_op::Output y0((prefix + "_out0").c_str());
  y0.x = store0.y;
  y0.ir_attr.SetIndex(0);
  y0.y.dtype = ge::DT_FLOAT16;
}
// Adds an "AscBackend" node called `name` to `compute_graph`.
// in_num / out_num: number of dynamic inputs "x" and dynamic outputs "y" declared on the op desc.
// Returns the node produced by AddNode; the owner graph is only set when that node is non-null, so a failed
// insertion is reported to the caller as a null pointer rather than dereferenced here.
static NodePtr CreateAscbcToAscGraph(const std::string &name, ComputeGraphPtr &compute_graph, int64_t in_num = 1,
                                     int64_t out_num = 1) {
  OpDescBuilder op_desc_builder(name, "AscBackend");
  op_desc_builder.AddDynamicInput("x", in_num);
  op_desc_builder.AddDynamicOutput("y", out_num);
  const auto &op_desc = op_desc_builder.Build();
  auto node = compute_graph->AddNode(op_desc);
  if (node != nullptr) {
    node->SetOwnerComputeGraph(compute_graph);
  }
  return node;
}
// Optimizes a fused graph made of three AscBackend nodes, each carrying its own AscGraph:
//   data0, data1        -> ascbc1 (g0, two outputs)
//   data2, ascbc1:0     -> ascbc2 (g1)
//   ascbc2:0, ascbc1:1  -> ascbc3 (g2)
//   ascbc3:0 -> output0, ascbc1:1 -> output1
// Expects one scheduled-result list per AscBackend node, origin vars s0/s1/s999, three inputs, two outputs and
// two workspace nodes.
TEST_F(TestOptimizer, optimize_with_fused_ascbacked) {
  std::shared_ptr<af::AscGraph> g0 = std::make_shared<af::AscGraph>("g0");
  CreateAscBackendGraphTwoInTwoOut(g0, "g0", 2);
  std::shared_ptr<af::AscGraph> g1 = std::make_shared<af::AscGraph>("g1");
  CreateAscBackendGraphTwoInOneOut(g1, "g1", 1);
  std::shared_ptr<af::AscGraph> g2 = std::make_shared<af::AscGraph>("g2");
  CreateAscBackendGraphTwoInOneOut(g2, "g2", 2);

  af::AscGraph fused_asc_graph("fused_graph");
  af::ascir_op::Data data0("data0", fused_asc_graph);
  auto ir_attr0 = data0.attr.ir_attr->DownCastTo<af::AscDataIrAttrDef>();
  ASSERT_NE(ir_attr0, nullptr);
  ir_attr0->SetIndex(0);
  af::ascir_op::Data data1("data1", fused_asc_graph);
  auto ir_attr1 = data1.attr.ir_attr->DownCastTo<af::AscDataIrAttrDef>();
  ASSERT_NE(ir_attr1, nullptr);
  ir_attr1->SetIndex(1);
  af::ascir_op::Data data2("data2", fused_asc_graph);
  auto ir_attr2 = data2.attr.ir_attr->DownCastTo<af::AscDataIrAttrDef>();
  ASSERT_NE(ir_attr2, nullptr);
  ir_attr2->SetIndex(2);

  auto fused_graph = af::AscGraphUtils::GetComputeGraph(fused_asc_graph);
  auto data0_node = fused_asc_graph.FindNode("data0");
  auto data1_node = fused_asc_graph.FindNode("data1");
  auto data2_node = fused_asc_graph.FindNode("data2");
  ASSERT_NE(data0_node, nullptr);
  ASSERT_NE(data1_node, nullptr);
  ASSERT_NE(data2_node, nullptr);

  auto ascbc1 = CreateAscbcToAscGraph("ascbc1", fused_graph, 2, 2);
  auto ascbc2 = CreateAscbcToAscGraph("ascbc2", fused_graph, 2, 1);
  auto ascbc3 = CreateAscbcToAscGraph("ascbc3", fused_graph, 2, 1);
  ASSERT_NE(ascbc1, nullptr);
  ASSERT_NE(ascbc2, nullptr);
  ASSERT_NE(ascbc3, nullptr);

  af::GraphUtils::AddEdge(data0_node->GetOutDataAnchor(0), ascbc1->GetInDataAnchor(0));
  af::GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), ascbc1->GetInDataAnchor(1));
  af::GraphUtils::AddEdge(data2_node->GetOutDataAnchor(0), ascbc2->GetInDataAnchor(0));
  af::GraphUtils::AddEdge(ascbc1->GetOutDataAnchor(0), ascbc2->GetInDataAnchor(1));
  af::GraphUtils::AddEdge(ascbc2->GetOutDataAnchor(0), ascbc3->GetInDataAnchor(0));
  af::GraphUtils::AddEdge(ascbc1->GetOutDataAnchor(1), ascbc3->GetInDataAnchor(1));

  af::ascir_op::Output output0("output0");
  auto out0_ir_attr = output0.attr.ir_attr->DownCastTo<af::AscDataIrAttrDef>();
  ASSERT_NE(out0_ir_attr, nullptr);
  out0_ir_attr->SetIndex(0);
  auto out0_desc = OpDescUtils::GetOpDescFromOperator(output0);
  auto output0_node = fused_graph->AddNode(out0_desc);
  ASSERT_NE(output0_node, nullptr);
  af::ascir_op::Output output1("output1");
  auto out1_ir_attr = output1.attr.ir_attr->DownCastTo<af::AscDataIrAttrDef>();
  ASSERT_NE(out1_ir_attr, nullptr);
  out1_ir_attr->SetIndex(1);
  auto out1_desc = OpDescUtils::GetOpDescFromOperator(output1);
  auto output1_node = fused_graph->AddNode(out1_desc);
  ASSERT_NE(output1_node, nullptr);
  af::GraphUtils::AddEdge(ascbc3->GetOutDataAnchor(0), output0_node->GetInDataAnchor(0));
  af::GraphUtils::AddEdge(ascbc1->GetOutDataAnchor(1), output1_node->GetInDataAnchor(0));

  // Attach the per-node AscGraphs to the AscBackend nodes.
  auto fuse1_attrs = ascbc1->GetOpDesc()->GetOrCreateAttrsGroup<ge::AutoFuseAttrs>();
  ASSERT_NE(fuse1_attrs, nullptr);
  fuse1_attrs->SetAscGraph(g0);
  auto fuse2_attrs = ascbc2->GetOpDesc()->GetOrCreateAttrsGroup<ge::AutoFuseAttrs>();
  ASSERT_NE(fuse2_attrs, nullptr);
  fuse2_attrs->SetAscGraph(g1);
  auto fuse3_attrs = ascbc3->GetOpDesc()->GetOrCreateAttrsGroup<ge::AutoFuseAttrs>();
  ASSERT_NE(fuse3_attrs, nullptr);
  fuse3_attrs->SetAscGraph(g2);
  fused_graph->TopologicalSorting();

  optimize::Optimizer opt(optimize::OptimizerOptions{.graph_type = optimize::GraphType::kFusedAscBackend});
  ::ascir::FusedScheduledResult fused_scheduled_result;
  ASSERT_EQ(opt.Optimize(fused_graph, fused_scheduled_result), 0);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  EXPECT_EQ(fused_scheduled_result.fused_graph_name.GetString(), fused_graph->GetName());
  ASSERT_EQ(fused_scheduled_result.origin_vars.size(), 3UL);
  EXPECT_EQ(std::string(fused_scheduled_result.origin_vars[0].Serialize().get()), "s0");
  EXPECT_EQ(std::string(fused_scheduled_result.origin_vars[1].Serialize().get()), "s1");
  EXPECT_EQ(std::string(fused_scheduled_result.origin_vars[2].Serialize().get()), "s999");
  ASSERT_EQ(fused_scheduled_result.input_nodes.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.output_nodes.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.workspace_nodes.size(), 2UL);
}
// Populates `graph` with three Data->Load branches feeding one Concat, followed by Store and Output:
//   concat_data{0,1,2} -> concat_load{0,1,2} -> concat -> concat_store -> concat_out
// Inputs use repeats {s0, s2}; the concat result and the store use repeats {s0, s1}.
static void CreatSingleConcatAscGraph(af::AscGraph &graph) {
  // NOTE(review): these symbols are constructed from strings, unlike af::Symbol(0) / af::Symbol(2) used
  // elsewhere in this file - confirm whether named symbols or integer constants are intended.
  auto s0 = af::Symbol("10");
  auto s2 = af::Symbol("8");
  auto s1 = af::Symbol("24");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  // Writes the two-axis layout {s0, inner} with row-major strides {inner, 1} onto op's output.
  const auto set_layout = [&](auto &op, const auto &inner) {
    op.attr.sched.axis = {z0.id, z1.id};
    *op.y.axis = {z0.id, z1.id};
    *op.y.repeats = {s0, inner};
    *op.y.strides = {inner, af::sym::kSymbolOne};
  };

  af::ascir_op::Data in0("concat_data0", graph);
  in0.ir_attr.SetIndex(0);
  set_layout(in0, s2);
  af::ascir_op::Load ld0("concat_load0");
  ld0.x = in0.y;
  set_layout(ld0, s2);

  af::ascir_op::Data in1("concat_data1", graph);
  in1.ir_attr.SetIndex(1);
  set_layout(in1, s2);
  af::ascir_op::Load ld1("concat_load1");
  ld1.x = in1.y;
  set_layout(ld1, s2);

  af::ascir_op::Data in2("concat_data2", graph);
  in2.ir_attr.SetIndex(2);
  set_layout(in2, s2);
  af::ascir_op::Load ld2("concat_load2");
  ld2.x = in2.y;
  set_layout(ld2, s2);

  af::ascir_op::Concat joined("concat");
  joined.x = {ld0.y, ld1.y, ld2.y};
  set_layout(joined, s1);

  af::ascir_op::Store sink("concat_store");
  sink.x = joined.y;
  set_layout(sink, s1);

  af::ascir_op::Output result("concat_out");
  result.ir_attr.SetIndex(0);
  result.x = sink.y;
  result.y.dtype = ge::DT_FLOAT16;
}
// Wraps the single-concat AscGraph (serialized into the "ascgraph" attribute of node ascbc1) in a compute
// graph with three Data inputs, optimizes it, checks the shape of the scheduled result and finally runs
// codegen, dumping the generated tiling function, tiling data and kernel into files in the working directory.
TEST_F(TestOptimizer, only_concat_graph_tail_dim1_scene) {
  auto builder = GraphBuilder("test");
  auto data0 = builder.AddNode("data0", "Data", 0, 1);
  af::AttrUtils::SetInt(data0->GetOpDescBarePtr(), "_parent_node_index", 0);
  auto data1 = builder.AddNode("data1", "Data", 0, 1);
  af::AttrUtils::SetInt(data1->GetOpDescBarePtr(), "_parent_node_index", 1);
  auto data2 = builder.AddNode("data2", "Data", 0, 1);
  af::AttrUtils::SetInt(data2->GetOpDescBarePtr(), "_parent_node_index", 2);
  auto ascg1 = builder.AddNode("ascbc1", "AscGraph", 3, 1);
  auto netoutput1 = builder.AddNode("netoutput1", NETOUTPUT, 2, 0);
  builder.AddDataEdge(data0, 0, ascg1, 0);
  builder.AddDataEdge(data1, 0, ascg1, 1);
  builder.AddDataEdge(data2, 0, ascg1, 2);
  builder.AddDataEdge(ascg1, 0, netoutput1, 0);
  ComputeGraphPtr compute_graph = builder.GetGraph();
  ASSERT_NE(compute_graph, nullptr);
  auto ascbc1 = compute_graph->FindNode("ascbc1");
  // Checked before use, matching the other fused-graph tests in this file.
  ASSERT_NE(ascbc1, nullptr);

  af::AscGraph concat_sub_graph("concat");
  CreatSingleConcatAscGraph(concat_sub_graph);
  std::string add_graph_str1;
  af::AscGraphUtils::SerializeToReadable(concat_sub_graph, add_graph_str1);
  af::AttrUtils::SetStr(ascbc1->GetOpDescBarePtr(), "ascgraph", add_graph_str1);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  ASSERT_EQ(optimizer.Optimize(compute_graph, fused_scheduled_result), 0);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  // Make sure element [0][0] exists before it is inspected.
  ASSERT_GE(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);
  EXPECT_EQ(fused_scheduled_result.input_nodes.size(), 3UL);
  EXPECT_EQ(fused_scheduled_result.output_nodes.size(), 1UL);

  auto codegen = codegen::Codegen(
      codegen::CodegenOptions{.tiling_lib_path = "gen_tiling.so", .tiling_lib_codegen_symbol = "CodegenTiling"});
  codegen::CodegenResult result;
  // NOTE(review): the return value of Generate is not checked - confirm whether that is intentional.
  codegen.Generate(fused_scheduled_result, result);
  std::fstream tiling_func("concat_tiling_func.cpp", std::ios::out);
  std::fstream tiling_data("concat_tiling_data.h", std::ios::out);
  std::fstream kernel_func("concat_kernel_func.cpp", std::ios::out);
  tiling_func << result.tiling;
  tiling_data << result.tiling_data;
  kernel_func << result.kernel;
}
*        NetOutput
*            |
*          AscBc3
*         /      \
*     AscBc1    AscBc2
*        |         |
*      data0     data1
*/
// Builds the compute graph drawn above: two Data nodes, each feeding its own single-input "AscGraph" node
// (ascbc1 / ascbc2), whose outputs are joined by ascbc3, which in turn feeds netoutput1.
// node_type is forwarded unchanged to the GraphBuilder constructor.
static ComputeGraphPtr BuildFusedGraph(const std::string node_type = "") {
  auto graph_builder = GraphBuilder("test", node_type);

  auto input0 = graph_builder.AddNode("data0", "Data", 0, 1);
  auto input1 = graph_builder.AddNode("data1", "Data", 0, 1);
  af::AttrUtils::SetInt(input0->GetOpDescBarePtr(), "_parent_node_index", 0);
  af::AttrUtils::SetInt(input1->GetOpDescBarePtr(), "_parent_node_index", 1);

  auto left = graph_builder.AddNode("ascbc1", "AscGraph", 1, 1);
  auto right = graph_builder.AddNode("ascbc2", "AscGraph", 1, 1);
  auto merged = graph_builder.AddNode("ascbc3", "AscGraph", 2, 1);
  auto sink = graph_builder.AddNode("netoutput1", NETOUTPUT, 1, 0);

  graph_builder.AddDataEdge(input0, 0, left, 0);
  graph_builder.AddDataEdge(input1, 0, right, 0);
  graph_builder.AddDataEdge(left, 0, merged, 0);
  graph_builder.AddDataEdge(right, 0, merged, 1);
  graph_builder.AddDataEdge(merged, 0, sink, 0);
  return graph_builder.GetGraph();
}
// Populates `graph` with a single-compute-node chain over two axes (z0 of size s0, z1 of size s1):
//   <prefix>sub_data0 -> <prefix>load0 -> <prefix>abs0 -> <prefix>store0 -> <prefix>out0
// Every tensor uses repeats {s0, s1} and strides {s1, 1}.
static void CreateOneNodeAscGraph(af::AscGraph &graph, const std::string &prefix = "g0") {
  const af::Expression s0 = graph.CreateSizeVar("s0");
  const af::Expression s1 = graph.CreateSizeVar("s1");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  // All layout-carrying nodes in this graph share the same two-axis description.
  const auto set_layout = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z1.id};
    *op.y.axis = {z0.id, z1.id};
    *op.y.repeats = {s0, s1};
    *op.y.strides = {s1, af::sym::kSymbolOne};
  };

  af::ascir_op::Data input((prefix + "sub_data0").c_str(), graph);
  input.ir_attr.SetIndex(0);
  set_layout(input);

  af::ascir_op::Load loaded((prefix + "load0").c_str());
  loaded.x = input.y;
  set_layout(loaded);

  af::ascir_op::Abs absolute((prefix + "abs0").c_str());
  absolute.x = loaded.y;
  set_layout(absolute);

  af::ascir_op::Store stored((prefix + "store0").c_str());
  stored.x = absolute.y;
  set_layout(stored);

  af::ascir_op::Output result((prefix + "out0").c_str());
  result.x = stored.y;
  result.y.dtype = ge::DT_FLOAT16;
  result.ir_attr.SetIndex(0);
}
// Populates `graph` with a concat over the innermost axis z2 (constant size 2):
//   data0 -> load0 \
//                   concat -> store -> out0
//   data1 -> load1 /
// Inputs use repeats {s0, s1, 1} with strides {s1, 1, 0}; concat and store use repeats {s0, s1, 2} with
// strides {s1 * 2, 2, 1}.
static void CreateTailPackAscGraph(af::AscGraph &graph) {
  const af::Expression s0 = graph.CreateSizeVar("s0");
  const af::Expression s1 = graph.CreateSizeVar("s1");
  const af::Expression s2 = af::Symbol(2);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Layout of each concat input: size 1 along z2.
  const auto set_input_layout = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z1.id, z2.id};
    *op.y.axis = {z0.id, z1.id, z2.id};
    *op.y.repeats = {s0, s1, af::sym::kSymbolOne};
    *op.y.strides = {s1, af::sym::kSymbolOne, af::sym::kSymbolZero};
  };
  // Layout of the concatenated tensor: size s2 along z2.
  const auto set_output_layout = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z1.id, z2.id};
    *op.y.axis = {z0.id, z1.id, z2.id};
    *op.y.repeats = {s0, s1, s2};
    *op.y.strides = {s1 * s2, s2, af::sym::kSymbolOne};
  };

  af::ascir_op::Data in0("data0", graph);
  in0.ir_attr.SetIndex(0);
  set_input_layout(in0);
  af::ascir_op::Load ld0("load0");
  ld0.x = in0.y;
  set_input_layout(ld0);

  af::ascir_op::Data in1("data1", graph);
  in1.ir_attr.SetIndex(1);
  set_input_layout(in1);
  af::ascir_op::Load ld1("load1");
  ld1.x = in1.y;
  set_input_layout(ld1);

  af::ascir_op::Concat packed("concat");
  packed.x = {ld0.y, ld1.y};
  set_output_layout(packed);

  af::ascir_op::Store sink("store");
  sink.x = packed.y;
  set_output_layout(sink);

  af::ascir_op::Output result("out0");
  result.x = sink.y;
  result.y.dtype = ge::DT_FLOAT16;
  result.ir_attr.SetIndex(0);
}
// Populates `graph` with a concat over the middle axis z2 (constant size 2); axes are ordered z0, z2, z1:
//   data0 -> load0 \
//                   concat -> store -> out0
//   data1 -> load1 /
// Inputs use repeats {s0, 1, s1} with strides {s1, s1, 1}; concat and store use repeats {s0, 2, s1} with
// strides {s1 * 2, s1, 1}.
static void CreateMidPackAscGraph(af::AscGraph &graph) {
  const af::Expression s0 = graph.CreateSizeVar("s0");
  const af::Expression s2 = af::Symbol(2);
  const af::Expression s1 = graph.CreateSizeVar("s1");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z1 = graph.CreateAxis("z1", s1);

  // Layout of each concat input: size 1 along z2.
  const auto set_input_layout = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z2.id, z1.id};
    *op.y.axis = {z0.id, z2.id, z1.id};
    *op.y.repeats = {s0, af::sym::kSymbolOne, s1};
    *op.y.strides = {s1, s1, af::sym::kSymbolOne};
  };
  // Layout of the concatenated tensor: size s2 along z2.
  const auto set_output_layout = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z2.id, z1.id};
    *op.y.axis = {z0.id, z2.id, z1.id};
    *op.y.repeats = {s0, s2, s1};
    *op.y.strides = {s1 * s2, s1, af::sym::kSymbolOne};
  };

  af::ascir_op::Data in0("data0", graph);
  in0.ir_attr.SetIndex(0);
  set_input_layout(in0);
  af::ascir_op::Load ld0("load0");
  ld0.x = in0.y;
  set_input_layout(ld0);

  af::ascir_op::Data in1("data1", graph);
  in1.ir_attr.SetIndex(1);
  set_input_layout(in1);
  af::ascir_op::Load ld1("load1");
  ld1.x = in1.y;
  set_input_layout(ld1);

  af::ascir_op::Concat packed("concat");
  packed.x = {ld0.y, ld1.y};
  set_output_layout(packed);

  af::ascir_op::Store sink("store");
  sink.x = packed.y;
  set_output_layout(sink);

  af::ascir_op::Output result("out0");
  result.x = sink.y;
  result.y.dtype = ge::DT_FLOAT16;
  result.ir_attr.SetIndex(0);
}
// Optimizes the three-node fused graph where ascbc1/ascbc2 are single-Abs graphs and ascbc3 is the tail-axis
// pack (concat) graph. Expects one scheduled result with one schedule group holding one impl graph, and that
// none of the original ascbc* nodes can be found in that impl graph.
TEST_F(TestOptimizer, OptimizeWithFusedTailPack) {
  ComputeGraphPtr compute_graph = BuildFusedGraph();
  ASSERT_NE(compute_graph, nullptr);
  auto ascbc1 = compute_graph->FindNode("ascbc1");
  ASSERT_NE(ascbc1, nullptr);
  auto ascbc2 = compute_graph->FindNode("ascbc2");
  ASSERT_NE(ascbc2, nullptr);
  auto ascbc3 = compute_graph->FindNode("ascbc3");
  ASSERT_NE(ascbc3, nullptr);

  af::AscGraph subgraph1("sub1");
  af::AscGraph subgraph2("sub2");
  af::AscGraph subgraph3("sub3");
  CreateOneNodeAscGraph(subgraph1, "g1");
  CreateOneNodeAscGraph(subgraph2, "g2");
  CreateTailPackAscGraph(subgraph3);

  // Each AscGraph is attached to its node as a serialized "ascgraph" string attribute.
  std::string add_graph_str1;
  af::AscGraphUtils::SerializeToReadable(subgraph1, add_graph_str1);
  af::AttrUtils::SetStr(ascbc1->GetOpDescBarePtr(), "ascgraph", add_graph_str1);
  std::string add_graph_str2;
  af::AscGraphUtils::SerializeToReadable(subgraph2, add_graph_str2);
  af::AttrUtils::SetStr(ascbc2->GetOpDescBarePtr(), "ascgraph", add_graph_str2);
  std::string add_graph_str3;
  af::AscGraphUtils::SerializeToReadable(subgraph3, add_graph_str3);
  af::AttrUtils::SetStr(ascbc3->GetOpDescBarePtr(), "ascgraph", add_graph_str3);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  ASSERT_EQ(optimizer.Optimize(compute_graph, fused_scheduled_result), 0);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
  // Fatal assertions: the impl graph is fetched by index right below.
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);
  auto &impl_graph = fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0];
  auto ascbc_1 = impl_graph.FindNode("ascbc1");
  EXPECT_EQ(ascbc_1, nullptr);
  auto ascbc_2 = impl_graph.FindNode("ascbc2");
  EXPECT_EQ(ascbc_2, nullptr);
  auto ascbc_3 = impl_graph.FindNode("ascbc3");
  EXPECT_EQ(ascbc_3, nullptr);
}
// Optimizes the three-node fused graph where ascbc1/ascbc2 are single-Abs graphs and ascbc3 is the middle-axis
// pack (concat) graph. Expects three scheduled results; the first has one group with one impl graph, the second
// has two groups whose first group holds two impl graphs. None of the original ascbc* nodes may be found in the
// first impl graph.
TEST_F(TestOptimizer, OptimizeWithFusedMidPack) {
  ComputeGraphPtr compute_graph = BuildFusedGraph();
  ASSERT_NE(compute_graph, nullptr);
  auto ascbc1 = compute_graph->FindNode("ascbc1");
  ASSERT_NE(ascbc1, nullptr);
  auto ascbc2 = compute_graph->FindNode("ascbc2");
  ASSERT_NE(ascbc2, nullptr);
  auto ascbc3 = compute_graph->FindNode("ascbc3");
  ASSERT_NE(ascbc3, nullptr);

  af::AscGraph subgraph1("sub1");
  af::AscGraph subgraph2("sub2");
  af::AscGraph subgraph3("sub3");
  CreateOneNodeAscGraph(subgraph1, "g1");
  CreateOneNodeAscGraph(subgraph2, "g2");
  CreateMidPackAscGraph(subgraph3);

  // Each AscGraph is attached to its node as a serialized "ascgraph" string attribute.
  std::string add_graph_str1;
  af::AscGraphUtils::SerializeToReadable(subgraph1, add_graph_str1);
  af::AttrUtils::SetStr(ascbc1->GetOpDescBarePtr(), "ascgraph", add_graph_str1);
  std::string add_graph_str2;
  af::AscGraphUtils::SerializeToReadable(subgraph2, add_graph_str2);
  af::AttrUtils::SetStr(ascbc2->GetOpDescBarePtr(), "ascgraph", add_graph_str2);
  std::string add_graph_str3;
  af::AscGraphUtils::SerializeToReadable(subgraph3, add_graph_str3);
  af::AttrUtils::SetStr(ascbc3->GetOpDescBarePtr(), "ascgraph", add_graph_str3);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  ASSERT_EQ(optimizer.Optimize(compute_graph, fused_scheduled_result), 0);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 3UL);
  // Size checks that guard a following index access are fatal; the last one guards nothing and stays EXPECT.
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][1].schedule_groups.size(), 2UL);
  EXPECT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][1].schedule_groups[0].impl_graphs.size(), 2UL);
  auto &impl_graph = fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0];
  auto ascbc_1 = impl_graph.FindNode("ascbc1");
  EXPECT_EQ(ascbc_1, nullptr);
  auto ascbc_2 = impl_graph.FindNode("ascbc2");
  EXPECT_EQ(ascbc_2, nullptr);
  auto ascbc_3 = impl_graph.FindNode("ascbc3");
  EXPECT_EQ(ascbc_3, nullptr);
}
// Populates `graph` with a graph in which the Exp result has two consumers:
//   arg4_1 -> b0_load -> b1_exp -> b0_abs -> b0_max -> b1_broadcast -> b0_store -> buf0 (output index 1)
//                              \-> b0_relu -> b1_store -> buf1 (output index 2)
// Two axes are used: z0 of size s0*s1*s2 and z1 of size s3. All compute nodes are FLOAT16 with repeats
// {s0*s1*s2, s3}; strides are {s3, 1} except for b0_max, which uses {1, 0}.
void Construct_Mul_Consumer_Struct_UT(af::AscGraph &graph) {
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto z0 = graph.CreateAxis("z0", s0 * s1 * s2);
  auto z1 = graph.CreateAxis("z1", s3);
  auto axis = {z0.id, z1.id};

  // Common description of a compute node; `reduced` selects the {1, 0} strides used by the reduce output.
  const auto describe = [&](auto &op, af::ComputeType compute_type, bool reduced) {
    op.attr.sched.axis = axis;
    op.attr.api.compute_type = compute_type;
    op.y.dtype = ge::DT_FLOAT16;
    *op.y.axis = axis;
    *op.y.repeats = {s0 * s1 * s2, s3};
    if (reduced) {
      *op.y.strides = {One, Zero};
    } else {
      *op.y.strides = {s3, One};
    }
  };
  // Data / Output nodes are tagged as buffers with an invalid compute type.
  const auto mark_buffer = [](auto &op) {
    op.attr.api.compute_type = af::ComputeType::kComputeInvalid;
    op.attr.api.type = af::ApiType::kAPITypeBuffer;
  };

  Data source("arg4_1", graph);
  mark_buffer(source);
  source.y.dtype = ge::DT_FLOAT16;
  source.ir_attr.SetIndex(0);

  Load loaded("b0_load");
  loaded.x = source.y;
  describe(loaded, af::ComputeType::kComputeLoad, false);

  Exp exponent("b1_exp");
  exponent.x = loaded.y;
  describe(exponent, af::ComputeType::kComputeElewise, false);
  exponent.attr.api.type = af::ApiType::kAPITypeCompute;

  // First consumer of the Exp result: abs -> max -> broadcast -> store -> buf0.
  Abs absolute("b0_abs");
  absolute.x = exponent.y;
  describe(absolute, af::ComputeType::kComputeElewise, false);

  af::ascir_op::Max maximum("b0_max");
  maximum.x = absolute.y;
  describe(maximum, af::ComputeType::kComputeReduce, true);

  Broadcast expanded("b1_broadcast");
  expanded.x = maximum.y;
  describe(expanded, af::ComputeType::kComputeBroadcast, false);

  Store reduce_store("b0_store");
  reduce_store.x = expanded.y;
  describe(reduce_store, af::ComputeType::kComputeStore, false);

  Output reduce_out("buf0");
  reduce_out.x = reduce_store.y;
  mark_buffer(reduce_out);
  reduce_out.y.dtype = ge::DT_FLOAT;
  reduce_out.ir_attr.SetIndex(1);

  // Second consumer of the Exp result: relu -> store -> buf1.
  af::ascir_op::Relu rectified("b0_relu");
  rectified.x = exponent.y;
  describe(rectified, af::ComputeType::kComputeElewise, false);

  Store relu_store("b1_store");
  relu_store.x = rectified.y;
  describe(relu_store, af::ComputeType::kComputeStore, false);

  Output relu_out("buf1");
  relu_out.x = relu_store.y;
  mark_buffer(relu_out);
  relu_out.y.dtype = ge::DT_FLOAT;
  relu_out.ir_attr.SetIndex(2);
}
// Optimizes the multi-consumer reduce graph with default options and checks the shape of the result:
// three scheduled results, the first with two schedule groups (one impl graph in the first group) and the
// second with three schedule groups.
TEST_F(TestOptimizer, REDUCE_MUL_CONSUMER) {
  af::AscGraph graph("REDUCE_MUL_CONSUMER");
  Construct_Mul_Consumer_Struct_UT(graph);
  // Named differently from the `optimizer` used by the other tests so it does not shadow it.
  optimize::Optimizer default_optimizer(optimize::OptimizerOptions{});
  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = default_optimizer.Optimize(graph, fused_scheduled_result);
  // The result is indexed below, so a failed optimization or an empty result must stop the test here.
  ASSERT_EQ(res, ge::SUCCESS);
  ASSERT_GE(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][1].schedule_groups.size(), 3UL);
}
// Populates `graph` with a reduce graph that has two outputs:
//   arg4_1 -> b0_load -> b1_exp -> b1_store -> buf1 (output index 0)
//                              \-> b0_max -> b0_store -> buf0 (output index 1)
// Two axes are used: z0 of size s0*s1*s2 and z1 of size s3. All compute nodes are FLOAT16 with repeats
// {s0*s1*s2, s3}; strides are {s3, 1} except for b0_max, which uses {1, 0}.
void ConstructReduceGraphWithMultiOutputs(af::AscGraph &graph) {
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto z0 = graph.CreateAxis("z0", s0 * s1 * s2);
  auto z1 = graph.CreateAxis("z1", s3);
  auto axis = {z0.id, z1.id};

  // Common description of a compute node; `reduced` selects the {1, 0} strides used by the reduce output.
  const auto describe = [&](auto &op, af::ComputeType compute_type, bool reduced) {
    op.attr.sched.axis = axis;
    op.attr.api.compute_type = compute_type;
    op.y.dtype = ge::DT_FLOAT16;
    *op.y.axis = axis;
    *op.y.repeats = {s0 * s1 * s2, s3};
    if (reduced) {
      *op.y.strides = {One, Zero};
    } else {
      *op.y.strides = {s3, One};
    }
  };
  // Data / Output nodes are tagged as buffers with an invalid compute type.
  const auto mark_buffer = [](auto &op) {
    op.attr.api.compute_type = af::ComputeType::kComputeInvalid;
    op.attr.api.type = af::ApiType::kAPITypeBuffer;
  };

  Data source("arg4_1", graph);
  mark_buffer(source);
  source.y.dtype = ge::DT_FLOAT16;
  source.ir_attr.SetIndex(0);

  Load loaded("b0_load");
  loaded.x = source.y;
  describe(loaded, af::ComputeType::kComputeLoad, false);

  Exp exponent("b1_exp");
  exponent.x = loaded.y;
  describe(exponent, af::ComputeType::kComputeElewise, false);
  exponent.attr.api.type = af::ApiType::kAPITypeCompute;

  // Output 0: the Exp result stored directly.
  Store exp_store("b1_store");
  exp_store.x = exponent.y;
  describe(exp_store, af::ComputeType::kComputeStore, false);

  Output exp_out("buf1");
  exp_out.x = exp_store.y;
  mark_buffer(exp_out);
  exp_out.y.dtype = ge::DT_FLOAT;
  exp_out.ir_attr.SetIndex(0);

  // Output 1: the Max reduction of the Exp result.
  af::ascir_op::Max maximum("b0_max");
  maximum.x = exponent.y;
  describe(maximum, af::ComputeType::kComputeReduce, true);

  Store reduce_store("b0_store");
  reduce_store.x = maximum.y;
  describe(reduce_store, af::ComputeType::kComputeStore, false);

  Output reduce_out("buf0");
  reduce_out.x = reduce_store.y;
  mark_buffer(reduce_out);
  reduce_out.y.dtype = ge::DT_FLOAT;
  reduce_out.ir_attr.SetIndex(1);
}
// Generates schedule tasks for the two-output reduce graph and checks that three tasks are produced, that the
// kRCore task has two grouped graphs, and that the Output node found in its second grouped graph carries
// index 1. Finally runs ReducePartitionCaseGenerator on the graph and expects success.
TEST_F(TestOptimizer, ReduceTaskGenerate) {
  af::AscGraph graph("ReduceGraph");
  ConstructReduceGraphWithMultiOutputs(graph);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  std::vector<optimize::ScheduleTask> schedule_tasks;
  optimize::OptimizerOptions options{optimize::GraphType::kAscGraph};
  Status res = optimize::ScheduleTaskGenerator::GenerateTasks(graph, schedule_tasks, options);
  ASSERT_EQ(res, ge::SUCCESS);
  ASSERT_EQ(schedule_tasks.size(), 3UL);

  // Stays -1 when no Output node is visited, which makes the final expectation fail.
  int64_t ir_idx = -1;
  for (auto &task : schedule_tasks) {
    if (task.reduce_type == optimize::ReduceTemplateType::kRCore) {
      ASSERT_EQ(task.grouped_graphs.size(), 2UL);
      for (const auto &node : task.grouped_graphs[1].GetAllNodes()) {
        if (IsOps<Output>(node)) {
          auto ir_attr = node->attr.ir_attr->DownCastTo<af::AscDataIrAttrDef>();
          // Do not dereference a failed downcast.
          ASSERT_NE(ir_attr, nullptr);
          ir_attr->GetIndex(ir_idx);
        }
      }
    }
  }
  EXPECT_EQ(ir_idx, 1);

  optimize::ReducePartitionCaseGenerator generator;
  std::vector<af::AscGraph> graphs;
  std::vector<std::string> score_functions;
  EXPECT_EQ(generator.Generate(graph, graphs, score_functions), 0);
}
// data -> abs -> store -> output, every tensor dense over (z0, z1, z2, z3).
// After MergeContinuousAxis the graph is expected to hold a single axis of size s0 * s1 * s2 * s3.
TEST_F(TestOptimizer, MergeAxesElewiseOnly) {
  af::AscGraph graph("LoadAbsStore");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z3 = graph.CreateAxis("z3", s3);

  // Schedule axes plus the FLOAT16 dense output layout that abs and store have in common.
  const auto apply_dense_layout = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
    op.y.dtype = ge::DT_FLOAT16;
    *op.y.axis = {z0.id, z1.id, z2.id, z3.id};
    *op.y.repeats = {s0, s1, s2, s3};
    *op.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  };

  af::ascir_op::Data data("data", graph);
  data.y.dtype = ge::DT_FLOAT16;
  data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data.attr.api.type = af::ApiType::kAPITypeBuffer;
  data.ir_attr.SetIndex(0);

  af::ascir_op::Abs abs("abs");
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs.x = data.y;
  apply_dense_layout(abs);

  af::ascir_op::Store store("store");
  store.attr.api.compute_type = af::ComputeType::kComputeElewise;
  store.x = abs.y;
  apply_dense_layout(store);

  af::ascir_op::Output y("y");
  y.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y.x = store.y;
  y.attr.api.type = af::ApiType::kAPITypeBuffer;
  y.y.dtype = ge::DT_FLOAT16;
  y.ir_attr.SetIndex(0);

  ASSERT_EQ(optimizer.MergeContinuousAxis(graph), 0);
  auto new_axis = graph.GetAllAxis();
  ASSERT_EQ(new_axis.size(), 1UL);
  EXPECT_EQ(new_axis[0]->size, s0 * s1 * s2 * s3);
}
// data0 -> load -> abs where every stride is scaled by 2, so the innermost stride is 2 rather than One.
// The test expects MergeContinuousAxis to leave two axes: s0 * s1 * s2 and s3.
TEST_F(TestOptimizer, TailAxisSliceDoNotMerge) {
  af::AscGraph graph("LoadAbsStore");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z3 = graph.CreateAxis("z3", s3);
  auto stride = af::Symbol(2);

  // Schedule axes plus the FLOAT16 strided output layout used by both load and abs.
  const auto apply_strided_layout = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
    op.y.dtype = ge::DT_FLOAT16;
    *op.y.axis = {z0.id, z1.id, z2.id, z3.id};
    *op.y.repeats = {s0, s1, s2, s3};
    *op.y.strides = {s1 * s2 * s3 * stride, s2 * s3 * stride, s3 * stride, stride};
  };

  af::ascir_op::Data data0("data0", graph);
  data0.y.dtype = ge::DT_FLOAT;
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data0.attr.api.type = af::ApiType::kAPITypeBuffer;
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load("load");
  load.x = data0.y;
  load.attr.api.compute_type = af::ComputeType::kComputeLoad;
  apply_strided_layout(load);

  af::ascir_op::Abs abs("abs");
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs.x = load.y;
  apply_strided_layout(abs);

  ASSERT_EQ(optimizer.MergeContinuousAxis(graph), ge::SUCCESS);
  auto new_axis = graph.GetAllAxis();
  ASSERT_EQ(new_axis.size(), 2UL);
  EXPECT_EQ(new_axis[0]->size, s0 * s1 * s2);
  EXPECT_EQ(new_axis[1]->size, s3);
}
// Gather graph: data0 over (z0, z1, z2) and data1 over (z3, z4) feed a gather whose output, like the
// downstream abs and store, is laid out over (z0, z1, z3, z4).
// After MergeContinuousAxis the axes at index 5 and 6 are expected to be s0 * s1 and s3 * s4.
TEST_F(TestOptimizer, MergeAxesGather) {
  af::AscGraph graph("LoadAbsStore");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto s4 = graph.CreateSizeVar("s4");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z3 = graph.CreateAxis("z3", s3);
  auto z4 = graph.CreateAxis("z4", s4);

  // Schedule axes plus the FLOAT dense output layout shared by gather, abs and store.
  const auto apply_result_layout = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z1.id, z3.id, z4.id};
    op.y.dtype = ge::DT_FLOAT;
    *op.y.axis = {z0.id, z1.id, z3.id, z4.id};
    *op.y.repeats = {s0, s1, s3, s4};
    *op.y.strides = {s1 * s3 * s4, s3 * s4, s4, One};
  };

  af::ascir_op::Data data0("data0", graph);
  data0.y.dtype = ge::DT_FLOAT;
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data0.attr.api.type = af::ApiType::kAPITypeBuffer;
  data0.ir_attr.SetIndex(0);
  data0.attr.sched.axis = {z0.id, z1.id, z2.id};
  *data0.y.axis = {z0.id, z1.id, z2.id};
  *data0.y.repeats = {s0, s1, s2};
  *data0.y.strides = {s1 * s2, s2, One};

  af::ascir_op::Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT;
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data1.attr.api.type = af::ApiType::kAPITypeBuffer;
  data1.ir_attr.SetIndex(1);
  data1.attr.sched.axis = {z3.id, z4.id};
  *data1.y.axis = {z3.id, z4.id};
  *data1.y.repeats = {s3, s4};
  *data1.y.strides = {s4, One};

  af::ascir_op::Gather gather("gather");
  gather.attr.api.compute_type = af::ComputeType::kComputeGather;
  gather.x1 = data0.y;
  gather.x2 = data1.y;
  gather.ir_attr.SetAxis(2);
  apply_result_layout(gather);

  af::ascir_op::Abs abs("abs");
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs.x = gather.y;
  apply_result_layout(abs);

  af::ascir_op::Store store("store");
  store.attr.api.compute_type = af::ComputeType::kComputeElewise;
  store.x = abs.y;
  apply_result_layout(store);

  af::ascir_op::Output y("y");
  y.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y.x = store.y;
  y.attr.api.type = af::ApiType::kAPITypeBuffer;
  y.y.dtype = ge::DT_FLOAT;
  y.ir_attr.SetIndex(0);

  ASSERT_EQ(optimizer.MergeContinuousAxis(graph), 0);
  auto new_axis = graph.GetAllAxis();
  // Guard the fixed indices below so a shorter axis list fails the test instead of reading out of bounds.
  ASSERT_GE(new_axis.size(), 7UL);
  EXPECT_EQ(new_axis[5]->size, s0 * s1);
  EXPECT_EQ(new_axis[6]->size, s3 * s4);
}
// Gather graph with a one-axis data0 over (z0) and indices data1 over (z1, z2); gather, abs and store
// are all laid out over (z1, z2).
// After MergeContinuousAxis the axis at index 3 is expected to be s1 * s2.
TEST_F(TestOptimizer, MergeAxesGatherOnlyOneDim) {
  af::AscGraph graph("LoadAbsStore");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Schedule axes plus the FLOAT dense output layout shared by gather, abs and store.
  const auto apply_result_layout = [&](auto &op) {
    op.attr.sched.axis = {z1.id, z2.id};
    op.y.dtype = ge::DT_FLOAT;
    *op.y.axis = {z1.id, z2.id};
    *op.y.repeats = {s1, s2};
    *op.y.strides = {s2, One};
  };

  af::ascir_op::Data data0("data0", graph);
  data0.y.dtype = ge::DT_FLOAT;
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data0.attr.api.type = af::ApiType::kAPITypeBuffer;
  data0.ir_attr.SetIndex(0);
  data0.attr.sched.axis = {z0.id};
  *data0.y.axis = {z0.id};
  *data0.y.repeats = {s0};
  *data0.y.strides = {One};

  af::ascir_op::Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT;
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data1.attr.api.type = af::ApiType::kAPITypeBuffer;
  data1.ir_attr.SetIndex(1);
  data1.attr.sched.axis = {z1.id, z2.id};
  *data1.y.axis = {z1.id, z2.id};
  *data1.y.repeats = {s1, s2};
  *data1.y.strides = {s2, One};

  af::ascir_op::Gather gather("gather");
  gather.attr.api.compute_type = af::ComputeType::kComputeGather;
  gather.x1 = data0.y;
  gather.x2 = data1.y;
  gather.ir_attr.SetAxis(2);
  apply_result_layout(gather);

  af::ascir_op::Abs abs("abs");
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs.x = gather.y;
  apply_result_layout(abs);

  af::ascir_op::Store store("store");
  store.attr.api.compute_type = af::ComputeType::kComputeElewise;
  store.x = abs.y;
  apply_result_layout(store);

  af::ascir_op::Output y("y");
  y.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y.x = store.y;
  y.attr.api.type = af::ApiType::kAPITypeBuffer;
  y.y.dtype = ge::DT_FLOAT;
  y.ir_attr.SetIndex(0);

  ASSERT_EQ(optimizer.MergeContinuousAxis(graph), 0);
  auto new_axis = graph.GetAllAxis();
  // Guard the fixed index below so a shorter axis list fails the test instead of reading out of bounds.
  ASSERT_GE(new_axis.size(), 4UL);
  EXPECT_EQ(new_axis[3]->size, s1 * s2);
}
// data -> abs -> sum -> output over six axes; the sum output has repeat One and stride Zero on z2 and z3.
// After MergeContinuousAxis three axes are expected: s0 * s1, s2 * s3 and s4 * s5.
TEST_F(TestOptimizer, MergeAxesReduce) {
  af::AscGraph graph("LoadAbsStore");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto s4 = graph.CreateSizeVar("s4");
  auto s5 = graph.CreateSizeVar("s5");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z3 = graph.CreateAxis("z3", s3);
  auto z4 = graph.CreateAxis("z4", s4);
  auto z5 = graph.CreateAxis("z5", s5);

  // Both compute nodes are scheduled on, and emit FLOAT16 over, all six axes; only repeats/strides differ.
  const auto bind_all_axes = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id, z4.id, z5.id};
    op.y.dtype = ge::DT_FLOAT16;
    *op.y.axis = {z0.id, z1.id, z2.id, z3.id, z4.id, z5.id};
  };

  af::ascir_op::Data data("data", graph);
  data.y.dtype = ge::DT_FLOAT16;
  data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data.attr.api.type = af::ApiType::kAPITypeBuffer;
  data.ir_attr.SetIndex(0);

  af::ascir_op::Abs abs("abs");
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs.x = data.y;
  bind_all_axes(abs);
  *abs.y.repeats = {s0, s1, s2, s3, s4, s5};
  *abs.y.strides = {s1 * s2 * s3 * s4 * s5, s2 * s3 * s4 * s5, s3 * s4 * s5, s4 * s5, s5, One};

  af::ascir_op::Sum sum("sum");
  sum.attr.api.compute_type = af::ComputeType::kComputeReduce;
  sum.x = abs.y;
  bind_all_axes(sum);
  *sum.y.repeats = {s0, s1, One, One, s4, s5};
  *sum.y.strides = {s1 * s4 * s5, s4 * s5, Zero, Zero, s5, One};

  af::ascir_op::Output y("y");
  y.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y.x = sum.y;
  y.attr.api.type = af::ApiType::kAPITypeBuffer;
  y.y.dtype = ge::DT_FLOAT16;
  y.ir_attr.SetIndex(0);

  ASSERT_EQ(optimizer.MergeContinuousAxis(graph), 0);
  auto new_axis = graph.GetAllAxis();
  ASSERT_EQ(new_axis.size(), 3UL);
  EXPECT_EQ(new_axis[0]->size, s0 * s1);
  EXPECT_EQ(new_axis[1]->size, s2 * s3);
  EXPECT_EQ(new_axis[2]->size, s4 * s5);
}
// data -> abs -> brc -> output; abs has repeat One / stride Zero on z0 and brc expands it to s0.
// After MergeContinuousAxis two axes are expected: s0 and s1 * s2 * s3.
TEST_F(TestOptimizer, MergeAxesBrc) {
  af::AscGraph graph("LoadAbsStore");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z3 = graph.CreateAxis("z3", s3);

  // Both compute nodes are scheduled on, and emit FLOAT16 over, all four axes; only repeats/strides differ.
  const auto bind_all_axes = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
    op.y.dtype = ge::DT_FLOAT16;
    *op.y.axis = {z0.id, z1.id, z2.id, z3.id};
  };

  af::ascir_op::Data data("data", graph);
  data.y.dtype = ge::DT_FLOAT16;
  data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data.attr.api.type = af::ApiType::kAPITypeBuffer;
  data.ir_attr.SetIndex(0);

  af::ascir_op::Abs abs("abs");
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs.x = data.y;
  bind_all_axes(abs);
  *abs.y.repeats = {One, s1, s2, s3};
  *abs.y.strides = {Zero, s2 * s3, s3, One};

  af::ascir_op::Broadcast brc("brc");
  brc.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc.x = abs.y;
  bind_all_axes(brc);
  *brc.y.repeats = {s0, s1, s2, s3};
  *brc.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Output y("y");
  y.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y.x = brc.y;
  y.attr.api.type = af::ApiType::kAPITypeBuffer;
  y.y.dtype = ge::DT_FLOAT16;
  y.ir_attr.SetIndex(0);

  ASSERT_EQ(optimizer.MergeContinuousAxis(graph), 0);
  auto new_axis = graph.GetAllAxis();
  ASSERT_EQ(new_axis.size(), 2UL);
  EXPECT_EQ(new_axis[0]->size, s0);
  EXPECT_EQ(new_axis[1]->size, s1 * s2 * s3);
}
// data -> abs -> "brc" -> output, where the node named "brc" is a Broadcast op tagged kComputeTranspose
// whose output axes are ordered (z2, z3, z0, z1).
// After MergeContinuousAxis two axes are expected: s0 * s1 and s2 * s3.
TEST_F(TestOptimizer, MergeAxesTransose) {
  af::AscGraph graph("LoadAbsStore");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z3 = graph.CreateAxis("z3", s3);

  // Both compute nodes share the schedule axes and the FLOAT16 output dtype; output layouts differ.
  const auto schedule_on_all_axes = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
    op.y.dtype = ge::DT_FLOAT16;
  };

  af::ascir_op::Data data("data", graph);
  data.y.dtype = ge::DT_FLOAT16;
  data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data.attr.api.type = af::ApiType::kAPITypeBuffer;
  data.ir_attr.SetIndex(0);

  af::ascir_op::Abs abs("abs");
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs.x = data.y;
  schedule_on_all_axes(abs);
  *abs.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs.y.repeats = {s0, s1, s2, s3};
  *abs.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Broadcast permute("brc");
  permute.attr.api.compute_type = af::ComputeType::kComputeTranspose;
  permute.x = abs.y;
  schedule_on_all_axes(permute);
  *permute.y.axis = {z2.id, z3.id, z0.id, z1.id};
  *permute.y.repeats = {s2, s3, s0, s1};
  *permute.y.strides = {s3 * s0 * s1, s0 * s1, s1, One};

  af::ascir_op::Output y("y");
  y.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y.x = permute.y;
  y.attr.api.type = af::ApiType::kAPITypeBuffer;
  y.y.dtype = ge::DT_FLOAT16;
  y.ir_attr.SetIndex(0);

  ASSERT_EQ(optimizer.MergeContinuousAxis(graph), 0);
  auto new_axis = graph.GetAllAxis();
  ASSERT_EQ(new_axis.size(), 2UL);
  EXPECT_EQ(new_axis[0]->size, s0 * s1);
  EXPECT_EQ(new_axis[1]->size, s2 * s3);
}
// Two abs branches over (s0, s1, s2, s3) are concatenated into (s0, s1, s4, s3) with s4 = s2 * 2;
// axis z2 is created with size s4.
// After MergeContinuousAxis two axes are expected: s0 * s1 and s4 * s3.
TEST_F(TestOptimizer, MergeAxesConcat) {
  af::AscGraph graph("LoadAbsStore");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto s4 = s2 * af::Symbol(2);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s4);
  auto z3 = graph.CreateAxis("z3", s3);

  // Schedule axes plus the FLOAT16 dense layout shared by the two concat inputs.
  const auto apply_input_layout = [&](auto &op) {
    op.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
    op.y.dtype = ge::DT_FLOAT16;
    *op.y.axis = {z0.id, z1.id, z2.id, z3.id};
    *op.y.repeats = {s0, s1, s2, s3};
    *op.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  };

  af::ascir_op::Data data("data", graph);
  data.y.dtype = ge::DT_FLOAT16;
  data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data.attr.api.type = af::ApiType::kAPITypeBuffer;
  data.ir_attr.SetIndex(0);

  af::ascir_op::Abs abs("abs");
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs.x = data.y;
  apply_input_layout(abs);

  af::ascir_op::Abs abs1("abs1");
  abs1.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs1.x = data.y;
  apply_input_layout(abs1);

  af::ascir_op::Concat concat("concat");
  concat.attr.api.compute_type = af::ComputeType::kComputeConcat;
  concat.x = {abs.y, abs1.y};
  concat.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  concat.y.dtype = ge::DT_FLOAT16;
  *concat.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *concat.y.repeats = {s0, s1, s4, s3};
  *concat.y.strides = {s1 * s4 * s3, s4 * s3, s3, One};

  af::ascir_op::Output y("y");
  y.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y.x = concat.y;
  y.attr.api.type = af::ApiType::kAPITypeBuffer;
  y.y.dtype = ge::DT_FLOAT16;
  y.ir_attr.SetIndex(0);

  ASSERT_EQ(optimizer.MergeContinuousAxis(graph), 0);
  auto new_axis = graph.GetAllAxis();
  ASSERT_EQ(new_axis.size(), 2UL);
  EXPECT_EQ(new_axis[0]->size, s0 * s1);
  EXPECT_EQ(new_axis[1]->size, s4 * s3);
}
// Scalar load -> brc1 -> brc2 -> brc3 -> add(+load1) -> store -> output, plus an exp fed by brc2.
// brc2 is given two-element repeats/strides although its axis list has three entries.
// The test expects GraphPass to succeed while leaving all 11 nodes, including the three broadcasts.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Size_Not_Equal_Failed) {
  af::AscGraph graph("ScalarBroadcastOptimization_Size_Not_Equal_Failed");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = ge::DT_FLOAT;

  // Scalar load: repeat One and stride Zero on every axis.
  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id, z2.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id, z2.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};

  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.x = load.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};

  // Repeats/strides hold two entries while axis holds three.
  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1};
  *brc2.y.strides = {af::ops::Zero, One};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = ge::DT_FLOAT;

  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data2.y;
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, af::ops::One};

  Add add("add");
  add.attr.sched.axis = {z0.id, z1.id, z2.id};
  add.x1 = brc3.y;
  add.x2 = load1.y;
  add.y.dtype = ge::DT_FLOAT;
  *add.y.axis = {z0.id, z1.id, z2.id};
  *add.y.repeats = {s0, s1, s2};
  *add.y.strides = {s1 * s2, s2, af::ops::One};

  // Second consumer of brc2.
  Exp exp("exp");
  exp.attr.sched.axis = {z0.id, z1.id, z2.id};
  exp.x = brc2.y;
  exp.y.dtype = ge::DT_FLOAT;
  *exp.y.axis = {z0.id, z1.id, z2.id};
  *exp.y.repeats = {af::ops::One, s1, s2};
  *exp.y.strides = {af::ops::Zero, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = add.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};

  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  // Stop here rather than dereference a null graph in the checks below.
  ASSERT_NE(compute_graph, nullptr);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 11);
  EXPECT_NE(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_NE(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_NE(compute_graph->FindNode("brc3"), nullptr);
}
// Scalar load -> brc1 -> brc2 -> brc3 -> add(+load1) -> store -> output.
// The test expects GraphPass to remove the three broadcasts (7 nodes left) and the add node to end up
// with input 0 repeats {s0, s1, s2} and input 1 repeats {One, One, One}.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Single) {
  af::AscGraph graph("ScalarBroadcastOptimization_Single");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = ge::DT_FLOAT;

  // Scalar load: repeat One and stride Zero on every axis.
  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id, z2.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id, z2.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};

  // Broadcast chain expanding one axis at a time: z2, then z1, then z0.
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1, s2};
  *brc2.y.strides = {af::ops::Zero, s2, af::ops::One};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = ge::DT_FLOAT;

  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data2.y;
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, af::ops::One};

  Add add("add");
  add.attr.sched.axis = {z0.id, z1.id, z2.id};
  add.x1 = brc3.y;
  add.x2 = load1.y;
  add.y.dtype = ge::DT_FLOAT;
  *add.y.axis = {z0.id, z1.id, z2.id};
  *add.y.repeats = {s0, s1, s2};
  *add.y.strides = {s1 * s2, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = add.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};

  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  // Stop here rather than dereference a null graph in the checks below.
  ASSERT_NE(compute_graph, nullptr);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 7);
  EXPECT_EQ(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc3"), nullptr);

  // Fatal check: add_node is dereferenced right after, so a missing or mistyped node must end the test.
  auto add_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("add"));
  ASSERT_NE(add_node, nullptr);

  std::stringstream add0_repeats;
  SizeExprListStr(add0_repeats, graph, add_node->inputs[0].attr.repeats);
  std::stringstream expected0_repeats;
  SizeExprListStr(expected0_repeats, graph, {s0, s1, s2});
  EXPECT_EQ(add0_repeats.str(), expected0_repeats.str());

  std::stringstream add1_repeats;
  SizeExprListStr(add1_repeats, graph, add_node->inputs[1].attr.repeats);
  std::stringstream expected1_repeats;
  SizeExprListStr(expected1_repeats, graph, {af::ops::One, af::ops::One, af::ops::One});
  EXPECT_EQ(add1_repeats.str(), expected1_repeats.str());
}
// Scalar load -> brc1 -> brc2 -> brc3, where brc3 feeds both add (with load1) and mul (with add's output);
// mul -> store -> output.
// The test expects GraphPass to remove the three broadcasts (8 nodes left) and both add and mul to end up
// with input 0 repeats {s0, s1, s2} and input 1 repeats {One, One, One}.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Multi_Out_Success) {
  af::AscGraph graph("ScalarBroadcastOptimization_Multi_Out_Success");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = ge::DT_FLOAT;

  // Scalar load: repeat One and stride Zero on every axis.
  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id, z2.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id, z2.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};

  // Broadcast chain expanding one axis at a time: z2, then z1, then z0.
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1, s2};
  *brc2.y.strides = {af::ops::Zero, s2, af::ops::One};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = ge::DT_FLOAT;

  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data2.y;
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, af::ops::One};

  Add add("add");
  add.attr.sched.axis = {z0.id, z1.id, z2.id};
  add.x1 = brc3.y;
  add.x2 = load1.y;
  add.y.dtype = ge::DT_FLOAT;
  *add.y.axis = {z0.id, z1.id, z2.id};
  *add.y.repeats = {s0, s1, s2};
  *add.y.strides = {s1 * s2, s2, af::ops::One};

  // Second consumer of brc3.
  Mul mul("mul");
  mul.attr.sched.axis = {z0.id, z1.id, z2.id};
  mul.x1 = add.y;
  mul.x2 = brc3.y;
  mul.y.dtype = ge::DT_FLOAT;
  *mul.y.axis = {z0.id, z1.id, z2.id};
  *mul.y.repeats = {s0, s1, s2};
  *mul.y.strides = {s1 * s2, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = mul.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};

  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  // Stop here rather than dereference a null graph in the checks below.
  ASSERT_NE(compute_graph, nullptr);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 8);
  EXPECT_EQ(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc3"), nullptr);

  // Fatal check: add_node is dereferenced right after, so a missing or mistyped node must end the test.
  auto add_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("add"));
  ASSERT_NE(add_node, nullptr);

  std::stringstream add0_repeats;
  SizeExprListStr(add0_repeats, graph, add_node->inputs[0].attr.repeats);
  std::stringstream expected0_repeats;
  SizeExprListStr(expected0_repeats, graph, {s0, s1, s2});
  EXPECT_EQ(add0_repeats.str(), expected0_repeats.str());

  std::stringstream add1_repeats;
  SizeExprListStr(add1_repeats, graph, add_node->inputs[1].attr.repeats);
  std::stringstream expected1_repeats;
  SizeExprListStr(expected1_repeats, graph, {af::ops::One, af::ops::One, af::ops::One});
  EXPECT_EQ(add1_repeats.str(), expected1_repeats.str());

  // Same fatal guard for mul before its inputs are inspected.
  auto mul_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("mul"));
  ASSERT_NE(mul_node, nullptr);

  std::stringstream mul0_repeats;
  SizeExprListStr(mul0_repeats, graph, mul_node->inputs[0].attr.repeats);
  EXPECT_EQ(mul0_repeats.str(), expected0_repeats.str());

  std::stringstream mul1_repeats;
  SizeExprListStr(mul1_repeats, graph, mul_node->inputs[1].attr.repeats);
  EXPECT_EQ(mul1_repeats.str(), expected1_repeats.str());
}
// Scalar load -> brc1 -> brc2 -> brc3 -> add(+load1) -> store -> output, with an extra exp consuming the
// intermediate brc2 output.
// The test expects GraphPass to succeed while leaving all 11 nodes, including the three broadcasts.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Multi_Out_Failed) {
  af::AscGraph graph("ScalarBroadcastOptimization_Multi_Out_Failed");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = ge::DT_FLOAT;

  // Scalar load: repeat One and stride Zero on every axis.
  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id, z2.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id, z2.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};

  // Broadcast chain expanding one axis at a time: z2, then z1, then z0.
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1, s2};
  *brc2.y.strides = {af::ops::Zero, s2, af::ops::One};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = ge::DT_FLOAT;

  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data2.y;
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, af::ops::One};

  Add add("add");
  add.attr.sched.axis = {z0.id, z1.id, z2.id};
  add.x1 = brc3.y;
  add.x2 = load1.y;
  add.y.dtype = ge::DT_FLOAT;
  *add.y.axis = {z0.id, z1.id, z2.id};
  *add.y.repeats = {s0, s1, s2};
  *add.y.strides = {s1 * s2, s2, af::ops::One};

  // Second consumer of the partially broadcast brc2 output.
  Exp exp("exp");
  exp.attr.sched.axis = {z0.id, z1.id, z2.id};
  exp.x = brc2.y;
  exp.y.dtype = ge::DT_FLOAT;
  *exp.y.axis = {z0.id, z1.id, z2.id};
  *exp.y.repeats = {af::ops::One, s1, s2};
  *exp.y.strides = {af::ops::Zero, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = add.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};

  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  // Stop here rather than dereference a null graph in the checks below.
  ASSERT_NE(compute_graph, nullptr);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 11);
  EXPECT_NE(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_NE(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_NE(compute_graph->FindNode("brc3"), nullptr);
}
// Scalar load -> brc1 -> brc2 -> brc3 -> sub(-load1) -> store -> output.
// The test expects GraphPass to succeed with 7 nodes left and none of brc1/brc2/brc3 findable by name.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Api_Not_Support_Scalar) {
  af::AscGraph graph("ScalarBroadcastOptimization_Api_Not_Support_Scalar");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = ge::DT_FLOAT;

  // Scalar load: repeat One and stride Zero on every axis.
  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id, z2.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id, z2.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};

  // Broadcast chain expanding one axis at a time: z2, then z1, then z0.
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1, s2};
  *brc2.y.strides = {af::ops::Zero, s2, af::ops::One};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = ge::DT_FLOAT;

  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data2.y;
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, af::ops::One};

  // The broadcast result is the first (x1) operand of sub.
  Sub sub("sub");
  sub.attr.sched.axis = {z0.id, z1.id, z2.id};
  sub.x1 = brc3.y;
  sub.x2 = load1.y;
  sub.y.dtype = ge::DT_FLOAT;
  *sub.y.axis = {z0.id, z1.id, z2.id};
  *sub.y.repeats = {s0, s1, s2};
  *sub.y.strides = {s1 * s2, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = sub.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};

  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  // Stop here rather than dereference a null graph in the checks below.
  ASSERT_NE(compute_graph, nullptr);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 7);
  EXPECT_EQ(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc3"), nullptr);
}
// Scalar-broadcast case where the Broadcast chain is fed by an Exp node rather than
// directly by the Load:
//   data  -> load (all repeats One) -> exp -> brc1 -> brc2 -> brc3 -> add.x1
//   data2 -> load1 (repeats s0,s1,s2) -------------------------------> add.x2
// After GraphPass the test expects 8 nodes and none of brc1/brc2/brc3 in the graph.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Not_Load) {
  af::AscGraph graph("ScalarBroadcastOptimization_Not_Load");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Shorthands for values repeated in every tensor description below.
  const auto &one = af::ops::One;
  const auto &zero = af::ops::Zero;
  const auto fp32 = ge::DT_FLOAT;
  const auto a0 = z0.id;
  const auto a1 = z1.id;
  const auto a2 = z2.id;

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = fp32;

  // Scalar-shaped load: every repeat is One, every stride is Zero.
  Load load("load");
  load.attr.sched.axis = {a0, a1, a2};
  load.x = data.y;
  load.y.dtype = fp32;
  *load.y.axis = {a0, a1, a2};
  *load.y.repeats = {one, one, one};
  *load.y.strides = {zero, zero, zero};

  // Elementwise op sitting between the scalar load and the broadcast chain.
  Exp exp("exp");
  exp.attr.sched.axis = {a0, a1, a2};
  exp.x = load.y;
  exp.y.dtype = fp32;
  *exp.y.axis = {a0, a1, a2};
  *exp.y.repeats = {one, one, one};
  *exp.y.strides = {zero, zero, zero};

  // Broadcast chain: expands z2 first, then z1, then z0.
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {a0, a1, a2};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = exp.y;
  brc1.y.dtype = fp32;
  *brc1.y.axis = {a0, a1, a2};
  *brc1.y.repeats = {one, one, s2};
  *brc1.y.strides = {zero, zero, one};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {a0, a1, a2};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  brc2.y.dtype = fp32;
  *brc2.y.axis = {a0, a1, a2};
  *brc2.y.repeats = {one, s1, s2};
  *brc2.y.strides = {zero, s2, one};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {a0, a1, a2};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  brc3.y.dtype = fp32;
  *brc3.y.axis = {a0, a1, a2};
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, one};

  // Second operand: a full-shape load.
  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = fp32;

  Load load1("load1");
  load1.attr.sched.axis = {a0, a1, a2};
  load1.x = data2.y;
  load1.y.dtype = fp32;
  *load1.y.axis = {a0, a1, a2};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, one};

  Add add("add");
  add.attr.sched.axis = {a0, a1, a2};
  add.x1 = brc3.y;
  add.x2 = load1.y;
  add.y.dtype = fp32;
  *add.y.axis = {a0, a1, a2};
  *add.y.repeats = {s0, s1, s2};
  *add.y.strides = {s1 * s2, s2, one};

  Store store("store");
  store.attr.sched.axis = {a0, a1, a2};
  store.x = add.y;
  store.y.dtype = fp32;
  *store.y.axis = {a0, a1, a2};
  *store.y.repeats = {s0, s1, s2};
  *store.y.strides = {s1 * s2, s2, one};

  Output out("output");
  out.ir_attr.SetIndex(0);
  out.x = store.y;
  out.y.dtype = fp32;

  Status status = optimizer.GraphPass(graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // 11 nodes were built; the expectation is that the three broadcasts are gone.
  auto cg = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(cg->GetAllNodesSize(), 8);
  for (const char *removed : {"brc1", "brc2", "brc3"}) {
    EXPECT_EQ(cg->FindNode(removed), nullptr) << removed;
  }
}
// Both Add operands come from a scalar load followed by a three-step broadcast chain:
//   data  -> load  -> brc1 -> brc2 -> brc3 -> add.x1
//   data2 -> load1 -> brc4 -> brc5 -> brc6 -> add.x2
// After GraphPass the test expects 10 nodes: brc1..brc3 absent, brc4..brc6 still present.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Two_Scalar) {
  af::AscGraph graph("ScalarBroadcastOptimization_Two_Scalar");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Shorthands for values repeated in every tensor description below.
  const auto &one = af::ops::One;
  const auto &zero = af::ops::Zero;
  const auto fp32 = ge::DT_FLOAT;
  const auto a0 = z0.id;
  const auto a1 = z1.id;
  const auto a2 = z2.id;

  // ---- First scalar operand and its broadcast chain ----
  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = fp32;

  Load load("load");
  load.attr.sched.axis = {a0, a1, a2};
  load.x = data.y;
  load.y.dtype = fp32;
  *load.y.axis = {a0, a1, a2};
  *load.y.repeats = {one, one, one};
  *load.y.strides = {zero, zero, zero};

  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {a0, a1, a2};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  brc1.y.dtype = fp32;
  *brc1.y.axis = {a0, a1, a2};
  *brc1.y.repeats = {one, one, s2};
  *brc1.y.strides = {zero, zero, one};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {a0, a1, a2};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  brc2.y.dtype = fp32;
  *brc2.y.axis = {a0, a1, a2};
  *brc2.y.repeats = {one, s1, s2};
  *brc2.y.strides = {zero, s2, one};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {a0, a1, a2};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  brc3.y.dtype = fp32;
  *brc3.y.axis = {a0, a1, a2};
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, one};

  // ---- Second scalar operand and its broadcast chain ----
  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = fp32;

  Load load1("load1");
  load1.attr.sched.axis = {a0, a1, a2};
  load1.x = data2.y;
  load1.y.dtype = fp32;
  *load1.y.axis = {a0, a1, a2};
  *load1.y.repeats = {one, one, one};
  *load1.y.strides = {zero, zero, zero};

  Broadcast brc4("brc4");
  brc4.attr.sched.axis = {a0, a1, a2};
  brc4.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc4.x = load1.y;
  brc4.y.dtype = fp32;
  *brc4.y.axis = {a0, a1, a2};
  *brc4.y.repeats = {one, one, s2};
  *brc4.y.strides = {zero, zero, one};

  Broadcast brc5("brc5");
  brc5.attr.sched.axis = {a0, a1, a2};
  brc5.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc5.x = brc4.y;
  brc5.y.dtype = fp32;
  *brc5.y.axis = {a0, a1, a2};
  *brc5.y.repeats = {one, s1, s2};
  *brc5.y.strides = {zero, s2, one};

  Broadcast brc6("brc6");
  brc6.attr.sched.axis = {a0, a1, a2};
  brc6.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc6.x = brc5.y;
  brc6.y.dtype = fp32;
  *brc6.y.axis = {a0, a1, a2};
  *brc6.y.repeats = {s0, s1, s2};
  *brc6.y.strides = {s1 * s2, s2, one};

  // ---- Consumer, store and graph output ----
  Add add("add");
  add.attr.sched.axis = {a0, a1, a2};
  add.x1 = brc3.y;
  add.x2 = brc6.y;
  add.y.dtype = fp32;
  *add.y.axis = {a0, a1, a2};
  *add.y.repeats = {s0, s1, s2};
  *add.y.strides = {s1 * s2, s2, one};

  Store store("store");
  store.attr.sched.axis = {a0, a1, a2};
  store.x = add.y;
  store.y.dtype = fp32;
  *store.y.axis = {a0, a1, a2};
  *store.y.repeats = {s0, s1, s2};
  *store.y.strides = {s1 * s2, s2, one};

  Output out("output");
  out.ir_attr.SetIndex(0);
  out.x = store.y;
  out.y.dtype = fp32;

  Status status = optimizer.GraphPass(graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // 13 nodes were built; only the first chain is expected to be removed.
  auto cg = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(cg->GetAllNodesSize(), 10);
  for (const char *removed : {"brc1", "brc2", "brc3"}) {
    EXPECT_EQ(cg->FindNode(removed), nullptr) << removed;
  }
  for (const char *kept : {"brc4", "brc5", "brc6"}) {
    EXPECT_NE(cg->FindNode(kept), nullptr) << kept;
  }
}
// Add consumes the same broadcast output on both inputs:
//   data -> load (all repeats One) -> brc1 -> brc2 -> brc3 -> add.x1 and add.x2
// After GraphPass the test expects the graph to be unchanged: 8 nodes, with
// brc1/brc2/brc3 all still present.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Same_Input) {
  af::AscGraph graph("ScalarBroadcastOptimization_Same_Input");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Shorthands for values repeated in every tensor description below.
  const auto &one = af::ops::One;
  const auto &zero = af::ops::Zero;
  const auto fp32 = ge::DT_FLOAT;
  const auto a0 = z0.id;
  const auto a1 = z1.id;
  const auto a2 = z2.id;

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = fp32;

  Load load("load");
  load.attr.sched.axis = {a0, a1, a2};
  load.x = data.y;
  load.y.dtype = fp32;
  *load.y.axis = {a0, a1, a2};
  *load.y.repeats = {one, one, one};
  *load.y.strides = {zero, zero, zero};

  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {a0, a1, a2};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  brc1.y.dtype = fp32;
  *brc1.y.axis = {a0, a1, a2};
  *brc1.y.repeats = {one, one, s2};
  *brc1.y.strides = {zero, zero, one};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {a0, a1, a2};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  brc2.y.dtype = fp32;
  *brc2.y.axis = {a0, a1, a2};
  *brc2.y.repeats = {one, s1, s2};
  *brc2.y.strides = {zero, s2, one};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {a0, a1, a2};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  brc3.y.dtype = fp32;
  *brc3.y.axis = {a0, a1, a2};
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, one};

  // Both Add inputs are wired to the very same broadcast output.
  Add add("add");
  add.attr.sched.axis = {a0, a1, a2};
  add.x1 = brc3.y;
  add.x2 = brc3.y;
  add.y.dtype = fp32;
  *add.y.axis = {a0, a1, a2};
  *add.y.repeats = {s0, s1, s2};
  *add.y.strides = {s1 * s2, s2, one};

  Store store("store");
  store.attr.sched.axis = {a0, a1, a2};
  store.x = add.y;
  store.y.dtype = fp32;
  *store.y.axis = {a0, a1, a2};
  *store.y.repeats = {s0, s1, s2};
  *store.y.strides = {s1 * s2, s2, one};

  Output out("output");
  out.ir_attr.SetIndex(0);
  out.x = store.y;
  out.y.dtype = fp32;

  Status status = optimizer.GraphPass(graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // All 8 constructed nodes are expected to survive.
  auto cg = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(cg->GetAllNodesSize(), 8);
  for (const char *kept : {"brc1", "brc2", "brc3"}) {
    EXPECT_NE(cg->FindNode(kept), nullptr) << kept;
  }
}
// Compare op (Ge) whose SECOND input is a broadcast scalar:
//   data2 -> load1 (repeats s0,s1,s2) ----------------------> ge.x1
//   data  -> load (all repeats One) -> brc1 -> brc2 -> brc3 -> ge.x2
// After GraphPass the test expects 7 nodes, no brc1/brc2/brc3, and the "ge" node's
// inputs to carry repeats {s0,s1,s2} (input 0) and {One,One,One} (input 1).
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Compare_2nd_Scalar) {
  af::AscGraph graph("ScalarBroadcastOptimization_Compare_2nd_Scalar");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = ge::DT_FLOAT;

  // Scalar-shaped load: every repeat is One, every stride is Zero.
  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id, z2.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id, z2.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};

  // Broadcast chain: expands z2 first, then z1, then z0.
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1, s2};
  *brc2.y.strides = {af::ops::Zero, s2, af::ops::One};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  // Full-shape operand.
  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = ge::DT_FLOAT;

  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data2.y;
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, af::ops::One};

  // Named ge_op (not "ge") so the local does not shadow the identifier of the ge
  // namespace used on the surrounding lines; the node name stays "ge".
  Ge ge_op("ge");
  ge_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  ge_op.x1 = load1.y;
  ge_op.x2 = brc3.y;
  ge_op.y.dtype = ge::DT_FLOAT;
  *ge_op.y.axis = {z0.id, z1.id, z2.id};
  *ge_op.y.repeats = {s0, s1, s2};
  *ge_op.y.strides = {s1 * s2, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = ge_op.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};

  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  // Fatal check: everything below dereferences compute_graph.
  ASSERT_NE(compute_graph, nullptr);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 7);
  EXPECT_EQ(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc3"), nullptr);

  // Fatal check: a missing node or a failed cast must fail the test rather than
  // crash on the dereferences below.
  auto ge_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("ge"));
  ASSERT_NE(ge_node, nullptr);

  // Input 0 is expected to keep the full shape.
  std::stringstream ge_repeats;
  SizeExprListStr(ge_repeats, graph, ge_node->inputs[0].attr.repeats);
  std::stringstream expected0_repeats;
  SizeExprListStr(expected0_repeats, graph, {s0, s1, s2});
  EXPECT_EQ(ge_repeats.str(), expected0_repeats.str());

  // Input 1 is expected to be scalar-shaped.
  std::stringstream ge1_repeats;
  SizeExprListStr(ge1_repeats, graph, ge_node->inputs[1].attr.repeats);
  std::stringstream expected1_repeats;
  SizeExprListStr(expected1_repeats, graph, {af::ops::One, af::ops::One, af::ops::One});
  EXPECT_EQ(ge1_repeats.str(), expected1_repeats.str());
}
// Compare op (Gt) whose FIRST input is a broadcast scalar:
//   data  -> load (all repeats One) -> brc1 -> brc2 -> brc3 -> gt.x1
//   data2 -> load1 (repeats s0,s1,s2) ----------------------> gt.x2
// After GraphPass the test expects the graph to be unchanged: 10 nodes, with
// brc1/brc2/brc3 all still present.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Compare_1nd_Scalar) {
  af::AscGraph graph("ScalarBroadcastOptimization_Compare_1nd_Scalar");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Shorthands for values repeated in every tensor description below.
  const auto &one = af::ops::One;
  const auto &zero = af::ops::Zero;
  const auto fp32 = ge::DT_FLOAT;
  const auto a0 = z0.id;
  const auto a1 = z1.id;
  const auto a2 = z2.id;

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = fp32;

  Load load("load");
  load.attr.sched.axis = {a0, a1, a2};
  load.x = data.y;
  load.y.dtype = fp32;
  *load.y.axis = {a0, a1, a2};
  *load.y.repeats = {one, one, one};
  *load.y.strides = {zero, zero, zero};

  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {a0, a1, a2};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  brc1.y.dtype = fp32;
  *brc1.y.axis = {a0, a1, a2};
  *brc1.y.repeats = {one, one, s2};
  *brc1.y.strides = {zero, zero, one};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {a0, a1, a2};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  brc2.y.dtype = fp32;
  *brc2.y.axis = {a0, a1, a2};
  *brc2.y.repeats = {one, s1, s2};
  *brc2.y.strides = {zero, s2, one};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {a0, a1, a2};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  brc3.y.dtype = fp32;
  *brc3.y.axis = {a0, a1, a2};
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, one};

  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = fp32;

  Load load1("load1");
  load1.attr.sched.axis = {a0, a1, a2};
  load1.x = data2.y;
  load1.y.dtype = fp32;
  *load1.y.axis = {a0, a1, a2};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, one};

  // The broadcast scalar is the first compare operand here.
  Gt gt("gt");
  gt.attr.sched.axis = {a0, a1, a2};
  gt.x1 = brc3.y;
  gt.x2 = load1.y;
  gt.y.dtype = fp32;
  *gt.y.axis = {a0, a1, a2};
  *gt.y.repeats = {s0, s1, s2};
  *gt.y.strides = {s1 * s2, s2, one};

  Store store("store");
  store.attr.sched.axis = {a0, a1, a2};
  store.x = gt.y;
  store.y.dtype = fp32;
  *store.y.axis = {a0, a1, a2};
  *store.y.repeats = {s0, s1, s2};
  *store.y.strides = {s1 * s2, s2, one};

  Output out("output");
  out.ir_attr.SetIndex(0);
  out.x = store.y;
  out.y.dtype = fp32;

  Status status = optimizer.GraphPass(graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // All 10 constructed nodes are expected to survive.
  auto cg = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(cg->GetAllNodesSize(), 10);
  for (const char *kept : {"brc1", "brc2", "brc3"}) {
    EXPECT_NE(cg->FindNode(kept), nullptr) << kept;
  }
}
// One broadcast scalar shared by two consumers:
//   data  -> load (all repeats One) -> brc1 -> brc2 -> brc3 -> add.x1 and eq.x2
//   data2 -> load1 (repeats s0,s1,s2) ----------------------> add.x2
//   add -> eq.x1
// After GraphPass the test expects 8 nodes, no brc1/brc2/brc3, and for both "add" and
// "eq": input 0 repeats == {s0,s1,s2}, input 1 repeats == {One,One,One}.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Add_Eq_Common_Scalar_Success) {
  af::AscGraph graph("ScalarBroadcastOptimization_Add_Eq_Common_Scalar_Success");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = ge::DT_FLOAT;

  // Scalar-shaped load: every repeat is One, every stride is Zero.
  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id, z2.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id, z2.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};

  // Broadcast chain: expands z2 first, then z1, then z0.
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1, s2};
  *brc2.y.strides = {af::ops::Zero, s2, af::ops::One};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  // Full-shape operand.
  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = ge::DT_FLOAT;

  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data2.y;
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, af::ops::One};

  // First consumer of the broadcast scalar.
  Add add("add");
  add.attr.sched.axis = {z0.id, z1.id, z2.id};
  add.x1 = brc3.y;
  add.x2 = load1.y;
  add.y.dtype = ge::DT_FLOAT;
  *add.y.axis = {z0.id, z1.id, z2.id};
  *add.y.repeats = {s0, s1, s2};
  *add.y.strides = {s1 * s2, s2, af::ops::One};

  // Second consumer of the same broadcast scalar (as its second input).
  Eq eq("eq");
  eq.attr.sched.axis = {z0.id, z1.id, z2.id};
  eq.x1 = add.y;
  eq.x2 = brc3.y;
  eq.y.dtype = ge::DT_FLOAT;
  *eq.y.axis = {z0.id, z1.id, z2.id};
  *eq.y.repeats = {s0, s1, s2};
  *eq.y.strides = {s1 * s2, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = eq.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};

  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  // Fatal check: everything below dereferences compute_graph.
  ASSERT_NE(compute_graph, nullptr);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 8);
  EXPECT_EQ(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc3"), nullptr);

  // Reference strings for the two expected input shapes.
  std::stringstream expected0_repeats;
  SizeExprListStr(expected0_repeats, graph, {s0, s1, s2});
  std::stringstream expected1_repeats;
  SizeExprListStr(expected1_repeats, graph, {af::ops::One, af::ops::One, af::ops::One});

  // Fatal checks: a missing node or failed cast must fail the test rather than crash
  // on the dereferences below.
  auto add_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("add"));
  ASSERT_NE(add_node, nullptr);
  std::stringstream add0_repeats;
  SizeExprListStr(add0_repeats, graph, add_node->inputs[0].attr.repeats);
  EXPECT_EQ(add0_repeats.str(), expected0_repeats.str());
  std::stringstream add1_repeats;
  SizeExprListStr(add1_repeats, graph, add_node->inputs[1].attr.repeats);
  EXPECT_EQ(add1_repeats.str(), expected1_repeats.str());

  auto eq_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("eq"));
  ASSERT_NE(eq_node, nullptr);
  // Compare the Eq node's own strings; the add strings were already checked above.
  std::stringstream eq0_repeats;
  SizeExprListStr(eq0_repeats, graph, eq_node->inputs[0].attr.repeats);
  EXPECT_EQ(eq0_repeats.str(), expected0_repeats.str());
  std::stringstream eq1_repeats;
  SizeExprListStr(eq1_repeats, graph, eq_node->inputs[1].attr.repeats);
  EXPECT_EQ(eq1_repeats.str(), expected1_repeats.str());
}
// One broadcast scalar shared by Add (first input) and Lt (first input):
//   data  -> load (all repeats One) -> brc1 -> brc2 -> brc3 -> add.x1 and lt.x1
//   data2 -> load1 (repeats s0,s1,s2) ----------------------> add.x2
//   add -> lt.x2
// After GraphPass the test expects the graph to be unchanged: 11 nodes, with
// brc1/brc2/brc3, add and lt all still present.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Add_Lt_Common_Scalar_Failed) {
  af::AscGraph graph("ScalarBroadcastOptimization_Add_Lt_Common_Scalar_Failed");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Shorthands for values repeated in every tensor description below.
  const auto &one = af::ops::One;
  const auto &zero = af::ops::Zero;
  const auto fp32 = ge::DT_FLOAT;
  const auto a0 = z0.id;
  const auto a1 = z1.id;
  const auto a2 = z2.id;

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = fp32;

  Load load("load");
  load.attr.sched.axis = {a0, a1, a2};
  load.x = data.y;
  load.y.dtype = fp32;
  *load.y.axis = {a0, a1, a2};
  *load.y.repeats = {one, one, one};
  *load.y.strides = {zero, zero, zero};

  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {a0, a1, a2};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  brc1.y.dtype = fp32;
  *brc1.y.axis = {a0, a1, a2};
  *brc1.y.repeats = {one, one, s2};
  *brc1.y.strides = {zero, zero, one};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {a0, a1, a2};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  brc2.y.dtype = fp32;
  *brc2.y.axis = {a0, a1, a2};
  *brc2.y.repeats = {one, s1, s2};
  *brc2.y.strides = {zero, s2, one};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {a0, a1, a2};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  brc3.y.dtype = fp32;
  *brc3.y.axis = {a0, a1, a2};
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, one};

  Data data2("data2", graph);
  data2.ir_attr.SetIndex(1);
  data2.y.dtype = fp32;

  Load load1("load1");
  load1.attr.sched.axis = {a0, a1, a2};
  load1.x = data2.y;
  load1.y.dtype = fp32;
  *load1.y.axis = {a0, a1, a2};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, one};

  Add add("add");
  add.attr.sched.axis = {a0, a1, a2};
  add.x1 = brc3.y;
  add.x2 = load1.y;
  add.y.dtype = fp32;
  *add.y.axis = {a0, a1, a2};
  *add.y.repeats = {s0, s1, s2};
  *add.y.strides = {s1 * s2, s2, one};

  // The broadcast scalar is the first compare operand here.
  Lt lt("lt");
  lt.attr.sched.axis = {a0, a1, a2};
  lt.x1 = brc3.y;
  lt.x2 = add.y;
  lt.y.dtype = fp32;
  *lt.y.axis = {a0, a1, a2};
  *lt.y.repeats = {s0, s1, s2};
  *lt.y.strides = {s1 * s2, s2, one};

  Store store("store");
  store.attr.sched.axis = {a0, a1, a2};
  store.x = lt.y;
  store.y.dtype = fp32;
  *store.y.axis = {a0, a1, a2};
  *store.y.repeats = {s0, s1, s2};
  *store.y.strides = {s1 * s2, s2, one};

  Output out("output");
  out.ir_attr.SetIndex(0);
  out.x = store.y;
  out.y.dtype = fp32;

  Status status = optimizer.GraphPass(graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // All 11 constructed nodes are expected to survive.
  auto cg = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(cg->GetAllNodesSize(), 11);
  for (const char *kept : {"brc1", "brc2", "brc3", "add", "lt"}) {
    EXPECT_NE(cg->FindNode(kept), nullptr) << kept;
  }
}
// Two independent broadcast scalars feeding Sub and Eq:
//   data  -> load  (all repeats One) -> brc1  -> brc2  -> brc3  -> sub.x1
//   data1 -> load1 (repeats s0,s1,s2) --------------------------> sub.x2
//   sub -> eq.x1
//   data2 -> load2 (all repeats One) -> brc21 -> brc22 -> brc23 -> eq.x2
// After GraphPass the test expects 10 nodes and none of the six Broadcast nodes.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Sub_Eq_Success) {
  af::AscGraph graph("ScalarBroadcastOptimization_Sub_Eq_Success");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Shorthands for values repeated in every tensor description below.
  const auto &one = af::ops::One;
  const auto &zero = af::ops::Zero;
  const auto fp32 = ge::DT_FLOAT;
  const auto a0 = z0.id;
  const auto a1 = z1.id;
  const auto a2 = z2.id;

  // ---- First scalar operand and its broadcast chain ----
  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = fp32;

  Load load("load");
  load.attr.sched.axis = {a0, a1, a2};
  load.x = data.y;
  load.y.dtype = fp32;
  *load.y.axis = {a0, a1, a2};
  *load.y.repeats = {one, one, one};
  *load.y.strides = {zero, zero, zero};

  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {a0, a1, a2};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  brc1.y.dtype = fp32;
  *brc1.y.axis = {a0, a1, a2};
  *brc1.y.repeats = {one, one, s2};
  *brc1.y.strides = {zero, zero, one};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {a0, a1, a2};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  brc2.y.dtype = fp32;
  *brc2.y.axis = {a0, a1, a2};
  *brc2.y.repeats = {one, s1, s2};
  *brc2.y.strides = {zero, s2, one};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {a0, a1, a2};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  brc3.y.dtype = fp32;
  *brc3.y.axis = {a0, a1, a2};
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, one};

  // ---- Full-shape operand and the Sub consumer ----
  Data data1("data1", graph);
  data1.ir_attr.SetIndex(1);
  data1.y.dtype = fp32;

  Load load1("load1");
  load1.attr.sched.axis = {a0, a1, a2};
  load1.x = data1.y;
  load1.y.dtype = fp32;
  *load1.y.axis = {a0, a1, a2};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, one};

  Sub sub("sub");
  sub.attr.sched.axis = {a0, a1, a2};
  sub.x1 = brc3.y;
  sub.x2 = load1.y;
  sub.y.dtype = fp32;
  *sub.y.axis = {a0, a1, a2};
  *sub.y.repeats = {s0, s1, s2};
  *sub.y.strides = {s1 * s2, s2, one};

  // ---- Second scalar operand and its broadcast chain ----
  Data data2("data2", graph);
  data2.ir_attr.SetIndex(2);
  data2.y.dtype = fp32;

  Load load2("load2");
  load2.attr.sched.axis = {a0, a1, a2};
  load2.x = data2.y;
  load2.y.dtype = fp32;
  *load2.y.axis = {a0, a1, a2};
  *load2.y.repeats = {one, one, one};
  *load2.y.strides = {zero, zero, zero};

  Broadcast brc21("brc21");
  brc21.attr.sched.axis = {a0, a1, a2};
  brc21.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc21.x = load2.y;
  brc21.y.dtype = fp32;
  *brc21.y.axis = {a0, a1, a2};
  *brc21.y.repeats = {one, one, s2};
  *brc21.y.strides = {zero, zero, one};

  Broadcast brc22("brc22");
  brc22.attr.sched.axis = {a0, a1, a2};
  brc22.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc22.x = brc21.y;
  brc22.y.dtype = fp32;
  *brc22.y.axis = {a0, a1, a2};
  *brc22.y.repeats = {one, s1, s2};
  *brc22.y.strides = {zero, s2, one};

  Broadcast brc23("brc23");
  brc23.attr.sched.axis = {a0, a1, a2};
  brc23.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc23.x = brc22.y;
  brc23.y.dtype = fp32;
  *brc23.y.axis = {a0, a1, a2};
  *brc23.y.repeats = {s0, s1, s2};
  *brc23.y.strides = {s1 * s2, s2, one};

  // ---- Compare, store and graph output ----
  Eq eq("eq");
  eq.attr.sched.axis = {a0, a1, a2};
  eq.x1 = sub.y;
  eq.x2 = brc23.y;
  eq.y.dtype = fp32;
  *eq.y.axis = {a0, a1, a2};
  *eq.y.repeats = {s0, s1, s2};
  *eq.y.strides = {s1 * s2, s2, one};

  Store store("store");
  store.attr.sched.axis = {a0, a1, a2};
  store.x = eq.y;
  store.y.dtype = fp32;
  *store.y.axis = {a0, a1, a2};
  *store.y.repeats = {s0, s1, s2};
  *store.y.strides = {s1 * s2, s2, one};

  Output out("output");
  out.ir_attr.SetIndex(0);
  out.x = store.y;
  out.y.dtype = fp32;

  Status status = optimizer.GraphPass(graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // 16 nodes were built; both broadcast chains are expected to be removed.
  auto cg = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(cg->GetAllNodesSize(), 10);
  for (const char *removed : {"brc1", "brc2", "brc3", "brc21", "brc22", "brc23"}) {
    EXPECT_EQ(cg->FindNode(removed), nullptr) << removed;
  }
}
// One broadcast scalar shared by Minimum (first input) and Ne (second input):
//   data  -> load (all repeats One) -> brc1 -> brc2 -> brc3 -> min.x1 and ne.x2
//   data1 -> load1 (repeats s0,s1,s2) ----------------------> min.x2
//   min -> ne.x1
// After GraphPass the test expects 8 nodes, no brc1/brc2/brc3, and for both "min" and
// "ne": input 0 repeats == {s0,s1,s2}, input 1 repeats == {One,One,One}.
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Min_Ne_Success) {
  af::AscGraph graph("ScalarBroadcastOptimization_Min_Ne_Success");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  Data data("data", graph);
  data.ir_attr.SetIndex(0);
  data.y.dtype = ge::DT_FLOAT;

  // Scalar-shaped load: every repeat is One, every stride is Zero.
  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id, z2.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id, z2.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};

  // Broadcast chain: expands z2 first, then z1, then z0.
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};

  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1, s2};
  *brc2.y.strides = {af::ops::Zero, s2, af::ops::One};

  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  // Full-shape operand.
  Data data1("data1", graph);
  data1.ir_attr.SetIndex(1);
  data1.y.dtype = ge::DT_FLOAT;

  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data1.y;
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  *load1.y.repeats = {s0, s1, s2};
  *load1.y.strides = {s1 * s2, s2, af::ops::One};

  // First consumer of the broadcast scalar; the node name stays "min".
  Minimum minimum("min");
  minimum.attr.sched.axis = {z0.id, z1.id, z2.id};
  minimum.x1 = brc3.y;
  minimum.x2 = load1.y;
  minimum.y.dtype = ge::DT_FLOAT;
  *minimum.y.axis = {z0.id, z1.id, z2.id};
  *minimum.y.repeats = {s0, s1, s2};
  *minimum.y.strides = {s1 * s2, s2, af::ops::One};

  // Second consumer of the same broadcast scalar (as its second input).
  Ne ne("ne");
  ne.attr.sched.axis = {z0.id, z1.id, z2.id};
  ne.x1 = minimum.y;
  ne.x2 = brc3.y;
  ne.y.dtype = ge::DT_FLOAT;
  *ne.y.axis = {z0.id, z1.id, z2.id};
  *ne.y.repeats = {s0, s1, s2};
  *ne.y.strides = {s1 * s2, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = ne.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};

  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  // Fatal check: everything below dereferences compute_graph.
  ASSERT_NE(compute_graph, nullptr);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 8);
  EXPECT_EQ(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc3"), nullptr);

  // Reference strings for the two expected input shapes.
  std::stringstream expected0_repeats;
  SizeExprListStr(expected0_repeats, graph, {s0, s1, s2});
  std::stringstream expected1_repeats;
  SizeExprListStr(expected1_repeats, graph, {af::ops::One, af::ops::One, af::ops::One});

  // Fatal checks: a missing node or failed cast must fail the test rather than crash
  // on the dereferences below.
  auto min_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("min"));
  ASSERT_NE(min_node, nullptr);
  std::stringstream min0_repeats;
  SizeExprListStr(min0_repeats, graph, min_node->inputs[0].attr.repeats);
  EXPECT_EQ(min0_repeats.str(), expected0_repeats.str());
  std::stringstream min1_repeats;
  SizeExprListStr(min1_repeats, graph, min_node->inputs[1].attr.repeats);
  EXPECT_EQ(min1_repeats.str(), expected1_repeats.str());

  auto ne_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("ne"));
  ASSERT_NE(ne_node, nullptr);
  std::stringstream ne0_repeats;
  SizeExprListStr(ne0_repeats, graph, ne_node->inputs[0].attr.repeats);
  EXPECT_EQ(ne0_repeats.str(), expected0_repeats.str());
  std::stringstream ne1_repeats;
  SizeExprListStr(ne1_repeats, graph, ne_node->inputs[1].attr.repeats);
  EXPECT_EQ(ne1_repeats.str(), expected1_repeats.str());
}
/**
 *                  where
 *            /1     /0     \2
 *           /      /         \
 *          /   not_equal      \
 *         |     /     \        \
 *         |    /       \        \
 *         |   /         \        \
 *         |  /        brc123   brc456
 *         |/            |        |
 *       load0         load1    load2
 *         |             |s       |s
 *       data0         data1    data2
 */
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Where_3S_Success) {
  // Graph under test: where(ne(brc3, load0), load0, brc6).
  //  - load0 has full repeats {s0, s1, s2}.
  //  - load1 / load2 have all-one repeats and are each expanded to {s0, s1, s2} by a chain of three
  //    Broadcast nodes (brc1..brc3 and brc4..brc6).
  // After GraphPass the test expects all six Broadcast nodes to be absent from the compute graph.
  af::AscGraph graph("ScalarBroadcastOptimization_Where_3S_Success");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Input 0: full-shape tensor.
  Data data0("data0", graph);
  data0.ir_attr.SetIndex(0);
  data0.y.dtype = ge::DT_FLOAT;
  Load load0("load0");
  load0.attr.sched.axis = {z0.id, z1.id, z2.id};
  load0.x = data0.y;
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.axis = {z0.id, z1.id, z2.id};
  *load0.y.repeats = {s0, s1, s2};
  *load0.y.strides = {s1 * s2, s2, af::ops::One};

  // Input 1: all-one repeats, broadcast one axis at a time by brc1 -> brc2 -> brc3.
  Data data1("data1", graph);
  data1.ir_attr.SetIndex(1);
  data1.y.dtype = ge::DT_FLOAT;
  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data1.y;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load1.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};
  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1, s2};
  *brc2.y.strides = {af::ops::Zero, s2, af::ops::One};
  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  // Input 2: all-one repeats, broadcast one axis at a time by brc4 -> brc5 -> brc6.
  Data data2("data2", graph);
  data2.ir_attr.SetIndex(2);
  data2.y.dtype = ge::DT_FLOAT;
  Load load2("load2");
  load2.attr.sched.axis = {z0.id, z1.id, z2.id};
  load2.x = data2.y;
  *load2.y.axis = {z0.id, z1.id, z2.id};
  load2.y.dtype = ge::DT_FLOAT;
  *load2.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load2.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};
  Broadcast brc4("brc4");
  brc4.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc4.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc4.x = load2.y;
  *brc4.y.axis = {z0.id, z1.id, z2.id};
  brc4.y.dtype = ge::DT_FLOAT;
  *brc4.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc4.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};
  Broadcast brc5("brc5");
  brc5.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc5.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc5.x = brc4.y;
  *brc5.y.axis = {z0.id, z1.id, z2.id};
  brc5.y.dtype = ge::DT_FLOAT;
  *brc5.y.repeats = {af::ops::One, s1, s2};
  *brc5.y.strides = {af::ops::Zero, s2, af::ops::One};
  Broadcast brc6("brc6");
  brc6.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc6.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc6.x = brc5.y;
  *brc6.y.axis = {z0.id, z1.id, z2.id};
  brc6.y.dtype = ge::DT_FLOAT;
  *brc6.y.repeats = {s0, s1, s2};
  *brc6.y.strides = {s1 * s2, s2, af::ops::One};

  // Compute nodes: ne(brc3, load0) feeds where as input 0; load0 and brc6 are where inputs 1 and 2.
  Ne ne("ne");
  ne.attr.sched.axis = {z0.id, z1.id, z2.id};
  ne.x1 = brc3.y;
  ne.x2 = load0.y;
  ne.y.dtype = ge::DT_FLOAT;
  *ne.y.axis = {z0.id, z1.id, z2.id};
  *ne.y.repeats = {s0, s1, s2};
  *ne.y.strides = {s1 * s2, s2, af::ops::One};
  Where where("where");
  where.attr.sched.axis = {z0.id, z1.id, z2.id};
  where.x1 = ne.y;
  where.x2 = load0.y;
  where.x3 = brc6.y;
  where.y.dtype = ge::DT_FLOAT;
  *where.y.axis = {z0.id, z1.id, z2.id};
  *where.y.repeats = {s0, s1, s2};
  *where.y.strides = {s1 * s2, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = where.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};
  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  // 16 ops were built; with the six Broadcasts gone, 10 are expected to remain
  // (3 Data, 3 Load, ne, where, store, output).
  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 10);
  EXPECT_EQ(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc3"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc4"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc5"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc6"), nullptr);

  // ne: input 0 is expected to carry the full repeats and input 1 the all-one repeats.
  auto ne_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("ne"));
  // Fatal check: the node is dereferenced below, so a missing node must stop the test here.
  ASSERT_NE(ne_node, nullptr);
  std::stringstream ne0_repeats;
  SizeExprListStr(ne0_repeats, graph, ne_node->inputs[0].attr.repeats);
  std::stringstream expected0_repeats;
  SizeExprListStr(expected0_repeats, graph, {s0, s1, s2});
  EXPECT_EQ(ne0_repeats.str(), expected0_repeats.str());
  std::stringstream ne1_repeats;
  SizeExprListStr(ne1_repeats, graph, ne_node->inputs[1].attr.repeats);
  std::stringstream expected1_repeats;
  SizeExprListStr(expected1_repeats, graph, {af::ops::One, af::ops::One, af::ops::One});
  EXPECT_EQ(ne1_repeats.str(), expected1_repeats.str());

  // where: input 1 (load0) keeps the full repeats, input 2 is expected to be all-one.
  auto where_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("where"));
  ASSERT_NE(where_node, nullptr);
  std::stringstream where1_repeats;
  SizeExprListStr(where1_repeats, graph, where_node->inputs[1].attr.repeats);
  EXPECT_EQ(where1_repeats.str(), expected0_repeats.str());
  std::stringstream where2_repeats;
  SizeExprListStr(where2_repeats, graph, where_node->inputs[2].attr.repeats);
  EXPECT_EQ(where2_repeats.str(), expected1_repeats.str());
}
/**
 *                  select
 *            /2     /0     \1
 *           /      /         \
 *          /   not_equal      \
 *         |     /     \        \
 *         |    /       \        \
 *         |   /         \        \
 *         |  /        brc123   brc456
 *         |/            |        |
 *       load0         load1    load2
 *         |             |s       |s
 *       data0         data1    data2
 */
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Select_2S_Success) {
  // Graph under test: select(ne(brc3, load0), brc6, load0).
  //  - load0 has full repeats {s0, s1, s2}.
  //  - load1 / load2 have all-one repeats and are each expanded to {s0, s1, s2} by a chain of three
  //    Broadcast nodes (brc1..brc3 and brc4..brc6).
  // After GraphPass the test expects all six Broadcast nodes to be absent from the compute graph.
  af::AscGraph graph("ScalarBroadcastOptimization_Select_2S_Success");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Input 0: full-shape tensor.
  Data data0("data0", graph);
  data0.ir_attr.SetIndex(0);
  data0.y.dtype = ge::DT_FLOAT;
  Load load0("load0");
  load0.attr.sched.axis = {z0.id, z1.id, z2.id};
  load0.x = data0.y;
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.axis = {z0.id, z1.id, z2.id};
  *load0.y.repeats = {s0, s1, s2};
  *load0.y.strides = {s1 * s2, s2, af::ops::One};

  // Input 1: all-one repeats, broadcast one axis at a time by brc1 -> brc2 -> brc3.
  Data data1("data1", graph);
  data1.ir_attr.SetIndex(1);
  data1.y.dtype = ge::DT_FLOAT;
  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data1.y;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load1.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};
  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1, s2};
  *brc2.y.strides = {af::ops::Zero, s2, af::ops::One};
  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  // Input 2: all-one repeats, broadcast one axis at a time by brc4 -> brc5 -> brc6.
  Data data2("data2", graph);
  data2.ir_attr.SetIndex(2);
  data2.y.dtype = ge::DT_FLOAT;
  Load load2("load2");
  load2.attr.sched.axis = {z0.id, z1.id, z2.id};
  load2.x = data2.y;
  *load2.y.axis = {z0.id, z1.id, z2.id};
  load2.y.dtype = ge::DT_FLOAT;
  *load2.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load2.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};
  Broadcast brc4("brc4");
  brc4.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc4.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc4.x = load2.y;
  *brc4.y.axis = {z0.id, z1.id, z2.id};
  brc4.y.dtype = ge::DT_FLOAT;
  *brc4.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc4.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};
  Broadcast brc5("brc5");
  brc5.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc5.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc5.x = brc4.y;
  *brc5.y.axis = {z0.id, z1.id, z2.id};
  brc5.y.dtype = ge::DT_FLOAT;
  *brc5.y.repeats = {af::ops::One, s1, s2};
  *brc5.y.strides = {af::ops::Zero, s2, af::ops::One};
  Broadcast brc6("brc6");
  brc6.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc6.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc6.x = brc5.y;
  *brc6.y.axis = {z0.id, z1.id, z2.id};
  brc6.y.dtype = ge::DT_FLOAT;
  *brc6.y.repeats = {s0, s1, s2};
  *brc6.y.strides = {s1 * s2, s2, af::ops::One};

  // Compute nodes: ne(brc3, load0) feeds select as input 0; brc6 and load0 are select inputs 1 and 2.
  Ne ne("ne");
  ne.attr.sched.axis = {z0.id, z1.id, z2.id};
  ne.x1 = brc3.y;
  ne.x2 = load0.y;
  ne.y.dtype = ge::DT_FLOAT;
  *ne.y.axis = {z0.id, z1.id, z2.id};
  *ne.y.repeats = {s0, s1, s2};
  *ne.y.strides = {s1 * s2, s2, af::ops::One};
  Select select("select");
  select.attr.sched.axis = {z0.id, z1.id, z2.id};
  select.x1 = ne.y;
  select.x2 = brc6.y;
  select.x3 = load0.y;
  select.y.dtype = ge::DT_FLOAT;
  *select.y.axis = {z0.id, z1.id, z2.id};
  *select.y.repeats = {s0, s1, s2};
  *select.y.strides = {s1 * s2, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = select.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};
  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  // 16 ops were built; with the six Broadcasts gone, 10 are expected to remain
  // (3 Data, 3 Load, ne, select, store, output).
  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 10);
  EXPECT_EQ(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc3"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc4"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc5"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc6"), nullptr);

  // ne: input 0 is expected to carry the full repeats and input 1 the all-one repeats.
  auto ne_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("ne"));
  // Fatal check: the node is dereferenced below, so a missing node must stop the test here.
  ASSERT_NE(ne_node, nullptr);
  std::stringstream ne0_repeats;
  SizeExprListStr(ne0_repeats, graph, ne_node->inputs[0].attr.repeats);
  std::stringstream expected0_repeats;
  SizeExprListStr(expected0_repeats, graph, {s0, s1, s2});
  EXPECT_EQ(ne0_repeats.str(), expected0_repeats.str());
  std::stringstream ne1_repeats;
  SizeExprListStr(ne1_repeats, graph, ne_node->inputs[1].attr.repeats);
  std::stringstream expected1_repeats;
  SizeExprListStr(expected1_repeats, graph, {af::ops::One, af::ops::One, af::ops::One});
  EXPECT_EQ(ne1_repeats.str(), expected1_repeats.str());

  // select: input 1 is expected to be all-one, input 2 (load0) keeps the full repeats.
  auto select_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("select"));
  ASSERT_NE(select_node, nullptr);
  std::stringstream select1_repeats;
  SizeExprListStr(select1_repeats, graph, select_node->inputs[1].attr.repeats);
  EXPECT_EQ(select1_repeats.str(), expected1_repeats.str());
  std::stringstream select2_repeats;
  SizeExprListStr(select2_repeats, graph, select_node->inputs[2].attr.repeats);
  EXPECT_EQ(select2_repeats.str(), expected0_repeats.str());
}
/**
 *                  select
 *               /0    \1      \2
 *              /        \       \
 *         not_equal      \       \
 *          /     \        \       \
 *         /       \        \       \
 *        /         \        \       \
 *       /        brc123   brc456  brc789
 *      /           |        |       |
 *   load0        load1    load2   load3
 *     |            |s       |s      |s
 *   data0        data1    data2   data3
 */
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Select_2S_3S_Success) {
  // Graph under test: select(ne(brc3, load0), brc6, brc9).
  //  - load0 has full repeats {s0, s1, s2}.
  //  - load1 / load2 / load3 have all-one repeats and are each expanded to {s0, s1, s2} by a chain of
  //    three Broadcast nodes (brc1..brc3, brc4..brc6 and brc7..brc9).
  // After GraphPass the test expects all nine Broadcast nodes to be absent from the compute graph.
  af::AscGraph graph("ScalarBroadcastOptimization_Select_2S_3S_Success");
  auto s0 = graph.CreateSizeVar(4);
  auto s1 = graph.CreateSizeVar(5);
  auto s2 = graph.CreateSizeVar(6);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Input 0: full-shape tensor.
  Data data0("data0", graph);
  data0.ir_attr.SetIndex(0);
  data0.y.dtype = ge::DT_FLOAT;
  Load load0("load0");
  load0.attr.sched.axis = {z0.id, z1.id, z2.id};
  load0.x = data0.y;
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.axis = {z0.id, z1.id, z2.id};
  *load0.y.repeats = {s0, s1, s2};
  *load0.y.strides = {s1 * s2, s2, af::ops::One};

  // Input 1: all-one repeats, broadcast one axis at a time by brc1 -> brc2 -> brc3.
  Data data1("data1", graph);
  data1.ir_attr.SetIndex(1);
  data1.y.dtype = ge::DT_FLOAT;
  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data1.y;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};
  Broadcast brc1("brc1");
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.x = load1.y;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  brc1.y.dtype = ge::DT_FLOAT;
  *brc1.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};
  Broadcast brc2("brc2");
  brc2.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.x = brc1.y;
  *brc2.y.axis = {z0.id, z1.id, z2.id};
  brc2.y.dtype = ge::DT_FLOAT;
  *brc2.y.repeats = {af::ops::One, s1, s2};
  *brc2.y.strides = {af::ops::Zero, s2, af::ops::One};
  Broadcast brc3("brc3");
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc3.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc3.x = brc2.y;
  *brc3.y.axis = {z0.id, z1.id, z2.id};
  brc3.y.dtype = ge::DT_FLOAT;
  *brc3.y.repeats = {s0, s1, s2};
  *brc3.y.strides = {s1 * s2, s2, af::ops::One};

  // Input 2: all-one repeats, broadcast one axis at a time by brc4 -> brc5 -> brc6.
  Data data2("data2", graph);
  data2.ir_attr.SetIndex(2);
  data2.y.dtype = ge::DT_FLOAT;
  Load load2("load2");
  load2.attr.sched.axis = {z0.id, z1.id, z2.id};
  load2.x = data2.y;
  *load2.y.axis = {z0.id, z1.id, z2.id};
  load2.y.dtype = ge::DT_FLOAT;
  *load2.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load2.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};
  Broadcast brc4("brc4");
  brc4.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc4.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc4.x = load2.y;
  *brc4.y.axis = {z0.id, z1.id, z2.id};
  brc4.y.dtype = ge::DT_FLOAT;
  *brc4.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc4.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};
  Broadcast brc5("brc5");
  brc5.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc5.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc5.x = brc4.y;
  *brc5.y.axis = {z0.id, z1.id, z2.id};
  brc5.y.dtype = ge::DT_FLOAT;
  *brc5.y.repeats = {af::ops::One, s1, s2};
  *brc5.y.strides = {af::ops::Zero, s2, af::ops::One};
  Broadcast brc6("brc6");
  brc6.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc6.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc6.x = brc5.y;
  *brc6.y.axis = {z0.id, z1.id, z2.id};
  brc6.y.dtype = ge::DT_FLOAT;
  *brc6.y.repeats = {s0, s1, s2};
  *brc6.y.strides = {s1 * s2, s2, af::ops::One};

  // Condition: ne(brc3, load0).
  Ne ne("ne");
  ne.attr.sched.axis = {z0.id, z1.id, z2.id};
  ne.x1 = brc3.y;
  ne.x2 = load0.y;
  ne.y.dtype = ge::DT_FLOAT;
  *ne.y.axis = {z0.id, z1.id, z2.id};
  *ne.y.repeats = {s0, s1, s2};
  *ne.y.strides = {s1 * s2, s2, af::ops::One};

  // Input 3: all-one repeats, broadcast one axis at a time by brc7 -> brc8 -> brc9.
  Data data3("data3", graph);
  data3.ir_attr.SetIndex(3);
  data3.y.dtype = ge::DT_FLOAT;
  Load load3("load3");
  load3.attr.sched.axis = {z0.id, z1.id, z2.id};
  load3.x = data3.y;
  *load3.y.axis = {z0.id, z1.id, z2.id};
  load3.y.dtype = ge::DT_FLOAT;
  *load3.y.repeats = {af::ops::One, af::ops::One, af::ops::One};
  *load3.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero};
  Broadcast brc7("brc7");
  brc7.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc7.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc7.x = load3.y;
  *brc7.y.axis = {z0.id, z1.id, z2.id};
  brc7.y.dtype = ge::DT_FLOAT;
  *brc7.y.repeats = {af::ops::One, af::ops::One, s2};
  *brc7.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::One};
  Broadcast brc8("brc8");
  brc8.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc8.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc8.x = brc7.y;
  *brc8.y.axis = {z0.id, z1.id, z2.id};
  brc8.y.dtype = ge::DT_FLOAT;
  *brc8.y.repeats = {af::ops::One, s1, s2};
  *brc8.y.strides = {af::ops::Zero, s2, af::ops::One};
  Broadcast brc9("brc9");
  brc9.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc9.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc9.x = brc8.y;
  *brc9.y.axis = {z0.id, z1.id, z2.id};
  brc9.y.dtype = ge::DT_FLOAT;
  *brc9.y.repeats = {s0, s1, s2};
  *brc9.y.strides = {s1 * s2, s2, af::ops::One};

  // select(ne, brc6, brc9): both value inputs come from broadcast chains.
  Select select("select");
  select.attr.sched.axis = {z0.id, z1.id, z2.id};
  select.x1 = ne.y;
  select.x2 = brc6.y;
  select.x3 = brc9.y;
  select.y.dtype = ge::DT_FLOAT;
  *select.y.axis = {z0.id, z1.id, z2.id};
  *select.y.repeats = {s0, s1, s2};
  *select.y.strides = {s1 * s2, s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = select.y;
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};
  Output output_op("output");
  output_op.ir_attr.SetIndex(0);
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;

  Status res = optimizer.GraphPass(graph);
  EXPECT_EQ(res, ge::SUCCESS);

  // 21 ops were built; with the nine Broadcasts gone, 12 are expected to remain
  // (4 Data, 4 Load, ne, select, store, output).
  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 12);
  EXPECT_EQ(compute_graph->FindNode("brc1"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc2"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc3"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc4"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc5"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc6"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc7"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc8"), nullptr);
  EXPECT_EQ(compute_graph->FindNode("brc9"), nullptr);

  // ne: input 0 is expected to carry the full repeats and input 1 the all-one repeats.
  auto ne_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("ne"));
  // Fatal check: the node is dereferenced below, so a missing node must stop the test here.
  ASSERT_NE(ne_node, nullptr);
  std::stringstream ne0_repeats;
  SizeExprListStr(ne0_repeats, graph, ne_node->inputs[0].attr.repeats);
  std::stringstream expected0_repeats;
  SizeExprListStr(expected0_repeats, graph, {s0, s1, s2});
  EXPECT_EQ(ne0_repeats.str(), expected0_repeats.str());
  std::stringstream ne1_repeats;
  SizeExprListStr(ne1_repeats, graph, ne_node->inputs[1].attr.repeats);
  std::stringstream expected1_repeats;
  SizeExprListStr(expected1_repeats, graph, {af::ops::One, af::ops::One, af::ops::One});
  EXPECT_EQ(ne1_repeats.str(), expected1_repeats.str());

  // select: inputs 1 and 2 are both expected to be all-one.
  auto select_node = std::dynamic_pointer_cast<af::AscNode>(compute_graph->FindNode("select"));
  ASSERT_NE(select_node, nullptr);
  std::stringstream select1_repeats;
  SizeExprListStr(select1_repeats, graph, select_node->inputs[1].attr.repeats);
  EXPECT_EQ(select1_repeats.str(), expected1_repeats.str());
  std::stringstream select2_repeats;
  SizeExprListStr(select2_repeats, graph, select_node->inputs[2].attr.repeats);
  EXPECT_EQ(select2_repeats.str(), expected1_repeats.str());
}
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Scalar) {
  // Graph under test: add(brc1(brc0(scalar)), load1) -> store -> output, with symbolic sizes.
  // After Optimize() the single resulting impl graph is expected to contain neither brc0 nor brc1.
  af::AscGraph graph("ScalarBroadcastOptimization_Scalar");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  // Scalar constant "0" (the node itself is named "data0").
  af::ascir_op::Scalar scalar0("data0", graph);
  scalar0.ir_attr.SetValue("0");
  scalar0.attr.sched.axis = {z0.id, z1.id, z2.id};
  scalar0.y.dtype = ge::DT_FLOAT16;
  *scalar0.y.axis = {z0.id, z1.id, z2.id};

  // Two chained broadcasts: first to {1, 1, s2}, then to {s0, 1, s2}.
  af::ascir_op::Broadcast brc0("brc0");
  brc0.x = scalar0.y;
  brc0.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc0.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc0.y.dtype = ge::DT_FLOAT16;
  *brc0.y.axis = {z0.id, z1.id, z2.id};
  *brc0.y.repeats = {One, One, s2};
  *brc0.y.strides = {Zero, Zero, One};
  af::ascir_op::Broadcast brc1("brc1");
  brc1.x = brc0.y;
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.y.dtype = ge::DT_FLOAT16;
  *brc1.y.axis = {z0.id, z1.id, z2.id};
  *brc1.y.repeats = {s0, One, s2};
  *brc1.y.strides = {s2, Zero, One};

  // Tensor operand with repeats {s0, 1, s2}.
  af::ascir_op::Data data1("data1", graph);
  data1.ir_attr.SetIndex(0);
  data1.attr.sched.axis = {z0.id, z1.id, z2.id};
  data1.y.dtype = ge::DT_FLOAT16;
  *data1.y.axis = {z0.id, z1.id, z2.id};
  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id, z2.id};
  load1.x = data1.y;
  *load1.y.axis = {z0.id, z1.id, z2.id};
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.repeats = {s0, One, s2};
  *load1.y.strides = {s2, Zero, One};

  Add add("add");
  add.attr.sched.axis = {z0.id, z1.id, z2.id};
  add.x1 = brc1.y;
  add.x2 = load1.y;
  *add.y.axis = {z0.id, z1.id, z2.id};
  add.y.dtype = ge::DT_FLOAT;
  *add.y.repeats = {s0, One, s2};
  *add.y.strides = {s2, Zero, One};

  af::ascir_op::Store store("store");
  store.x = add.y;
  store.attr.sched.axis = {z0.id, z1.id, z2.id};
  store.attr.api.compute_type = af::ComputeType::kComputeStore;
  store.y.dtype = ge::DT_FLOAT16;
  *store.y.axis = {z0.id, z1.id, z2.id};
  *store.y.repeats = {s0, One, s2};
  *store.y.strides = {s2, Zero, One};
  af::ascir_op::Output y("y");
  y.ir_attr.SetIndex(0);
  y.x = store.y;
  y.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y.attr.api.type = af::ApiType::kAPITypeBuffer;
  y.y.dtype = ge::DT_FLOAT16;

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  EXPECT_EQ(res, ge::SUCCESS);

  // Size checks are fatal because each container is indexed immediately afterwards; a non-fatal
  // EXPECT would let the test run into an out-of-range access when the size is wrong.
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  auto &schedule_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_EQ(schedule_results.size(), 1UL);
  ASSERT_EQ(schedule_results[0].schedule_groups.size(), 1UL);
  ASSERT_EQ(schedule_results[0].schedule_groups[0].impl_graphs.size(), 1UL);
  auto const &impl_graphs = schedule_results[0].schedule_groups[0].impl_graphs;
  EXPECT_EQ(impl_graphs[0].FindNode("brc0"), nullptr);
  EXPECT_EQ(impl_graphs[0].FindNode("brc1"), nullptr);
}
TEST_F(TestOptimizer, ScalarBroadcastOptimization_Not_Output) {
  // A Scalar constant is expanded by two chained Broadcasts and stored straight to the graph output,
  // with no compute node in between. The only expectation is that Optimize() reports success.
  af::AscGraph graph("ScalarBroadcastOptimization_Not_Output");
  auto dim0 = graph.CreateSizeVar("s0");
  auto dim1 = graph.CreateSizeVar("s1");
  auto dim2 = graph.CreateSizeVar("s2");
  auto axis0 = graph.CreateAxis("z0", dim0);
  auto axis1 = graph.CreateAxis("z1", dim1);
  auto axis2 = graph.CreateAxis("z2", dim2);

  // Scalar constant "0" (the node itself is named "data0").
  af::ascir_op::Scalar const_zero("data0", graph);
  const_zero.ir_attr.SetValue("0");
  const_zero.y.dtype = ge::DT_FLOAT16;
  *const_zero.y.axis = {axis0.id, axis1.id, axis2.id};
  const_zero.attr.sched.axis = {axis0.id, axis1.id, axis2.id};

  // First broadcast: output repeats {1, 1, s2}.
  af::ascir_op::Broadcast bcast_inner("brc0");
  bcast_inner.x = const_zero.y;
  bcast_inner.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  bcast_inner.attr.sched.axis = {axis0.id, axis1.id, axis2.id};
  *bcast_inner.y.axis = {axis0.id, axis1.id, axis2.id};
  *bcast_inner.y.repeats = {One, One, dim2};
  *bcast_inner.y.strides = {Zero, Zero, One};
  bcast_inner.y.dtype = ge::DT_FLOAT16;

  // Second broadcast: output repeats {s0, 1, s2}.
  af::ascir_op::Broadcast bcast_outer("brc1");
  bcast_outer.x = bcast_inner.y;
  bcast_outer.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  bcast_outer.attr.sched.axis = {axis0.id, axis1.id, axis2.id};
  *bcast_outer.y.axis = {axis0.id, axis1.id, axis2.id};
  *bcast_outer.y.repeats = {dim0, One, dim2};
  *bcast_outer.y.strides = {dim2, Zero, One};
  bcast_outer.y.dtype = ge::DT_FLOAT16;

  // The broadcast result is written out directly.
  af::ascir_op::Store store_op("store");
  store_op.x = bcast_outer.y;
  store_op.attr.api.compute_type = af::ComputeType::kComputeStore;
  store_op.attr.sched.axis = {axis0.id, axis1.id, axis2.id};
  *store_op.y.axis = {axis0.id, axis1.id, axis2.id};
  *store_op.y.repeats = {dim0, One, dim2};
  *store_op.y.strides = {dim2, Zero, One};
  store_op.y.dtype = ge::DT_FLOAT16;

  af::ascir_op::Output graph_output("y");
  graph_output.ir_attr.SetIndex(0);
  graph_output.x = store_op.y;
  graph_output.attr.api.type = af::ApiType::kAPITypeBuffer;
  graph_output.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  graph_output.y.dtype = ge::DT_FLOAT16;

  ::ascir::FusedScheduledResult schedule_output;
  EXPECT_EQ(optimizer.Optimize(graph, schedule_output), ge::SUCCESS);
}
TEST_F(TestOptimizer, NodeCacheMarkerConcat) {
  // Graph under test:
  //  - load0 reads data0 over axes (z0, z1, z2_0), where z2_0 has size s2.
  //  - A Scalar "0" is broadcast to {s0, s1, s2} through brc0 -> brc1 -> brc2.
  //  - concat(load0, brc2) produces a last dimension of size s2 * 2 and is stored to output y0.
  //  - brc2 is additionally stored on its own to output y1.
  // The test checks the execute condition recorded on brc0 in one specific impl graph.
  af::AscGraph graph("NodeCacheMarkerConcat");
  auto s0 = af::Symbol("s0");
  auto s1 = af::Symbol("s1");
  auto s2 = af::Symbol("s2");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2 * af::Symbol(2));
  auto z2_0 = graph.CreateAxis("z2_0", s2);

  af::ascir_op::Data data0("data0", graph);
  data0.ir_attr.SetIndex(0);
  data0.attr.sched.axis = {z0.id, z1.id, z2.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data0.attr.api.type = af::ApiType::kAPITypeBuffer;
  data0.y.dtype = ge::DT_INT8;
  af::ascir_op::Load load0("load0");
  load0.x = data0.y;
  load0.attr.sched.axis = {z0.id, z1.id, z2_0.id};
  load0.attr.api.compute_type = af::ComputeType::kComputeLoad;
  load0.y.dtype = ge::DT_INT8;
  *load0.y.axis = {z0.id, z1.id, z2_0.id};
  *load0.y.repeats = {s0, s1, s2};
  *load0.y.strides = {s1 * s2, s2, One};

  af::ascir_op::Scalar scalar("scalar", graph);
  scalar.ir_attr.SetValue("0");
  scalar.attr.sched.axis = {z0.id, z1.id, z2.id};
  scalar.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  scalar.attr.api.type = af::ApiType::kAPITypeBuffer;
  scalar.y.dtype = ge::DT_INT8;
  *scalar.y.axis = {z0.id, z1.id, z2.id};
  *scalar.y.repeats = {One, One, One};
  *scalar.y.strides = {Zero, Zero, Zero};

  // Broadcast chain over (z0, z1, z2_0): {1, 1, s2} -> {1, s1, s2} -> {s0, s1, s2}.
  af::ascir_op::Broadcast brc0("brc0");
  brc0.x = scalar.y;
  brc0.attr.sched.axis = {z0.id, z1.id, z2_0.id};
  brc0.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc0.y.dtype = ge::DT_INT8;
  *brc0.y.axis = {z0.id, z1.id, z2_0.id};
  *brc0.y.repeats = {One, One, s2};
  *brc0.y.strides = {Zero, Zero, One};
  af::ascir_op::Broadcast brc1("brc1");
  brc1.x = brc0.y;
  brc1.attr.sched.axis = {z0.id, z1.id, z2_0.id};
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.y.dtype = ge::DT_INT8;
  *brc1.y.axis = {z0.id, z1.id, z2_0.id};
  *brc1.y.repeats = {One, s1, s2};
  *brc1.y.strides = {Zero, s2, One};
  af::ascir_op::Broadcast brc2("brc2");
  brc2.x = brc1.y;
  brc2.attr.sched.axis = {z0.id, z1.id, z2_0.id};
  brc2.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc2.y.dtype = ge::DT_INT8;
  *brc2.y.axis = {z0.id, z1.id, z2_0.id};
  *brc2.y.repeats = {s0, s1, s2};
  *brc2.y.strides = {s1 * s2, s2, One};

  // Concat branch -> store0 -> y0 (output index 1).
  af::ascir_op::Concat concat("concat");
  concat.x = {load0.y, brc2.y};
  concat.attr.sched.axis = {z0.id, z1.id, z2.id};
  concat.y.dtype = ge::DT_INT8;
  *concat.y.axis = {z0.id, z1.id, z2.id};
  *concat.y.repeats = {s0, s1, s2 * af::Symbol(2)};
  *concat.y.strides = {s1 * s2 * af::Symbol(2), s2 * af::Symbol(2), One};
  concat.attr.api.compute_type = af::ComputeType::kComputeConcat;
  af::ascir_op::Store store0("store0");
  store0.x = concat.y;
  store0.attr.sched.axis = {z0.id, z1.id, z2.id};
  store0.attr.api.compute_type = af::ComputeType::kComputeStore;
  store0.y.dtype = ge::DT_INT8;
  *store0.y.axis = {z0.id, z1.id, z2.id};
  *store0.y.repeats = {s0, s1, s2 * af::Symbol(2)};
  *store0.y.strides = {s1 * s2 * af::Symbol(2), s2 * af::Symbol(2), One};
  af::ascir_op::Output y0("y0");
  y0.ir_attr.SetIndex(1);
  y0.x = store0.y;
  y0.attr.sched.axis = {z0.id, z1.id, z2.id};
  y0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y0.attr.api.type = af::ApiType::kAPITypeBuffer;
  y0.y.dtype = ge::DT_INT8;
  *y0.y.axis = {z0.id, z1.id, z2.id};

  // Direct branch: brc2 -> store1 -> y1 (output index 0).
  af::ascir_op::Store store1("store1");
  store1.x = brc2.y;
  store1.attr.sched.axis = {z0.id, z1.id, z2_0.id};
  store1.attr.api.compute_type = af::ComputeType::kComputeStore;
  store1.y.dtype = ge::DT_INT8;
  *store1.y.axis = {z0.id, z1.id, z2_0.id};
  *store1.y.repeats = {s0, s1, s2};
  *store1.y.strides = {s1 * s2, s2, One};
  af::ascir_op::Output y1("y1");
  y1.ir_attr.SetIndex(0);
  y1.x = store1.y;
  y1.attr.sched.axis = {z0.id, z1.id, z2.id};
  y1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y1.attr.api.type = af::ApiType::kAPITypeBuffer;
  y1.y.dtype = ge::DT_INT8;
  *y1.y.axis = {z0.id, z1.id, z2.id};

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  EXPECT_EQ(res, ge::SUCCESS);

  // Size checks are fatal because each container is indexed immediately afterwards; a non-fatal
  // EXPECT would let the test run into an out-of-range access when the size is wrong.
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  auto &schedule_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_EQ(schedule_results.size(), 2UL);
  ASSERT_EQ(schedule_results[1].schedule_groups.size(), 2UL);
  ASSERT_EQ(schedule_results[1].schedule_groups[1].impl_graphs.size(), 3UL);
  auto const &impl_graphs = schedule_results[1].schedule_groups[1].impl_graphs;

  // brc0 must exist in the third impl graph; the check is fatal because the node is dereferenced next.
  const auto brc0_node = impl_graphs[2].FindNode("brc0");
  ASSERT_NE(brc0_node, nullptr);
  EXPECT_EQ(brc0_node->attr.sched.exec_condition, af::ExecuteCondition::kCacheBlockSplitFusedBroadcastAxis);
}
// Builds a 4-axis graph (z0..z3 over s0..s3) in which four differently shaped
// inputs are each expanded to the full {s0, s1, s2, s3} shape through chains of
// Broadcast nodes and then summed (add0 -> add1 -> add2 -> cast4 -> store0 -> y0).
// After Optimize(), every Broadcast node found in any generated impl graph must
// carry exec_condition kNoCache or kCacheBlockSplitFusedBroadcastAxis.
TEST_F(TestOptimizer, NodeCacheMarkerBroadcast) {
  af::AscGraph graph("NodeCacheMarkerBroadcast");
  auto s0 = af::Symbol("s0");
  auto s1 = af::Symbol("s1");
  auto s2 = af::Symbol("s2");
  auto s3 = af::Symbol("s3");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z3 = graph.CreateAxis("z3", s3);

  // ---- Branch 0: input with repeats {1, 1, 1, s3} --------------------------
  // data0 -> load0 -> cast0 -> relu0 -> abs0 -> sqrt0 -> exp0
  Data data0("data0", graph);
  data0.ir_attr.SetIndex(0);
  data0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data0.attr.api.type = af::ApiType::kAPITypeBuffer;
  data0.y.dtype = ge::DT_INT8;
  Load load0("load0");
  load0.x = data0.y;
  load0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  load0.attr.api.compute_type = af::ComputeType::kComputeLoad;
  load0.y.dtype = ge::DT_INT8;
  *load0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *load0.y.repeats = {One, One, One, s3};
  *load0.y.strides = {Zero, Zero, Zero, One};
  Cast cast0("cast0");
  cast0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  cast0.attr.api.compute_type = af::ComputeType::kComputeElewise;
  cast0.x = load0.y;
  cast0.y.dtype = ge::DT_FLOAT;
  *cast0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *cast0.y.repeats = {One, One, One, s3};
  *cast0.y.strides = {Zero, Zero, Zero, One};
  Relu relu0("relu0");
  relu0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  relu0.attr.api.compute_type = af::ComputeType::kComputeElewise;
  relu0.x = cast0.y;
  relu0.y.dtype = ge::DT_FLOAT;
  *relu0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *relu0.y.repeats = {One, One, One, s3};
  *relu0.y.strides = {Zero, Zero, Zero, One};
  Abs abs0("abs0");
  abs0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  abs0.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs0.x = relu0.y;
  abs0.y.dtype = ge::DT_FLOAT;
  *abs0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs0.y.repeats = {One, One, One, s3};
  *abs0.y.strides = {Zero, Zero, Zero, One};
  Sqrt sqrt0("sqrt0");
  sqrt0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  sqrt0.attr.api.compute_type = af::ComputeType::kComputeElewise;
  sqrt0.x = abs0.y;
  sqrt0.y.dtype = ge::DT_FLOAT;
  *sqrt0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *sqrt0.y.repeats = {One, One, One, s3};
  *sqrt0.y.strides = {Zero, Zero, Zero, One};
  Exp exp0("exp0");
  exp0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  exp0.attr.api.compute_type = af::ComputeType::kComputeElewise;
  exp0.x = sqrt0.y;
  exp0.y.dtype = ge::DT_FLOAT;
  *exp0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *exp0.y.repeats = {One, One, One, s3};
  *exp0.y.strides = {Zero, Zero, Zero, One};
  // Broadcast one axis per node: brc00 adds s2, brc01 adds s1, brc02 adds s0.
  af::ascir_op::Broadcast brc00("brc00");
  brc00.x = exp0.y;
  brc00.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc00.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc00.y.dtype = ge::DT_INT8;
  *brc00.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc00.y.repeats = {One, One, s2, s3};
  *brc00.y.strides = {Zero, Zero, s3, One};
  af::ascir_op::Broadcast brc01("brc01");
  brc01.x = brc00.y;
  brc01.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc01.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc01.y.dtype = ge::DT_INT8;
  *brc01.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc01.y.repeats = {One, s1, s2, s3};
  *brc01.y.strides = {Zero, s2 * s3, s3, One};
  af::ascir_op::Broadcast brc02("brc02");
  brc02.x = brc01.y;
  brc02.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc02.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc02.y.dtype = ge::DT_INT8;
  *brc02.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc02.y.repeats = {s0, s1, s2, s3};
  *brc02.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  // ---- Branch 1: input with repeats {1, 1, s2, 1} --------------------------
  // data1 -> load1 -> cast1 -> sigmoid1 -> sign1 -> mul1 (sign1 * sign1) -> exp1
  af::ascir_op::Data data1("data1", graph);
  data1.ir_attr.SetIndex(1);
  data1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data1.attr.api.type = af::ApiType::kAPITypeBuffer;
  data1.y.dtype = ge::DT_INT8;
  af::ascir_op::Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  load1.attr.api.compute_type = af::ComputeType::kComputeLoad;
  load1.y.dtype = ge::DT_INT8;
  *load1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *load1.y.repeats = {One, One, s2, One};
  *load1.y.strides = {Zero, Zero, One, Zero};
  Cast cast1("cast1");
  cast1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  cast1.attr.api.compute_type = af::ComputeType::kComputeElewise;
  cast1.x = load1.y;
  cast1.y.dtype = ge::DT_FLOAT;
  *cast1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *cast1.y.repeats = {One, One, s2, One};
  *cast1.y.strides = {Zero, Zero, One, Zero};
  Sigmoid sigmoid1("sigmoid1");
  sigmoid1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  sigmoid1.attr.api.compute_type = af::ComputeType::kComputeElewise;
  sigmoid1.x = cast1.y;
  sigmoid1.y.dtype = ge::DT_FLOAT;
  *sigmoid1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *sigmoid1.y.repeats = {One, One, s2, One};
  *sigmoid1.y.strides = {Zero, Zero, One, Zero};
  Sign sign1("sign1");
  sign1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  sign1.attr.api.compute_type = af::ComputeType::kComputeElewise;
  sign1.x = sigmoid1.y;
  sign1.y.dtype = ge::DT_FLOAT;
  *sign1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *sign1.y.repeats = {One, One, s2, One};
  *sign1.y.strides = {Zero, Zero, One, Zero};
  Mul mul1("mul1");
  mul1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  mul1.attr.api.compute_type = af::ComputeType::kComputeElewise;
  mul1.x1 = sign1.y;
  mul1.x2 = sign1.y;
  mul1.y.dtype = ge::DT_FLOAT;
  *mul1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *mul1.y.repeats = {One, One, s2, One};
  *mul1.y.strides = {Zero, Zero, One, Zero};
  Exp exp1("exp1");
  exp1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  exp1.attr.api.compute_type = af::ComputeType::kComputeElewise;
  exp1.x = mul1.y;
  exp1.y.dtype = ge::DT_FLOAT;
  *exp1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *exp1.y.repeats = {One, One, s2, One};
  *exp1.y.strides = {Zero, Zero, One, Zero};
  // Broadcast one axis per node: brc10 adds s3, brc11 adds s1, brc12 adds s0.
  af::ascir_op::Broadcast brc10("brc10");
  brc10.x = exp1.y;
  brc10.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc10.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc10.y.dtype = ge::DT_INT8;
  *brc10.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc10.y.repeats = {One, One, s2, s3};
  *brc10.y.strides = {Zero, Zero, s3, One};
  af::ascir_op::Broadcast brc11("brc11");
  brc11.x = brc10.y;
  brc11.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc11.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc11.y.dtype = ge::DT_INT8;
  *brc11.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc11.y.repeats = {One, s1, s2, s3};
  *brc11.y.strides = {Zero, s2 * s3, s3, One};
  af::ascir_op::Broadcast brc12("brc12");
  brc12.x = brc11.y;
  brc12.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc12.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc12.y.dtype = ge::DT_INT8;
  *brc12.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc12.y.repeats = {s0, s1, s2, s3};
  *brc12.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  // add0 = brc12 + brc02 (branch 1 + branch 0), full shape.
  Add add0("add0");
  add0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  add0.attr.api.compute_type = af::ComputeType::kComputeElewise;
  add0.x1 = brc12.y;
  add0.x2 = brc02.y;
  add0.y.dtype = ge::DT_FLOAT;
  *add0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *add0.y.repeats = {s0, s1, s2, s3};
  *add0.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  // ---- Branch 2: input with repeats {1, s1, s2, 1} -------------------------
  // data2 -> load2 -> cast2 -> abs2 -> sign2 -> exp2
  af::ascir_op::Data data2("data2", graph);
  data2.ir_attr.SetIndex(2);
  data2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  data2.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data2.attr.api.type = af::ApiType::kAPITypeBuffer;
  data2.y.dtype = ge::DT_INT8;
  af::ascir_op::Load load2("load2");
  load2.x = data2.y;
  load2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  load2.attr.api.compute_type = af::ComputeType::kComputeLoad;
  load2.y.dtype = ge::DT_INT8;
  *load2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *load2.y.repeats = {One, s1, s2, One};
  *load2.y.strides = {Zero, s2, One, Zero};
  Cast cast2("cast2");
  cast2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  cast2.attr.api.compute_type = af::ComputeType::kComputeElewise;
  cast2.x = load2.y;
  cast2.y.dtype = ge::DT_FLOAT;
  *cast2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *cast2.y.repeats = {One, s1, s2, One};
  *cast2.y.strides = {Zero, s2, One, Zero};
  Abs abs2("abs2");
  abs2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  abs2.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs2.x = cast2.y;
  abs2.y.dtype = ge::DT_FLOAT;
  *abs2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs2.y.repeats = {One, s1, s2, One};
  *abs2.y.strides = {Zero, s2, One, Zero};
  Sign sign2("sign2");
  sign2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  sign2.attr.api.compute_type = af::ComputeType::kComputeElewise;
  sign2.x = abs2.y;
  sign2.y.dtype = ge::DT_FLOAT;
  *sign2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *sign2.y.repeats = {One, s1, s2, One};
  *sign2.y.strides = {Zero, s2, One, Zero};
  Exp exp2("exp2");
  exp2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  exp2.attr.api.compute_type = af::ComputeType::kComputeElewise;
  exp2.x = sign2.y;
  exp2.y.dtype = ge::DT_FLOAT;
  *exp2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *exp2.y.repeats = {One, s1, s2, One};
  *exp2.y.strides = {Zero, s2, One, Zero};
  // Broadcast one axis per node: brc20 adds s3, brc21 adds s0.
  af::ascir_op::Broadcast brc20("brc20");
  brc20.x = exp2.y;
  brc20.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc20.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc20.y.dtype = ge::DT_INT8;
  *brc20.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc20.y.repeats = {One, s1, s2, s3};
  *brc20.y.strides = {Zero, s2 * s3, s3, One};
  af::ascir_op::Broadcast brc21("brc21");
  brc21.x = brc20.y;
  brc21.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc21.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc21.y.dtype = ge::DT_INT8;
  *brc21.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc21.y.repeats = {s0, s1, s2, s3};
  *brc21.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  // add1 = brc21 + add0.
  Add add1("add1");
  add1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  add1.attr.api.compute_type = af::ComputeType::kComputeElewise;
  add1.x1 = brc21.y;
  add1.x2 = add0.y;
  add1.y.dtype = ge::DT_FLOAT;
  *add1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *add1.y.repeats = {s0, s1, s2, s3};
  *add1.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  // ---- Branch 3: input with repeats {s0, 1, 1, s3} -------------------------
  // data3 -> load3 -> cast3 -> abs3 -> sign3 -> exp3
  af::ascir_op::Data data3("data3", graph);
  data3.ir_attr.SetIndex(3);
  data3.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  data3.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  data3.attr.api.type = af::ApiType::kAPITypeBuffer;
  data3.y.dtype = ge::DT_INT8;
  af::ascir_op::Load load3("load3");
  load3.x = data3.y;
  load3.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  load3.attr.api.compute_type = af::ComputeType::kComputeLoad;
  load3.y.dtype = ge::DT_INT8;
  *load3.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *load3.y.repeats = {s0, One, One, s3};
  *load3.y.strides = {s3, Zero, Zero, One};
  Cast cast3("cast3");
  cast3.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  cast3.attr.api.compute_type = af::ComputeType::kComputeElewise;
  cast3.x = load3.y;
  cast3.y.dtype = ge::DT_FLOAT;
  *cast3.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *cast3.y.repeats = {s0, One, One, s3};
  *cast3.y.strides = {s3, Zero, Zero, One};
  Abs abs3("abs3");
  abs3.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  abs3.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs3.x = cast3.y;
  abs3.y.dtype = ge::DT_FLOAT;
  *abs3.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs3.y.repeats = {s0, One, One, s3};
  *abs3.y.strides = {s3, Zero, Zero, One};
  Sign sign3("sign3");
  sign3.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  sign3.attr.api.compute_type = af::ComputeType::kComputeElewise;
  sign3.x = abs3.y;
  sign3.y.dtype = ge::DT_FLOAT;
  *sign3.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *sign3.y.repeats = {s0, One, One, s3};
  *sign3.y.strides = {s3, Zero, Zero, One};
  Exp exp3("exp3");
  exp3.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  exp3.attr.api.compute_type = af::ComputeType::kComputeElewise;
  exp3.x = sign3.y;
  exp3.y.dtype = ge::DT_FLOAT;
  *exp3.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *exp3.y.repeats = {s0, One, One, s3};
  *exp3.y.strides = {s3, Zero, Zero, One};
  // Broadcast one axis per node: brc30 adds s2, brc31 adds s1.
  af::ascir_op::Broadcast brc30("brc30");
  brc30.x = exp3.y;
  brc30.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc30.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc30.y.dtype = ge::DT_INT8;
  *brc30.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc30.y.repeats = {s0, One, s2, s3};
  *brc30.y.strides = {s2 * s3, Zero, s3, One};
  af::ascir_op::Broadcast brc31("brc31");
  brc31.x = brc30.y;
  brc31.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc31.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc31.y.dtype = ge::DT_INT8;
  *brc31.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc31.y.repeats = {s0, s1, s2, s3};
  *brc31.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  // ---- Tail: add2 = brc31 + add1 -> cast4 (FLOAT16) -> store0 -> y0 --------
  Add add2("add2");
  add2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  add2.attr.api.compute_type = af::ComputeType::kComputeElewise;
  add2.x1 = brc31.y;
  add2.x2 = add1.y;
  add2.y.dtype = ge::DT_FLOAT;
  *add2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *add2.y.repeats = {s0, s1, s2, s3};
  *add2.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  Cast cast4("cast4");
  cast4.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  cast4.attr.api.compute_type = af::ComputeType::kComputeElewise;
  cast4.x = add2.y;
  cast4.y.dtype = ge::DT_FLOAT16;
  *cast4.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *cast4.y.repeats = {s0, s1, s2, s3};
  *cast4.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  af::ascir_op::Store store0("store0");
  store0.x = cast4.y;
  store0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  store0.attr.api.compute_type = af::ComputeType::kComputeStore;
  store0.y.dtype = ge::DT_FLOAT16;
  *store0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *store0.y.repeats = {s0, s1, s2, s3};
  *store0.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  af::ascir_op::Output y0("y0");
  y0.ir_attr.SetIndex(0);
  y0.x = store0.y;
  y0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  y0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y0.attr.api.type = af::ApiType::kAPITypeBuffer;
  y0.y.dtype = ge::DT_FLOAT16;
  *y0.y.axis = {z0.id, z1.id, z2.id, z3.id};

  // ---- Run the optimizer and inspect every generated impl graph ------------
  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  EXPECT_EQ(res, ge::SUCCESS);
  ASSERT_TRUE(!fused_scheduled_result.node_idx_to_scheduled_results.empty());
  auto &schedule_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  for (auto const &result : schedule_results) {
    for (auto const &group : result.schedule_groups) {
      for (auto const &impl_graph : group.impl_graphs) {
        for (auto const &node : impl_graph.GetAllNodes()) {
          if (node != nullptr && IsOps<Broadcast>(node)) {
            const bool condition =
                node->attr.sched.exec_condition == af::ExecuteCondition::kNoCache ||
                node->attr.sched.exec_condition == af::ExecuteCondition::kCacheBlockSplitFusedBroadcastAxis;
            // Stream the context into the assertion itself so it is reported
            // together with the failure instead of via an unterminated printf.
            EXPECT_TRUE(condition) << "Graph: " << impl_graph.GetName() << ", Broadcast " << node->GetNamePtr();
          }
        }
      }
    }
  }
}
void Construct_Enable_Cache_Max_Struct(af::AscGraph &graph) {
static af::Expression Zero = af::Symbol(0);
static af::Expression One = af::Symbol(1);
auto s0 = graph.CreateSizeVar("s0");
auto s1 = graph.CreateSizeVar("s1");
auto s2 = graph.CreateSizeVar("s2");
auto s3 = graph.CreateSizeVar("s3");
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
auto z2 = graph.CreateAxis("z2", s2);
auto z3 = graph.CreateAxis("z3", s3);
auto axis = {z0.id, z1.id, z2.id, z3.id};
Data data_0("data_0", graph);
data_0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
data_0.attr.api.type = af::ApiType::kAPITypeBuffer;
data_0.y.dtype = ge::DT_FLOAT16;
data_0.ir_attr.SetIndex(0);
Load b0_load("b0_load");
b0_load.x = data_0.y;
b0_load.attr.sched.axis = axis;
b0_load.attr.api.compute_type = af::ComputeType::kComputeLoad;
b0_load.y.dtype = ge::DT_FLOAT16;
*b0_load.y.axis = axis;
*b0_load.y.repeats = {s0, s1, s2, s3};
*b0_load.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
Broadcast b0_broadcast("b0_broadcast");
b0_broadcast.x = b0_load.y;
b0_broadcast.attr.sched.axis = axis;
b0_broadcast.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
b0_broadcast.y.dtype = ge::DT_FLOAT16;
*b0_broadcast.y.axis = axis;
*b0_broadcast.y.repeats = {s0, s1, s2, s3};
*b0_broadcast.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
Data data_1("data_1", graph);
data_1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
data_1.attr.api.type = af::ApiType::kAPITypeBuffer;
data_1.y.dtype = ge::DT_FLOAT16;
data_1.ir_attr.SetIndex(1);
Load b1_load("b1_load");
b1_load.x = data_1.y;
b1_load.attr.sched.axis = axis;
b1_load.attr.api.compute_type = af::ComputeType::kComputeLoad;
b1_load.y.dtype = ge::DT_FLOAT16;
*b1_load.y.axis = axis;
*b1_load.y.repeats = {s0, s1, One, s3};
*b1_load.y.strides = {s1 * s3, s3, Zero, One};
Broadcast b1_broadcast("b1_broadcast");
b1_broadcast.x = b1_load.y;
b1_broadcast.attr.sched.axis = axis;
b1_broadcast.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
b1_broadcast.y.dtype = ge::DT_FLOAT16;
*b1_broadcast.y.axis = axis;
*b1_broadcast.y.repeats = {s0, s1, s2, s3};
*b1_broadcast.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
Add b0_add("b0_add");
b0_add.x1 = b0_broadcast.y;
b0_add.x2 = b1_broadcast.y;
b0_add.attr.sched.axis = axis;
b0_add.attr.api.compute_type = af::ComputeType::kComputeElewise;
b0_add.y.dtype = ge::DT_FLOAT16;
*b0_add.y.axis = axis;
*b0_add.y.repeats = {s0, s1, s2, s3};
*b0_add.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
af::ascir_op::Max b0_max("b0_max");
b0_max.x = b0_add.y;
b0_max.attr.sched.axis = axis;
b0_max.attr.api.compute_type = af::ComputeType::kComputeReduce;
b0_max.y.dtype = ge::DT_FLOAT16;
*b0_max.y.axis = axis;
*b0_max.y.repeats = {s0, One, One, s3};
*b0_max.y.strides = {s3, Zero, Zero, One};
Store b0_store("b0_store");
b0_store.x = b0_max.y;
b0_store.attr.sched.axis = axis;
b0_store.attr.api.compute_type = af::ComputeType::kComputeStore;
b0_store.y.dtype = ge::DT_FLOAT16;
*b0_store.y.axis = axis;
*b0_store.y.repeats = {s0, One, One, s3};
*b0_store.y.strides = {s3, Zero, Zero, One};
Output output_0("output_0");
output_0.x = b0_store.y;
output_0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
output_0.attr.api.type = af::ApiType::kAPITypeBuffer;
output_0.y.dtype = ge::DT_FLOAT;
output_0.ir_attr.SetIndex(0);
}
void Construct_Enable_Cache_Sum_Struct(af::AscGraph &graph) {
static af::Expression Zero = af::Symbol(0);
static af::Expression One = af::Symbol(1);
auto s0 = graph.CreateSizeVar("s0");
auto s1 = graph.CreateSizeVar("s1");
auto s2 = graph.CreateSizeVar("s2");
auto s3 = graph.CreateSizeVar("s3");
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
auto z2 = graph.CreateAxis("z2", s2);
auto z3 = graph.CreateAxis("z3", s3);
auto axis = {z0.id, z1.id, z2.id, z3.id};
Data data_0("data_0", graph);
data_0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
data_0.attr.api.type = af::ApiType::kAPITypeBuffer;
data_0.y.dtype = ge::DT_FLOAT16;
data_0.ir_attr.SetIndex(0);
Load b0_load("b0_load");
b0_load.x = data_0.y;
b0_load.attr.sched.axis = axis;
b0_load.attr.api.compute_type = af::ComputeType::kComputeLoad;
b0_load.y.dtype = ge::DT_FLOAT16;
*b0_load.y.axis = axis;
*b0_load.y.repeats = {s0, s1, s2, s3};
*b0_load.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
Broadcast b0_broadcast("b0_broadcast");
b0_broadcast.x = b0_load.y;
b0_broadcast.attr.sched.axis = axis;
b0_broadcast.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
b0_broadcast.y.dtype = ge::DT_FLOAT16;
*b0_broadcast.y.axis = axis;
*b0_broadcast.y.repeats = {s0, s1, s2, s3};
*b0_broadcast.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
Data data_1("data_1", graph);
data_1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
data_1.attr.api.type = af::ApiType::kAPITypeBuffer;
data_1.y.dtype = ge::DT_FLOAT16;
data_1.ir_attr.SetIndex(1);
Load b1_load("b1_load");
b1_load.x = data_1.y;
b1_load.attr.sched.axis = axis;
b1_load.attr.api.compute_type = af::ComputeType::kComputeLoad;
b1_load.y.dtype = ge::DT_FLOAT16;
*b1_load.y.axis = axis;
*b1_load.y.repeats = {s0, s1, One, s3};
*b1_load.y.strides = {s1 * s3, s3, Zero, One};
Broadcast b1_broadcast("b1_broadcast");
b1_broadcast.x = b1_load.y;
b1_broadcast.attr.sched.axis = axis;
b1_broadcast.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
b1_broadcast.y.dtype = ge::DT_FLOAT16;
*b1_broadcast.y.axis = axis;
*b1_broadcast.y.repeats = {s0, s1, s2, s3};
*b1_broadcast.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
Add b0_add("b0_add");
b0_add.x1 = b0_broadcast.y;
b0_add.x2 = b1_broadcast.y;
b0_add.attr.sched.axis = axis;
b0_add.attr.api.compute_type = af::ComputeType::kComputeElewise;
b0_add.y.dtype = ge::DT_FLOAT16;
*b0_add.y.axis = axis;
*b0_add.y.repeats = {s0, s1, s2, s3};
*b0_add.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
af::ascir_op::Sum b0_sum("b0_sum");
b0_sum.x = b0_add.y;
b0_sum.attr.sched.axis = axis;
b0_sum.attr.api.compute_type = af::ComputeType::kComputeReduce;
b0_sum.y.dtype = ge::DT_FLOAT16;
*b0_sum.y.axis = axis;
*b0_sum.y.repeats = {s0, One, One, s3};
*b0_sum.y.strides = {s3, Zero, Zero, One};
Store b0_store("b0_store");
b0_store.x = b0_sum.y;
b0_store.attr.sched.axis = axis;
b0_store.attr.api.compute_type = af::ComputeType::kComputeStore;
b0_store.y.dtype = ge::DT_FLOAT16;
*b0_store.y.axis = axis;
*b0_store.y.repeats = {s0, One, One, s3};
*b0_store.y.strides = {s3, Zero, Zero, One};
Output output_0("output_0");
output_0.x = b0_store.y;
output_0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
output_0.attr.api.type = af::ApiType::kAPITypeBuffer;
output_0.y.dtype = ge::DT_FLOAT;
output_0.ir_attr.SetIndex(0);
}
// Smoke test: optimizing the Max-reduce graph and generating code from the
// result must not throw. Return values of Optimize/Generate are not inspected.
TEST_F(TestOptimizer, EnableCacheMax) {
  af::AscGraph max_graph("enable_cache_max");
  Construct_Enable_Cache_Max_Struct(max_graph);

  bool pipeline_threw = false;
  try {
    auto code_generator = codegen::Codegen(
        codegen::CodegenOptions{.tiling_lib_path = "gen_tiling.so", .tiling_lib_codegen_symbol = "CodegenTiling"});
    ::ascir::FusedScheduledResult scheduled;
    // A default-configured optimizer, distinct from the fixture's `optimizer`.
    optimize::Optimizer default_optimizer(optimize::OptimizerOptions{});
    default_optimizer.Optimize(max_graph, scheduled);
    codegen::CodegenResult generated;
    code_generator.Generate(scheduled, generated);
  } catch (...) {
    pipeline_threw = true;
  }
  EXPECT_FALSE(pipeline_threw);
}
// Smoke test: optimizing the Sum-reduce graph and generating code from the
// result must not throw. Return values of Optimize/Generate are not inspected.
TEST_F(TestOptimizer, EnableCacheSum) {
  af::AscGraph sum_graph("enable_cache_sum");
  Construct_Enable_Cache_Sum_Struct(sum_graph);

  bool pipeline_threw = false;
  try {
    auto code_generator = codegen::Codegen(
        codegen::CodegenOptions{.tiling_lib_path = "gen_tiling.so", .tiling_lib_codegen_symbol = "CodegenTiling"});
    ::ascir::FusedScheduledResult scheduled;
    // A default-configured optimizer, distinct from the fixture's `optimizer`.
    optimize::Optimizer default_optimizer(optimize::OptimizerOptions{});
    default_optimizer.Optimize(sum_graph, scheduled);
    codegen::CodegenResult generated;
    code_generator.Generate(scheduled, generated);
  } catch (...) {
    pipeline_threw = true;
  }
  EXPECT_FALSE(pipeline_threw);
}
// Builds mul(transpose(load1), add(load2, load3)) -> store -> y on three axes
// of constant sizes 16 x 86 x 1536, runs ScheduleTaskGenerator::GenerateTasks
// and checks that two tasks are produced, that the first grouped graph of
// task 1 has no node named "transpose", and that task 0 carries a score
// function returning -1.
TEST_F(TestOptimizer, TransposeLongTailWithoutUB) {
af::AscGraph graph("transpose_long_tail_without_UB");
auto s0 = graph.CreateSizeVar(16);
auto s1 = graph.CreateSizeVar(86);
auto s2 = graph.CreateSizeVar(1536);
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
auto z2 = graph.CreateAxis("z2", s2);
// Three graph inputs with indices 0..2 and one Load per input.
Data x1_op("x1", graph);
x1_op.ir_attr.SetIndex(0);
Data x2_op("x2", graph);
x2_op.ir_attr.SetIndex(1);
Data x3_op("x3", graph);
x3_op.ir_attr.SetIndex(2);
Load load_op1("load1");
Load load_op2("load2");
Load load_op3("load3");
// NOTE(review): the vectors hold copies of the op objects. Whether the writes
// in the loop below reach x*_op / load_op* depends on the copy semantics of
// Data/Load (presumably handle-like) — confirm.
std::vector<Data> all_data{x1_op, x2_op, x3_op};
std::vector<Load> all_load{load_op1, load_op2, load_op3};
for (size_t i = 0U; i < all_data.size(); ++i) {
auto &x_op = all_data[i];
auto &load_op = all_load[i];
x_op.y.dtype = ge::DT_FLOAT16;
x_op.attr.api.type = af::ApiType::kAPITypeBuffer;
load_op.x = x_op.y;
load_op.attr.sched.axis = {z0.id, z1.id, z2.id};
load_op.y.dtype = ge::DT_FLOAT16;
// NOTE(review): y.axis receives two axis ids while repeats/strides below have
// three entries — confirm this is intended.
*load_op.y.axis = {z0.id, z1.id};
// NOTE(review): dtype is assigned a second time here with the same value.
load_op.y.dtype = ge::DT_FLOAT16;
*load_op.y.repeats = {s0, s1, s2};
*load_op.y.strides = {s1 * s2, s2, af::ops::One};
}
// load1 is re-described in (z1, z0, z2) order; the transpose below outputs
// (z0, z1, z2).
load_op1.attr.sched.axis = {z1.id, z0.id, z2.id};
*load_op1.y.axis = {z1.id, z0.id, z2.id};
*load_op1.y.repeats = {s1, s0, s2};
*load_op1.y.strides = {s0 * s2, s2, af::ops::One};
af::ascir_op::Transpose transpose_op("transpose");
transpose_op.attr.sched.axis = {z0.id, z1.id, z2.id};
transpose_op.x = load_op1.y;
transpose_op.y.dtype = ge::DT_FLOAT16;
*transpose_op.y.axis = {z0.id, z1.id, z2.id};
*transpose_op.y.strides = {s1 * s2, s2, af::ops::One};
*transpose_op.y.repeats = {s0, s1, s2};
// add = load2 + load3
af::ascir_op::Add add_op("add");
add_op.attr.sched.axis = {z0.id, z1.id, z2.id};
add_op.x1 = load_op2.y;
add_op.x2 = load_op3.y;
add_op.y.dtype = ge::DT_FLOAT16;
*add_op.y.axis = {z0.id, z1.id, z2.id};
*add_op.y.strides = {s1 * s2, s2, af::ops::One};
*add_op.y.repeats = {s0, s1, s2};
// mul = transpose * add
af::ascir_op::Mul mul_op("mul");
mul_op.attr.sched.axis = {z0.id, z1.id, z2.id};
mul_op.x1 = transpose_op.y;
mul_op.x2 = add_op.y;
mul_op.y.dtype = ge::DT_FLOAT16;
*mul_op.y.axis = {z0.id, z1.id, z2.id};
*mul_op.y.strides = {s1 * s2, s2, af::ops::One};
*mul_op.y.repeats = {s0, s1, s2};
Store store_op("store");
store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
store_op.x = mul_op.y;
store_op.y.dtype = ge::DT_FLOAT16;
*store_op.y.axis = {z0.id, z1.id, z2.id};
*store_op.y.repeats = {s0, s1, s2};
*store_op.y.strides = {s1 * s2, s2, af::ops::One};
store_op.ir_attr.SetOffset(af::Symbol(0));
Output y_op("y");
y_op.x = store_op.y;
y_op.ir_attr.SetIndex(0);
// Complete the API info on the graph, then generate schedule tasks from it.
optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
std::vector<optimize::ScheduleTask> schedule_tasks;
optimize::OptimizerOptions options{optimize::GraphType::kAscGraph};
int res = optimize::ScheduleTaskGenerator::GenerateTasks(graph, schedule_tasks, options);
ASSERT_EQ(res, 0);
ASSERT_EQ(schedule_tasks.size(), 2);
// The first grouped graph of task 1 must not contain a node named "transpose".
auto compute_graph = af::AscGraphUtils::GetComputeGraph(schedule_tasks[1].grouped_graphs[0]);
ASSERT_EQ(compute_graph->FindNode("transpose"), nullptr);
// Task 0 must carry exactly this generated score function text.
ASSERT_EQ(schedule_tasks[0].score_func,
"int32_t CalcScore(const AutofuseTilingData &tiling_data) {\n"
" return -1;\n"
"}\n");
}
// Five Pow nodes (exponents -1, -2, -0.5, scalar^scalar with 3, and 4) are
// summed and stored. After the pass runner, none of the Pow nodes may remain
// and the compute graph must contain 25 nodes.
TEST_F(TestOptimizer, PowEqiv) {
  auto graph = af::testing::AscGraphBuilder("PowBrc")
                   .Loops({af::testing::Sym(32)})
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .Scalar("scalar0", "-1")
                   .template Op<af::ascir_op::Pow>("pow0", {"load0", "scalar0"})
                   .Load("load1", "data0")
                   .Scalar("scalar1", "-2")
                   .template Op<af::ascir_op::Pow>("pow1", {"load1", "scalar1"})
                   .Load("load2", "data0")
                   .Scalar("scalar2", "-0.5")
                   .template Op<af::ascir_op::Pow>("pow2", {"load2", "scalar2"})
                   .Scalar("scalar3", "3")
                   .template Op<af::ascir_op::Pow>("pow3", {"scalar3", "scalar3"})
                   .Load("load4", "data0")
                   .Scalar("scalar4", "4")
                   .template Op<af::ascir_op::Pow>("pow4", {"load4", "scalar4"})
                   .Add("add0", "pow0", "pow1")
                   .Add("add1", "pow2", "pow3")
                   .Add("add2", "add0", "add1")
                   .Add("add3", "add2", "pow4")
                   .Store("store", "add3")
                   .Output("y", "store", 0)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  ::ascir::utils::DumpGraph(graph, "BEFORE");
  Status pass_status = optimize::autoschedule::PassRunnerHandler().RunPasses(graph);
  EXPECT_EQ(pass_status, ge::SUCCESS);
  ::ascir::utils::DumpGraph(graph, "AFTER");

  // Every Pow node is expected to be gone from the graph.
  EXPECT_EQ(graph.FindNode("pow0"), nullptr);
  EXPECT_EQ(graph.FindNode("pow1"), nullptr);
  EXPECT_EQ(graph.FindNode("pow2"), nullptr);
  EXPECT_EQ(graph.FindNode("pow3"), nullptr);
  EXPECT_EQ(graph.FindNode("pow4"), nullptr);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 25UL);
}
// pow0 = Pow(load0, load0) must survive the passes; pow1 = Pow(load1, 0) must
// be removed while its input node data1 stays in the graph. The resulting
// compute graph is expected to hold 8 nodes.
TEST_F(TestOptimizer, PowEqivCase2) {
  auto graph = af::testing::AscGraphBuilder("PowBrc")
                   .Loops({af::testing::Sym(32)})
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .template Op<af::ascir_op::Pow>("pow0", {"load0", "load0"})
                   .Data("data1", 1)
                   .Load("load1", "data1")
                   .Scalar("scalar1", "0")
                   .template Op<af::ascir_op::Pow>("pow1", {"load1", "scalar1"})
                   .Add("add0", "pow0", "pow1")
                   .Store("store", "add0")
                   .Output("y", "store", 0)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  ::ascir::utils::DumpGraph(graph, "BEFORE");
  Status pass_status = optimize::autoschedule::PassRunnerHandler().RunPasses(graph);
  EXPECT_EQ(pass_status, ge::SUCCESS);
  ::ascir::utils::DumpGraph(graph, "AFTER");

  EXPECT_NE(graph.FindNode("pow0"), nullptr);
  EXPECT_EQ(graph.FindNode("pow1"), nullptr);
  EXPECT_NE(graph.FindNode("data1"), nullptr);

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 8UL);
}
// A graph with no Store/Output (data0 -> load0 -> pow0): PruneGraph must
// succeed and the compute graph must still contain all 3 nodes afterwards.
TEST_F(TestOptimizer, SkipPruneGraph) {
  auto graph = af::testing::AscGraphBuilder("PowBrc")
                   .Loops({af::testing::Sym(32)})
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .template Op<af::ascir_op::Pow>("pow0", {"load0", "load0"})
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  ::ascir::utils::DumpGraph(graph, "BEFORE");
  Status prune_status = optimize::PassUtils::PruneGraph(graph);
  EXPECT_EQ(prune_status, ge::SUCCESS);
  ::ascir::utils::DumpGraph(graph, "After");

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  EXPECT_EQ(compute_graph->GetAllNodesSize(), 3UL);
}
// The graph has one complete chain (data0 -> load0 -> abs0 -> store -> output)
// and one chain that ends without a Store (data1 -> load1 -> dangling_abs).
// After Optimize(), load1 and dangling_abs must be gone, while data1 remains
// with no data consumers and one outgoing control edge; output must have one
// incoming control edge.
TEST_F(TestOptimizer, OptimizeRemoveDanglingNodes) {
  auto graph = af::testing::AscGraphBuilder("RemoveDanglingNodes")
                   .Loops({af::testing::Sym(32)})
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .Abs("abs0", "load0")
                   .Store("store", "abs0")
                   .Output("output", "store", 0)
                   .Data("data1", 1)
                   .Load("load1", "data1")
                   .Abs("dangling_abs", "load1")
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  // Preconditions: the dangling chain is present before optimization.
  auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
  ASSERT_NE(compute_graph, nullptr);
  ASSERT_NE(graph.FindNode("data1"), nullptr);
  ASSERT_NE(graph.FindNode("load1"), nullptr);
  ASSERT_NE(graph.FindNode("dangling_abs"), nullptr);

  ::ascir::FusedScheduledResult scheduled;
  Status optimize_status = optimizer.Optimize(graph, scheduled);
  EXPECT_EQ(optimize_status, ge::SUCCESS);

  auto data1_after = graph.FindNode("data1");
  auto load1_after = graph.FindNode("load1");
  auto dangling_after = graph.FindNode("dangling_abs");
  auto output_after = graph.FindNode("output");
  ASSERT_NE(data1_after, nullptr);
  ASSERT_NE(output_after, nullptr);
  EXPECT_EQ(load1_after, nullptr);
  EXPECT_EQ(dangling_after, nullptr);
  EXPECT_EQ(data1_after->GetOutDataNodesSize(), 0U);
  EXPECT_EQ(data1_after->GetOutControlNodesSize(), 1U);
  EXPECT_EQ(output_after->GetInControlNodesSize(), 1U);
}
// The graph ends in a Workspace node (data0 -> load0 -> abs0 -> workspace) and
// has an unconnected scalar_data input. After Optimize(), both nodes must still
// exist, scalar_data must have one outgoing control edge and no data
// consumers, and workspace must have one incoming control edge.
TEST_F(TestOptimizer, OptimizeRemoveDanglingNodesWithWorkspaceOutput) {
  auto graph = af::testing::AscGraphBuilder("RemoveDanglingNodesWithWorkspaceOutput")
                   .Loops({af::testing::Sym(32)})
                   .Data("data0", 0)
                   .Load("load0", "data0")
                   .Abs("abs0", "load0")
                   .Workspace("workspace", "abs0")
                   .ScalarData("scalar_data", 1)
                   .Build();
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  // Before optimization: workspace feeds nothing, scalar_data has no control edges.
  auto workspace_before = graph.FindNode("workspace");
  auto scalar_before = graph.FindNode("scalar_data");
  ASSERT_NE(workspace_before, nullptr);
  ASSERT_NE(scalar_before, nullptr);
  EXPECT_EQ(workspace_before->GetOutDataNodesSize(), 0U);
  EXPECT_EQ(scalar_before->GetOutControlNodesSize(), 0U);

  ::ascir::FusedScheduledResult scheduled;
  Status optimize_status = optimizer.Optimize(graph, scheduled);
  EXPECT_EQ(optimize_status, ge::SUCCESS);

  // After optimization: look the nodes up again and check the control edge.
  auto workspace_after = graph.FindNode("workspace");
  auto scalar_after = graph.FindNode("scalar_data");
  ASSERT_NE(workspace_after, nullptr);
  ASSERT_NE(scalar_after, nullptr);
  EXPECT_EQ(scalar_after->GetOutDataNodesSize(), 0U);
  EXPECT_EQ(scalar_after->GetOutControlNodesSize(), 1U);
  EXPECT_EQ(workspace_after->GetInControlNodesSize(), 1U);
}
// Builds store(mul(transpose(load1), add(load2, load3))) over three static axes and checks
// that ScheduleTaskGenerator::GenerateTasks() yields two schedule tasks: the first grouped
// graph of the second task must not contain a node named "transpose", and the first task
// must carry the constant score function asserted below.
TEST_F(TestOptimizer, TransposeWithUB) {
  af::AscGraph graph("transpose_with_ub");
  auto s0 = graph.CreateSizeVar(16);
  auto s1 = graph.CreateSizeVar(86);
  auto s2 = graph.CreateSizeVar(200);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  Data x1_op("x1", graph);
  x1_op.ir_attr.SetIndex(0);
  Data x2_op("x2", graph);
  x2_op.ir_attr.SetIndex(1);
  Data x3_op("x3", graph);
  x3_op.ir_attr.SetIndex(2);
  Load load_op1("load1");
  Load load_op2("load2");
  Load load_op3("load3");

  // Shared Data/Load configuration for all three inputs.
  // NOTE(review): the vectors hold copies of the op handles; this relies on a copy
  // referring to the same graph node as the original handle — confirm.
  std::vector<Data> all_data{x1_op, x2_op, x3_op};
  std::vector<Load> all_load{load_op1, load_op2, load_op3};
  for (size_t i = 0U; i < all_data.size(); ++i) {
    auto &x_op = all_data[i];
    auto &load_op = all_load[i];
    x_op.y.dtype = ge::DT_FLOAT16;
    x_op.attr.api.type = af::ApiType::kAPITypeBuffer;
    load_op.x = x_op.y;
    load_op.attr.sched.axis = {z0.id, z1.id, z2.id};
    load_op.y.dtype = ge::DT_FLOAT16;
    // NOTE(review): only two axes are listed here while repeats/strides have three
    // entries; load1 is overwritten below, load2/load3 keep this value — confirm intent.
    *load_op.y.axis = {z0.id, z1.id};
    *load_op.y.repeats = {s0, s1, s2};
    *load_op.y.strides = {s1 * s2, s2, af::ops::One};
  }

  // load1 is read in (z1, z0, z2) order so that the transpose below has work to do.
  load_op1.attr.sched.axis = {z1.id, z0.id, z2.id};
  *load_op1.y.axis = {z1.id, z0.id, z2.id};
  *load_op1.y.repeats = {s1, s0, s2};
  *load_op1.y.strides = {s0 * s2, s2, af::ops::One};

  af::ascir_op::Transpose transpose_op("transpose");
  transpose_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  transpose_op.x = load_op1.y;
  transpose_op.y.dtype = ge::DT_FLOAT16;
  *transpose_op.y.axis = {z0.id, z1.id, z2.id};
  *transpose_op.y.strides = {s1 * s2, s2, af::ops::One};
  *transpose_op.y.repeats = {s0, s1, s2};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  add_op.x1 = load_op2.y;
  add_op.x2 = load_op3.y;
  add_op.y.dtype = ge::DT_FLOAT16;
  *add_op.y.axis = {z0.id, z1.id, z2.id};
  *add_op.y.strides = {s1 * s2, s2, af::ops::One};
  *add_op.y.repeats = {s0, s1, s2};

  af::ascir_op::Mul mul_op("mul");
  mul_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  mul_op.x1 = transpose_op.y;
  mul_op.x2 = add_op.y;
  mul_op.y.dtype = ge::DT_FLOAT16;
  *mul_op.y.axis = {z0.id, z1.id, z2.id};
  *mul_op.y.strides = {s1 * s2, s2, af::ops::One};
  *mul_op.y.repeats = {s0, s1, s2};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = mul_op.y;
  store_op.y.dtype = ge::DT_FLOAT16;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};
  store_op.ir_attr.SetOffset(af::Symbol(0));

  Output y_op("y");
  y_op.x = store_op.y;
  y_op.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  std::vector<optimize::ScheduleTask> schedule_tasks;
  optimize::OptimizerOptions options{optimize::GraphType::kAscGraph};
  int res = optimize::ScheduleTaskGenerator::GenerateTasks(graph, schedule_tasks, options);
  ASSERT_EQ(res, 0);
  ASSERT_EQ(schedule_tasks.size(), 2U);
  // Fail cleanly instead of indexing an empty container / dereferencing a null graph.
  ASSERT_FALSE(schedule_tasks[1].grouped_graphs.empty());
  auto compute_graph = af::AscGraphUtils::GetComputeGraph(schedule_tasks[1].grouped_graphs[0]);
  ASSERT_NE(compute_graph, nullptr);
  ASSERT_EQ(compute_graph->FindNode("transpose"), nullptr);
  ASSERT_EQ(schedule_tasks[0].score_func,
            "int32_t CalcScore(const AutofuseTilingData &tiling_data) {\n"
            "  return 1;\n"
            "}\n");
  SetCurShapeEnvContext(nullptr);
}
// Same graph as TransposeWithUB, except that the innermost size s2 is a symbol created
// through a ShapeEnvAttr in dynamic mode. Checks that GenerateTasks() yields two schedule
// tasks, that the first grouped graph of the second task has no "transpose" node, and that
// the first task carries the constant score function asserted below.
TEST_F(TestOptimizer, TransposeWithDynamicTail) {
  af::AscGraph graph("transpose_with_dynamic_tail");
  auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic));
  SetCurShapeEnvContext(&shape_env);
  // A failing ASSERT_* returns from the test body immediately. Without this guard the
  // current shape-env context would keep pointing at the destroyed local `shape_env`.
  // Declared after `shape_env`, so it is destroyed (and resets the context) first.
  struct ShapeEnvContextResetter {
    ~ShapeEnvContextResetter() { SetCurShapeEnvContext(nullptr); }
  } shape_env_context_resetter;

  auto s0 = graph.CreateSizeVar(16);
  auto s1 = graph.CreateSizeVar(86);
  auto s2 = shape_env.CreateSymbol(200, MakeShared<GraphInputShapeSourceStub>(0, 0));
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);

  Data x1_op("x1", graph);
  x1_op.ir_attr.SetIndex(0);
  Data x2_op("x2", graph);
  x2_op.ir_attr.SetIndex(1);
  Data x3_op("x3", graph);
  x3_op.ir_attr.SetIndex(2);
  Load load_op1("load1");
  Load load_op2("load2");
  Load load_op3("load3");

  // Shared Data/Load configuration for all three inputs.
  // NOTE(review): the vectors hold copies of the op handles; this relies on a copy
  // referring to the same graph node as the original handle — confirm.
  std::vector<Data> all_data{x1_op, x2_op, x3_op};
  std::vector<Load> all_load{load_op1, load_op2, load_op3};
  for (size_t i = 0U; i < all_data.size(); ++i) {
    auto &x_op = all_data[i];
    auto &load_op = all_load[i];
    x_op.y.dtype = ge::DT_FLOAT16;
    x_op.attr.api.type = af::ApiType::kAPITypeBuffer;
    load_op.x = x_op.y;
    load_op.attr.sched.axis = {z0.id, z1.id, z2.id};
    load_op.y.dtype = ge::DT_FLOAT16;
    // NOTE(review): only two axes are listed here while repeats/strides have three
    // entries; load1 is overwritten below, load2/load3 keep this value — confirm intent.
    *load_op.y.axis = {z0.id, z1.id};
    *load_op.y.repeats = {s0, s1, s2};
    *load_op.y.strides = {s1 * s2, s2, af::ops::One};
  }

  // load1 is read in (z1, z0, z2) order so that the transpose below has work to do.
  load_op1.attr.sched.axis = {z1.id, z0.id, z2.id};
  *load_op1.y.axis = {z1.id, z0.id, z2.id};
  *load_op1.y.repeats = {s1, s0, s2};
  *load_op1.y.strides = {s0 * s2, s2, af::ops::One};

  af::ascir_op::Transpose transpose_op("transpose");
  transpose_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  transpose_op.x = load_op1.y;
  transpose_op.y.dtype = ge::DT_FLOAT16;
  *transpose_op.y.axis = {z0.id, z1.id, z2.id};
  *transpose_op.y.strides = {s1 * s2, s2, af::ops::One};
  *transpose_op.y.repeats = {s0, s1, s2};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  add_op.x1 = load_op2.y;
  add_op.x2 = load_op3.y;
  add_op.y.dtype = ge::DT_FLOAT16;
  *add_op.y.axis = {z0.id, z1.id, z2.id};
  *add_op.y.strides = {s1 * s2, s2, af::ops::One};
  *add_op.y.repeats = {s0, s1, s2};

  af::ascir_op::Mul mul_op("mul");
  mul_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  mul_op.x1 = transpose_op.y;
  mul_op.x2 = add_op.y;
  mul_op.y.dtype = ge::DT_FLOAT16;
  *mul_op.y.axis = {z0.id, z1.id, z2.id};
  *mul_op.y.strides = {s1 * s2, s2, af::ops::One};
  *mul_op.y.repeats = {s0, s1, s2};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id, z2.id};
  store_op.x = mul_op.y;
  store_op.y.dtype = ge::DT_FLOAT16;
  *store_op.y.axis = {z0.id, z1.id, z2.id};
  *store_op.y.repeats = {s0, s1, s2};
  *store_op.y.strides = {s1 * s2, s2, af::ops::One};
  store_op.ir_attr.SetOffset(af::Symbol(0));

  Output y_op("y");
  y_op.x = store_op.y;
  y_op.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  std::vector<optimize::ScheduleTask> schedule_tasks;
  optimize::OptimizerOptions options{optimize::GraphType::kAscGraph};
  int res = optimize::ScheduleTaskGenerator::GenerateTasks(graph, schedule_tasks, options);
  ASSERT_EQ(res, 0);
  ASSERT_EQ(schedule_tasks.size(), 2U);
  // Fail cleanly instead of indexing an empty container / dereferencing a null graph.
  ASSERT_FALSE(schedule_tasks[1].grouped_graphs.empty());
  auto compute_graph = af::AscGraphUtils::GetComputeGraph(schedule_tasks[1].grouped_graphs[0]);
  ASSERT_NE(compute_graph, nullptr);
  ASSERT_EQ(compute_graph->FindNode("transpose"), nullptr);
  ASSERT_EQ(schedule_tasks[0].score_func,
            "int32_t CalcScore(const AutofuseTilingData &tiling_data) {\n"
            "  return 1;\n"
            "}\n");
  // Kept so the success path resets the context at the same point as before; the guard
  // above covers every early exit.
  SetCurShapeEnvContext(nullptr);
}
// Builds data -> load -> sum -> store -> output where the sum output has repeats {1, 1}
// and strides {0, 0} on both axes, and checks that Optimize() succeeds.
TEST_F(TestOptimizer, AllReduce) {
  af::AscGraph graph("all_reduce");
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data("data", graph);
  data.y.dtype = ge::DT_FLOAT;
  data.ir_attr.SetIndex(0);

  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.strides = {s1, af::ops::One};
  *load.y.repeats = {s0, s1};

  Sum sum("sum");
  sum.attr.sched.axis = {z0.id, z1.id};
  sum.attr.api.compute_type = af::ComputeType::kComputeReduce;
  sum.x = load.y;
  *sum.y.axis = {z0.id, z1.id};
  sum.y.dtype = ge::DT_FLOAT;
  *sum.y.repeats = {af::ops::One, af::ops::One};
  *sum.y.strides = {af::ops::Zero, af::ops::Zero};

  Store store_op1("store1");
  store_op1.attr.sched.axis = {z0.id, z1.id};
  store_op1.x = sum.y;
  store_op1.y.dtype = ge::DT_FLOAT;
  *store_op1.y.axis = {z0.id, z1.id};
  *store_op1.y.repeats = {af::ops::One, af::ops::One};
  *store_op1.y.strides = {af::ops::Zero, af::ops::Zero};

  Output output_op("output");
  output_op.x = store_op1.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Compare against the named status constant, as the sibling Optimize() tests do.
  EXPECT_EQ(res, ge::SUCCESS);
}
// Builds y = store(load2(x2) * (load0(x) + broadcast(load1(x1)))) on a 320 x 2889 float
// graph, runs Optimize(), and inspects the second impl graph of the only schedule group:
// it must have 12 nodes, contain "broadcast1", "broadcast1_remove_pad_0" and "add0", and
// the outputs of add0 and mul0 must carry the same que id.
TEST_F(TestOptimizer, BufQueAllocator_RemovePad_MemUnique) {
  af::AscGraph graph("BufQueAllocator_RemovePad_MemUnique");
  const af::Expression s0 = graph.CreateSizeVar(320);
  const af::Expression s1 = graph.CreateSizeVar(2889);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  // Input 0: full (s0, s1) tensor.
  af::ascir_op::Data x0("x", graph);
  x0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  x0.attr.api.type = af::ApiType::kAPITypeBuffer;
  x0.ir_attr.SetIndex(0);
  x0.y.dtype = ge::DataType::DT_FLOAT;

  af::ascir_op::Load load0("load0");
  load0.x = x0.y;
  load0.attr.api.compute_type = af::ComputeType::kComputeLoad;
  load0.attr.api.type = af::ApiType::kAPITypeCompute;
  load0.attr.sched.axis = {z0.id, z1.id};
  *load0.y.axis = {z0.id, z1.id};
  *load0.y.repeats = {s0, s1};
  *load0.y.strides = {s1, One};
  load0.y.dtype = ge::DataType::DT_FLOAT;
  load0.attr.api.unit = af::ComputeUnit::kUnitMTE2;

  // Input 1: loaded with repeats {1, s1} and broadcast to (s0, s1) afterwards.
  af::ascir_op::Data x1("x1", graph);
  x1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  x1.attr.api.type = af::ApiType::kAPITypeBuffer;
  x1.y.dtype = ge::DataType::DT_FLOAT;
  x1.ir_attr.SetIndex(1);

  af::ascir_op::Load load1("load1");
  load1.x = x1.y;
  load1.attr.api.compute_type = af::ComputeType::kComputeLoad;
  load1.attr.api.type = af::ApiType::kAPITypeCompute;
  load1.attr.sched.axis = {z0.id, z1.id};
  *load1.y.axis = {z0.id, z1.id};
  *load1.y.repeats = {One, s1};
  *load1.y.strides = {Zero, One};
  load1.y.dtype = ge::DataType::DT_FLOAT;
  load1.attr.api.unit = af::ComputeUnit::kUnitMTE2;

  af::ascir_op::Broadcast broadcast1("broadcast1");
  broadcast1.x = load1.y;
  broadcast1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  broadcast1.attr.api.type = af::ApiType::kAPITypeCompute;
  broadcast1.attr.sched.axis = {z0.id, z1.id};
  *broadcast1.y.axis = {z0.id, z1.id};
  *broadcast1.y.repeats = {s0, s1};
  *broadcast1.y.strides = {s1, One};
  broadcast1.y.dtype = ge::DataType::DT_FLOAT;
  broadcast1.attr.api.unit = af::ComputeUnit::kUnitVector;

  af::ascir_op::Add add0("add0");
  add0.x1 = load0.y;
  add0.x2 = broadcast1.y;
  add0.attr.api.compute_type = af::ComputeType::kComputeElewise;
  add0.attr.api.type = af::ApiType::kAPITypeCompute;
  add0.attr.sched.axis = {z0.id, z1.id};
  *add0.y.axis = {z0.id, z1.id};
  *add0.y.repeats = {s0, s1};
  *add0.y.strides = {s1, One};
  add0.y.dtype = ge::DataType::DT_FLOAT;
  add0.attr.api.unit = af::ComputeUnit::kUnitVector;

  // Input 2: full (s0, s1) tensor multiplied with the add result.
  af::ascir_op::Data x2("x2", graph);
  x2.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  x2.attr.api.type = af::ApiType::kAPITypeBuffer;
  x2.ir_attr.SetIndex(2);
  x2.y.dtype = ge::DataType::DT_FLOAT;

  af::ascir_op::Load load2("load2");
  load2.x = x2.y;
  load2.attr.api.compute_type = af::ComputeType::kComputeLoad;
  load2.attr.api.type = af::ApiType::kAPITypeCompute;
  load2.attr.sched.axis = {z0.id, z1.id};
  *load2.y.axis = {z0.id, z1.id};
  *load2.y.repeats = {s0, s1};
  *load2.y.strides = {s1, One};
  load2.y.dtype = ge::DataType::DT_FLOAT;
  load2.attr.api.unit = af::ComputeUnit::kUnitMTE2;

  af::ascir_op::Mul mul0("mul0");
  mul0.x1 = load2.y;
  mul0.x2 = add0.y;
  mul0.attr.api.compute_type = af::ComputeType::kComputeElewise;
  mul0.attr.api.type = af::ApiType::kAPITypeCompute;
  mul0.attr.sched.axis = {z0.id, z1.id};
  *mul0.y.axis = {z0.id, z1.id};
  *mul0.y.repeats = {s0, s1};
  *mul0.y.strides = {s1, One};
  mul0.y.dtype = ge::DataType::DT_FLOAT;
  mul0.attr.api.unit = af::ComputeUnit::kUnitVector;

  af::ascir_op::Store store("store");
  store.x = mul0.y;
  store.attr.api.compute_type = af::ComputeType::kComputeStore;
  store.attr.api.type = af::ApiType::kAPITypeCompute;
  store.attr.sched.axis = {z0.id, z1.id};
  *store.y.axis = {z0.id, z1.id};
  *store.y.repeats = {s0, s1};
  *store.y.strides = {s1, One};
  store.y.dtype = ge::DataType::DT_FLOAT;
  store.attr.api.unit = af::ComputeUnit::kUnitMTE3;

  af::ascir_op::Output y("y");
  y.x = store.y;
  y.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  y.attr.api.type = af::ApiType::kAPITypeBuffer;
  y.y.dtype = ge::DataType::DT_FLOAT;
  y.ir_attr.SetIndex(0);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Everything below indexes into the result, so a wrong status or shape must stop the
  // test here instead of reading out of bounds.
  ASSERT_EQ(res, ge::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 3UL);

  auto impl_graph2 = af::AscGraphUtils::GetComputeGraph(
      fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[1]);
  ASSERT_NE(impl_graph2, nullptr);
  EXPECT_EQ(impl_graph2->GetAllNodesSize(), 12);
  EXPECT_NE(impl_graph2->FindNode("broadcast1"), nullptr);
  EXPECT_NE(impl_graph2->FindNode("broadcast1_remove_pad_0"), nullptr);
  EXPECT_NE(impl_graph2->FindNode("add0"), nullptr);

  // The casts yield null when the node is missing or is not an AscNode; check before
  // dereferencing so a regression fails the test instead of crashing the binary.
  const auto impl_graph2_mul0 = std::dynamic_pointer_cast<af::AscNode>(impl_graph2->FindNode("mul0"));
  const auto impl_graph2_add0 = std::dynamic_pointer_cast<af::AscNode>(impl_graph2->FindNode("add0"));
  ASSERT_NE(impl_graph2_mul0, nullptr);
  ASSERT_NE(impl_graph2_add0, nullptr);
  EXPECT_EQ(impl_graph2_add0->outputs[0].attr.que.id, impl_graph2_mul0->outputs[0].attr.que.id);
}
// Builds x1, x2 -> gather -> abs -> store -> y, where the trailing axis z2 has size One,
// then checks that RemoveAllZeroStrideLoopAxis() succeeds and leaves the gather node with
// two sched axes and two output axes, and that GenerateTasks() produces exactly one
// schedule task.
TEST_F(TestOptimizer, GatherLastAxisTest) {
af::AscGraph graph("gather_last_axis");
auto s0 = graph.CreateSizeVar("s0");
auto s1 = graph.CreateSizeVar("s1");
// The last axis is created with a constant size of One.
auto s2 = One;
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
auto z2 = graph.CreateAxis("z2", s2);
// First gather input: float data over z0.
Data x1("x1", graph);
x1.attr.sched.axis = {z0.id};
x1.y.dtype = ge::DT_FLOAT;
*x1.y.repeats = {s0};
*(x1.y.axis) = {z0.id};
// Second gather input: int64 data over (z1, z2).
Data x2("x2", graph);
x2.attr.sched.axis = {z1.id, z2.id};
x2.y.dtype = ge::DT_INT64;
*x2.y.repeats = {s1, s2};
*(x2.y.axis) = {z1.id, z2.id};
Gather gather("gather");
gather.attr.api.compute_type = af::ComputeType::kComputeGather;
gather.x1 = x1.y;
gather.x2 = x2.y;
gather.attr.sched.axis = {z1.id, z2.id};
gather.ir_attr.SetAxis(0);
gather.y.dtype = ge::DT_FLOAT;
*gather.y.axis = {z1.id, z2.id};
*gather.y.repeats = {s1, s2};
// The z2 entry of the output has stride Zero.
*gather.y.strides = {One, Zero};
gather.attr.tmp_buffers = {{{af::Symbol(8192), -1}, af::MemAttr(), 0}};
// abs, store and y are added to the graph explicitly before their inputs are linked.
af::ascir_op::Abs abs("abs");
graph.AddNode(abs);
abs.x = gather.y;
abs.attr.sched.axis = {z1.id, z2.id};
abs.y.dtype = ge::DT_FLOAT;
// NOTE(review): abs.y.axis is not assigned here, unlike gather/store/y — confirm intent.
*abs.y.repeats = {s1, s2};
*abs.y.strides = {One, Zero};
Store store("store");
graph.AddNode(store);
store.x = abs.y;
store.attr.sched.axis = {z1.id, z2.id};
store.y.dtype = ge::DT_FLOAT;
*store.y.axis = {z1.id, z2.id};
*store.y.repeats = {s1, s2};
*store.y.strides = {One, Zero};
Output y("y");
graph.AddNode(y);
y.x = store.y;
y.attr.sched.axis = {z1.id, z2.id};
y.y.dtype = ge::DT_FLOAT;
*y.y.axis = {z1.id, z2.id};
// After the pass the gather node is expected to still carry both axes.
ASSERT_EQ(optimize::Optimizer::RemoveAllZeroStrideLoopAxis(graph), SUCCESS);
auto gather_node = graph.FindNode("gather");
ASSERT_NE(gather_node, nullptr);
EXPECT_EQ(gather_node->attr.sched.axis.size(), 2UL);
EXPECT_EQ(gather_node->outputs[0].attr.axis.size(), 2UL);
optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
std::vector<optimize::ScheduleTask> schedule_tasks;
optimize::OptimizerOptions options{optimize::GraphType::kAscGraph};
int res = optimize::ScheduleTaskGenerator::GenerateTasks(graph, schedule_tasks, options);
ASSERT_EQ(res, 0);
ASSERT_EQ(schedule_tasks.size(), 1);
}
// Builds data -> load -> sum -> store -> output (plus an abs fed by sum), checks that
// TilingGroup::GenTilingGroup() returns 0 for it, and then checks MergeAxesGroup() on a
// y-only group and an r-only group over the same axes in both argument orders.
TEST_F(TestOptimizer, MergeGroupYAndR) {
af::AscGraph graph("merge_group_y_and_r");
auto s0 = graph.CreateSizeVar(128);
auto s1 = graph.CreateSizeVar(64);
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
Data data("data", graph);
data.y.dtype = ge::DT_FLOAT;
data.ir_attr.SetIndex(0);
Load load("load");
load.attr.sched.axis = {z0.id, z1.id};
load.x = data.y;
*load.y.axis = {z0.id, z1.id};
load.y.dtype = ge::DT_FLOAT;
*load.y.strides = {s1, af::ops::One};
*load.y.repeats = {s0, s1};
// The sum output has repeats {One, One} and strides {Zero, Zero} on both axes.
Sum sum("sum");
sum.attr.sched.axis = {z0.id, z1.id};
sum.x = load.y;
*sum.y.axis = {z0.id, z1.id};
sum.y.dtype = ge::DT_FLOAT;
*sum.y.strides = {af::ops::Zero, af::ops::Zero};
*sum.y.repeats = {af::ops::One, af::ops::One};
// abs consumes sum.y, but no op in this test consumes abs.y (the store reads sum.y).
af::ascir_op::Abs abs("abs");
abs.x = sum.y;
abs.attr.sched.axis = {z0.id, z1.id};
abs.y.dtype = ge::DT_FLOAT;
*abs.y.axis = {z0.id, z1.id};
*abs.y.strides = {af::ops::Zero, af::ops::Zero};
*abs.y.repeats = {af::ops::One, af::ops::One};
Store store_op1("store1");
store_op1.attr.sched.axis = {z0.id, z1.id};
store_op1.x = sum.y;
*store_op1.y.axis = {z0.id, z1.id};
store_op1.y.dtype = ge::DT_FLOAT;
// NOTE(review): repeats the axis assignment made two lines above — looks redundant.
*store_op1.y.axis = {z0.id, z1.id};
*store_op1.y.strides = {af::ops::Zero, af::ops::Zero};
*store_op1.y.repeats = {af::ops::One, af::ops::One};
Output output_op("output");
output_op.x = store_op1.y;
output_op.y.dtype = ge::DT_FLOAT;
output_op.ir_attr.SetIndex(0);
optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
optimize::autoschedule::AxisGroup axes_group;
EXPECT_EQ(optimize::autoschedule::TilingGroup::GenTilingGroup(graph, axes_group), 0);
// Hand-built groups: one with only y_group set, one with only r_group set.
optimize::autoschedule::AxisGroup cur_y_group;
cur_y_group.y_group = {z0.id, z1.id};
optimize::autoschedule::AxisGroup cur_r_group;
cur_r_group.r_group = {z0.id, z1.id};
// Expected: merging succeeds with the y group first and fails with the r group first.
EXPECT_TRUE(optimize::autoschedule::TilingGroup::MergeAxesGroup(cur_y_group, cur_r_group, true));
EXPECT_FALSE(optimize::autoschedule::TilingGroup::MergeAxesGroup(cur_r_group, cur_y_group, true));
}
// Concatenates three loads of widths 412, 16 and 32 along the last axis, runs Optimize(),
// and checks the "offset" attribute of the Store node found in the first impl graph of
// each schedule group, as well as the input/output/workspace nodes reported in the result.
TEST_F(TestOptimizer, ConcatTailDimStatic) {
  af::AscGraph graph("concat_last_dim_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar(412);
  auto s2 = graph.CreateSizeVar(16);
  auto s3 = graph.CreateSizeVar(32);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1 + s2 + s3);
  std::vector<af::Expression> input_dims{s1, s2, s3};

  Data x1_op("x1", graph);
  x1_op.ir_attr.SetIndex(0);
  Data x2_op("x2", graph);
  x2_op.ir_attr.SetIndex(1);
  Data x3_op("x3", graph);
  x3_op.ir_attr.SetIndex(2);
  Load load_op1("load1");
  Load load_op2("load2");
  Load load_op3("load3");

  // Each load gets its own last-dimension width from input_dims.
  // NOTE(review): the vectors hold copies of the op handles; this relies on a copy
  // referring to the same graph node as the original handle — confirm.
  std::vector<Data> all_data{x1_op, x2_op, x3_op};
  std::vector<Load> all_load{load_op1, load_op2, load_op3};
  for (size_t i = 0U; i < all_data.size(); ++i) {
    auto &x_op = all_data[i];
    auto &load_op = all_load[i];
    x_op.y.dtype = ge::DT_FLOAT16;
    load_op.x = x_op.y;
    load_op.attr.sched.axis = {z0.id, z1.id};
    load_op.y.dtype = ge::DT_FLOAT16;
    *load_op.y.axis = {z0.id, z1.id};
    *load_op.y.repeats = {s0, input_dims[i]};
    *load_op.y.strides = {input_dims[i], af::ops::One};
  }

  // NOTE(review): the concat/store output width is s1 + s2 although three inputs
  // (s1, s2, s3) are concatenated and z1 has size s1 + s2 + s3 — confirm intent.
  af::ascir_op::Concat concat_op("concat");
  concat_op.attr.sched.axis = {z0.id, z1.id};
  concat_op.x = {load_op1.y, load_op2.y, load_op3.y};
  concat_op.y.dtype = ge::DT_FLOAT16;
  *concat_op.y.axis = {z0.id, z1.id};
  *concat_op.y.repeats = {s0, s1 + s2};
  *concat_op.y.strides = {s1 + s2, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = concat_op.y;
  store_op.y.dtype = ge::DT_FLOAT16;
  *store_op.y.axis = {z0.id, z1.id};
  *store_op.y.repeats = {s0, s1 + s2};
  *store_op.y.strides = {s1 + s2, af::ops::One};
  store_op.ir_attr.SetOffset(af::Symbol(0));

  Output y_op("y");
  y_op.x = store_op.y;
  y_op.ir_attr.SetIndex(0);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  ASSERT_EQ(res, 0);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
  auto &schedule_result = fused_scheduled_result.node_idx_to_scheduled_results[0][0];

  // Collect the offset of every Store node in the first impl graph of each group.
  std::vector<af::Expression> offsets;
  std::vector<af::Expression> expect = {af::Symbol(0), s1};
  for (const auto &schedule_group : schedule_result.schedule_groups) {
    ASSERT_FALSE(schedule_group.impl_graphs.empty());
    auto &sub_impl_graph = schedule_group.impl_graphs.front();
    for (const auto &sub_node : sub_impl_graph.GetAllNodes()) {
      if (sub_node->GetType() == "Store") {
        af::Expression offset;
        EXPECT_EQ(sub_node->attr.ir_attr->GetAttrValue("offset", offset), 0);
        offsets.emplace_back(offset);
      }
    }
  }
  // expect[i] is indexed by the number of collected offsets; refuse to read past its end.
  ASSERT_LE(offsets.size(), expect.size());
  for (size_t i = 0; i < offsets.size(); ++i) {
    EXPECT_SYMBOL_EQ(offsets[i], expect[i]);
  }

  EXPECT_EQ(fused_scheduled_result.workspace_nodes.size(), 0);
  // The name checks below index these vectors, so their sizes are hard requirements.
  ASSERT_EQ(fused_scheduled_result.input_nodes.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.output_nodes.size(), 1UL);
  EXPECT_EQ(fused_scheduled_result.input_nodes[0]->GetName(), "x1");
  EXPECT_EQ(fused_scheduled_result.input_nodes[1]->GetName(), "x2");
  EXPECT_EQ(fused_scheduled_result.input_nodes[2]->GetName(), "x3");
  EXPECT_EQ(fused_scheduled_result.output_nodes[0]->GetName(), "y");
}
// Builds store(broadcast(sum(exp(load))) - exp(load)), runs Optimize(), and checks the
// shape of the scheduled result (3 results; 2 groups in the first, 3 in the second) and
// that the expected workspace nodes exist in the three impl graphs of the second result.
TEST_F(TestOptimizer, ReducePartition) {
  af::AscGraph graph("reduce_partition");
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data("data", graph);
  data.y.dtype = ge::DT_FLOAT;
  data.ir_attr.SetIndex(0);

  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.strides = {s1, af::ops::One};
  *load.y.repeats = {s0, s1};

  Exp exp("exp");
  exp.x = load.y;
  exp.attr.sched.axis = {z0.id, z1.id};
  *exp.y.axis = {z0.id, z1.id};
  exp.y.dtype = ge::DT_FLOAT;
  *exp.y.strides = {s1, af::ops::One};
  *exp.y.repeats = {s0, s1};

  // NOTE(review): AllReduce in this file uses repeats {One, One} / strides {Zero, Zero}
  // for a full reduction; here the two look swapped — confirm whether intentional.
  Sum sum("sum");
  sum.attr.sched.axis = {z0.id, z1.id};
  sum.x = exp.y;
  *sum.y.axis = {z0.id, z1.id};
  sum.y.dtype = ge::DT_FLOAT;
  *sum.y.strides = {af::ops::One, af::ops::One};
  *sum.y.repeats = {af::ops::Zero, af::ops::Zero};

  Broadcast broadcast("broadcast");
  broadcast.x = sum.y;
  broadcast.attr.sched.axis = {z0.id, z1.id};
  *broadcast.y.axis = {z0.id, z1.id};
  broadcast.y.dtype = ge::DT_FLOAT;
  *broadcast.y.strides = {s1, af::ops::One};
  *broadcast.y.repeats = {s0, s1};

  Sub sub("sub");
  sub.x1 = broadcast.y;
  sub.x2 = exp.y;
  sub.attr.sched.axis = {z0.id, z1.id};
  *sub.y.axis = {z0.id, z1.id};
  sub.y.dtype = ge::DT_FLOAT;
  *sub.y.strides = {s1, af::ops::One};
  *sub.y.repeats = {s0, s1};

  Store store_op1("store1");
  store_op1.attr.sched.axis = {z0.id, z1.id};
  store_op1.x = sub.y;
  *store_op1.y.axis = {z0.id, z1.id};
  store_op1.y.dtype = ge::DT_FLOAT;
  *store_op1.y.strides = {s1, af::ops::One};
  *store_op1.y.repeats = {s0, s1};

  Output output_op("output");
  output_op.x = store_op1.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ::ascir::FusedScheduledResult fused_scheduled_result;
  // Every size that is indexed into afterwards is a fatal assertion, so a regression
  // fails the test instead of reading out of bounds.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), ge::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  auto &scheduled_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_EQ(scheduled_results.size(), 3UL);

  auto &first_groups = scheduled_results[0].schedule_groups;
  ASSERT_EQ(first_groups.size(), 2UL);
  EXPECT_EQ(first_groups[0].impl_graphs.size(), 2UL);
  EXPECT_EQ(first_groups[1].impl_graphs.size(), 2UL);

  auto &second_groups = scheduled_results[1].schedule_groups;
  ASSERT_EQ(second_groups.size(), 3UL);
  ASSERT_EQ(second_groups[0].impl_graphs.size(), 1UL);
  ASSERT_FALSE(second_groups[1].impl_graphs.empty());
  ASSERT_FALSE(second_groups[2].impl_graphs.empty());

  auto impl_graph_sum_phase1 = second_groups[0].impl_graphs[0];
  auto impl_graph_sum_phase2 = second_groups[1].impl_graphs[0];
  auto impl_graph_sub = second_groups[2].impl_graphs[0];
  auto phase1_workspace = impl_graph_sum_phase1.FindNode("reduce_partition_0_r_multicore_phase_2_graph_workspace");
  auto phase2_workspace1 = impl_graph_sum_phase2.FindNode("reduce_partition_0_r_multicore_phase_2_graph_workspace");
  auto phase2_workspace2 = impl_graph_sum_phase2.FindNode("sum_Workspace");
  auto sub_workspace = impl_graph_sub.FindNode("sum_Workspace");
  ASSERT_NE(phase1_workspace, nullptr);
  ASSERT_NE(phase2_workspace1, nullptr);
  ASSERT_NE(phase2_workspace2, nullptr);
  ASSERT_NE(sub_workspace, nullptr);
}
// Builds load -> max -> broadcast -> sub -> exp -> sum -> broadcast -> div -> store on
// symbolic sizes, runs Optimize(), and checks the shape of the scheduled result (3 results;
// 3 groups in the first, 5 in the second) and that the expected workspace nodes exist in
// the five impl graphs of the second result.
TEST_F(TestOptimizer, ReducePartitionLoad) {
  af::AscGraph graph("Reduce_partition_load");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto s2 = graph.CreateSizeVar("s2");
  auto s3 = graph.CreateSizeVar("s3");
  auto z0 = graph.CreateAxis("z0", s0 * s1 * s2);
  auto z1 = graph.CreateAxis("z1", s3);
  // Shared axis list assigned to every op's sched axes and output axes below.
  auto axis = {z0.id, z1.id};

  Data arg4_1("arg4_1", graph);
  arg4_1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  arg4_1.y.dtype = ge::DT_FLOAT16;
  arg4_1.ir_attr.SetIndex(0);

  Load b0_load("b0_load");
  b0_load.x = arg4_1.y;
  b0_load.attr.sched.axis = axis;
  b0_load.y.dtype = ge::DT_FLOAT16;
  *b0_load.y.axis = axis;
  *b0_load.y.repeats = {s0 * s1 * s2, s3};
  *b0_load.y.strides = {s3, One};

  // Output of max has repeats {s0 * s1 * s2, One} and strides {One, Zero}.
  af::ascir_op::Max b0_max("b0_max");
  b0_max.x = b0_load.y;
  b0_max.attr.sched.axis = axis;
  b0_max.y.dtype = ge::DT_FLOAT16;
  *b0_max.y.axis = axis;
  *b0_max.y.repeats = {s0 * s1 * s2, One};
  *b0_max.y.strides = {One, Zero};

  Broadcast b1_broadcast("b1_broadcast");
  b1_broadcast.x = b0_max.y;
  b1_broadcast.attr.sched.axis = axis;
  b1_broadcast.y.dtype = ge::DT_FLOAT16;
  *b1_broadcast.y.axis = axis;
  *b1_broadcast.y.repeats = {s0 * s1 * s2, s3};
  *b1_broadcast.y.strides = {s3, One};

  af::ascir_op::Sub b1_sub("b1_sub");
  b1_sub.x1 = b0_load.y;
  b1_sub.x2 = b1_broadcast.y;
  b1_sub.attr.sched.axis = axis;
  b1_sub.y.dtype = ge::DT_FLOAT16;
  *b1_sub.y.axis = axis;
  *b1_sub.y.repeats = {s0 * s1 * s2, s3};
  *b1_sub.y.strides = {s3, One};

  Exp b1_exp("b1_exp");
  b1_exp.x = b1_sub.y;
  b1_exp.attr.sched.axis = axis;
  b1_exp.y.dtype = ge::DT_FLOAT16;
  *b1_exp.y.axis = axis;
  *b1_exp.y.repeats = {s0 * s1 * s2, s3};
  *b1_exp.y.strides = {s3, One};

  // Output of sum has repeats {s0 * s1 * s2, One} and strides {One, Zero}.
  Sum b2_sum("b2_sum");
  b2_sum.x = b1_exp.y;
  b2_sum.attr.sched.axis = axis;
  b2_sum.y.dtype = ge::DT_FLOAT16;
  *b2_sum.y.axis = axis;
  *b2_sum.y.repeats = {s0 * s1 * s2, One};
  *b2_sum.y.strides = {One, Zero};

  // The Output op is created here and connected after the store below.
  Output buf3("buf3");
  buf3.ir_attr.SetIndex(2);

  Broadcast b3_broadcast("b3_broadcast");
  b3_broadcast.x = b2_sum.y;
  b3_broadcast.attr.sched.axis = axis;
  b3_broadcast.y.dtype = ge::DT_FLOAT16;
  *b3_broadcast.y.axis = axis;
  *b3_broadcast.y.repeats = {s0 * s1 * s2, s3};
  *b3_broadcast.y.strides = {s3, One};

  af::ascir_op::Div b3_div("b3_div");
  b3_div.x1 = b1_exp.y;
  b3_div.x2 = b3_broadcast.y;
  b3_div.attr.sched.axis = axis;
  b3_div.y.dtype = ge::DT_FLOAT16;
  *b3_div.y.axis = axis;
  *b3_div.y.repeats = {s0 * s1 * s2, s3};
  *b3_div.y.strides = {s3, One};

  Store b3_store("b3_store");
  b3_store.x = b3_div.y;
  b3_store.attr.sched.axis = axis;
  b3_store.y.dtype = ge::DT_FLOAT16;
  *b3_store.y.axis = axis;
  *b3_store.y.repeats = {s0 * s1 * s2, s3};
  *b3_store.y.strides = {s3, One};
  buf3.x = b3_store.y;
  buf3.y.dtype = ge::DT_FLOAT16;

  ::ascir::FusedScheduledResult fused_scheduled_result;
  // Every size that is indexed into afterwards is a fatal assertion, so a regression
  // fails the test instead of reading out of bounds.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), ge::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  auto &scheduled_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_EQ(scheduled_results.size(), 3UL);

  auto &first_groups = scheduled_results[0].schedule_groups;
  ASSERT_EQ(first_groups.size(), 3UL);
  EXPECT_EQ(first_groups[0].impl_graphs.size(), 1UL);
  EXPECT_EQ(first_groups[1].impl_graphs.size(), 1UL);
  EXPECT_EQ(first_groups[2].impl_graphs.size(), 2UL);

  auto &second_groups = scheduled_results[1].schedule_groups;
  ASSERT_EQ(second_groups.size(), 5UL);
  ASSERT_EQ(second_groups[0].impl_graphs.size(), 1UL);
  ASSERT_EQ(second_groups[1].impl_graphs.size(), 1UL);
  ASSERT_EQ(second_groups[2].impl_graphs.size(), 1UL);
  ASSERT_EQ(second_groups[3].impl_graphs.size(), 1UL);
  ASSERT_EQ(second_groups[4].impl_graphs.size(), 2UL);

  auto impl_graph_max_phase1 = second_groups[0].impl_graphs[0];
  auto impl_graph_max_phase2 = second_groups[1].impl_graphs[0];
  auto impl_graph_sum_phase1 = second_groups[2].impl_graphs[0];
  auto impl_graph_sum_phase2 = second_groups[3].impl_graphs[0];
  auto impl_graph_div = second_groups[4].impl_graphs[0];
  auto max_phase1_workspace = impl_graph_max_phase1.FindNode("Reduce_partition_load_0_r_multicore_phase_2_graph_workspace");
  auto max_phase2_workspace1 = impl_graph_max_phase2.FindNode("Reduce_partition_load_0_r_multicore_phase_2_graph_workspace");
  auto max_phase2_workspace2 = impl_graph_max_phase2.FindNode("b0_max_Workspace");
  auto sum_phase1_workspace1 = impl_graph_sum_phase1.FindNode("b0_max_Workspace");
  auto sum_phase1_workspace2 = impl_graph_sum_phase1.FindNode("Reduce_partition_load_1_r_multicore_phase_2_graph_workspace");
  auto sum_phase2_workspace1 = impl_graph_sum_phase2.FindNode("Reduce_partition_load_1_r_multicore_phase_2_graph_workspace");
  auto sum_phase2_workspace2 = impl_graph_sum_phase2.FindNode("b2_sum_Workspace");
  auto div_workspace = impl_graph_div.FindNode("b2_sum_Workspace");
  ASSERT_NE(max_phase1_workspace, nullptr);
  ASSERT_NE(max_phase2_workspace1, nullptr);
  ASSERT_NE(max_phase2_workspace2, nullptr);
  ASSERT_NE(sum_phase1_workspace1, nullptr);
  ASSERT_NE(sum_phase1_workspace2, nullptr);
  ASSERT_NE(sum_phase2_workspace1, nullptr);
  ASSERT_NE(sum_phase2_workspace2, nullptr);
  ASSERT_NE(div_workspace, nullptr);
}
// Builds store(cast(sum(abs(exp(load)) - exp(load)))), runs Optimize(), and checks the
// shape of the scheduled result (3 results; 1 group in the first, 2 in the second) and
// that both impl graphs of the second result contain the phase-2 workspace node.
TEST_F(TestOptimizer, ReducePartition3) {
  af::AscGraph graph("reduce_partition");
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data("data", graph);
  data.y.dtype = ge::DT_FLOAT;
  data.ir_attr.SetIndex(0);

  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.strides = {s1, af::ops::One};
  *load.y.repeats = {s0, s1};

  Exp exp("exp");
  exp.x = load.y;
  exp.attr.sched.axis = {z0.id, z1.id};
  *exp.y.axis = {z0.id, z1.id};
  exp.y.dtype = ge::DT_FLOAT;
  *exp.y.strides = {s1, af::ops::One};
  *exp.y.repeats = {s0, s1};

  Abs abs("abs");
  abs.attr.sched.axis = {z0.id, z1.id};
  abs.x = exp.y;
  *abs.y.axis = {z0.id, z1.id};
  abs.y.dtype = ge::DT_FLOAT;
  *abs.y.strides = {s1, af::ops::One};
  *abs.y.repeats = {s0, s1};

  Sub sub("sub");
  sub.x1 = abs.y;
  sub.x2 = exp.y;
  sub.attr.sched.axis = {z0.id, z1.id};
  *sub.y.axis = {z0.id, z1.id};
  sub.y.dtype = ge::DT_FLOAT;
  *sub.y.strides = {s1, af::ops::One};
  *sub.y.repeats = {s0, s1};

  // NOTE(review): AllReduce in this file uses repeats {One, One} / strides {Zero, Zero}
  // for a full reduction; here (and on cast/store below) the two look swapped — confirm.
  Sum sum("b2_sum");
  sum.x = sub.y;
  sum.attr.sched.axis = {z0.id, z1.id};
  sum.y.dtype = ge::DT_FLOAT16;
  *sum.y.axis = {z0.id, z1.id};
  *sum.y.strides = {af::ops::One, af::ops::One};
  *sum.y.repeats = {af::ops::Zero, af::ops::Zero};

  Cast cast("cast");
  cast.x = sum.y;
  cast.attr.sched.axis = {z0.id, z1.id};
  cast.y.dtype = ge::DT_FLOAT;
  *cast.y.axis = {z0.id, z1.id};
  *cast.y.strides = {af::ops::One, af::ops::One};
  *cast.y.repeats = {af::ops::Zero, af::ops::Zero};

  Store store_op1("store1");
  store_op1.attr.sched.axis = {z0.id, z1.id};
  store_op1.x = cast.y;
  *store_op1.y.axis = {z0.id, z1.id};
  store_op1.y.dtype = ge::DT_FLOAT;
  *store_op1.y.strides = {af::ops::One, af::ops::One};
  *store_op1.y.repeats = {af::ops::Zero, af::ops::Zero};

  Output output_op("output");
  output_op.x = store_op1.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ::ascir::FusedScheduledResult fused_scheduled_result;
  // Every size that is indexed into afterwards is a fatal assertion, so a regression
  // fails the test instead of reading out of bounds.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), ge::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  auto &scheduled_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_EQ(scheduled_results.size(), 3UL);
  EXPECT_EQ(scheduled_results[0].schedule_groups.size(), 1UL);

  auto &second_groups = scheduled_results[1].schedule_groups;
  ASSERT_EQ(second_groups.size(), 2UL);
  ASSERT_FALSE(second_groups[0].impl_graphs.empty());
  ASSERT_FALSE(second_groups[1].impl_graphs.empty());

  auto impl_graph_phase1 = second_groups[0].impl_graphs[0];
  auto impl_graph_phase2 = second_groups[1].impl_graphs[0];
  auto phase1_workspace1 = impl_graph_phase1.FindNode("reduce_partition_0_r_multicore_phase_2_graph_workspace");
  auto phase2_workspace1 = impl_graph_phase2.FindNode("reduce_partition_0_r_multicore_phase_2_graph_workspace");
  ASSERT_NE(phase1_workspace1, nullptr);
  ASSERT_NE(phase2_workspace1, nullptr);
}
// Graph under test:
//   data -> load -> exp -> abs -> sub(abs, exp) -> b2_sum -> cast -> broadcast -> store1 -> output
// After Optimize, the second scheduled result is expected to hold three schedule groups, and the
// named workspace nodes must be present in their first impl graphs.
TEST_F(TestOptimizer, ReducePartition4) {
  af::AscGraph graph("reduce_partition");
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data("data", graph);
  data.y.dtype = ge::DT_FLOAT;
  data.ir_attr.SetIndex(0);

  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.strides = {s1, af::ops::One};
  *load.y.repeats = {s0, s1};

  Exp exp("exp");
  exp.x = load.y;
  exp.attr.sched.axis = {z0.id, z1.id};
  *exp.y.axis = {z0.id, z1.id};
  exp.y.dtype = ge::DT_FLOAT;
  *exp.y.strides = {s1, af::ops::One};
  *exp.y.repeats = {s0, s1};

  Abs abs("abs");
  abs.attr.sched.axis = {z0.id, z1.id};
  abs.x = exp.y;
  *abs.y.axis = {z0.id, z1.id};
  abs.y.dtype = ge::DT_FLOAT;
  *abs.y.strides = {s1, af::ops::One};
  *abs.y.repeats = {s0, s1};

  Sub sub("sub");
  sub.x1 = abs.y;
  sub.x2 = exp.y;
  sub.attr.sched.axis = {z0.id, z1.id};
  *sub.y.axis = {z0.id, z1.id};
  sub.y.dtype = ge::DT_FLOAT;
  *sub.y.strides = {s1, af::ops::One};
  *sub.y.repeats = {s0, s1};

  Sum sum("b2_sum");
  sum.x = sub.y;
  sum.attr.sched.axis = {z0.id, z1.id};
  sum.y.dtype = ge::DT_FLOAT16;
  *sum.y.axis = {z0.id, z1.id};
  *sum.y.strides = {af::ops::One, af::ops::One};
  *sum.y.repeats = {af::ops::Zero, af::ops::Zero};

  Cast cast("cast");
  cast.x = sum.y;
  cast.attr.sched.axis = {z0.id, z1.id};
  cast.y.dtype = ge::DT_FLOAT;
  *cast.y.axis = {z0.id, z1.id};
  *cast.y.strides = {af::ops::One, af::ops::One};
  *cast.y.repeats = {af::ops::Zero, af::ops::Zero};

  Broadcast broadcast("broadcast");
  broadcast.x = cast.y;
  broadcast.attr.sched.axis = {z0.id, z1.id};
  *broadcast.y.axis = {z0.id, z1.id};
  broadcast.y.dtype = ge::DT_FLOAT;
  *broadcast.y.strides = {s1, af::ops::One};
  *broadcast.y.repeats = {s0, s1};

  Store store_op1("store1");
  store_op1.attr.sched.axis = {z0.id, z1.id};
  store_op1.x = broadcast.y;
  *store_op1.y.axis = {z0.id, z1.id};
  store_op1.y.dtype = ge::DT_FLOAT;
  *store_op1.y.strides = {s1, af::ops::One};
  *store_op1.y.repeats = {s0, s1};

  Output output_op("output");
  output_op.x = store_op1.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ::ascir::FusedScheduledResult fused_scheduled_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_scheduled_result), 0);

  // The size checks are fatal: everything below indexes into these containers, so continuing
  // after a mismatch would read out of bounds instead of reporting a clean failure.
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1);
  auto &scheduled_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_EQ(scheduled_results.size(), 3);
  EXPECT_EQ(scheduled_results[0].schedule_groups.size(), 2);
  ASSERT_EQ(scheduled_results[1].schedule_groups.size(), 3);
  const auto &groups = scheduled_results[1].schedule_groups;
  for (const auto &group : groups) {
    ASSERT_FALSE(group.impl_graphs.empty());
  }

  // Bind by reference; the graphs are only queried, so copying them is unnecessary.
  const auto &impl_graph_phase1 = groups[0].impl_graphs[0];
  const auto &impl_graph_phase2 = groups[1].impl_graphs[0];
  const auto &impl_graph_brc = groups[2].impl_graphs[0];

  auto phase1_workspace = impl_graph_phase1.FindNode("reduce_partition_0_r_multicore_phase_2_graph_workspace");
  auto phase2_workspace1 = impl_graph_phase2.FindNode("reduce_partition_0_r_multicore_phase_2_graph_workspace");
  auto phase2_workspace2 = impl_graph_phase2.FindNode("b2_sum_Workspace");
  auto brc_workspace = impl_graph_brc.FindNode("b2_sum_Workspace");
  ASSERT_NE(phase1_workspace, nullptr);
  ASSERT_NE(phase2_workspace1, nullptr);
  ASSERT_NE(phase2_workspace2, nullptr);
  ASSERT_NE(brc_workspace, nullptr);
}
// Graph under test:
//   data -> load -> exp -> abs -> b2_sum -> cast -> store1 -> output
// After Optimize, the second scheduled result is expected to hold two schedule groups whose first
// impl graphs both contain the "reduce_r_multicore_0_r_multicore_phase_2_graph_workspace" node.
TEST_F(TestOptimizer, ReduceRMulticore) {
  af::AscGraph graph("reduce_r_multicore");
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data("data", graph);
  data.y.dtype = ge::DT_FLOAT;
  data.ir_attr.SetIndex(0);

  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.strides = {s1, af::ops::One};
  *load.y.repeats = {s0, s1};

  Exp exp("exp");
  exp.x = load.y;
  exp.attr.sched.axis = {z0.id, z1.id};
  *exp.y.axis = {z0.id, z1.id};
  exp.y.dtype = ge::DT_FLOAT;
  *exp.y.strides = {s1, af::ops::One};
  *exp.y.repeats = {s0, s1};

  Abs abs("abs");
  abs.attr.sched.axis = {z0.id, z1.id};
  abs.x = exp.y;
  *abs.y.axis = {z0.id, z1.id};
  abs.y.dtype = ge::DT_FLOAT;
  *abs.y.strides = {s1, af::ops::One};
  *abs.y.repeats = {s0, s1};

  Sum sum("b2_sum");
  sum.x = abs.y;
  sum.attr.sched.axis = {z0.id, z1.id};
  sum.y.dtype = ge::DT_FLOAT16;
  *sum.y.axis = {z0.id, z1.id};
  *sum.y.strides = {af::ops::One, af::ops::One};
  *sum.y.repeats = {af::ops::Zero, af::ops::Zero};

  Cast cast("cast");
  cast.x = sum.y;
  cast.attr.sched.axis = {z0.id, z1.id};
  cast.y.dtype = ge::DT_FLOAT;
  *cast.y.axis = {z0.id, z1.id};
  *cast.y.strides = {af::ops::One, af::ops::One};
  *cast.y.repeats = {af::ops::Zero, af::ops::Zero};

  Store store_op1("store1");
  store_op1.attr.sched.axis = {z0.id, z1.id};
  store_op1.x = cast.y;
  *store_op1.y.axis = {z0.id, z1.id};
  store_op1.y.dtype = ge::DT_FLOAT;
  *store_op1.y.strides = {af::ops::One, af::ops::One};
  *store_op1.y.repeats = {af::ops::Zero, af::ops::Zero};

  Output output_op("output");
  output_op.x = store_op1.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ::ascir::FusedScheduledResult fused_scheduled_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_scheduled_result), 0);

  // The size checks are fatal: everything below indexes into these containers, so continuing
  // after a mismatch would read out of bounds instead of reporting a clean failure.
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1);
  auto &scheduled_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_EQ(scheduled_results.size(), 3);
  EXPECT_EQ(scheduled_results[0].schedule_groups.size(), 1);
  ASSERT_EQ(scheduled_results[1].schedule_groups.size(), 2);
  const auto &groups = scheduled_results[1].schedule_groups;
  for (const auto &group : groups) {
    ASSERT_FALSE(group.impl_graphs.empty());
  }

  // Bind by reference; the graphs are only queried, so copying them is unnecessary.
  const auto &impl_graph_phase1 = groups[0].impl_graphs[0];
  const auto &impl_graph_phase2 = groups[1].impl_graphs[0];

  auto phase1_workspace = impl_graph_phase1.FindNode("reduce_r_multicore_0_r_multicore_phase_2_graph_workspace");
  auto phase2_workspace1 = impl_graph_phase2.FindNode("reduce_r_multicore_0_r_multicore_phase_2_graph_workspace");
  ASSERT_NE(phase1_workspace, nullptr);
  ASSERT_NE(phase2_workspace1, nullptr);
}
// Graph under test:
//   workspace -> load -> add(scalar, load) -> b2_mean -> broadcast -> sub(load, broadcast)
//             -> add1(sub, scalar) -> store1 -> output
// After Optimize, the test looks for the "copy_from_*" nodes in the second schedule group of the
// first scheduled result, and for the multicore workspace node in the first two schedule groups
// of the second scheduled result.
TEST_F(TestOptimizer, ReducePartitionScalar) {
  af::AscGraph graph("reduce_partition_scalar");
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Scalar scalar("scalar", graph);
  scalar.y.dtype = ge::DT_FLOAT;

  Workspace workspace("workspace");
  graph.AddNode(workspace);

  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id};
  load.x = workspace.y;
  *load.y.axis = {z0.id, z1.id};
  load.y.dtype = ge::DT_FLOAT;
  *load.y.strides = {s1, af::ops::One};
  *load.y.repeats = {s0, s1};

  Add add("add");
  add.x1 = scalar.y;
  add.x2 = load.y;
  add.attr.sched.axis = {z0.id, z1.id};
  *add.y.axis = {z0.id, z1.id};
  add.y.dtype = ge::DT_FLOAT;
  *add.y.strides = {s1, af::ops::One};
  *add.y.repeats = {s0, s1};

  Mean mean("b2_mean");
  mean.x = add.y;
  mean.attr.sched.axis = {z0.id, z1.id};
  mean.y.dtype = ge::DT_FLOAT;
  *mean.y.axis = {z0.id, z1.id};
  *mean.y.strides = {af::ops::One, af::ops::One};
  *mean.y.repeats = {af::ops::Zero, af::ops::Zero};

  Broadcast broadcast("broadcast");
  broadcast.x = mean.y;
  broadcast.attr.sched.axis = {z0.id, z1.id};
  *broadcast.y.axis = {z0.id, z1.id};
  broadcast.y.dtype = ge::DT_FLOAT;
  *broadcast.y.strides = {s1, af::ops::One};
  *broadcast.y.repeats = {s0, s1};

  Sub sub("sub");
  sub.x1 = load.y;
  sub.x2 = broadcast.y;
  sub.attr.sched.axis = {z0.id, z1.id};
  *sub.y.axis = {z0.id, z1.id};
  sub.y.dtype = ge::DT_FLOAT;
  *sub.y.strides = {s1, af::ops::One};
  *sub.y.repeats = {s0, s1};

  Add add1("add1");
  add1.x1 = sub.y;
  add1.x2 = scalar.y;
  add1.attr.sched.axis = {z0.id, z1.id};
  *add1.y.axis = {z0.id, z1.id};
  add1.y.dtype = ge::DT_FLOAT;
  *add1.y.strides = {s1, af::ops::One};
  *add1.y.repeats = {s0, s1};

  Store store_op1("store1");
  store_op1.attr.sched.axis = {z0.id, z1.id};
  store_op1.x = add1.y;
  *store_op1.y.axis = {z0.id, z1.id};
  store_op1.y.dtype = ge::DT_FLOAT;
  *store_op1.y.strides = {af::ops::One, af::ops::One};
  *store_op1.y.repeats = {af::ops::Zero, af::ops::Zero};

  Output output_op("output");
  output_op.x = store_op1.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ::ascir::FusedScheduledResult fused_scheduled_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_scheduled_result), 0);

  // The size checks are fatal: everything below indexes into these containers, so continuing
  // after a mismatch would read out of bounds instead of reporting a clean failure.
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1);
  auto &scheduled_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_EQ(scheduled_results.size(), 3);
  ASSERT_EQ(scheduled_results[0].schedule_groups.size(), 2);
  ASSERT_EQ(scheduled_results[1].schedule_groups.size(), 3);

  const auto &first_groups = scheduled_results[0].schedule_groups;
  ASSERT_FALSE(first_groups[1].impl_graphs.empty());
  const auto &sub_add_graph = first_groups[1].impl_graphs[0];
  auto copy_scalar = sub_add_graph.FindNode("copy_from_scalar");
  auto copy_load = sub_add_graph.FindNode("copy_from_load");
  auto copy_workspace = sub_add_graph.FindNode("copy_from_workspace");
  ASSERT_NE(copy_scalar, nullptr);
  ASSERT_NE(copy_load, nullptr);
  ASSERT_NE(copy_workspace, nullptr);

  const auto &second_groups = scheduled_results[1].schedule_groups;
  ASSERT_FALSE(second_groups[0].impl_graphs.empty());
  ASSERT_FALSE(second_groups[1].impl_graphs.empty());
  const auto &impl_graph_phase1 = second_groups[0].impl_graphs[0];
  const auto &impl_graph_phase2 = second_groups[1].impl_graphs[0];
  auto phase1_workspace = impl_graph_phase1.FindNode("reduce_partition_scalar_0_r_multicore_phase_2_graph_workspace");
  auto phase2_workspace1 = impl_graph_phase2.FindNode("reduce_partition_scalar_0_r_multicore_phase_2_graph_workspace");
  ASSERT_NE(phase1_workspace, nullptr);
  ASSERT_NE(phase2_workspace1, nullptr);
}
// Graph under test (symbolic sizes s0, s1, s2):
//   arg4_1 -> b0_load -> abs -+-> mul(abs, abs) -> b0_max -> b3_store -> buf3
//                             +-> b0_sum -> b4_store -> buf4
// Both reduce outputs have repeat One / stride Zero on z1.
// The test only checks that Optimize succeeds and how many scheduled results it produces.
TEST_F(TestOptimizer, ReduceAllLoad) {
af::AscGraph graph("reduce_all_load");
auto s0 = graph.CreateSizeVar("s0");
auto s1 = graph.CreateSizeVar("s1");
auto s2 = graph.CreateSizeVar("s2");
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
auto z2 = graph.CreateAxis("z2", s2);
// Input buffer.
Data arg4_1("arg4_1", graph);
arg4_1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
arg4_1.attr.api.type = af::ApiType::kAPITypeBuffer;
arg4_1.y.dtype = ge::DT_FLOAT;
arg4_1.ir_attr.SetIndex(0);
Load b0_load("b0_load");
b0_load.x = arg4_1.y;
b0_load.attr.sched.axis = {z0.id, z1.id, z2.id};
b0_load.attr.api.compute_type = af::ComputeType::kComputeLoad;
b0_load.y.dtype = ge::DT_FLOAT;
*b0_load.y.axis = {z0.id, z1.id, z2.id};
*b0_load.y.repeats = {s0, s1, s2};
*b0_load.y.strides = {s1 * s2, s2, One};
Abs abs("abs");
abs.x = b0_load.y;
abs.attr.sched.axis = {z0.id, z1.id, z2.id};
abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs.y.dtype = ge::DT_FLOAT;
*abs.y.axis = {z0.id, z1.id, z2.id};
*abs.y.repeats = {s0, s1, s2};
*abs.y.strides = {s1 * s2, s2, One};
// First branch: abs * abs, then Max.
Mul mul("mul");
mul.attr.sched.axis = {z0.id, z1.id, z2.id};
mul.x1 = abs.y;
mul.x2 = abs.y;
mul.y.dtype = ge::DT_FLOAT;
*mul.y.axis = {z0.id, z1.id, z2.id};
*mul.y.repeats = {s0, s1, s2};
*mul.y.strides = {s1 * s2, s2, af::ops::One};
af::ascir_op::Max b0_max("b0_max");
b0_max.x = mul.y;
b0_max.attr.sched.axis = {z0.id, z1.id, z2.id};
b0_max.attr.api.compute_type = af::ComputeType::kComputeReduce;
b0_max.y.dtype = ge::DT_FLOAT;
*b0_max.y.axis = {z0.id, z1.id, z2.id};
*b0_max.y.repeats = {s0, One, s2};
*b0_max.y.strides = {s2, Zero, One};
Store b3_store("b3_store");
b3_store.x = b0_max.y;
b3_store.attr.sched.axis = {z0.id, z1.id, z2.id};
b3_store.attr.api.compute_type = af::ComputeType::kComputeStore;
b3_store.y.dtype = ge::DT_FLOAT;
*b3_store.y.axis = {z0.id, z1.id, z2.id};
*b3_store.y.repeats = {s0, One, s2};
*b3_store.y.strides = {s2, Zero, One};
Output buf3("buf3");
buf3.x = b3_store.y;
buf3.attr.api.compute_type = af::ComputeType::kComputeInvalid;
buf3.attr.api.type = af::ApiType::kAPITypeBuffer;
buf3.y.dtype = ge::DT_FLOAT;
buf3.ir_attr.SetIndex(0);
// Second branch: Sum taken directly from abs.
af::ascir_op::Sum b0_sum("b0_sum");
b0_sum.x = abs.y;
b0_sum.attr.sched.axis = {z0.id, z1.id, z2.id};
b0_sum.attr.api.compute_type = af::ComputeType::kComputeReduce;
b0_sum.y.dtype = ge::DT_FLOAT;
*b0_sum.y.axis = {z0.id, z1.id, z2.id};
*b0_sum.y.repeats = {s0, One, s2};
*b0_sum.y.strides = {s2, Zero, One};
Store b4_store("b4_store");
b4_store.x = b0_sum.y;
b4_store.attr.sched.axis = {z0.id, z1.id, z2.id};
b4_store.attr.api.compute_type = af::ComputeType::kComputeStore;
b4_store.y.dtype = ge::DT_FLOAT;
*b4_store.y.axis = {z0.id, z1.id, z2.id};
*b4_store.y.repeats = {s0, One, s2};
*b4_store.y.strides = {s2, Zero, One};
Output buf4("buf4");
buf4.x = b4_store.y;
buf4.attr.api.compute_type = af::ComputeType::kComputeInvalid;
buf4.attr.api.type = af::ApiType::kAPITypeBuffer;
buf4.y.dtype = ge::DT_FLOAT;
// NOTE(review): buf4 uses output index 0, the same index as buf3 above. This looks like a
// copy-paste; confirm whether index 1 was intended before relying on output ordering.
buf4.ir_attr.SetIndex(0);
::ascir::FusedScheduledResult fused_scheduled_result;
EXPECT_EQ(optimizer.Optimize(graph, fused_scheduled_result), 0);
EXPECT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1);
EXPECT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 6);
}
* data0 data1
* | |
* load0 load1
* | |
* abs broadcast
* \ /
* \ /
* add
* |
* store
* |
* output
*/
// Builds add(abs(load0(data0)), broadcast(load1(data1))) -> store -> output and checks which op
// type carries each op-desc id, both before Optimize (id 3) and in the first impl graph after
// Optimize (ids 2..5).
TEST_F(TestOptimizer, LoadOpSequenceAdjustCase1) {
  af::AscGraph graph("reorder_load_op");
  auto s0 = graph.CreateSizeVar(64);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data0("data0", graph);
  data0.y.dtype = ge::DT_FLOAT;
  data0.ir_attr.SetIndex(0);

  Load load0("load0");
  load0.attr.sched.axis = {z0.id, z1.id};
  load0.x = data0.y;
  *load0.y.axis = {z0.id, z1.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.strides = {s1, af::ops::One};
  *load0.y.repeats = {s0, s1};

  af::ascir_op::Abs abs("abs");
  graph.AddNode(abs);
  abs.x = load0.y;
  abs.attr.sched.axis = {z0.id, z1.id};
  abs.y.dtype = ge::DT_FLOAT16;
  *abs.y.axis = {z0.id, z1.id};
  *abs.y.repeats = {s0, s1};
  *abs.y.strides = {s1, One};
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;

  Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT;
  data1.ir_attr.SetIndex(3);

  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1.id};
  load1.x = data1.y;
  *load1.y.axis = {z0.id, z1.id};
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.strides = {af::ops::One, af::ops::One};
  *load1.y.repeats = {af::ops::One, af::ops::One};

  Broadcast broadcast("broadcast");
  broadcast.x = load1.y;
  broadcast.attr.sched.axis = {z0.id, z1.id};
  *broadcast.y.axis = {z0.id, z1.id};
  broadcast.y.dtype = ge::DT_FLOAT;
  *broadcast.y.strides = {s1, af::ops::One};
  *broadcast.y.repeats = {s0, s1};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z0.id, z1.id};
  add_op.x1 = abs.y;
  add_op.x2 = broadcast.y;
  add_op.y.dtype = ge::DT_FLOAT16;
  *add_op.y.axis = {z0.id, z1.id};
  *add_op.y.strides = {s1, af::ops::One};
  *add_op.y.repeats = {s0, s1};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = add_op.y;
  *store_op.y.axis = {z0.id, z1.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {s1, af::ops::One};
  *store_op.y.repeats = {s0, s1};

  Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(8);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  // Before optimization the node with id 3 is expected to be a Data node.
  for (const auto &node : graph.GetAllNodes()) {
    if (node->GetOpDesc()->GetId() == 3) {
      EXPECT_EQ(node->GetOpDesc()->GetType(), "Data");
    }
  }

  ::ascir::FusedScheduledResult fused_scheduled_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_scheduled_result), 0);

  // Fatal guards: the result is indexed four levels deep below, which would be out of bounds
  // if Optimize produced nothing.
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results[0].empty());
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.empty());
  ASSERT_FALSE(
      fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.empty());
  auto &impl_graph = fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0];

  // After optimization the ids 2..5 are expected to map to Data, Abs, Load, Add in that order.
  for (const auto &node : impl_graph.GetAllNodes()) {
    const auto op_desc = node->GetOpDesc();
    const auto id = op_desc->GetId();
    if (id == 2) {
      EXPECT_EQ(op_desc->GetType(), "Data");
    }
    if (id == 3) {
      EXPECT_EQ(op_desc->GetType(), "Abs");
    }
    if (id == 4) {
      EXPECT_EQ(op_desc->GetType(), "Load");
    }
    if (id == 5) {
      EXPECT_EQ(op_desc->GetType(), "Add");
    }
  }
}
// Builds
//   mul(add(abs(load0(data0)), broadcast1(broadcast0(scalar0))),
//       broadcast3(broadcast2(load1(data1)))) -> store -> output
// and checks which op type carries each op-desc id (2..7) in the first impl graph after Optimize.
TEST_F(TestOptimizer, LoadOpSequenceAdjustCase2) {
  af::AscGraph graph("reorder_load_op");
  auto s0 = graph.CreateSizeVar(64);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data0("data0", graph);
  data0.attr.sched.axis = {z0.id, z1.id};
  data0.y.dtype = ge::DT_FLOAT;
  *data0.y.axis = {z0.id, z1.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {s1, af::ops::One};
  *data0.y.repeats = {s0, s1};
  data0.ir_attr.SetIndex(0);

  Load load0("load0");
  load0.attr.sched.axis = {z0.id, z1.id};
  load0.x = data0.y;
  *load0.y.axis = {z0.id, z1.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.strides = {s1, af::ops::One};
  *load0.y.repeats = {s0, s1};

  af::ascir_op::Abs abs("abs");
  graph.AddNode(abs);
  abs.x = load0.y;
  abs.attr.sched.axis = {z0.id, z1.id};
  abs.y.dtype = ge::DT_FLOAT16;
  *abs.y.axis = {z0.id, z1.id};
  *abs.y.repeats = {s0, s1};
  *abs.y.strides = {s1, One};
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;

  af::ascir_op::Scalar scalar0("scalar0", graph);
  scalar0.attr.sched.axis = {z0.id, z1.id};
  scalar0.ir_attr.SetValue("0");
  scalar0.y.dtype = ge::DT_FLOAT;
  *scalar0.y.axis = {z0.id, z1.id};
  *scalar0.y.repeats = {One, One};
  *scalar0.y.strides = {Zero, Zero};

  // The scalar is broadcast in two steps: first along z1, then along z0.
  Broadcast broadcast0("broadcast0");
  broadcast0.x = scalar0.y;
  broadcast0.attr.sched.axis = {z0.id, z1.id};
  *broadcast0.y.axis = {z0.id, z1.id};
  broadcast0.y.dtype = ge::DT_FLOAT;
  *broadcast0.y.repeats = {One, s1};
  *broadcast0.y.strides = {Zero, One};

  Broadcast broadcast1("broadcast1");
  broadcast1.x = broadcast0.y;
  broadcast1.attr.sched.axis = {z0.id, z1.id};
  *broadcast1.y.axis = {z0.id, z1.id};
  broadcast1.y.dtype = ge::DT_FLOAT;
  *broadcast1.y.repeats = {s0, s1};
  *broadcast1.y.strides = {s1, af::ops::One};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z0.id, z1.id};
  add_op.x1 = abs.y;
  add_op.x2 = broadcast1.y;
  add_op.y.dtype = ge::DT_FLOAT;
  *add_op.y.axis = {z0.id, z1.id};
  *add_op.y.repeats = {s0, s1};
  *add_op.y.strides = {s1, af::ops::One};

  Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT;
  data1.attr.sched.axis = {z0.id, z1.id};
  *data1.y.axis = {z0.id, z1.id};
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1.y.repeats = {One, One};
  *data1.y.strides = {Zero, Zero};
  data1.ir_attr.SetIndex(1);

  Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z0.id, z1.id};
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id};
  *load1.y.strides = {Zero, Zero};
  *load1.y.repeats = {One, One};

  // The single-element load is broadcast in two steps as well.
  Broadcast broadcast2("broadcast2");
  broadcast2.x = load1.y;
  broadcast2.attr.sched.axis = {z0.id, z1.id};
  *broadcast2.y.axis = {z0.id, z1.id};
  broadcast2.y.dtype = ge::DT_FLOAT;
  *broadcast2.y.repeats = {One, s1};
  *broadcast2.y.strides = {Zero, One};

  Broadcast broadcast3("broadcast3");
  broadcast3.x = broadcast2.y;
  broadcast3.attr.sched.axis = {z0.id, z1.id};
  *broadcast3.y.axis = {z0.id, z1.id};
  broadcast3.y.dtype = ge::DT_FLOAT;
  *broadcast3.y.repeats = {s0, s1};
  *broadcast3.y.strides = {s1, One};

  Mul mul("mul");
  mul.attr.sched.axis = {z0.id, z1.id};
  mul.x1 = add_op.y;
  mul.x2 = broadcast3.y;
  mul.y.dtype = ge::DT_FLOAT;
  *mul.y.axis = {z0.id, z1.id};
  *mul.y.repeats = {s0, s1};
  *mul.y.strides = {s1, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = mul.y;
  *store_op.y.axis = {z0.id, z1.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {s1, af::ops::One};
  *store_op.y.repeats = {s0, s1};

  Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(8);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ::ascir::FusedScheduledResult fused_scheduled_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_scheduled_result), 0);

  // Fatal guards: the result is indexed four levels deep below, which would be out of bounds
  // if Optimize produced nothing.
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results[0].empty());
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.empty());
  ASSERT_FALSE(
      fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.empty());
  auto &impl_graph = fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0];

  // Ids 2..7 are expected to map to Data, Abs, Scalar, Add, Load, Mul in that order.
  for (const auto &node : impl_graph.GetAllNodes()) {
    const auto op_desc = node->GetOpDesc();
    const auto id = op_desc->GetId();
    if (id == 2) {
      EXPECT_EQ(op_desc->GetType(), "Data");
    }
    if (id == 3) {
      EXPECT_EQ(op_desc->GetType(), "Abs");
    }
    if (id == 4) {
      EXPECT_EQ(op_desc->GetType(), "Scalar");
    }
    if (id == 5) {
      EXPECT_EQ(op_desc->GetType(), "Add");
    }
    if (id == 6) {
      EXPECT_EQ(op_desc->GetType(), "Load");
    }
    if (id == 7) {
      EXPECT_EQ(op_desc->GetType(), "Mul");
    }
  }
}
// Checks that the current platform string is "2201", that PlatformFactory returns the same
// non-null platform on repeated calls, and that it returns nullptr once the platform is set to
// an unregistered name.
TEST_F(TestOptimizer, platform_reg_test) {
  af::AscGraph graph("tmp");
  std::string platform_str;
  ge::PlatformContext::GetInstance().GetCurrentPlatformString(platform_str);
  EXPECT_EQ(platform_str, "2201");

  auto platform_v1 = optimize::PlatformFactory::GetInstance().GetPlatform();
  // Fatal: platform_v1 is dereferenced below.
  ASSERT_NE(platform_v1, nullptr);
  auto platform_v1_new = optimize::PlatformFactory::GetInstance().GetPlatform();
  EXPECT_EQ(platform_v1, platform_v1_new);
  EXPECT_EQ(platform_v1->PartitionSubFunctions(graph), ge::SUCCESS);

  // NOTE(review): the "fake" platform is left set on the PlatformContext singleton when this
  // test ends; confirm that the fixture resets it so later tests are not affected.
  ge::PlatformContext::GetInstance().SetPlatform("fake");
  auto platform_fake = optimize::PlatformFactory::GetInstance().GetPlatform();
  EXPECT_EQ(platform_fake, nullptr);
}
// The backend spec instance must exist and report a concat input limit of 63.
TEST_F(TestOptimizer, BackendSpec) {
  const auto backend_spec = optimize::BackendSpec::GetInstance();
  ASSERT_FALSE(backend_spec == nullptr);
  const auto max_concat_inputs = backend_spec->concat_max_input_num;
  ASSERT_EQ(max_concat_inputs, 63);
}
// Graph under test (axes z0..z3 with sizes 12, 32, 64, 64):
//   add0 = brc0(abs0(load0(x)))  + brc1(abs1(load1(x1)))   // load0 has repeat One on z2, load1 on z3
//   add1 = brc3(exp2(abs2(load2(x2)))) + exp3(abs3(load3(x3)))   // load2 has repeat One on z0 and z1
//   y    = store(add2(add0, add1))
// After Optimize, the third impl graph must give "abs2" and "exp2" different output buffer ids.
TEST_F(TestOptimizer, BrcCacheReuseOtherMem) {
  af::AscGraph graph("BrcCacheReuseOtherMem");
  const af::Expression s0 = af::Symbol(12);
  const af::Expression s1 = af::Symbol(32);
  const af::Expression s2 = af::Symbol(64);
  const af::Expression s3 = af::Symbol(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z3 = graph.CreateAxis("z3", s3);

  af::ascir_op::Data x0("x", graph);
  x0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.x = x0.y;
  load0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *load0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *load0.y.repeats = {s0, s1, One, s3};
  *load0.y.strides = {s1 * s3, s3, Zero, One};

  af::ascir_op::Abs abs0("abs0");
  abs0.x = load0.y;
  abs0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs0.y.repeats = {s0, s1, One, s3};
  *abs0.y.strides = {s1 * s3, s3, Zero, One};

  af::ascir_op::Broadcast brc0("brc0");
  brc0.x = abs0.y;
  brc0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc0.y.repeats = {s0, s1, s2, s3};
  *brc0.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Data x1("x1", graph);
  x1.ir_attr.SetIndex(1);

  af::ascir_op::Load load1("load1");
  load1.x = x1.y;
  load1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *load1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *load1.y.repeats = {s0, s1, s2, One};
  *load1.y.strides = {s1 * s2, s2, One, Zero};

  af::ascir_op::Abs abs1("abs1");
  abs1.x = load1.y;
  abs1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs1.y.repeats = {s0, s1, s2, One};
  *abs1.y.strides = {s1 * s2, s2, One, Zero};

  af::ascir_op::Broadcast brc1("brc1");
  brc1.x = abs1.y;
  brc1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc1.y.repeats = {s0, s1, s2, s3};
  *brc1.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Add add0("add0");
  add0.x1 = brc0.y;
  add0.x2 = brc1.y;
  add0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *add0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *add0.y.repeats = {s0, s1, s2, s3};
  *add0.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Data x2("x2", graph);
  x2.ir_attr.SetIndex(2);

  af::ascir_op::Load load2("load2");
  load2.x = x2.y;
  load2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *load2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *load2.y.repeats = {One, One, s2, s3};
  *load2.y.strides = {Zero, Zero, s3, One};

  af::ascir_op::Abs abs2("abs2");
  abs2.x = load2.y;
  abs2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs2.y.repeats = {One, One, s2, s3};
  *abs2.y.strides = {Zero, Zero, s3, One};

  af::ascir_op::Exp exp2("exp2");
  exp2.x = abs2.y;
  exp2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *exp2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *exp2.y.repeats = {One, One, s2, s3};
  *exp2.y.strides = {Zero, Zero, s3, One};

  af::ascir_op::Broadcast brc3("brc3");
  brc3.x = exp2.y;
  brc3.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc3.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc3.y.repeats = {s0, s1, s2, s3};
  *brc3.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Data x3("x3", graph);
  x3.ir_attr.SetIndex(3);

  af::ascir_op::Load load3("load3");
  load3.x = x3.y;
  load3.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *load3.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *load3.y.repeats = {s0, s1, s2, s3};
  *load3.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Abs abs3("abs3");
  abs3.x = load3.y;
  abs3.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs3.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs3.y.repeats = {s0, s1, s2, s3};
  *abs3.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Exp exp3("exp3");
  exp3.x = abs3.y;
  exp3.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *exp3.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *exp3.y.repeats = {s0, s1, s2, s3};
  *exp3.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Add add1("add1");
  add1.x1 = brc3.y;
  add1.x2 = exp3.y;
  add1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *add1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *add1.y.repeats = {s0, s1, s2, s3};
  *add1.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Add add2("add2");
  add2.x1 = add0.y;
  add2.x2 = add1.y;
  add2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *add2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *add2.y.repeats = {s0, s1, s2, s3};
  *add2.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Store store("store");
  store.x = add2.y;
  store.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  *store.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *store.y.repeats = {s0, s1, s2, s3};
  *store.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  af::ascir_op::Output y("y");
  y.ir_attr.SetIndex(0);
  y.x = store.y;

  ::ascir::FusedScheduledResult fused_scheduled_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_scheduled_result), ge::SUCCESS);

  // The size checks are fatal: impl_graphs[2] is read below, so continuing after a mismatch
  // would read out of bounds instead of reporting a clean failure.
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 3);
  const auto &impl_graphs = fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs;
  const auto &impl_graph = impl_graphs[2];

  auto abs2_result = impl_graph.FindNode("abs2");
  auto exp2_result = impl_graph.FindNode("exp2");
  // Both nodes are dereferenced below, so a missing node must stop the test.
  ASSERT_NE(abs2_result, nullptr);
  ASSERT_NE(exp2_result, nullptr);
  EXPECT_NE(abs2_result->outputs[0].attr.buf.id, exp2_result->outputs[0].attr.buf.id);
}
// Two strided loads (load0 from data1, load1 from data0) feed a Concat along z1, followed by
// store -> output. After Optimize, each of the two impl graphs must contain the nodes
// "load0_remove_pad_0" and "load1_remove_pad_0".
TEST_F(TestOptimizer, SliceSliceConcatD) {
  af::AscGraph graph("slice_concat");
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(90);
  auto s2 = graph.CreateSizeVar(1);
  auto s1_0 = graph.CreateSizeVar(60);
  auto s1_1 = graph.CreateSizeVar(30);
  auto s3 = graph.CreateSizeVar(97);
  auto s4 = graph.CreateSizeVar(65);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z1_1 = graph.CreateAxis("z1_1", s1_1);
  auto z1_0 = graph.CreateAxis("z1_0", s1_0);

  Data data0("data0", graph);
  data0.y.dtype = ge::DT_FLOAT;
  data0.ir_attr.SetIndex(0);
  data0.attr.sched.axis = {z0.id, z1.id, z2.id};
  *data0.y.axis = {z0.id, z1.id, z2.id};
  *data0.y.repeats = {s0, s1_1, One};
  *data0.y.strides = {s1_1, One, One};

  Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT;
  data1.ir_attr.SetIndex(1);
  data1.attr.sched.axis = {z0.id, z1.id, z2.id};
  *data1.y.axis = {z0.id, z1.id, z2.id};
  *data1.y.repeats = {s0, s1_0, One};
  *data1.y.strides = {s1_0, One, One};

  // load0 reads data1 with stride s3 on its last axis.
  Load load0("load0");
  load0.attr.sched.axis = {z0.id, z1_0.id};
  load0.x = data1.y;
  *load0.y.axis = {z0.id, z1_0.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.repeats = {s0, s1_0};
  *load0.y.strides = {s3 * s1_0, s3};

  // load1 reads data0 with stride s4 on its last axis.
  Load load1("load1");
  load1.attr.sched.axis = {z0.id, z1_1.id};
  load1.x = data0.y;
  *load1.y.axis = {z0.id, z1_1.id};
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.repeats = {s0, s1_1};
  *load1.y.strides = {s4 * s1_1, s4};

  af::ascir_op::Concat concat_op("concat");
  concat_op.attr.sched.axis = {z0.id, z1.id};
  concat_op.x = {load0.y, load1.y};
  concat_op.y.dtype = ge::DT_FLOAT;
  *concat_op.y.axis = {z0.id, z1.id};
  *concat_op.y.repeats = {s0, s1};
  *concat_op.y.strides = {s1, af::ops::One};

  Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = concat_op.y;
  *store_op.y.axis = {z0.id, z1.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.repeats = {s0, s1};
  *store_op.y.strides = {s1, af::ops::One};

  Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_scheduled_result), ge::SUCCESS);

  // Fatal guards: each container is indexed right after it is checked.
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results[0].empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  const auto &impl_graphs = fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs;
  EXPECT_EQ(impl_graphs.size(), 2UL);

  // Iterate by reference; the graphs are only queried, so copying them is unnecessary.
  for (const auto &impl_graph : impl_graphs) {
    auto load0_remove_pad_0 = impl_graph.FindNode("load0_remove_pad_0");
    EXPECT_NE(load0_remove_pad_0, nullptr);
    auto load1_remove_pad_0 = impl_graph.FindNode("load1_remove_pad_0");
    EXPECT_NE(load1_remove_pad_0, nullptr);
  }
}
// data0 -> load0 -> store0 -> output0 on a single axis of size 256, where the store output has
// stride 134 on that axis. After Optimize, both load0 and store0 are expected to report a
// vectorized stride of 8.
TEST_F(TestOptimizer, LastAxisStoreWithStride) {
  af::AscGraph graph("matmul");
  auto s0 = graph.CreateSizeVar(256);
  auto z0 = graph.CreateAxis("z0", s0);

  Data data0("data0", graph);
  data0.ir_attr.SetIndex(0);

  Load load0("load0");
  load0.attr.sched.axis = {z0.id};
  load0.x = data0.y;
  *load0.y.axis = {z0.id};
  *load0.y.repeats = {s0};
  *load0.y.strides = {af::ops::One};

  Store store0("store0");
  store0.attr.sched.axis = {z0.id};
  store0.x = load0.y;
  *store0.y.axis = {z0.id};
  *store0.y.repeats = {s0};
  *store0.y.strides = {af::Symbol(134)};

  Output out0("output0");
  out0.x = store0.y;
  out0.ir_attr.SetIndex(0);

  ::optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ::ascir::FusedScheduledResult fused_scheduled_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_scheduled_result), 0);

  // Fatal guards: each container is indexed right after it is checked.
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results[0].empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);
  const auto &impl_graph = fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0];

  std::vector<af::Expression> golden_stride = {af::Symbol(8)};

  // Each node is dereferenced right after lookup, so a missing node must stop the test.
  auto load_node = impl_graph.FindNode("load0");
  ASSERT_NE(load_node, nullptr);
  EXPECT_EQ(load_node->outputs[0].attr.vectorized_strides, golden_stride);

  auto store_node = impl_graph.FindNode("store0");
  ASSERT_NE(store_node, nullptr);
  EXPECT_EQ(store_node->outputs[0].attr.vectorized_strides, golden_stride);
}
// data0 -> load0 (axes z2, z1, z0) -> transpose0 (axes z0, z1, z2) -> cast0 (to int64)
//       -> store0 -> out0.
// Only verifies that Optimize succeeds and yields exactly one entry in
// node_idx_to_scheduled_results.
TEST_F(TestOptimizer, AbsAbsAbsAbsTransposeCast) {
  af::AscGraph graph("trans_int64");
  auto s0 = af::Symbol(3);
  auto s1 = af::Symbol(10);
  auto s2 = af::Symbol(4);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  // Scheduling / load order and the transposed output order of the same three axes.
  std::vector<int64_t> sched_axis_ids = {z2.id, z1.id, z0.id};
  std::vector<int64_t> transposed_axis_ids = {z0.id, z1.id, z2.id};

  Data input("data0", graph);
  input.y.dtype = ge::DT_FLOAT;
  input.ir_attr.SetIndex(0);

  Load load("load0");
  load.attr.sched.axis = sched_axis_ids;
  load.x = input.y;
  load.y.dtype = ge::DT_FLOAT;
  *load.y.axis = sched_axis_ids;
  *load.y.repeats = {s2, s1, s0};
  *load.y.strides = {s1 * s0, s0, af::ops::One};

  Transpose transpose("transpose0");
  transpose.attr.sched.axis = sched_axis_ids;
  transpose.x = load.y;
  transpose.y.dtype = ge::DT_FLOAT;
  *transpose.y.axis = transposed_axis_ids;
  *transpose.y.repeats = {s0, s1, s2};
  *transpose.y.strides = {s1 * s2, s2, af::ops::One};

  Cast cast("cast0");
  cast.attr.sched.axis = sched_axis_ids;
  cast.x = transpose.y;
  cast.y.dtype = ge::DT_INT64;
  *cast.y.axis = transposed_axis_ids;
  *cast.y.repeats = {s0, s1, s2};
  *cast.y.strides = {s1 * s2, s2, af::ops::One};

  Store store("store0");
  store.attr.sched.axis = sched_axis_ids;
  store.x = cast.y;
  store.y.dtype = ge::DT_INT64;
  *store.y.axis = transposed_axis_ids;
  *store.y.repeats = {s0, s1, s2};
  *store.y.strides = {s1 * s2, s2, af::ops::One};

  Output output("out0");
  output.x = store.y;
  output.y.dtype = ge::DT_INT64;
  output.ir_attr.SetIndex(0);

  ::optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
  ::ascir::FusedScheduledResult scheduled;
  EXPECT_EQ(optimizer.Optimize(graph, scheduled), 0);
  EXPECT_EQ(scheduled.node_idx_to_scheduled_results.size(), 1UL);
}