* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include "ascendc_ir.h"
#include "ascendc_ir_def.h"
#include "ascir_ops.h"
#define private public
#include "optimize.h"
#include "platform_context.h"
#undef private
#include "ascir_ops_utils.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "graph/utils/graph_utils.h"
#include "attribute_group/attr_group_shape_env.h"
#include "fused_graph/fused_graph_unfolder.h"
#include "graph/debug/ge_attr_define.h"
#include "util/mem_utils.h"
#include "runtime_stub.h"
using namespace std;
using namespace ge;
using namespace af::ops;
using namespace af::ascir_op;
namespace {
class VectorFuncSt : public ::testing::Test {
protected:
void SetUp() override {
setenv("DUMP_GE_GRAPH", "2", 1);
ge::PlatformContext::GetInstance().Reset();
auto stub_v2 = std::make_shared<af::RuntimeStubV2>();
RuntimeStub::SetInstance(stub_v2);
}
void TearDown() override {
setenv("DUMP_GE_GRAPH", "0", 1);
RuntimeStub::Reset();
ge::PlatformContext::GetInstance().Reset();
}
optimize::Optimizer optimizer;
VectorFuncSt() : optimizer(optimize::OptimizerOptions{}) {}
};
}
namespace optimize {
// Builds the chain data0 -> load0 -> brc -> {abs, abs1} -> add -> store -> out on two axes
// (z0 created from Symbol(999), z1 from Symbol(10)) and only checks that Optimize() succeeds.
// No assertion is made on the contents of fused_scheduled_result.
TEST_F(VectorFuncSt, vf_partition) {
af::AscGraph graph("brc_abs");
auto s0 = af::Symbol(999);
auto s1 = af::Symbol(10);
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
// Graph input. NOTE(review): index is 1 although this is the only Data node in the graph - confirm intended.
af::ascir_op::Data data0("data0", graph);
data0.y.dtype = af::DT_FLOAT;
data0.ir_attr.SetIndex(1);
data0.attr.api.type = af::ApiType::kAPITypeBuffer;
// load0 output: repeat One / stride Zero on z1, i.e. not yet expanded along z1.
af::ascir_op::Load load0("load0");
load0.x = data0.y;
load0.attr.api.compute_type = af::ComputeType::kComputeLoad;
load0.attr.sched.axis = {z0.id, z1.id};
load0.y.dtype = af::DT_FLOAT;
*load0.y.axis = {z0.id, z1.id};
*load0.y.repeats = {s0, One};
*load0.y.strides = {One, Zero};
*load0.y.vectorized_axis = {z0.id, z1.id};
// brc output carries the full {s0, s1} repeats.
af::ascir_op::Broadcast brc("brc");
brc.x = load0.y;
brc.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
brc.attr.sched.axis = {z0.id, z1.id};
brc.y.dtype = af::DT_FLOAT;
*brc.y.axis = {z0.id, z1.id};
*brc.y.repeats = {s0, s1};
*brc.y.strides = {s1, One};
*brc.y.vectorized_axis = {z0.id, z1.id};
// Two Abs nodes both consume brc.y.
af::ascir_op::Abs abs("abs");
abs.x = brc.y;
abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs.attr.sched.axis = {z0.id, z1.id};
abs.y.dtype = af::DT_FLOAT;
*abs.y.axis = {z0.id, z1.id};
*abs.y.repeats = {s0, s1};
*abs.y.strides = {s1, One};
*abs.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Abs abs1("abs1");
abs1.x = brc.y;
abs1.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs1.attr.sched.axis = {z0.id, z1.id};
abs1.y.dtype = af::DT_FLOAT;
*abs1.y.axis = {z0.id, z1.id};
*abs1.y.repeats = {s0, s1};
*abs1.y.strides = {s1, One};
*abs1.y.vectorized_axis = {z0.id, z1.id};
// add joins the two Abs branches.
af::ascir_op::Add add("add");
add.x1 = abs.y;
add.x2 = abs1.y;
add.attr.api.compute_type = af::ComputeType::kComputeElewise;
add.attr.sched.axis = {z0.id, z1.id};
add.y.dtype = af::DT_FLOAT;
*add.y.axis = {z0.id, z1.id};
*add.y.repeats = {s0, s1};
*add.y.strides = {s1, One};
*add.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Store store("store");
store.x = add.y;
store.attr.api.compute_type = af::ComputeType::kComputeStore;
store.attr.sched.axis = {z0.id, z1.id};
store.y.dtype = af::DT_FLOAT;
*store.y.axis = {z0.id, z1.id};
*store.y.repeats = {s0, s1};
*store.y.strides = {s1, One};
*store.y.vectorized_axis = {z0.id, z1.id};
// Graph output.
Output out("out");
out.x = store.y;
out.attr.api.compute_type = ComputeType::kComputeInvalid;
out.attr.api.type = af::ApiType::kAPITypeBuffer;
out.y.dtype = af::DT_FLOAT;
out.ir_attr.SetIndex(0);
::ascir::FusedScheduledResult fused_scheduled_result;
Status res = optimizer.Optimize(graph, fused_scheduled_result);
EXPECT_EQ(res, af::SUCCESS);
}
// Builds data0 -> load0 -> abs, then feeds abs.y both directly into add.x1 and, through the
// Broadcast node "cat", into add.x2; add -> store -> out. Only checks that Optimize() succeeds.
TEST_F(VectorFuncSt, skip_fuse_for_cycle) {
af::AscGraph graph("brc_abs");
auto s0 = af::Symbol(999);
auto s1 = af::Symbol(10);
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
// Graph input. NOTE(review): index is 1 although this is the only Data node in the graph - confirm intended.
af::ascir_op::Data data0("data0", graph);
data0.y.dtype = af::DT_FLOAT;
data0.ir_attr.SetIndex(1);
data0.attr.api.type = af::ApiType::kAPITypeBuffer;
af::ascir_op::Load load0("load0");
load0.x = data0.y;
load0.attr.api.compute_type = af::ComputeType::kComputeLoad;
load0.attr.sched.axis = {z0.id, z1.id};
load0.y.dtype = af::DT_FLOAT;
*load0.y.axis = {z0.id, z1.id};
*load0.y.repeats = {s0, One};
*load0.y.strides = {One, Zero};
*load0.y.vectorized_axis = {z0.id, z1.id};
// NOTE(review): abs uses repeats {s0, Zero} while its input load0 uses {s0, One} - confirm Zero is intentional.
af::ascir_op::Abs abs("abs");
abs.x = load0.y;
abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs.attr.sched.axis = {z0.id, z1.id};
abs.y.dtype = af::DT_FLOAT;
*abs.y.axis = {z0.id, z1.id};
*abs.y.repeats = {s0, Zero};
*abs.y.strides = {One, Zero};
*abs.y.vectorized_axis = {z0.id, z1.id};
// Despite the variable/node name "cat", this is a Broadcast op whose output has repeats {s0, s1}.
af::ascir_op::Broadcast cat("cat");
cat.x = abs.y;
cat.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
cat.attr.sched.axis = {z0.id, z1.id};
cat.y.dtype = af::DT_FLOAT;
*cat.y.axis = {z0.id, z1.id};
*cat.y.repeats = {s0, s1};
*cat.y.strides = {s1, One};
*cat.y.vectorized_axis = {z0.id, z1.id};
// add consumes abs.y twice: once directly (x1) and once via the broadcast (x2).
af::ascir_op::Add add("add");
add.x1 = abs.y;
add.x2 = cat.y;
add.attr.api.compute_type = af::ComputeType::kComputeElewise;
add.attr.sched.axis = {z0.id, z1.id};
add.y.dtype = af::DT_FLOAT;
*add.y.axis = {z0.id, z1.id};
*add.y.repeats = {s0, s1};
*add.y.strides = {s1, One};
*add.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Store store("store");
store.x = add.y;
store.attr.api.compute_type = af::ComputeType::kComputeStore;
store.attr.sched.axis = {z0.id, z1.id};
store.y.dtype = af::DT_FLOAT;
*store.y.axis = {z0.id, z1.id};
*store.y.repeats = {s0, s1};
*store.y.strides = {s1, One};
*store.y.vectorized_axis = {z0.id, z1.id};
// Graph output.
Output out("out");
out.x = store.y;
out.attr.api.compute_type = ComputeType::kComputeInvalid;
out.attr.api.type = af::ApiType::kAPITypeBuffer;
out.y.dtype = af::DT_FLOAT;
out.ir_attr.SetIndex(0);
::ascir::FusedScheduledResult fused_scheduled_result;
Status res = optimizer.Optimize(graph, fused_scheduled_result);
EXPECT_EQ(res, af::SUCCESS);
}
// Builds four parallel load -> abs branches that all read data0, reduces them with a tree of
// three Add nodes (add, add1, add2), then store -> out.
// Expects exactly one scheduled result / schedule group / impl graph, and that
// GetAllSubGraphs() on that impl graph yields exactly one graph.
TEST_F(VectorFuncSt, ResetIOLimit) {
af::AscGraph graph("brc_abs");
auto s0 = af::Symbol(999);
auto s1 = af::Symbol(10);
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
// Single graph input shared by all four loads.
// NOTE(review): index is 1 although this is the only Data node in the graph - confirm intended.
af::ascir_op::Data data0("data0", graph);
data0.y.dtype = af::DT_FLOAT;
data0.ir_attr.SetIndex(1);
data0.attr.api.type = af::ApiType::kAPITypeBuffer;
// The local variable is named "load" but the node itself is named "load0".
af::ascir_op::Load load("load0");
load.x = data0.y;
load.attr.api.compute_type = af::ComputeType::kComputeLoad;
load.attr.sched.axis = {z0.id, z1.id};
load.y.dtype = af::DT_FLOAT;
*load.y.axis = {z0.id, z1.id};
*load.y.repeats = {s0, s1};
*load.y.strides = {s1, One};
*load.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Load load1("load1");
load1.x = data0.y;
load1.attr.api.compute_type = af::ComputeType::kComputeLoad;
load1.attr.sched.axis = {z0.id, z1.id};
load1.y.dtype = af::DT_FLOAT;
*load1.y.axis = {z0.id, z1.id};
*load1.y.repeats = {s0, s1};
*load1.y.strides = {s1, One};
*load1.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Load load2("load2");
load2.x = data0.y;
load2.attr.api.compute_type = af::ComputeType::kComputeLoad;
load2.attr.sched.axis = {z0.id, z1.id};
load2.y.dtype = af::DT_FLOAT;
*load2.y.axis = {z0.id, z1.id};
*load2.y.repeats = {s0, s1};
*load2.y.strides = {s1, One};
*load2.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Load load3("load3");
load3.x = data0.y;
load3.attr.api.compute_type = af::ComputeType::kComputeLoad;
load3.attr.sched.axis = {z0.id, z1.id};
load3.y.dtype = af::DT_FLOAT;
*load3.y.axis = {z0.id, z1.id};
*load3.y.repeats = {s0, s1};
*load3.y.strides = {s1, One};
*load3.y.vectorized_axis = {z0.id, z1.id};
// One Abs per load.
af::ascir_op::Abs abs("abs");
abs.x = load.y;
abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs.attr.sched.axis = {z0.id, z1.id};
abs.y.dtype = af::DT_FLOAT;
*abs.y.axis = {z0.id, z1.id};
*abs.y.repeats = {s0, s1};
*abs.y.strides = {s1, One};
*abs.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Abs abs1("abs1");
abs1.x = load1.y;
abs1.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs1.attr.sched.axis = {z0.id, z1.id};
abs1.y.dtype = af::DT_FLOAT;
*abs1.y.axis = {z0.id, z1.id};
*abs1.y.repeats = {s0, s1};
*abs1.y.strides = {s1, One};
*abs1.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Abs abs2("abs2");
abs2.x = load2.y;
abs2.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs2.attr.sched.axis = {z0.id, z1.id};
abs2.y.dtype = af::DT_FLOAT;
*abs2.y.axis = {z0.id, z1.id};
*abs2.y.repeats = {s0, s1};
*abs2.y.strides = {s1, One};
*abs2.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Abs abs3("abs3");
abs3.x = load3.y;
abs3.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs3.attr.sched.axis = {z0.id, z1.id};
abs3.y.dtype = af::DT_FLOAT;
*abs3.y.axis = {z0.id, z1.id};
*abs3.y.repeats = {s0, s1};
*abs3.y.strides = {s1, One};
*abs3.y.vectorized_axis = {z0.id, z1.id};
// add = abs + abs1
af::ascir_op::Add add("add");
add.x1 = abs.y;
add.x2 = abs1.y;
add.attr.api.compute_type = af::ComputeType::kComputeElewise;
add.attr.sched.axis = {z0.id, z1.id};
add.y.dtype = af::DT_FLOAT;
*add.y.axis = {z0.id, z1.id};
*add.y.repeats = {s0, s1};
*add.y.strides = {s1, One};
*add.y.vectorized_axis = {z0.id, z1.id};
// add1 = abs2 + abs3
af::ascir_op::Add add1("add1");
add1.x1 = abs2.y;
add1.x2 = abs3.y;
add1.attr.api.compute_type = af::ComputeType::kComputeElewise;
add1.attr.sched.axis = {z0.id, z1.id};
add1.y.dtype = af::DT_FLOAT;
*add1.y.axis = {z0.id, z1.id};
*add1.y.repeats = {s0, s1};
*add1.y.strides = {s1, One};
*add1.y.vectorized_axis = {z0.id, z1.id};
// add2 = add1 + add
af::ascir_op::Add add2("add2");
add2.x1 = add1.y;
add2.x2 = add.y;
add2.attr.api.compute_type = af::ComputeType::kComputeElewise;
add2.attr.sched.axis = {z0.id, z1.id};
add2.y.dtype = af::DT_FLOAT;
*add2.y.axis = {z0.id, z1.id};
*add2.y.repeats = {s0, s1};
*add2.y.strides = {s1, One};
*add2.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Store store("store");
store.x = add2.y;
store.attr.api.compute_type = af::ComputeType::kComputeStore;
store.attr.sched.axis = {z0.id, z1.id};
store.y.dtype = af::DT_FLOAT;
*store.y.axis = {z0.id, z1.id};
*store.y.repeats = {s0, s1};
*store.y.strides = {s1, One};
*store.y.vectorized_axis = {z0.id, z1.id};
// Graph output.
Output out("out");
out.x = store.y;
out.attr.api.compute_type = ComputeType::kComputeInvalid;
out.attr.api.type = af::ApiType::kAPITypeBuffer;
out.y.dtype = af::DT_FLOAT;
out.ir_attr.SetIndex(0);
::ascir::FusedScheduledResult fused_scheduled_result;
Status res = optimizer.Optimize(graph, fused_scheduled_result);
EXPECT_EQ(res, af::SUCCESS);
// The ASSERTs below guard every container level before it is indexed.
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);
std::vector<af::AscGraph> asc_graphs;
fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0].GetAllSubGraphs(
asc_graphs);
EXPECT_EQ(asc_graphs.size(), 1UL);
}
// Builds two branches with different element types and joins them with an Add:
//   data0 (FLOAT16) -> load0 -> abs0 -> cast0 (to FLOAT)
//   data1 (FLOAT)   -> load1 -> abs1 -> cast1 (to INT64) -> cast2 (to FLOAT)
//   add(cast0, cast2) -> store -> out
// Expects one scheduled result / group / impl graph whose GetAllSubGraphs() yields two graphs.
TEST_F(VectorFuncSt, cast_bit_with) {
af::AscGraph graph("brc_abs");
auto s0 = af::Symbol(999);
auto s1 = af::Symbol(10);
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
// First branch: FLOAT16 input.
af::ascir_op::Data data0("data0", graph);
data0.y.dtype = af::DT_FLOAT16;
data0.ir_attr.SetIndex(0);
data0.attr.api.type = af::ApiType::kAPITypeBuffer;
af::ascir_op::Load load0("load0");
load0.x = data0.y;
load0.attr.api.compute_type = af::ComputeType::kComputeLoad;
load0.attr.sched.axis = {z0.id, z1.id};
load0.y.dtype = af::DT_FLOAT16;
*load0.y.axis = {z0.id, z1.id};
*load0.y.repeats = {s0, s1};
*load0.y.strides = {s1, One};
*load0.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Abs abs0("abs0");
abs0.x = load0.y;
abs0.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs0.attr.sched.axis = {z0.id, z1.id};
abs0.y.dtype = af::DT_FLOAT16;
*abs0.y.axis = {z0.id, z1.id};
*abs0.y.repeats = {s0, s1};
*abs0.y.strides = {s1, One};
*abs0.y.vectorized_axis = {z0.id, z1.id};
// FLOAT16 -> FLOAT
af::ascir_op::Cast cast0("cast0");
cast0.x = abs0.y;
cast0.attr.api.compute_type = af::ComputeType::kComputeElewise;
cast0.attr.sched.axis = {z0.id, z1.id};
cast0.y.dtype = af::DT_FLOAT;
*cast0.y.axis = {z0.id, z1.id};
*cast0.y.repeats = {s0, s1};
*cast0.y.strides = {s1, One};
*cast0.y.vectorized_axis = {z0.id, z1.id};
// Second branch: FLOAT input.
// NOTE(review): data1 uses index 0, the same index as data0 - confirm whether 1 was intended.
af::ascir_op::Data data1("data1", graph);
data1.y.dtype = af::DT_FLOAT;
data1.ir_attr.SetIndex(0);
data1.attr.api.type = af::ApiType::kAPITypeBuffer;
af::ascir_op::Load load1("load1");
load1.x = data1.y;
load1.attr.api.compute_type = af::ComputeType::kComputeLoad;
load1.attr.sched.axis = {z0.id, z1.id};
load1.y.dtype = af::DT_FLOAT;
*load1.y.axis = {z0.id, z1.id};
*load1.y.repeats = {s0, s1};
*load1.y.strides = {s1, One};
*load1.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Abs abs1("abs1");
abs1.x = load1.y;
abs1.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs1.attr.sched.axis = {z0.id, z1.id};
abs1.y.dtype = af::DT_FLOAT;
*abs1.y.axis = {z0.id, z1.id};
*abs1.y.repeats = {s0, s1};
*abs1.y.strides = {s1, One};
*abs1.y.vectorized_axis = {z0.id, z1.id};
// FLOAT -> INT64
af::ascir_op::Cast cast1("cast1");
cast1.x = abs1.y;
cast1.attr.api.compute_type = af::ComputeType::kComputeElewise;
cast1.attr.sched.axis = {z0.id, z1.id};
cast1.y.dtype = af::DT_INT64;
*cast1.y.axis = {z0.id, z1.id};
*cast1.y.repeats = {s0, s1};
*cast1.y.strides = {s1, One};
*cast1.y.vectorized_axis = {z0.id, z1.id};
// INT64 -> FLOAT
af::ascir_op::Cast cast2("cast2");
cast2.x = cast1.y;
cast2.attr.api.compute_type = af::ComputeType::kComputeElewise;
cast2.attr.sched.axis = {z0.id, z1.id};
cast2.y.dtype = af::DT_FLOAT;
*cast2.y.axis = {z0.id, z1.id};
*cast2.y.repeats = {s0, s1};
*cast2.y.strides = {s1, One};
*cast2.y.vectorized_axis = {z0.id, z1.id};
// Join of the two branches, both FLOAT at this point.
af::ascir_op::Add add("add");
add.x1 = cast0.y;
add.x2 = cast2.y;
add.attr.api.compute_type = af::ComputeType::kComputeElewise;
add.attr.sched.axis = {z0.id, z1.id};
add.y.dtype = af::DT_FLOAT;
*add.y.axis = {z0.id, z1.id};
*add.y.repeats = {s0, s1};
*add.y.strides = {s1, One};
*add.y.vectorized_axis = {z0.id, z1.id};
af::ascir_op::Store store("store");
store.x = add.y;
store.attr.api.compute_type = af::ComputeType::kComputeStore;
store.attr.sched.axis = {z0.id, z1.id};
store.y.dtype = af::DT_FLOAT;
*store.y.axis = {z0.id, z1.id};
*store.y.repeats = {s0, s1};
*store.y.strides = {s1, One};
*store.y.vectorized_axis = {z0.id, z1.id};
// Graph output.
Output out("out");
out.x = store.y;
out.attr.api.compute_type = ComputeType::kComputeInvalid;
out.attr.api.type = af::ApiType::kAPITypeBuffer;
out.y.dtype = af::DT_FLOAT;
out.ir_attr.SetIndex(0);
::ascir::FusedScheduledResult fused_scheduled_result;
Status res = optimizer.Optimize(graph, fused_scheduled_result);
EXPECT_EQ(res, af::SUCCESS);
// The ASSERTs below guard every container level before it is indexed.
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);
std::vector<af::AscGraph> asc_graphs;
fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0].GetAllSubGraphs(
asc_graphs);
// Unlike the single-dtype cases in this file, two sub-graphs are expected here.
EXPECT_EQ(asc_graphs.size(), 2UL);
}
// Builds a single-axis graph in which sub.y and exp.y each have two consumers:
//   load0(data0), load1(data1) -> exp
//   sub(load0, exp) -> relu
//   mul(sub, exp)   -> abs
//   add(relu, abs) -> store -> out
// No vectorized_axis is set on any output in this case.
// Expects one scheduled result / group / impl graph whose GetAllSubGraphs() yields one graph.
TEST_F(VectorFuncSt, cycle_bugfix) {
af::AscGraph graph("brc_abs");
auto s0 = af::Symbol(999);
auto z0 = graph.CreateAxis("z0", s0);
// Two graph inputs with indices 0 and 1.
af::ascir_op::Data data0("data0", graph);
data0.y.dtype = af::DT_FLOAT;
data0.ir_attr.SetIndex(0);
data0.attr.api.type = af::ApiType::kAPITypeBuffer;
af::ascir_op::Data data1("data1", graph);
data1.y.dtype = af::DT_FLOAT;
data1.ir_attr.SetIndex(1);
data1.attr.api.type = af::ApiType::kAPITypeBuffer;
af::ascir_op::Load load0("load0");
load0.x = data0.y;
load0.attr.api.compute_type = af::ComputeType::kComputeLoad;
load0.attr.sched.axis = {z0.id};
load0.y.dtype = af::DT_FLOAT;
*load0.y.axis = {z0.id};
*load0.y.repeats = {s0};
*load0.y.strides = {One};
af::ascir_op::Load load1("load1");
load1.x = data1.y;
load1.attr.api.compute_type = af::ComputeType::kComputeLoad;
load1.attr.sched.axis = {z0.id};
load1.y.dtype = af::DT_FLOAT;
*load1.y.axis = {z0.id};
*load1.y.repeats = {s0};
*load1.y.strides = {One};
// exp.y is consumed by both sub (x2) and mul (x2).
af::ascir_op::Exp exp("exp");
exp.x = load1.y;
exp.attr.api.compute_type = af::ComputeType::kComputeElewise;
exp.attr.sched.axis = {z0.id};
exp.y.dtype = af::DT_FLOAT;
*exp.y.axis = {z0.id};
*exp.y.repeats = {s0};
*exp.y.strides = {One};
// sub.y is consumed by both relu and mul (x1).
af::ascir_op::Sub sub("sub");
sub.x1 = load0.y;
sub.x2 = exp.y;
sub.attr.api.compute_type = af::ComputeType::kComputeElewise;
sub.attr.sched.axis = {z0.id};
sub.y.dtype = af::DT_FLOAT;
*sub.y.axis = {z0.id};
*sub.y.repeats = {s0};
*sub.y.strides = {One};
af::ascir_op::Relu relu("relu");
relu.x = sub.y;
relu.attr.api.compute_type = af::ComputeType::kComputeElewise;
relu.attr.sched.axis = {z0.id};
relu.y.dtype = af::DT_FLOAT;
*relu.y.axis = {z0.id};
*relu.y.repeats = {s0};
*relu.y.strides = {One};
af::ascir_op::Mul mul("mul");
mul.x1 = sub.y;
mul.x2 = exp.y;
mul.attr.api.compute_type = af::ComputeType::kComputeElewise;
mul.attr.sched.axis = {z0.id};
mul.y.dtype = af::DT_FLOAT;
*mul.y.axis = {z0.id};
*mul.y.repeats = {s0};
*mul.y.strides = {One};
af::ascir_op::Abs abs("abs");
abs.x = mul.y;
abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
abs.attr.sched.axis = {z0.id};
abs.y.dtype = af::DT_FLOAT;
*abs.y.axis = {z0.id};
*abs.y.repeats = {s0};
*abs.y.strides = {One};
// Join of the relu and abs branches.
af::ascir_op::Add add("add");
add.x1 = relu.y;
add.x2 = abs.y;
add.attr.api.compute_type = af::ComputeType::kComputeElewise;
add.attr.sched.axis = {z0.id};
add.y.dtype = af::DT_FLOAT;
*add.y.axis = {z0.id};
*add.y.repeats = {s0};
*add.y.strides = {One};
af::ascir_op::Store store("store");
store.x = add.y;
store.attr.api.compute_type = af::ComputeType::kComputeStore;
store.attr.sched.axis = {z0.id};
store.y.dtype = af::DT_FLOAT;
*store.y.axis = {z0.id};
*store.y.repeats = {s0};
*store.y.strides = {One};
// Graph output.
Output out("out");
out.x = store.y;
out.attr.api.compute_type = ComputeType::kComputeInvalid;
out.attr.api.type = af::ApiType::kAPITypeBuffer;
out.y.dtype = af::DT_FLOAT;
out.ir_attr.SetIndex(0);
::ascir::FusedScheduledResult fused_scheduled_result;
Status res = optimizer.Optimize(graph, fused_scheduled_result);
EXPECT_EQ(res, af::SUCCESS);
// The ASSERTs below guard every container level before it is indexed.
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);
std::vector<af::AscGraph> asc_graphs;
fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0].GetAllSubGraphs(
asc_graphs);
EXPECT_EQ(asc_graphs.size(), 1UL);
}
// Builds a three-axis graph (z0/z1/z2 from Symbol(128)/Symbol(64)/Symbol(160)) with four
// FLOAT16 inputs, each followed by load -> cast to FLOAT:
//   data0 -> load0 -> cast0 -> brc0   (repeats {One, One, s2} expanded to {s0, s1, s2})
//   data1 -> load1 -> cast1
//   data2 -> load2 -> cast2
//   data3 -> load3 -> cast3 -> brc3   (repeats {One, One, One} expanded to {s0, s1, s2})
//   mul0(brc0, cast1) -> sub0(mul0, cast2) -> add0(brc3, sub0)
//   add0 -> store0 -> out
//   add0 -> prod0 -> store1 -> out1   (prod0 output has repeat One on z2)
// No compute_type / api.type attributes are set explicitly in this case.
// Expects four scheduled results for entry 0; the first has one schedule group with at least
// one impl graph, and GetAllSubGraphs() on its first impl graph yields one graph.
TEST_F(VectorFuncSt, CastHorizontolFusion) {
af::AscGraph graph("cast_graph");
auto s0 = af::Symbol(128);
auto s1 = af::Symbol(64);
auto s2 = af::Symbol(160);
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
auto z2 = graph.CreateAxis("z2", s2);
// Input 0: only z2 has a non-unit repeat; broadcast to the full shape by brc0.
Data data0("data0", graph);
data0.y.dtype = af::DT_FLOAT16;
data0.ir_attr.SetIndex(0);
Load load0("load0");
load0.x = data0.y;
load0.attr.sched.axis = {z0.id, z1.id, z2.id};
load0.y.dtype = af::DT_FLOAT16;
*load0.y.axis = {z0.id, z1.id, z2.id};
*load0.y.repeats = {One, One, s2};
*load0.y.strides = {Zero, Zero, One};
Cast cast0("cast0");
cast0.x = load0.y;
cast0.attr.sched.axis = {z0.id, z1.id, z2.id};
cast0.y.dtype = af::DT_FLOAT;
*cast0.y.axis = {z0.id, z1.id, z2.id};
*cast0.y.repeats = {One, One, s2};
*cast0.y.strides = {Zero, Zero, One};
Broadcast brc0("brc0");
brc0.x = cast0.y;
brc0.attr.sched.axis = {z0.id, z1.id, z2.id};
brc0.y.dtype = af::DT_FLOAT;
*brc0.y.axis = {z0.id, z1.id, z2.id};
*brc0.y.repeats = {s0, s1, s2};
*brc0.y.strides = {s1 * s2, s2, One};
// Input 1: full {s0, s1, s2} repeats.
af::ascir_op::Data data1("data1", graph);
data1.y.dtype = af::DT_FLOAT16;
data1.ir_attr.SetIndex(1);
Load load1("load1");
load1.x = data1.y;
load1.attr.sched.axis = {z0.id, z1.id, z2.id};
load1.y.dtype = af::DT_FLOAT16;
*load1.y.axis = {z0.id, z1.id, z2.id};
*load1.y.repeats = {s0, s1, s2};
*load1.y.strides = {s1 * s2, s2, One};
Cast cast1("cast1");
cast1.x = load1.y;
cast1.attr.sched.axis = {z0.id, z1.id, z2.id};
cast1.y.dtype = af::DT_FLOAT;
*cast1.y.axis = {z0.id, z1.id, z2.id};
*cast1.y.repeats = {s0, s1, s2};
*cast1.y.strides = {s1 * s2, s2, One};
Mul mul0("mul0");
mul0.x1 = brc0.y;
mul0.x2 = cast1.y;
mul0.attr.sched.axis = {z0.id, z1.id, z2.id};
mul0.y.dtype = af::DT_FLOAT;
*mul0.y.axis = {z0.id, z1.id, z2.id};
*mul0.y.repeats = {s0, s1, s2};
*mul0.y.strides = {s1 * s2, s2, One};
// Input 2: full {s0, s1, s2} repeats.
af::ascir_op::Data data2("data2", graph);
data2.y.dtype = af::DT_FLOAT16;
data2.ir_attr.SetIndex(2);
Load load2("load2");
load2.x = data2.y;
load2.attr.sched.axis = {z0.id, z1.id, z2.id};
load2.y.dtype = af::DT_FLOAT16;
*load2.y.axis = {z0.id, z1.id, z2.id};
*load2.y.repeats = {s0, s1, s2};
*load2.y.strides = {s1 * s2, s2, One};
Cast cast2("cast2");
cast2.x = load2.y;
cast2.attr.sched.axis = {z0.id, z1.id, z2.id};
cast2.y.dtype = af::DT_FLOAT;
*cast2.y.axis = {z0.id, z1.id, z2.id};
*cast2.y.repeats = {s0, s1, s2};
*cast2.y.strides = {s1 * s2, s2, One};
Sub sub0("sub0");
sub0.x1 = mul0.y;
sub0.x2 = cast2.y;
sub0.attr.sched.axis = {z0.id, z1.id, z2.id};
sub0.y.dtype = af::DT_FLOAT;
*sub0.y.axis = {z0.id, z1.id, z2.id};
*sub0.y.repeats = {s0, s1, s2};
*sub0.y.strides = {s1 * s2, s2, One};
// Input 3: repeat One on every axis; broadcast to the full shape by brc3.
af::ascir_op::Data data3("data3", graph);
data3.y.dtype = af::DT_FLOAT16;
data3.ir_attr.SetIndex(3);
Load load3("load3");
load3.x = data3.y;
load3.attr.sched.axis = {z0.id, z1.id, z2.id};
load3.y.dtype = af::DT_FLOAT16;
*load3.y.axis = {z0.id, z1.id, z2.id};
*load3.y.repeats = {One, One, One};
*load3.y.strides = {Zero, Zero, Zero};
Cast cast3("cast3");
cast3.x = load3.y;
cast3.attr.sched.axis = {z0.id, z1.id, z2.id};
cast3.y.dtype = af::DT_FLOAT;
*cast3.y.axis = {z0.id, z1.id, z2.id};
*cast3.y.repeats = {One, One, One};
*cast3.y.strides = {Zero, Zero, Zero};
Broadcast brc3("brc3");
brc3.x = cast3.y;
brc3.attr.sched.axis = {z0.id, z1.id, z2.id};
brc3.y.dtype = af::DT_FLOAT;
*brc3.y.axis = {z0.id, z1.id, z2.id};
*brc3.y.repeats = {s0, s1, s2};
*brc3.y.strides = {s1 * s2, s2, One};
// add0.y has two consumers: store0 and prod0.
Add add0("add0");
add0.x1 = brc3.y;
add0.x2 = sub0.y;
add0.attr.sched.axis = {z0.id, z1.id, z2.id};
add0.y.dtype = af::DT_FLOAT;
*add0.y.axis = {z0.id, z1.id, z2.id};
*add0.y.repeats = {s0, s1, s2};
*add0.y.strides = {s1 * s2, s2, One};
Store store0("store0");
store0.x = add0.y;
store0.attr.sched.axis = {z0.id, z1.id, z2.id};
store0.y.dtype = af::DT_FLOAT;
*store0.y.axis = {z0.id, z1.id, z2.id};
*store0.y.repeats = {s0, s1, s2};
*store0.y.strides = {s1 * s2, s2, One};
Output out("out");
out.x = store0.y;
out.y.dtype = af::DT_FLOAT;
out.ir_attr.SetIndex(0);
// Second output path: prod0 output has repeat One / stride Zero on z2.
Prod prod0("prod0");
prod0.x = add0.y;
prod0.attr.sched.axis = {z0.id, z1.id, z2.id};
prod0.y.dtype = af::DT_FLOAT;
*prod0.y.axis = {z0.id, z1.id, z2.id};
*prod0.y.repeats = {s0, s1, One};
*prod0.y.strides = {s1, One, Zero};
Store store1("store1");
store1.x = prod0.y;
store1.attr.sched.axis = {z0.id, z1.id, z2.id};
store1.y.dtype = af::DT_FLOAT;
*store1.y.axis = {z0.id, z1.id, z2.id};
*store1.y.repeats = {s0, s1, One};
*store1.y.strides = {s1, One, Zero};
// NOTE(review): out1 uses index 0, the same index as out - confirm whether 1 was intended.
Output out1("out1");
out1.x = store1.y;
out1.y.dtype = af::DT_FLOAT;
out1.ir_attr.SetIndex(0);
::ascir::FusedScheduledResult fused_scheduled_result;
Status res = optimizer.Optimize(graph, fused_scheduled_result);
EXPECT_EQ(res, af::SUCCESS);
// The ASSERTs below guard every container level before it is indexed.
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 4UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
ASSERT_TRUE(!fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.empty());
std::vector<af::AscGraph> asc_graphs;
fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0].GetAllSubGraphs(
asc_graphs);
EXPECT_EQ(asc_graphs.size(), 1UL);
}
// Builds a three-axis graph whose axis sizes are named symbols ("s0", "s1", "s2") rather than
// constants: data0 (FLOAT16) -> load0 -> cast0 (FLOAT) -> neg0 -> sum0 -> cast1 (FLOAT16)
// -> store -> out. sum0 and everything after it have repeat One / stride Zero on z2.
// No vectorized_axis is set on any output by the test itself.
// Expects three scheduled results for entry 0; the first has one schedule group with one
// impl graph, and GetAllSubGraphs() on it yields one graph.
TEST_F(VectorFuncSt, vectorized_not_empty) {
af::AscGraph graph("brc_abs");
auto s0 = af::Symbol("s0");
auto s1 = af::Symbol("s1");
auto s2 = af::Symbol("s2");
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
auto z2 = graph.CreateAxis("z2", s2);
// Graph input.
af::ascir_op::Data data0("data0", graph);
data0.y.dtype = af::DT_FLOAT16;
data0.ir_attr.SetIndex(0);
data0.attr.api.type = af::ApiType::kAPITypeBuffer;
af::ascir_op::Load load0("load0");
load0.x = data0.y;
load0.attr.api.compute_type = af::ComputeType::kComputeLoad;
load0.attr.sched.axis = {z0.id, z1.id, z2.id};
load0.y.dtype = af::DT_FLOAT16;
*load0.y.axis = {z0.id, z1.id, z2.id};
*load0.y.repeats = {s0, s1, s2};
*load0.y.strides = {s1 * s2, s2, One};
// FLOAT16 -> FLOAT
af::ascir_op::Cast cast0("cast0");
cast0.x = load0.y;
cast0.attr.api.compute_type = af::ComputeType::kComputeElewise;
cast0.attr.sched.axis = {z0.id, z1.id, z2.id};
cast0.y.dtype = af::DT_FLOAT;
*cast0.y.axis = {z0.id, z1.id, z2.id};
*cast0.y.repeats = {s0, s1, s2};
*cast0.y.strides = {s1 * s2, s2, One};
af::ascir_op::Neg neg0("neg0");
neg0.x = cast0.y;
neg0.attr.api.compute_type = af::ComputeType::kComputeElewise;
neg0.attr.sched.axis = {z0.id, z1.id, z2.id};
neg0.y.dtype = af::DT_FLOAT;
*neg0.y.axis = {z0.id, z1.id, z2.id};
*neg0.y.repeats = {s0, s1, s2};
*neg0.y.strides = {s1 * s2, s2, One};
// Reduce node: output repeat on z2 drops to One.
af::ascir_op::Sum sum0("sum0");
sum0.x = neg0.y;
sum0.attr.api.compute_type = af::ComputeType::kComputeReduce;
sum0.attr.sched.axis = {z0.id, z1.id, z2.id};
sum0.y.dtype = af::DT_FLOAT;
*sum0.y.axis = {z0.id, z1.id, z2.id};
*sum0.y.repeats = {s0, s1, One};
*sum0.y.strides = {s1, One, Zero};
// FLOAT -> FLOAT16
af::ascir_op::Cast cast1("cast1");
cast1.x = sum0.y;
cast1.attr.api.compute_type = af::ComputeType::kComputeElewise;
cast1.attr.sched.axis = {z0.id, z1.id, z2.id};
cast1.y.dtype = af::DT_FLOAT16;
*cast1.y.axis = {z0.id, z1.id, z2.id};
*cast1.y.repeats = {s0, s1, One};
*cast1.y.strides = {s1, One, Zero};
af::ascir_op::Store store("store");
store.x = cast1.y;
store.attr.api.compute_type = af::ComputeType::kComputeStore;
store.attr.sched.axis = {z0.id, z1.id, z2.id};
store.y.dtype = af::DT_FLOAT16;
*store.y.axis = {z0.id, z1.id, z2.id};
*store.y.repeats = {s0, s1, One};
*store.y.strides = {s1, One, Zero};
// NOTE(review): out.y is declared DT_FLOAT while store.y is DT_FLOAT16 - confirm the mismatch is intended.
Output out("out");
out.x = store.y;
out.attr.api.compute_type = ComputeType::kComputeInvalid;
out.attr.api.type = af::ApiType::kAPITypeBuffer;
out.y.dtype = af::DT_FLOAT;
out.ir_attr.SetIndex(0);
::ascir::FusedScheduledResult fused_scheduled_result;
Status res = optimizer.Optimize(graph, fused_scheduled_result);
EXPECT_EQ(res, af::SUCCESS);
// The ASSERTs below guard every container level before it is indexed.
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 3UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);
std::vector<af::AscGraph> asc_graphs;
fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs[0].GetAllSubGraphs(
asc_graphs);
EXPECT_EQ(asc_graphs.size(), 1UL);
}
// Builds a two-axis graph with named size symbols ("s0", "s1") in which both Add operands
// come from a Broadcast of a value with repeat One on every axis:
//   data0 -> load0 -> abs0 -> brc0
//   scalar0 ("0")          -> brc1
//   add0(brc0, brc1) -> store -> out
// After Optimize(), every impl graph of the first scheduled result must contain neither
// "brc0" nor "brc1", and GetAllSubGraphs() on it must yield exactly one graph.
TEST_F(VectorFuncSt, scalar_brc_fusion) {
  af::AscGraph graph("brc_abs");
  auto s0 = af::Symbol("s0");
  auto s1 = af::Symbol("s1");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  // Graph input; load0/abs outputs have repeat One and stride Zero on both axes.
  af::ascir_op::Data data0("data0", graph);
  data0.y.dtype = af::DT_FLOAT;
  data0.ir_attr.SetIndex(0);
  data0.attr.api.type = af::ApiType::kAPITypeBuffer;

  af::ascir_op::Load load0("load0");
  load0.x = data0.y;
  load0.attr.api.compute_type = af::ComputeType::kComputeLoad;
  load0.attr.sched.axis = {z0.id, z1.id};
  load0.y.dtype = af::DT_FLOAT;
  *load0.y.axis = {z0.id, z1.id};
  *load0.y.repeats = {One, One};
  *load0.y.strides = {Zero, Zero};

  // The local variable is named "abs" but the node itself is named "abs0".
  af::ascir_op::Abs abs("abs0");
  abs.x = load0.y;
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;
  abs.attr.sched.axis = {z0.id, z1.id};
  abs.y.dtype = af::DT_FLOAT;
  *abs.y.axis = {z0.id, z1.id};
  *abs.y.repeats = {One, One};
  *abs.y.strides = {Zero, Zero};

  // brc0 expands abs.y to the full {s0, s1} repeats.
  af::ascir_op::Broadcast brc0("brc0");
  brc0.x = abs.y;
  brc0.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc0.attr.sched.axis = {z0.id, z1.id};
  brc0.y.dtype = af::DT_FLOAT;
  *brc0.y.axis = {z0.id, z1.id};
  *brc0.y.repeats = {s0, s1};
  *brc0.y.strides = {s1, One};

  // Scalar node with value string "0"; no axis list is assigned to its output here.
  af::ascir_op::Scalar scalar("scalar0", graph);
  scalar.y.dtype = af::DT_FLOAT;
  scalar.ir_attr.SetValue("0");
  scalar.attr.api.type = af::ApiType::kAPITypeBuffer;
  *scalar.y.repeats = {One, One};
  *scalar.y.strides = {Zero, Zero};

  // brc1 expands the scalar to the full {s0, s1} repeats.
  af::ascir_op::Broadcast brc1("brc1");
  brc1.x = scalar.y;
  brc1.attr.api.compute_type = af::ComputeType::kComputeBroadcast;
  brc1.attr.sched.axis = {z0.id, z1.id};
  brc1.y.dtype = af::DT_FLOAT;
  *brc1.y.axis = {z0.id, z1.id};
  *brc1.y.repeats = {s0, s1};
  *brc1.y.strides = {s1, One};

  af::ascir_op::Add add0("add0");
  add0.x1 = brc0.y;
  add0.x2 = brc1.y;
  add0.attr.api.compute_type = af::ComputeType::kComputeElewise;
  add0.attr.sched.axis = {z0.id, z1.id};
  add0.y.dtype = af::DT_FLOAT;
  *add0.y.axis = {z0.id, z1.id};
  *add0.y.repeats = {s0, s1};
  *add0.y.strides = {s1, One};

  // NOTE(review): store.y is DT_FLOAT16 while its input add0.y and the Output are DT_FLOAT -
  // kept as in the original test; confirm the mismatch is intended.
  af::ascir_op::Store store("store");
  store.x = add0.y;
  store.attr.api.compute_type = af::ComputeType::kComputeStore;
  store.attr.sched.axis = {z0.id, z1.id};
  store.y.dtype = af::DT_FLOAT16;
  *store.y.axis = {z0.id, z1.id};
  *store.y.repeats = {s0, s1};
  *store.y.strides = {s1, One};

  // Graph output.
  Output out("out");
  out.x = store.y;
  out.attr.api.compute_type = ComputeType::kComputeInvalid;
  out.attr.api.type = af::ApiType::kAPITypeBuffer;
  out.y.dtype = af::DT_FLOAT;
  out.ir_attr.SetIndex(0);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal: the result structure is only meaningful (and only safe to index) on success.
  ASSERT_EQ(res, af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  // Guard the [0][0] access below; previously an empty inner container was indexed unchecked.
  ASSERT_GE(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 1UL);
  for (const auto &group : fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups) {
    for (const auto &impl_graph : group.impl_graphs) {
      // Neither broadcast node may remain in the top-level impl graph.
      auto brc_node0 = impl_graph.FindNode("brc0");
      EXPECT_EQ(brc_node0, nullptr);
      auto brc_node1 = impl_graph.FindNode("brc1");
      EXPECT_EQ(brc_node1, nullptr);
      std::vector<af::AscGraph> asc_graphs;
      impl_graph.GetAllSubGraphs(asc_graphs);
      EXPECT_EQ(asc_graphs.size(), 1UL);
    }
  }
}
// Runs the optimizer on a graph in which a single broadcast output (brc0) is consumed by two
// element-wise nodes (sub0 and add0), then inspects the sub-graphs of the resulting impl graphs.
//
// Graph under test (axes z0..z3 sized 4 x 28 x 28 x 4):
//   data0 -> load0 -> cast0 -> abs0 -> brc0   repeats {s0, s1, 1, 1} widened to {s0, s1, s2, s3}
//   data1 -> load1 -> cast1 -> relu0          repeats {s0, s1, s2, s3}
//   sub0(x1 = relu0, x2 = brc0), add0(x1 = brc0, x2 = relu0)
//   add1(x1 = add0, x2 = sub0) -> abs1 -> relu1 -> cast2 -> store -> out
//
// Checks:
//   * impl_graphs[0] of the first schedule group reports 2 sub-graphs;
//   * impl_graphs[1] reports 1 sub-graph and every node returned by GetAllNodes() has one of the
//     types Data / Output / VectorFunc / Load / Store.
TEST_F(VectorFuncSt, BrcMultiReuse) {
  af::AscGraph graph("brc_abs");
  auto s0 = af::Symbol(4);
  auto s1 = af::Symbol(28);
  auto s2 = af::Symbol(28);
  auto s3 = af::Symbol(4);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  auto z2 = graph.CreateAxis("z2", s2);
  auto z3 = graph.CreateAxis("z3", s3);

  // Branch A: fp16 input that only spans z0/z1 (z2/z3 have repeat 1, stride 0).
  Data data0("data0", graph);
  data0.y.dtype = af::DT_FLOAT16;
  data0.ir_attr.SetIndex(0);
  Load load0("load0");
  load0.x = data0.y;
  load0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  load0.y.dtype = af::DT_FLOAT16;
  *load0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *load0.y.repeats = {s0, s1, One, One};
  *load0.y.strides = {s1, One, Zero, Zero};
  Cast cast0("cast0");
  cast0.x = load0.y;
  cast0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  cast0.y.dtype = af::DT_FLOAT;
  *cast0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *cast0.y.repeats = {s0, s1, One, One};
  *cast0.y.strides = {s1, One, Zero, Zero};
  Abs abs0("abs0");
  abs0.x = cast0.y;
  abs0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  abs0.y.dtype = af::DT_FLOAT;
  *abs0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs0.y.repeats = {s0, s1, One, One};
  *abs0.y.strides = {s1, One, Zero, Zero};
  // brc0 widens branch A to the full 4-D shape; its output is used twice below.
  Broadcast brc0("brc0");
  brc0.x = abs0.y;
  brc0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  brc0.y.dtype = af::DT_FLOAT;
  *brc0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *brc0.y.repeats = {s0, s1, s2, s3};
  *brc0.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  // Branch B: fp16 input that already has the full 4-D shape.
  Data data1("data1", graph);
  data1.y.dtype = af::DT_FLOAT16;
  data1.ir_attr.SetIndex(1);
  Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  load1.y.dtype = af::DT_FLOAT16;
  *load1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *load1.y.repeats = {s0, s1, s2, s3};
  *load1.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  Cast cast1("cast1");
  cast1.x = load1.y;
  cast1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  cast1.y.dtype = af::DT_FLOAT;
  *cast1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *cast1.y.repeats = {s0, s1, s2, s3};
  *cast1.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  Relu relu0("relu0");
  relu0.x = cast1.y;
  relu0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  relu0.y.dtype = af::DT_FLOAT;
  *relu0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *relu0.y.repeats = {s0, s1, s2, s3};
  *relu0.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  // Two consumers of brc0.y: sub0 and add0.
  Sub sub0("sub0");
  sub0.x1 = relu0.y;
  sub0.x2 = brc0.y;
  sub0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  sub0.y.dtype = af::DT_FLOAT;
  *sub0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *sub0.y.repeats = {s0, s1, s2, s3};
  *sub0.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  Add add0("add0");
  add0.x1 = brc0.y;
  add0.x2 = relu0.y;
  add0.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  add0.y.dtype = af::DT_FLOAT;
  *add0.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *add0.y.repeats = {s0, s1, s2, s3};
  *add0.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};

  // Element-wise tail: add1 -> abs1 -> relu1 -> cast2 (back to fp16) -> store -> out.
  Add add1("add1");
  add1.x1 = add0.y;
  add1.x2 = sub0.y;
  add1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  add1.y.dtype = af::DT_FLOAT;
  *add1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *add1.y.repeats = {s0, s1, s2, s3};
  *add1.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  Abs abs1("abs1");
  abs1.x = add1.y;
  abs1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  abs1.y.dtype = af::DT_FLOAT;
  *abs1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *abs1.y.repeats = {s0, s1, s2, s3};
  *abs1.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  Relu relu1("relu1");
  relu1.x = abs1.y;
  relu1.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  relu1.y.dtype = af::DT_FLOAT;
  *relu1.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *relu1.y.repeats = {s0, s1, s2, s3};
  *relu1.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  Cast cast2("cast2");
  cast2.x = relu1.y;
  cast2.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  cast2.y.dtype = af::DT_FLOAT16;
  *cast2.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *cast2.y.repeats = {s0, s1, s2, s3};
  *cast2.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  af::ascir_op::Store store("store");
  store.x = cast2.y;
  store.attr.sched.axis = {z0.id, z1.id, z2.id, z3.id};
  store.y.dtype = af::DT_FLOAT16;
  *store.y.axis = {z0.id, z1.id, z2.id, z3.id};
  *store.y.repeats = {s0, s1, s2, s3};
  *store.y.strides = {s1 * s2 * s3, s2 * s3, s3, One};
  Output out("out");
  out.x = store.y;
  out.y.dtype = af::DT_FLOAT;
  out.ir_attr.SetIndex(0);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal checks: everything below indexes into the result, so stop cleanly instead of reading
  // out of bounds when the optimizer failed or produced fewer graphs than expected.
  ASSERT_EQ(res, af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  auto &scheduled_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_FALSE(scheduled_results.empty());
  auto &schedule_groups = scheduled_results[0].schedule_groups;
  ASSERT_FALSE(schedule_groups.empty());
  auto &impl_graphs = schedule_groups[0].impl_graphs;
  ASSERT_GE(impl_graphs.size(), 2UL);

  std::vector<af::AscGraph> asc_graphs;
  impl_graphs[0].GetAllSubGraphs(asc_graphs);
  EXPECT_EQ(asc_graphs.size(), 2UL);

  // Inspect the second impl graph in place; a copy is not needed for read-only queries.
  auto &graph1 = impl_graphs[1];
  std::vector<af::AscGraph> asc_graphs1;
  graph1.GetAllSubGraphs(asc_graphs1);
  ASSERT_EQ(asc_graphs1.size(), 1UL);
  const std::set<std::string> supported_types = {"Data", "Output", "VectorFunc", "Load", "Store"};
  for (const auto &node : graph1.GetAllNodes()) {
    EXPECT_TRUE(supported_types.count(node->GetType()) > 0UL);
  }
}
// Runs the optimizer on a three-output graph that mixes float compute with INT64 casts and then
// counts the Cast nodes returned by GetAllNodes() on the second impl graph.
//
// Graph under test (axes z0 x z1 sized 32 x 60):
//   data0 -> load0 -> abs0
//   add0(x1 = abs0, x2 = load0) -> cast0 (FLOAT -> INT64) -> store  -> out  (index 0)
//   abs0 -> cast1 (FLOAT -> INT64)                        -> store1 -> out1 (index 1)
//   scalar0 (INT64) -> brc3
//   data1 -> load1 (repeats {1, s1}) -> brc4 -> sign0 -> mul0(sign0, sign0) -> Abs1
//   gt0(x1 = abs0, x2 = Abs1)
//   where0(x1 = gt0, x2 = cast1, x3 = brc3)               -> store2 -> out2 (index 2)
//
// Checks:
//   * impl_graphs[0] of the first schedule group reports 2 sub-graphs;
//   * impl_graphs[1] reports 2 sub-graphs and exactly one node of type "Cast" among the nodes
//     returned by GetAllNodes().
TEST_F(VectorFuncSt, CastNotFusion) {
  af::AscGraph graph("brc_abs");
  auto s0 = af::Symbol(32);
  auto s1 = af::Symbol(60);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  // Shared float front end: load0 -> abs0, add0 = abs0 + load0.
  Data data0("data0", graph);
  data0.ir_attr.SetIndex(0);
  Load load0("load0");
  load0.x = data0.y;
  load0.attr.sched.axis = {z0.id, z1.id};
  load0.y.dtype = af::DT_FLOAT;
  *load0.y.axis = {z0.id, z1.id};
  *load0.y.repeats = {s0, s1};
  *load0.y.strides = {s1, One};
  Abs abs0("abs0");
  abs0.x = load0.y;
  abs0.attr.sched.axis = {z0.id, z1.id};
  abs0.y.dtype = af::DT_FLOAT;
  *abs0.y.axis = {z0.id, z1.id};
  *abs0.y.repeats = {s0, s1};
  *abs0.y.strides = {s1, One};
  Add add0("add0");
  add0.x1 = abs0.y;
  add0.x2 = load0.y;
  add0.attr.sched.axis = {z0.id, z1.id};
  add0.y.dtype = af::DT_FLOAT;
  *add0.y.axis = {z0.id, z1.id};
  *add0.y.repeats = {s0, s1};
  *add0.y.strides = {s1, One};

  // Output 0: add0 cast to INT64 and stored.
  Cast cast0("cast0");
  cast0.x = add0.y;
  cast0.attr.sched.axis = {z0.id, z1.id};
  cast0.y.dtype = af::DT_INT64;
  *cast0.y.axis = {z0.id, z1.id};
  *cast0.y.repeats = {s0, s1};
  *cast0.y.strides = {s1, One};
  Store store("store");
  store.x = cast0.y;
  store.attr.sched.axis = {z0.id, z1.id};
  store.y.dtype = af::DT_INT64;
  *store.y.axis = {z0.id, z1.id};
  *store.y.repeats = {s0, s1};
  *store.y.strides = {s1, One};
  Output out("out");
  out.x = store.y;
  out.y.dtype = af::DT_INT64;
  out.ir_attr.SetIndex(0);

  // Output 1: abs0 cast to INT64 and stored; cast1 is also an input of where0 below.
  Cast cast1("cast1");
  cast1.x = abs0.y;
  cast1.attr.sched.axis = {z0.id, z1.id};
  cast1.y.dtype = af::DT_INT64;
  *cast1.y.axis = {z0.id, z1.id};
  *cast1.y.repeats = {s0, s1};
  *cast1.y.strides = {s1, One};
  Store store1("store1");
  store1.x = cast1.y;
  store1.attr.sched.axis = {z0.id, z1.id};
  store1.y.dtype = af::DT_INT64;
  *store1.y.axis = {z0.id, z1.id};
  *store1.y.repeats = {s0, s1};
  *store1.y.strides = {s1, One};
  Output out1("out1");
  out1.x = store1.y;
  out1.y.dtype = af::DT_INT64;
  out1.ir_attr.SetIndex(1);

  // INT64 scalar broadcast to the full shape; third input of where0.
  Scalar scalar("scalar0", graph);
  scalar.y.dtype = af::DT_INT64;
  Broadcast brc3("brc3");
  brc3.x = scalar.y;
  brc3.attr.sched.axis = {z0.id, z1.id};
  brc3.y.dtype = af::DT_INT64;
  *brc3.y.axis = {z0.id, z1.id};
  *brc3.y.repeats = {s0, s1};
  *brc3.y.strides = {s1, One};

  // Second float input: loaded with repeat 1 on z0, broadcast, then sign0 -> mul0 -> Abs1.
  Data data1("data1", graph);
  data1.ir_attr.SetIndex(1);
  Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z0.id, z1.id};
  load1.y.dtype = af::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id};
  *load1.y.repeats = {One, s1};
  *load1.y.strides = {Zero, One};
  Broadcast brc4("brc4");
  brc4.x = load1.y;
  brc4.attr.sched.axis = {z0.id, z1.id};
  brc4.y.dtype = af::DT_FLOAT;
  *brc4.y.axis = {z0.id, z1.id};
  *brc4.y.repeats = {s0, s1};
  *brc4.y.strides = {s1, One};
  Sign sign0("sign0");
  sign0.x = brc4.y;
  sign0.attr.sched.axis = {z0.id, z1.id};
  sign0.y.dtype = af::DT_FLOAT;
  *sign0.y.axis = {z0.id, z1.id};
  *sign0.y.repeats = {s0, s1};
  *sign0.y.strides = {s1, One};
  Mul mul0("mul0");
  mul0.x1 = sign0.y;
  mul0.x2 = sign0.y;
  mul0.attr.sched.axis = {z0.id, z1.id};
  mul0.y.dtype = af::DT_FLOAT;
  *mul0.y.axis = {z0.id, z1.id};
  *mul0.y.repeats = {s0, s1};
  *mul0.y.strides = {s1, One};
  Abs abs1("Abs1");
  abs1.x = mul0.y;
  abs1.attr.sched.axis = {z0.id, z1.id};
  abs1.y.dtype = af::DT_FLOAT;
  *abs1.y.axis = {z0.id, z1.id};
  *abs1.y.repeats = {s0, s1};
  *abs1.y.strides = {s1, One};

  // Output 2: where0 picks between cast1 and brc3 using gt0(abs0, Abs1) as its first input.
  Gt gt0("gt0");
  gt0.x1 = abs0.y;
  gt0.x2 = abs1.y;
  gt0.attr.sched.axis = {z0.id, z1.id};
  gt0.y.dtype = af::DT_FLOAT;
  *gt0.y.axis = {z0.id, z1.id};
  *gt0.y.repeats = {s0, s1};
  *gt0.y.strides = {s1, One};
  Where where0("where0");
  where0.x1 = gt0.y;
  where0.x2 = cast1.y;
  where0.x3 = brc3.y;
  where0.attr.sched.axis = {z0.id, z1.id};
  where0.y.dtype = af::DT_INT64;
  *where0.y.axis = {z0.id, z1.id};
  *where0.y.repeats = {s0, s1};
  *where0.y.strides = {s1, One};
  Store store2("store2");
  store2.x = where0.y;
  store2.attr.sched.axis = {z0.id, z1.id};
  store2.y.dtype = af::DT_INT64;
  *store2.y.axis = {z0.id, z1.id};
  *store2.y.repeats = {s0, s1};
  *store2.y.strides = {s1, One};
  Output out2("out2");
  out2.x = store2.y;
  out2.y.dtype = af::DT_INT64;
  out2.ir_attr.SetIndex(2);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal checks: everything below indexes into the result, so stop cleanly instead of reading
  // out of bounds when the optimizer failed or produced fewer graphs than expected.
  ASSERT_EQ(res, af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  auto &scheduled_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_FALSE(scheduled_results.empty());
  auto &schedule_groups = scheduled_results[0].schedule_groups;
  ASSERT_FALSE(schedule_groups.empty());
  auto &impl_graphs = schedule_groups[0].impl_graphs;
  ASSERT_GE(impl_graphs.size(), 2UL);

  std::vector<af::AscGraph> asc_graphs;
  impl_graphs[0].GetAllSubGraphs(asc_graphs);
  EXPECT_EQ(asc_graphs.size(), 2UL);

  // Inspect the second impl graph in place; a copy is not needed for read-only queries.
  auto &graph1 = impl_graphs[1];
  std::vector<af::AscGraph> asc_graphs1;
  graph1.GetAllSubGraphs(asc_graphs1);
  ASSERT_EQ(asc_graphs1.size(), 2UL);

  // Count Cast nodes among the nodes GetAllNodes() reports for this impl graph.
  size_t cast_num = 0UL;
  for (const auto &node : graph1.GetAllNodes()) {
    if (node->GetType() == "Cast") {
      ++cast_num;
    }
  }
  EXPECT_EQ(cast_num, 1UL);
}
// Runs the optimizer on a 1-D graph where a Maximum node takes a Scalar as its second input and
// then checks which nodes can be found in the single sub-graph of the first impl graph.
//
// Graph under test (axis z1 sized 60):
//   data0 -> load0 -> abs0
//   add0(x1 = abs0, x2 = load0)
//   maximum(x1 = add0, x2 = scalar0) -> abs1 -> store -> out
//
// Checks:
//   * impl_graphs[0] of the first schedule group reports exactly 1 sub-graph;
//   * abs0 and add0 are found inside that sub-graph;
//   * scalar0, maximum and abs1 are found in `graph` (see the NOTE(review) below).
TEST_F(VectorFuncSt, MaximumNotFusion) {
  af::AscGraph graph("brc_abs");
  auto s1 = af::Symbol(60);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data0("data0", graph);
  data0.ir_attr.SetIndex(0);
  Load load0("load0");
  load0.x = data0.y;
  load0.attr.sched.axis = {z1.id};
  load0.y.dtype = af::DT_FLOAT;
  *load0.y.axis = {z1.id};
  *load0.y.repeats = {s1};
  *load0.y.strides = {One};
  Abs abs0("abs0");
  abs0.x = load0.y;
  abs0.attr.sched.axis = {z1.id};
  abs0.y.dtype = af::DT_FLOAT;
  *abs0.y.axis = {z1.id};
  *abs0.y.repeats = {s1};
  *abs0.y.strides = {One};
  Add add0("add0");
  add0.x1 = abs0.y;
  add0.x2 = load0.y;
  add0.attr.sched.axis = {z1.id};
  add0.y.dtype = af::DT_FLOAT;
  *add0.y.axis = {z1.id};
  *add0.y.repeats = {s1};
  *add0.y.strides = {One};

  // Maximum with a scalar operand; scalar0 is wired in directly, without a Broadcast node.
  af::ascir_op::Scalar scalar0("scalar0", graph);
  Maximum maximum("maximum");
  maximum.x1 = add0.y;
  maximum.x2 = scalar0.y;
  maximum.attr.sched.axis = {z1.id};
  maximum.y.dtype = af::DT_FLOAT;
  *maximum.y.axis = {z1.id};
  *maximum.y.repeats = {s1};
  *maximum.y.strides = {One};
  Abs abs1("abs1");
  abs1.x = maximum.y;
  abs1.attr.sched.axis = {z1.id};
  abs1.y.dtype = af::DT_FLOAT;
  *abs1.y.axis = {z1.id};
  *abs1.y.repeats = {s1};
  *abs1.y.strides = {One};
  Store store("store");
  store.x = abs1.y;
  store.attr.sched.axis = {z1.id};
  store.y.dtype = af::DT_FLOAT;
  *store.y.axis = {z1.id};
  *store.y.repeats = {s1};
  *store.y.strides = {One};
  Output out("out");
  out.x = store.y;
  out.y.dtype = af::DT_FLOAT;
  out.ir_attr.SetIndex(0);

  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal checks: everything below indexes into the result, so stop cleanly instead of reading
  // out of bounds when the optimizer failed or produced fewer graphs than expected.
  ASSERT_EQ(res, af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  auto &scheduled_results = fused_scheduled_result.node_idx_to_scheduled_results[0];
  ASSERT_FALSE(scheduled_results.empty());
  auto &schedule_groups = scheduled_results[0].schedule_groups;
  ASSERT_FALSE(schedule_groups.empty());
  auto &impl_graphs = schedule_groups[0].impl_graphs;
  ASSERT_FALSE(impl_graphs.empty());

  // The first impl graph is queried once; the sub-graph count must be exactly 1 because
  // asc_graphs1[0] is dereferenced below.
  auto &graph1 = impl_graphs[0];
  std::vector<af::AscGraph> asc_graphs1;
  graph1.GetAllSubGraphs(asc_graphs1);
  ASSERT_EQ(asc_graphs1.size(), 1UL);

  // NOTE(review): these three lookups run against the input `graph`, not against `graph1` (the
  // optimized impl graph). If the intent is to prove that scalar0/maximum/abs1 stay outside the
  // sub-graph after optimization, they probably should query `graph1` - confirm before changing.
  auto scalar_node = graph.FindNode("scalar0");
  ASSERT_NE(scalar_node, nullptr);
  auto maximum_node = graph.FindNode("maximum");
  ASSERT_NE(maximum_node, nullptr);
  auto abs1_node = graph.FindNode("abs1");
  ASSERT_NE(abs1_node, nullptr);

  // abs0 and add0 are expected inside the single sub-graph.
  auto abs_node = asc_graphs1[0].FindNode("abs0");
  ASSERT_NE(abs_node, nullptr);
  auto add_node = asc_graphs1[0].FindNode("add0");
  ASSERT_NE(add_node, nullptr);
}
}