* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include "autofuse_config/auto_fuse_config.h"
#define private public
#include "codegen_tiling.h"
#include "ascir_ops.h"
#include "ascir_ops_utils.h"
#include "schedule_result.h"
#include "runtime_stub.h"
#include "platform_context.h"
#include "ascgraph_info_complete.h"
#include "optimize/optimize.h"
#include <fstream>
#include <filesystem>
namespace {
// Runs `command` through popen() and collects everything the command writes to stdout.
//
// Returns {exit_code, captured_stdout}, where exit_code is
//   - the command's exit status when it terminated normally,
//   - 128 + signal number when it was terminated by a signal, so that a crashed command is
//     never mistaken for a successful one,
//   - -1 when pclose() failed or the wait status could not be classified.
// Throws std::runtime_error when the pipe cannot be opened.
std::pair<int, std::string> execute_command(const std::string& command) {
  std::array<char, 128> buffer{};
  std::string output;
  // The deleter closes the pipe if appending to `output` throws part-way through.
  std::unique_ptr<FILE, decltype(&pclose)> pipe(popen(command.c_str(), "r"), pclose);
  if (!pipe) {
    throw std::runtime_error("Failed to open pipe");
  }
  while (fgets(buffer.data(), static_cast<int>(buffer.size()), pipe.get()) != nullptr) {
    output += buffer.data();
  }
  // Release ownership first: pclose() must be called exactly once and its status is needed here.
  const int wait_status = pclose(pipe.release());
  int exit_code = -1;
  if (wait_status != -1) {
    if (WIFEXITED(wait_status)) {
      exit_code = WEXITSTATUS(wait_status);
    } else if (WIFSIGNALED(wait_status)) {
      exit_code = 128 + WTERMSIG(wait_status);
    }
  }
  return {exit_code, output};
}
// Checks that `code` compiles and links as a standalone C++17 program.
//
// The code, followed by an empty main(), is written to a scratch directory under CMAKE_BINARY_DIR and
// built with g++ against the include/ and lib64/ directories of ASCEND_INSTALL_PATH, linking c_sec.
// The produced binary is placed inside the scratch directory so that nothing is left behind in the
// current working directory; the scratch directory is removed before returning.
//
// Returns true when the compiler command exits with status 0, false otherwise (including when the
// source file could not be written). On a compiler failure its diagnostics are echoed to stderr.
bool CompileCode(const std::string &code){
  const std::string cmake_dir = CMAKE_BINARY_DIR;
  const std::string temp_dir = cmake_dir + "/tests/ut/temp_compile_codegen_tiling";
  std::filesystem::remove_all(temp_dir);
  std::filesystem::create_directories(temp_dir);
  const std::string source_file = temp_dir + "/temp_codegen_infershape.cpp";
  const std::string binary_file = temp_dir + "/temp_codegen_infershape.out";
  std::ofstream source_stream(source_file);
  source_stream << code << R"(
int main() {
return 0;
}
)";
  source_stream.close();
  if (!source_stream) {
    // Opening, writing or closing failed: there is nothing meaningful to compile.
    std::filesystem::remove_all(temp_dir);
    return false;
  }
  const std::string ascend_install_path = ASCEND_INSTALL_PATH;
  // Paths are quoted so that directories containing spaces survive the shell; "2>&1" routes the
  // compiler diagnostics into the captured output.
  const std::string include_path = "-I\"" + ascend_install_path + "/include/\"";
  const std::string link_path = "-L\"" + ascend_install_path + "/lib64\"";
  const std::string compile_command = "g++ -std=c++17 " + include_path + " " + link_path + " \"" + source_file +
                                      "\" -o \"" + binary_file + "\" -lc_sec 2>&1";
  auto [compile_exit_code, compile_output] = execute_command(compile_command);
  if (compile_exit_code != 0) {
    fprintf(stderr, "CompileCode failed (exit code %d):\n%s\n", compile_exit_code, compile_output.c_str());
  }
  std::filesystem::remove_all(temp_dir);
  return compile_exit_code == 0;
}
}
namespace {
// Populates `graph` with a fixed-shape elementwise pipeline:
//   Data("data0") -> Load("load0") -> Relu("relu") -> Store("store") -> Output("output").
// Every tensor is DT_FLOAT over the axes z_n/z_c/z_h/z_w with extents 1/64/56/56 and row-major strides.
// After CompleteApiInfo has run, the memory attributes of the first output of node "data0" are overwritten.
static void CreateElemwiseGraphWithRelu(af::AscGraph &graph) {
  // Constant extents and the axes built on them.
  auto dim_n = graph.CreateSizeVar(1);
  auto dim_c = graph.CreateSizeVar(64);
  auto dim_h = graph.CreateSizeVar(56);
  auto dim_w = graph.CreateSizeVar(56);
  auto ax_n = graph.CreateAxis("z_n", dim_n);
  auto ax_c = graph.CreateAxis("z_c", dim_c);
  auto ax_h = graph.CreateAxis("z_h", dim_h);
  auto ax_w = graph.CreateAxis("z_w", dim_w);

  // Graph input with index 0.
  af::ascir_op::Data input_node("data0", graph);
  input_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  input_node.y.dtype = ge::DT_FLOAT;
  *input_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  input_node.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *input_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *input_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  input_node.ir_attr.SetIndex(0);

  // Load fed by the input.
  af::ascir_op::Load load_node("load0");
  load_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  load_node.x = input_node.y;
  *load_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  load_node.y.dtype = ge::DT_FLOAT;
  *load_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *load_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Relu is added to the graph explicitly before it is wired up.
  af::ascir_op::Relu relu_node("relu");
  graph.AddNode(relu_node);
  relu_node.x = load_node.y;
  relu_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  relu_node.y.dtype = ge::DT_FLOAT;
  *relu_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *relu_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *relu_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  relu_node.attr.api.compute_type = af::ComputeType::kComputeElewise;

  // Store consuming the Relu result.
  af::ascir_op::Store store_node("store");
  store_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.x = relu_node.y;
  *store_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.y.dtype = ge::DT_FLOAT;
  *store_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *store_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Graph output with index 0.
  af::ascir_op::Output output_node("output");
  output_node.x = store_node.y;
  output_node.y.dtype = ge::DT_FLOAT;
  output_node.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  // Override the memory placement of data0's first output.
  auto data0_handle = graph.FindNode("data0");
  data0_handle->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  data0_handle->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  data0_handle->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}
// Asserts that the text stored under "tiling_def_and_tiling_const" in `res` contains the expected
// generated entry points: FindBestTilingKey, an AutofuseIsStaticShape returning true, TilingFunc,
// GetTilingDataSize and the AutofuseTilingData name.
// Each check is a fatal gtest assertion, so the first missing snippet ends the verification.
// std::map::at throws std::out_of_range when the key is absent from `res`.
static void VerifyTilingCodeBasic(const std::map<std::string, std::string> &res) {
  // Look the generated text up once and search it repeatedly.
  const std::string &tiling_code = res.at("tiling_def_and_tiling_const");

  auto pos = tiling_code.find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(pos, std::string::npos);

  auto static_shape_pos = tiling_code.find("extern \"C\" bool AutofuseIsStaticShape() {\n return true;");
  ASSERT_NE(static_shape_pos, std::string::npos);

  auto tiling_func_pos = tiling_code.find("extern \"C\" ge::graphStatus TilingFunc");
  ASSERT_NE(tiling_func_pos, std::string::npos);

  auto get_size_pos = tiling_code.find("extern \"C\" size_t GetTilingDataSize()");
  ASSERT_NE(get_size_pos, std::string::npos);

  auto tiling_data_pos = tiling_code.find("AutofuseTilingData");
  ASSERT_NE(tiling_data_pos, std::string::npos);
}
// Populates `conv2d_offset_bias_graph` with a Conv2DOffsetBias node fed by four Data/Load pairs:
//   data0/load0 -> x        (DT_FLOAT16, axes z_n/z_c/z_h/z_w, extents 1/64/56/56)
//   data1/load1 -> filter   (DT_FLOAT16, all repeats One, all strides Zero)
//   data2/load2 -> bias     (DT_FLOAT,   axis z_c only)
//   data3/load3 -> offset_w (DT_FLOAT16, axis z_c only)
// The DT_FLOAT convolution result goes through Store("store") into Output("output").
static void CreateConv2DOffsetBiasGraph(af::AscGraph &conv2d_offset_bias_graph) {
  // Constant extents and the axes built on them.
  auto dim_n = conv2d_offset_bias_graph.CreateSizeVar(1);
  auto dim_c = conv2d_offset_bias_graph.CreateSizeVar(64);
  auto dim_h = conv2d_offset_bias_graph.CreateSizeVar(56);
  auto dim_w = conv2d_offset_bias_graph.CreateSizeVar(56);
  auto ax_n = conv2d_offset_bias_graph.CreateAxis("z_n", dim_n);
  auto ax_c = conv2d_offset_bias_graph.CreateAxis("z_c", dim_c);
  auto ax_h = conv2d_offset_bias_graph.CreateAxis("z_h", dim_h);
  auto ax_w = conv2d_offset_bias_graph.CreateAxis("z_w", dim_w);

  // Input 0: feature map.
  af::ascir_op::Data feature_data("data0", conv2d_offset_bias_graph);
  feature_data.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_data.y.dtype = ge::DT_FLOAT16;
  *feature_data.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *feature_data.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *feature_data.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  feature_data.ir_attr.SetIndex(0);

  af::ascir_op::Load feature_load("load0");
  feature_load.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_load.x = feature_data.y;
  *feature_load.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_load.y.dtype = ge::DT_FLOAT16;
  *feature_load.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *feature_load.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Input 1: filter, described with unit repeats and zero strides.
  af::ascir_op::Data filter_data("data1", conv2d_offset_bias_graph);
  filter_data.y.dtype = ge::DT_FLOAT16;
  filter_data.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *filter_data.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  filter_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *filter_data.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *filter_data.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  filter_data.ir_attr.SetIndex(1);

  af::ascir_op::Load filter_load("load1");
  filter_load.x = filter_data.y;
  filter_load.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  filter_load.y.dtype = ge::DT_FLOAT16;
  *filter_load.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *filter_load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  *filter_load.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};

  // Input 2: bias along z_c.
  af::ascir_op::Data bias_data("data2", conv2d_offset_bias_graph);
  bias_data.y.dtype = ge::DT_FLOAT;
  bias_data.attr.sched.axis = {ax_c.id};
  *bias_data.y.axis = {ax_c.id};
  bias_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *bias_data.y.repeats = {dim_c};
  *bias_data.y.strides = {af::ops::One};
  bias_data.ir_attr.SetIndex(2);

  af::ascir_op::Load bias_load("load2");
  bias_load.x = bias_data.y;
  bias_load.attr.sched.axis = {ax_c.id};
  bias_load.y.dtype = ge::DT_FLOAT;
  *bias_load.y.axis = {ax_c.id};
  *bias_load.y.strides = {af::ops::One};
  *bias_load.y.repeats = {dim_c};

  // Input 3: offset_w along z_c.
  af::ascir_op::Data offset_w_data("data3", conv2d_offset_bias_graph);
  offset_w_data.y.dtype = ge::DT_FLOAT16;
  offset_w_data.attr.sched.axis = {ax_c.id};
  *offset_w_data.y.axis = {ax_c.id};
  offset_w_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *offset_w_data.y.repeats = {dim_c};
  *offset_w_data.y.strides = {af::ops::One};
  offset_w_data.ir_attr.SetIndex(3);

  af::ascir_op::Load offset_w_load("load3");
  offset_w_load.x = offset_w_data.y;
  offset_w_load.attr.sched.axis = {ax_c.id};
  offset_w_load.y.dtype = ge::DT_FLOAT16;
  *offset_w_load.y.axis = {ax_c.id};
  *offset_w_load.y.strides = {af::ops::One};
  *offset_w_load.y.repeats = {dim_c};

  // The convolution node and its IR attributes.
  af::ascir_op::Conv2DOffsetBias conv_node("conv2d_offset_bias");
  conv_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  conv_node.x = feature_load.y;
  conv_node.filter = filter_load.y;
  conv_node.bias = bias_load.y;
  conv_node.offset_w = offset_w_load.y;
  conv_node.y.dtype = ge::DT_FLOAT;
  *conv_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *conv_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *conv_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  conv_node.attr.api.compute_type = af::ComputeType::kComputeCube;
  conv_node.ir_attr.SetStrides({1, 1});
  conv_node.ir_attr.SetPads({1, 1, 1, 1});
  conv_node.ir_attr.SetDilations({1, 1});
  conv_node.ir_attr.SetGroups(1);
  conv_node.ir_attr.SetData_format("NCHW");
  conv_node.ir_attr.SetOffset_x(0);
  conv_node.ir_attr.SetEnable_hf32(false);

  // Store consuming the convolution result.
  af::ascir_op::Store store_node("store");
  store_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.x = conv_node.y;
  *store_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.y.dtype = ge::DT_FLOAT;
  *store_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *store_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Graph output with index 0.
  af::ascir_op::Output output_node("output");
  output_node.x = store_node.y;
  output_node.y.dtype = ge::DT_FLOAT;
  output_node.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(conv2d_offset_bias_graph);
}
// Populates `graph` with the pipeline
//   Data("data0") -> Load("load0") -> Relu("relu") -> Store("store") -> Output("output"),
// where the four axis extents are the named size variables "n", "c", "h" and "w" rather than constants.
// Every tensor is DT_FLOAT with row-major strides over z_n/z_c/z_h/z_w.
// After CompleteApiInfo has run, the memory attributes of the first output of node "data0" are overwritten.
static void CreateElemwiseGraphWithReluDynamic(af::AscGraph &graph) {
  // Named extents and the axes built on them.
  auto dim_n = graph.CreateSizeVar("n");
  auto dim_c = graph.CreateSizeVar("c");
  auto dim_h = graph.CreateSizeVar("h");
  auto dim_w = graph.CreateSizeVar("w");
  auto ax_n = graph.CreateAxis("z_n", dim_n);
  auto ax_c = graph.CreateAxis("z_c", dim_c);
  auto ax_h = graph.CreateAxis("z_h", dim_h);
  auto ax_w = graph.CreateAxis("z_w", dim_w);

  // Graph input with index 0.
  af::ascir_op::Data input_node("data0", graph);
  input_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  input_node.y.dtype = ge::DT_FLOAT;
  *input_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  input_node.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *input_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *input_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  input_node.ir_attr.SetIndex(0);

  // Load fed by the input.
  af::ascir_op::Load load_node("load0");
  load_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  load_node.x = input_node.y;
  *load_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  load_node.y.dtype = ge::DT_FLOAT;
  *load_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *load_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Relu is added to the graph explicitly before it is wired up.
  af::ascir_op::Relu relu_node("relu");
  graph.AddNode(relu_node);
  relu_node.x = load_node.y;
  relu_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  relu_node.y.dtype = ge::DT_FLOAT;
  *relu_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *relu_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *relu_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  relu_node.attr.api.compute_type = af::ComputeType::kComputeElewise;

  // Store consuming the Relu result.
  af::ascir_op::Store store_node("store");
  store_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.x = relu_node.y;
  *store_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.y.dtype = ge::DT_FLOAT;
  *store_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *store_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Graph output with index 0.
  af::ascir_op::Output output_node("output");
  output_node.x = store_node.y;
  output_node.y.dtype = ge::DT_FLOAT;
  output_node.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  // Override the memory placement of data0's first output.
  auto data0_handle = graph.FindNode("data0");
  data0_handle->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  data0_handle->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  data0_handle->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}
// Populates `graph` with a two-axis (z0 x z1, extents "s0" x "s1") DT_FLOAT elementwise chain:
//   data0 -> load0 -> abs ---------------------------\
//   scalar0("0") -> broadcast0 -> broadcast1 ---------> add --\
//   data1 -> load1 -> broadcast2 -> broadcast3 ----------------> mul -> store -> output
// Note that, despite its name, this function does not create a matmul node itself.
// After CompleteApiInfo has run, the memory attributes of the first output of node "data0" are overwritten.
static void CreateMatmulElemwiseDynamicGraph(af::AscGraph &graph) {
  // Named extents and the axes built on them.
  auto num_rows = graph.CreateSizeVar("s0");
  auto num_cols = graph.CreateSizeVar("s1");
  auto row_ax = graph.CreateAxis("z0", num_rows);
  auto col_ax = graph.CreateAxis("z1", num_cols);

  // Graph input with index 0.
  af::ascir_op::Data in0_data("data0", graph);
  in0_data.attr.sched.axis = {row_ax.id, col_ax.id};
  in0_data.y.dtype = ge::DT_FLOAT;
  *in0_data.y.axis = {row_ax.id, col_ax.id};
  in0_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *in0_data.y.strides = {num_cols, af::ops::One};
  *in0_data.y.repeats = {num_rows, num_cols};
  in0_data.ir_attr.SetIndex(0);

  af::ascir_op::Load in0_load("load0");
  in0_load.attr.sched.axis = {row_ax.id, col_ax.id};
  in0_load.x = in0_data.y;
  *in0_load.y.axis = {row_ax.id, col_ax.id};
  in0_load.y.dtype = ge::DT_FLOAT;
  *in0_load.y.strides = {num_cols, af::ops::One};
  *in0_load.y.repeats = {num_rows, num_cols};

  // Abs is added to the graph explicitly before it is wired up.
  af::ascir_op::Abs abs_node("abs");
  graph.AddNode(abs_node);
  abs_node.x = in0_load.y;
  abs_node.attr.sched.axis = {row_ax.id, col_ax.id};
  abs_node.y.dtype = ge::DT_FLOAT;
  *abs_node.y.axis = {row_ax.id, col_ax.id};
  *abs_node.y.repeats = {num_rows, num_cols};
  *abs_node.y.strides = {num_cols, af::ops::One};
  abs_node.attr.api.compute_type = af::ComputeType::kComputeElewise;

  // Scalar with value "0", widened in two Broadcast steps: first to {One, s1}, then to {s0, s1}.
  af::ascir_op::Scalar const_zero("scalar0", graph);
  const_zero.attr.sched.axis = {row_ax.id, col_ax.id};
  const_zero.ir_attr.SetValue("0");
  const_zero.y.dtype = ge::DT_FLOAT;
  *const_zero.y.axis = {row_ax.id, col_ax.id};
  *const_zero.y.repeats = {af::ops::One, af::ops::One};
  *const_zero.y.strides = {af::ops::Zero, af::ops::Zero};

  af::ascir_op::Broadcast zero_row("broadcast0");
  zero_row.x = const_zero.y;
  zero_row.attr.sched.axis = {row_ax.id, col_ax.id};
  *zero_row.y.axis = {row_ax.id, col_ax.id};
  zero_row.y.dtype = ge::DT_FLOAT;
  *zero_row.y.repeats = {af::ops::One, num_cols};
  *zero_row.y.strides = {af::ops::Zero, af::ops::One};

  af::ascir_op::Broadcast zero_full("broadcast1");
  zero_full.x = zero_row.y;
  zero_full.attr.sched.axis = {row_ax.id, col_ax.id};
  *zero_full.y.axis = {row_ax.id, col_ax.id};
  zero_full.y.dtype = ge::DT_FLOAT;
  *zero_full.y.repeats = {num_rows, num_cols};
  *zero_full.y.strides = {num_cols, af::ops::One};

  // add = abs + broadcast scalar.
  af::ascir_op::Add add_node("add");
  add_node.attr.sched.axis = {row_ax.id, col_ax.id};
  add_node.x1 = abs_node.y;
  add_node.x2 = zero_full.y;
  add_node.y.dtype = ge::DT_FLOAT;
  *add_node.y.axis = {row_ax.id, col_ax.id};
  *add_node.y.repeats = {num_rows, num_cols};
  *add_node.y.strides = {num_cols, af::ops::One};

  // Graph input with index 1, described with unit repeats and zero strides.
  af::ascir_op::Data in1_data("data1", graph);
  in1_data.y.dtype = ge::DT_FLOAT;
  in1_data.attr.sched.axis = {row_ax.id, col_ax.id};
  *in1_data.y.axis = {row_ax.id, col_ax.id};
  in1_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *in1_data.y.repeats = {af::ops::One, af::ops::One};
  *in1_data.y.strides = {af::ops::Zero, af::ops::Zero};
  in1_data.ir_attr.SetIndex(1);

  af::ascir_op::Load in1_load("load1");
  in1_load.x = in1_data.y;
  in1_load.attr.sched.axis = {row_ax.id, col_ax.id};
  in1_load.y.dtype = ge::DT_FLOAT;
  *in1_load.y.axis = {row_ax.id, col_ax.id};
  *in1_load.y.strides = {af::ops::Zero, af::ops::Zero};
  *in1_load.y.repeats = {af::ops::One, af::ops::One};

  // Second input widened in two Broadcast steps: first to {One, s1}, then to {s0, s1}.
  af::ascir_op::Broadcast in1_row("broadcast2");
  in1_row.x = in1_load.y;
  in1_row.attr.sched.axis = {row_ax.id, col_ax.id};
  *in1_row.y.axis = {row_ax.id, col_ax.id};
  in1_row.y.dtype = ge::DT_FLOAT;
  *in1_row.y.repeats = {af::ops::One, num_cols};
  *in1_row.y.strides = {af::ops::Zero, af::ops::One};

  af::ascir_op::Broadcast in1_full("broadcast3");
  in1_full.x = in1_row.y;
  in1_full.attr.sched.axis = {row_ax.id, col_ax.id};
  *in1_full.y.axis = {row_ax.id, col_ax.id};
  in1_full.y.dtype = ge::DT_FLOAT;
  *in1_full.y.repeats = {num_rows, num_cols};
  *in1_full.y.strides = {num_cols, af::ops::One};

  // mul = add * broadcast second input.
  af::ascir_op::Mul mul_node("mul");
  mul_node.attr.sched.axis = {row_ax.id, col_ax.id};
  mul_node.x1 = add_node.y;
  mul_node.x2 = in1_full.y;
  mul_node.y.dtype = ge::DT_FLOAT;
  *mul_node.y.axis = {row_ax.id, col_ax.id};
  *mul_node.y.repeats = {num_rows, num_cols};
  *mul_node.y.strides = {num_cols, af::ops::One};

  // Store consuming the Mul result.
  af::ascir_op::Store store_node("store");
  store_node.attr.sched.axis = {row_ax.id, col_ax.id};
  store_node.x = mul_node.y;
  *store_node.y.axis = {row_ax.id, col_ax.id};
  store_node.y.dtype = ge::DT_FLOAT;
  *store_node.y.strides = {num_cols, af::ops::One};
  *store_node.y.repeats = {num_rows, num_cols};

  // Graph output with index 0.
  af::ascir_op::Output output_node("output");
  output_node.x = store_node.y;
  output_node.y.dtype = ge::DT_FLOAT;
  output_node.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  // Override the memory placement of data0's first output.
  auto data0_handle = graph.FindNode("data0");
  data0_handle->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  data0_handle->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  data0_handle->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}
// Populates `graph` with a DT_FLOAT elementwise multiply over z_n/z_c/z_h/z_w, whose extents are the
// named size variables "n", "c", "h" and "w":
//   data0 -> load0 ----------------------\
//   scalar("2.0") -> broadcast ----------> mul -> store -> output
// After CompleteApiInfo has run, the memory attributes of the first output of node "data0" are overwritten.
static void CreateElemwiseGraphWithMulDynamic(af::AscGraph &graph) {
  // Named extents and the axes built on them.
  auto dim_n = graph.CreateSizeVar("n");
  auto dim_c = graph.CreateSizeVar("c");
  auto dim_h = graph.CreateSizeVar("h");
  auto dim_w = graph.CreateSizeVar("w");
  auto ax_n = graph.CreateAxis("z_n", dim_n);
  auto ax_c = graph.CreateAxis("z_c", dim_c);
  auto ax_h = graph.CreateAxis("z_h", dim_h);
  auto ax_w = graph.CreateAxis("z_w", dim_w);

  // Graph input with index 0.
  af::ascir_op::Data input_node("data0", graph);
  input_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  input_node.y.dtype = ge::DT_FLOAT;
  *input_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  input_node.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *input_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *input_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  input_node.ir_attr.SetIndex(0);

  af::ascir_op::Load load_node("load0");
  load_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  load_node.x = input_node.y;
  *load_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  load_node.y.dtype = ge::DT_FLOAT;
  *load_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *load_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Scalar with value "2.0"; no sched axis is assigned to it here.
  af::ascir_op::Scalar factor_node("scalar", graph);
  factor_node.ir_attr.SetValue("2.0");
  factor_node.y.dtype = ge::DT_FLOAT;
  *factor_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *factor_node.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *factor_node.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};

  // Broadcast of the scalar to the full {n, c, h, w} extent.
  af::ascir_op::Broadcast factor_full("broadcast");
  factor_full.x = factor_node.y;
  factor_full.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *factor_full.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  factor_full.y.dtype = ge::DT_FLOAT;
  *factor_full.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *factor_full.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};

  // mul = loaded input * broadcast scalar.
  af::ascir_op::Mul mul_node("mul");
  mul_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  mul_node.x1 = load_node.y;
  mul_node.x2 = factor_full.y;
  mul_node.y.dtype = ge::DT_FLOAT;
  *mul_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *mul_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *mul_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  mul_node.attr.api.compute_type = af::ComputeType::kComputeElewise;

  // Store consuming the Mul result.
  af::ascir_op::Store store_node("store");
  store_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.x = mul_node.y;
  *store_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.y.dtype = ge::DT_FLOAT;
  *store_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *store_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Graph output with index 0.
  af::ascir_op::Output output_node("output");
  output_node.x = store_node.y;
  output_node.y.dtype = ge::DT_FLOAT;
  output_node.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  // Override the memory placement of data0's first output.
  auto data0_handle = graph.FindNode("data0");
  data0_handle->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  data0_handle->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  data0_handle->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}
// Populates `graph` with a fixed-shape DT_FLOAT chain over z_n/z_c/z_h/z_w (extents 1/64/56/56):
//   data0 -> load0 -> abs -----------------\
//   scalar0("0.1") -> broadcast0 ----------> add -> store -> output
// After CompleteApiInfo has run, the memory attributes of the first output of node "data0" are overwritten.
static void CreateElemwiseGraphWithAbsAndAddStatic(af::AscGraph &graph) {
  // Constant extents and the axes built on them.
  auto dim_n = graph.CreateSizeVar(1);
  auto dim_c = graph.CreateSizeVar(64);
  auto dim_h = graph.CreateSizeVar(56);
  auto dim_w = graph.CreateSizeVar(56);
  auto ax_n = graph.CreateAxis("z_n", dim_n);
  auto ax_c = graph.CreateAxis("z_c", dim_c);
  auto ax_h = graph.CreateAxis("z_h", dim_h);
  auto ax_w = graph.CreateAxis("z_w", dim_w);

  // Graph input with index 0.
  af::ascir_op::Data input_node("data0", graph);
  input_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  input_node.y.dtype = ge::DT_FLOAT;
  *input_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  input_node.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *input_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *input_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  input_node.ir_attr.SetIndex(0);

  af::ascir_op::Load load_node("load0");
  load_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  load_node.x = input_node.y;
  *load_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  load_node.y.dtype = ge::DT_FLOAT;
  *load_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *load_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Abs is added to the graph explicitly before it is wired up.
  af::ascir_op::Abs abs_node("abs");
  graph.AddNode(abs_node);
  abs_node.x = load_node.y;
  abs_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  abs_node.y.dtype = ge::DT_FLOAT;
  *abs_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *abs_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *abs_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  abs_node.attr.api.compute_type = af::ComputeType::kComputeElewise;

  // Scalar with value "0.1", described with unit repeats and zero strides.
  af::ascir_op::Scalar addend_node("scalar0", graph);
  addend_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  addend_node.ir_attr.SetValue("0.1");
  addend_node.y.dtype = ge::DT_FLOAT;
  *addend_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *addend_node.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *addend_node.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};

  // Broadcast of the scalar to the full {n, c, h, w} extent.
  af::ascir_op::Broadcast addend_full("broadcast0");
  addend_full.x = addend_node.y;
  addend_full.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *addend_full.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  addend_full.y.dtype = ge::DT_FLOAT;
  *addend_full.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *addend_full.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};

  // add = abs + broadcast scalar.
  af::ascir_op::Add add_node("add");
  add_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  add_node.x1 = abs_node.y;
  add_node.x2 = addend_full.y;
  add_node.y.dtype = ge::DT_FLOAT;
  *add_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *add_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *add_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};

  // Store consuming the Add result.
  af::ascir_op::Store store_node("store");
  store_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.x = add_node.y;
  *store_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.y.dtype = ge::DT_FLOAT;
  *store_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *store_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Graph output with index 0.
  af::ascir_op::Output output_node("output");
  output_node.x = store_node.y;
  output_node.y.dtype = ge::DT_FLOAT;
  output_node.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  // Override the memory placement of data0's first output.
  auto data0_handle = graph.FindNode("data0");
  data0_handle->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  data0_handle->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  data0_handle->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}
// Populates `conv2d_offset_graph` with a Conv2DOffset node fed by three Data/Load pairs:
//   data0/load0 -> x        (DT_FLOAT16, axes z_n/z_c/z_h/z_w, extents 1/64/56/56)
//   data1/load1 -> filter   (DT_FLOAT16, all repeats One, all strides Zero)
//   data2/load2 -> offset_w (DT_FLOAT16, axis z_c only)
// The DT_FLOAT convolution result goes through Store("store") into Output("output").
static void CreateConv2DOffsetGraph(af::AscGraph &conv2d_offset_graph) {
  // Constant extents and the axes built on them.
  auto dim_n = conv2d_offset_graph.CreateSizeVar(1);
  auto dim_c = conv2d_offset_graph.CreateSizeVar(64);
  auto dim_h = conv2d_offset_graph.CreateSizeVar(56);
  auto dim_w = conv2d_offset_graph.CreateSizeVar(56);
  auto ax_n = conv2d_offset_graph.CreateAxis("z_n", dim_n);
  auto ax_c = conv2d_offset_graph.CreateAxis("z_c", dim_c);
  auto ax_h = conv2d_offset_graph.CreateAxis("z_h", dim_h);
  auto ax_w = conv2d_offset_graph.CreateAxis("z_w", dim_w);

  // Input 0: feature map.
  af::ascir_op::Data feature_data("data0", conv2d_offset_graph);
  feature_data.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_data.y.dtype = ge::DT_FLOAT16;
  *feature_data.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *feature_data.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *feature_data.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  feature_data.ir_attr.SetIndex(0);

  af::ascir_op::Load feature_load("load0");
  feature_load.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_load.x = feature_data.y;
  *feature_load.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_load.y.dtype = ge::DT_FLOAT16;
  *feature_load.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *feature_load.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Input 1: filter, described with unit repeats and zero strides.
  af::ascir_op::Data filter_data("data1", conv2d_offset_graph);
  filter_data.y.dtype = ge::DT_FLOAT16;
  filter_data.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *filter_data.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  filter_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *filter_data.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *filter_data.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  filter_data.ir_attr.SetIndex(1);

  af::ascir_op::Load filter_load("load1");
  filter_load.x = filter_data.y;
  filter_load.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  filter_load.y.dtype = ge::DT_FLOAT16;
  *filter_load.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *filter_load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  *filter_load.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};

  // Input 2: offset_w along z_c.
  af::ascir_op::Data offset_w_data("data2", conv2d_offset_graph);
  offset_w_data.y.dtype = ge::DT_FLOAT16;
  offset_w_data.attr.sched.axis = {ax_c.id};
  *offset_w_data.y.axis = {ax_c.id};
  offset_w_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *offset_w_data.y.repeats = {dim_c};
  *offset_w_data.y.strides = {af::ops::One};
  offset_w_data.ir_attr.SetIndex(2);

  af::ascir_op::Load offset_w_load("load2");
  offset_w_load.x = offset_w_data.y;
  offset_w_load.attr.sched.axis = {ax_c.id};
  offset_w_load.y.dtype = ge::DT_FLOAT16;
  *offset_w_load.y.axis = {ax_c.id};
  *offset_w_load.y.strides = {af::ops::One};
  *offset_w_load.y.repeats = {dim_c};

  // The convolution node and its IR attributes.
  af::ascir_op::Conv2DOffset conv_node("conv2d_offset");
  conv_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  conv_node.x = feature_load.y;
  conv_node.filter = filter_load.y;
  conv_node.offset_w = offset_w_load.y;
  conv_node.y.dtype = ge::DT_FLOAT;
  *conv_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *conv_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *conv_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  conv_node.attr.api.compute_type = af::ComputeType::kComputeCube;
  conv_node.ir_attr.SetStrides({1, 1});
  conv_node.ir_attr.SetPads({1, 1, 1, 1});
  conv_node.ir_attr.SetDilations({1, 1});
  conv_node.ir_attr.SetGroups(1);
  conv_node.ir_attr.SetData_format("NCHW");
  conv_node.ir_attr.SetOffset_x(0);
  conv_node.ir_attr.SetEnable_hf32(false);

  // Store consuming the convolution result.
  af::ascir_op::Store store_node("store");
  store_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.x = conv_node.y;
  *store_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.y.dtype = ge::DT_FLOAT;
  *store_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *store_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Graph output with index 0.
  af::ascir_op::Output output_node("output");
  output_node.x = store_node.y;
  output_node.y.dtype = ge::DT_FLOAT;
  output_node.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(conv2d_offset_graph);
}
// Populates `conv2d_graph` with a Conv2D node configured with groups = 4 and strides {2, 2}
// (pads {1, 1, 1, 1}, dilations {1, 1}, data format "NCHW"):
//   data0/load0 -> x      (DT_FLOAT16, axes z_n/z_c/z_h/z_w, extents 1/64/56/56)
//   data1/load1 -> filter (DT_FLOAT16, all repeats One, all strides Zero)
// The DT_FLOAT convolution result goes through Store("store") into Output("output").
static void CreateConv2DGraphWithGroups(af::AscGraph &conv2d_graph) {
  // Constant extents and the axes built on them.
  auto dim_n = conv2d_graph.CreateSizeVar(1);
  auto dim_c = conv2d_graph.CreateSizeVar(64);
  auto dim_h = conv2d_graph.CreateSizeVar(56);
  auto dim_w = conv2d_graph.CreateSizeVar(56);
  auto ax_n = conv2d_graph.CreateAxis("z_n", dim_n);
  auto ax_c = conv2d_graph.CreateAxis("z_c", dim_c);
  auto ax_h = conv2d_graph.CreateAxis("z_h", dim_h);
  auto ax_w = conv2d_graph.CreateAxis("z_w", dim_w);

  // Input 0: feature map.
  af::ascir_op::Data feature_data("data0", conv2d_graph);
  feature_data.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_data.y.dtype = ge::DT_FLOAT16;
  *feature_data.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *feature_data.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *feature_data.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  feature_data.ir_attr.SetIndex(0);

  af::ascir_op::Load feature_load("load0");
  feature_load.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_load.x = feature_data.y;
  *feature_load.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_load.y.dtype = ge::DT_FLOAT16;
  *feature_load.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *feature_load.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Input 1: filter, described with unit repeats and zero strides.
  af::ascir_op::Data filter_data("data1", conv2d_graph);
  filter_data.y.dtype = ge::DT_FLOAT16;
  filter_data.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *filter_data.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  filter_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *filter_data.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *filter_data.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  filter_data.ir_attr.SetIndex(1);

  af::ascir_op::Load filter_load("load1");
  filter_load.x = filter_data.y;
  filter_load.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  filter_load.y.dtype = ge::DT_FLOAT16;
  *filter_load.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *filter_load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  *filter_load.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};

  // The convolution node and its IR attributes.
  af::ascir_op::Conv2D conv_node("conv2d");
  conv_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  conv_node.x = feature_load.y;
  conv_node.filter = filter_load.y;
  conv_node.y.dtype = ge::DT_FLOAT;
  *conv_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *conv_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *conv_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  conv_node.attr.api.compute_type = af::ComputeType::kComputeCube;
  conv_node.ir_attr.SetStrides({2, 2});
  conv_node.ir_attr.SetPads({1, 1, 1, 1});
  conv_node.ir_attr.SetDilations({1, 1});
  conv_node.ir_attr.SetGroups(4);
  conv_node.ir_attr.SetData_format("NCHW");
  conv_node.ir_attr.SetOffset_x(0);
  conv_node.ir_attr.SetEnable_hf32(false);

  // Store consuming the convolution result.
  af::ascir_op::Store store_node("store");
  store_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.x = conv_node.y;
  *store_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.y.dtype = ge::DT_FLOAT;
  *store_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *store_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Graph output with index 0.
  af::ascir_op::Output output_node("output");
  output_node.x = store_node.y;
  output_node.y.dtype = ge::DT_FLOAT;
  output_node.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(conv2d_graph);
}
// Populates `conv2d_graph` with a Conv2D node configured with dilations {2, 2} and pads {2, 2, 2, 2}
// (strides {1, 1}, groups = 1, data format "NCHW"):
//   data0/load0 -> x      (DT_FLOAT16, axes z_n/z_c/z_h/z_w, extents 1/64/56/56)
//   data1/load1 -> filter (DT_FLOAT16, all repeats One, all strides Zero)
// The DT_FLOAT convolution result goes through Store("store") into Output("output").
static void CreateConv2DGraphWithDilation(af::AscGraph &conv2d_graph) {
  // Constant extents and the axes built on them.
  auto dim_n = conv2d_graph.CreateSizeVar(1);
  auto dim_c = conv2d_graph.CreateSizeVar(64);
  auto dim_h = conv2d_graph.CreateSizeVar(56);
  auto dim_w = conv2d_graph.CreateSizeVar(56);
  auto ax_n = conv2d_graph.CreateAxis("z_n", dim_n);
  auto ax_c = conv2d_graph.CreateAxis("z_c", dim_c);
  auto ax_h = conv2d_graph.CreateAxis("z_h", dim_h);
  auto ax_w = conv2d_graph.CreateAxis("z_w", dim_w);

  // Input 0: feature map.
  af::ascir_op::Data feature_data("data0", conv2d_graph);
  feature_data.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_data.y.dtype = ge::DT_FLOAT16;
  *feature_data.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *feature_data.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *feature_data.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  feature_data.ir_attr.SetIndex(0);

  af::ascir_op::Load feature_load("load0");
  feature_load.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_load.x = feature_data.y;
  *feature_load.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  feature_load.y.dtype = ge::DT_FLOAT16;
  *feature_load.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *feature_load.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Input 1: filter, described with unit repeats and zero strides.
  af::ascir_op::Data filter_data("data1", conv2d_graph);
  filter_data.y.dtype = ge::DT_FLOAT16;
  filter_data.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *filter_data.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  filter_data.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *filter_data.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *filter_data.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  filter_data.ir_attr.SetIndex(1);

  af::ascir_op::Load filter_load("load1");
  filter_load.x = filter_data.y;
  filter_load.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  filter_load.y.dtype = ge::DT_FLOAT16;
  *filter_load.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *filter_load.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  *filter_load.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};

  // The convolution node and its IR attributes.
  af::ascir_op::Conv2D conv_node("conv2d");
  conv_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  conv_node.x = feature_load.y;
  conv_node.filter = filter_load.y;
  conv_node.y.dtype = ge::DT_FLOAT;
  *conv_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  *conv_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};
  *conv_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  conv_node.attr.api.compute_type = af::ComputeType::kComputeCube;
  conv_node.ir_attr.SetStrides({1, 1});
  conv_node.ir_attr.SetPads({2, 2, 2, 2});
  conv_node.ir_attr.SetDilations({2, 2});
  conv_node.ir_attr.SetGroups(1);
  conv_node.ir_attr.SetData_format("NCHW");
  conv_node.ir_attr.SetOffset_x(0);
  conv_node.ir_attr.SetEnable_hf32(false);

  // Store consuming the convolution result.
  af::ascir_op::Store store_node("store");
  store_node.attr.sched.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.x = conv_node.y;
  *store_node.y.axis = {ax_n.id, ax_c.id, ax_h.id, ax_w.id};
  store_node.y.dtype = ge::DT_FLOAT;
  *store_node.y.strides = {dim_c * dim_h * dim_w, dim_h * dim_w, dim_w, af::ops::One};
  *store_node.y.repeats = {dim_n, dim_c, dim_h, dim_w};

  // Graph output with index 0.
  af::ascir_op::Output output_node("output");
  output_node.x = store_node.y;
  output_node.y.dtype = ge::DT_FLOAT;
  output_node.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(conv2d_graph);
}
// Fills `graph` with a symbolic-shape elementwise chain: Data -> Load -> Relu -> Store -> Output.
// Every tensor is DT_FLOAT over the axes (z_batch, z_m, z_n) with strides {m*n, n, 1} and
// repeats {batch, m, n}. Note: despite the function name, the only compute node built here is a Relu.
static void CreateBatchMatmulElemwiseDynamicGraph(af::AscGraph &graph) {
  // Symbolic dimensions and the axes that iterate over them.
  auto batch_sz = graph.CreateSizeVar("batch");
  auto m_sz = graph.CreateSizeVar("m");
  auto n_sz = graph.CreateSizeVar("n");
  auto batch_axis = graph.CreateAxis("z_batch", batch_sz);
  auto m_axis = graph.CreateAxis("z_m", m_sz);
  auto n_axis = graph.CreateAxis("z_n", n_sz);

  // Graph input with index 0.
  af::ascir_op::Data input_op("data0", graph);
  input_op.attr.sched.axis = {batch_axis.id, m_axis.id, n_axis.id};
  input_op.y.dtype = ge::DT_FLOAT;
  *input_op.y.axis = {batch_axis.id, m_axis.id, n_axis.id};
  input_op.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *input_op.y.strides = {m_sz * n_sz, n_sz, af::ops::One};
  *input_op.y.repeats = {batch_sz, m_sz, n_sz};
  input_op.ir_attr.SetIndex(0);

  // Load fed by the input.
  af::ascir_op::Load load_op("load0");
  load_op.attr.sched.axis = {batch_axis.id, m_axis.id, n_axis.id};
  load_op.x = input_op.y;
  *load_op.y.axis = {batch_axis.id, m_axis.id, n_axis.id};
  load_op.y.dtype = ge::DT_FLOAT;
  *load_op.y.strides = {m_sz * n_sz, n_sz, af::ops::One};
  *load_op.y.repeats = {batch_sz, m_sz, n_sz};

  // Elementwise Relu; this node is registered with the graph explicitly before it is wired.
  af::ascir_op::Relu relu_op("relu");
  graph.AddNode(relu_op);
  relu_op.x = load_op.y;
  relu_op.attr.sched.axis = {batch_axis.id, m_axis.id, n_axis.id};
  relu_op.y.dtype = ge::DT_FLOAT;
  *relu_op.y.axis = {batch_axis.id, m_axis.id, n_axis.id};
  *relu_op.y.repeats = {batch_sz, m_sz, n_sz};
  *relu_op.y.strides = {m_sz * n_sz, n_sz, af::ops::One};
  relu_op.attr.api.compute_type = af::ComputeType::kComputeElewise;

  // Store of the Relu result.
  af::ascir_op::Store result_store("store");
  result_store.attr.sched.axis = {batch_axis.id, m_axis.id, n_axis.id};
  result_store.x = relu_op.y;
  *result_store.y.axis = {batch_axis.id, m_axis.id, n_axis.id};
  result_store.y.dtype = ge::DT_FLOAT;
  *result_store.y.strides = {m_sz * n_sz, n_sz, af::ops::One};
  *result_store.y.repeats = {batch_sz, m_sz, n_sz};

  // Graph output with index 0.
  af::ascir_op::Output result_out("output");
  result_out.x = result_store.y;
  result_out.y.dtype = ge::DT_FLOAT;
  result_out.ir_attr.SetIndex(0);

  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  // After API info completion, override the memory attributes of the data node's first output.
  auto data0_node = graph.FindNode("data0");
  auto &data0_mem = data0_node->outputs[0].attr.mem;
  data0_mem.alloc_type = af::AllocType::kAllocTypeQueue;
  data0_mem.hardware = af::MemHardware::kMemHardwareUB;
  data0_mem.position = af::Position::kPositionVecIn;
}
// Asserts (fatally, one check at a time) that the generated "tiling_def_and_tiling_const" text
// of a dynamic-shape graph contains each expected declaration or symbol.
// `res.at()` throws std::out_of_range when the entry is missing from the map.
static void VerifyDynamicShapeTiling(const std::map<std::string, std::string> &res) {
  const std::string &tiling_src = res.at("tiling_def_and_tiling_const");

  const auto find_best_key_at = tiling_src.find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(find_best_key_at, std::string::npos);

  // The static-shape probe must report `false` for a dynamic-shape graph.
  const auto is_static_at = tiling_src.find("extern \"C\" bool AutofuseIsStaticShape() {\n return false;");
  ASSERT_NE(is_static_at, std::string::npos);

  const auto tiling_func_at =
      tiling_src.find("extern \"C\" ge::graphStatus TilingFunc(gert::TilingSymbolEvalContext *context)");
  ASSERT_NE(tiling_func_at, std::string::npos);

  const auto tiling_call_at = tiling_src.find("AutofuseTilingWithConfig");
  ASSERT_NE(tiling_call_at, std::string::npos);

  const auto cache_key_at = tiling_src.find("extern \"C\" ge::graphStatus GetSymbolTilingCacheKey");
  ASSERT_NE(cache_key_at, std::string::npos);

  const auto tiling_data_at = tiling_src.find("AutofuseTilingData");
  ASSERT_NE(tiling_data_at, std::string::npos);
}
// Asserts (fatally, one check at a time) that the generated "tiling_def_and_tiling_const" text
// for the Conv2D + elementwise case contains each expected declaration or statement.
// `res.at()` throws std::out_of_range when the entry is missing from the map.
static void VerifyConv2dElemwiseTiling(const std::map<std::string, std::string> &res) {
  const std::string &tiling_src = res.at("tiling_def_and_tiling_const");

  const auto find_best_key_at = tiling_src.find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(find_best_key_at, std::string::npos);

  // The static-shape probe must report `true` here.
  const auto is_static_at = tiling_src.find("extern \"C\" bool AutofuseIsStaticShape() {\n return true;");
  ASSERT_NE(is_static_at, std::string::npos);

  const auto tiling_func_at = tiling_src.find("extern \"C\" ge::graphStatus TilingFunc");
  ASSERT_NE(tiling_func_at, std::string::npos);

  const auto tiling_parse_at = tiling_src.find("extern \"C\" ge::graphStatus TilingParse");
  ASSERT_NE(tiling_parse_at, std::string::npos);

  const auto data_size_at = tiling_src.find("extern \"C\" size_t GetTilingDataSize()");
  ASSERT_NE(data_size_at, std::string::npos);

  // The workspace size is expected to be emitted as a fixed 16 * 1024 * 1024 expression.
  const auto workspace_at = tiling_src.find("*context->GetWorkspaceSizes(1) = 16 * 1024 * 1024");
  ASSERT_NE(workspace_at, std::string::npos);

  const auto tiling_data_at = tiling_src.find("AutofuseTilingData");
  ASSERT_NE(tiling_data_at, std::string::npos);

  const auto block_dim_at = tiling_src.find("set_block_dim");
  ASSERT_NE(block_dim_at, std::string::npos);
}
// Asserts (fatally, one check at a time) that the generated "tiling_def_and_tiling_const" text
// for the Conv2D + bias + elementwise case contains each expected declaration or statement.
// `res.at()` throws std::out_of_range when the entry is missing from the map.
static void VerifyConv2DBiasElemwiseTiling(const std::map<std::string, std::string> &res) {
  const std::string &tiling_src = res.at("tiling_def_and_tiling_const");

  const auto find_best_key_at = tiling_src.find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(find_best_key_at, std::string::npos);

  // The static-shape probe must report `true` here.
  const auto is_static_at = tiling_src.find("extern \"C\" bool AutofuseIsStaticShape() {\n return true;");
  ASSERT_NE(is_static_at, std::string::npos);

  const auto tiling_func_at = tiling_src.find("extern \"C\" ge::graphStatus TilingFunc");
  ASSERT_NE(tiling_func_at, std::string::npos);

  const auto data_size_at = tiling_src.find("extern \"C\" size_t GetTilingDataSize()");
  ASSERT_NE(data_size_at, std::string::npos);

  const auto tiling_data_at = tiling_src.find("AutofuseTilingData");
  ASSERT_NE(tiling_data_at, std::string::npos);

  // The workspace size is expected to be emitted as a fixed 16 * 1024 * 1024 expression.
  const auto workspace_at = tiling_src.find("*context->GetWorkspaceSizes(1) = 16 * 1024 * 1024");
  ASSERT_NE(workspace_at, std::string::npos);
}
// Asserts (fatally, one check at a time) that the generated "tiling_def_and_tiling_const" text
// for the Conv2D offset case contains each expected declaration or symbol.
// `res.at()` throws std::out_of_range when the entry is missing from the map.
static void VerifyConv2DOffsetTiling(const std::map<std::string, std::string> &res) {
  const std::string &tiling_src = res.at("tiling_def_and_tiling_const");

  const auto find_best_key_at = tiling_src.find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(find_best_key_at, std::string::npos);

  // The static-shape probe must report `true` here.
  const auto is_static_at = tiling_src.find("extern \"C\" bool AutofuseIsStaticShape() {\n return true;");
  ASSERT_NE(is_static_at, std::string::npos);

  const auto tiling_func_at = tiling_src.find("extern \"C\" ge::graphStatus TilingFunc");
  ASSERT_NE(tiling_func_at, std::string::npos);

  const auto data_size_at = tiling_src.find("extern \"C\" size_t GetTilingDataSize()");
  ASSERT_NE(data_size_at, std::string::npos);

  const auto tiling_data_at = tiling_src.find("AutofuseTilingData");
  ASSERT_NE(tiling_data_at, std::string::npos);
}
}
// Test fixture that also inherits codegen::TilingLib, so test bodies can invoke the library's
// generator methods directly through `this`.
class TestCodegenTiling : public testing::Test, public codegen::TilingLib {
 public:
  // Raises the ASCGEN module log level to debug before every test.
  void SetUp() override {
    dlog_setlevel(ASCGEN_MODULE_NAME, DLOG_DEBUG, 0);
  }

  void TearDown() override {}

  // Configures output 0 of `load` as a single-axis, unit-stride tensor with queue-type allocation
  // at the VecIn position (tensor id 1, queue id 0, depth 2, 2 buffers).
  void SetupLoadAttrs(af::AscNode &load, uint64_t z0_id, const af::Expression &z0_size) {
    const auto axis_id = static_cast<int64_t>(z0_id);
    auto &out = load.outputs[0].attr;

    // Tensor layout.
    out.axis = {axis_id};
    out.vectorized_axis = {axis_id};
    out.vectorized_strides = {af::ops::One};
    out.repeats = {z0_size};
    out.strides = {af::ops::One};

    // Memory placement.
    out.mem.position = af::Position::kPositionVecIn;
    out.mem.alloc_type = af::AllocType::kAllocTypeQueue;
    out.mem.tensor_id = 1;
    out.mem.reuse_id = 0;

    // Queue settings.
    out.que.id = 0;
    out.que.depth = 2;
    out.que.buf_num = 2;

    out.opt.merge_scope = af::kIdNone;
  }

  // Configures output 0 of `store` as a single-axis, unit-stride tensor with global allocation
  // (tensor id 2).
  void SetupStoreAttrs(af::AscNode &store, uint64_t z0_id, const af::Expression &z0_size) {
    const auto axis_id = static_cast<int64_t>(z0_id);
    auto &out = store.outputs[0].attr;

    // Memory placement.
    out.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
    out.mem.tensor_id = 2;

    // Tensor layout.
    out.axis = {axis_id};
    out.vectorized_axis = {axis_id};
    out.vectorized_strides = {af::ops::One};
    out.repeats = {z0_size};
    out.strides = {af::ops::One};
  }

  // Builds a minimal x -> load -> store -> y graph whose only axis "z0" has size Zero, and wraps it
  // into a FusedScheduledResult that holds the same single-group scheduled result twice.
  // `origin_vars` is copied into the returned object unchanged.
  ascir::FusedScheduledResult GenBasicFusedScheduleResult(const std::vector<af::Expression> &origin_vars = {}) {
    af::AscGraph graph("test_graph");
    graph.CreateSizeVar("s0");
    auto z0 = graph.CreateAxis("z0", af::ops::Zero);

    // Ops: the Data op is attached at construction, the others are added explicitly.
    af::ascir_op::Data data_op("x", graph);
    data_op.ir_attr.SetIndex(0);
    af::ascir_op::Load load_op("load");
    af::ascir_op::Store store_op("store");
    af::ascir_op::Output out_op("y");
    out_op.ir_attr.SetIndex(0);
    graph.AddNode(load_op);
    graph.AddNode(store_op);
    graph.AddNode(out_op);

    // Wiring.
    load_op.x = data_op.y;
    load_op.y.dtype = ge::DT_FLOAT16;
    store_op.x = load_op.y;
    out_op.x = store_op.y;

    // Per-node attributes, set on the nodes looked up from the graph.
    auto x_node = graph.FindNode("x");
    auto load_node = graph.FindNode("load");
    auto store_node = graph.FindNode("store");
    auto y_node = graph.FindNode("y");
    x_node->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
    x_node->outputs[0].attr.mem.tensor_id = 0;
    x_node->attr.api.unit = af::ComputeUnit::kUnitNone;
    y_node->attr.api.unit = af::ComputeUnit::kUnitNone;
    SetupLoadAttrs(*load_node, z0.id, z0.size);
    SetupStoreAttrs(*store_node, z0.id, z0.size);

    // One schedule group containing a copy of the graph.
    ::ascir::ScheduledResult scheduled;
    scheduled.schedule_groups.resize(1);
    scheduled.schedule_groups.front().impl_graphs.emplace_back(graph);

    // The identical scheduled result is registered twice for node index 0.
    std::vector<ascir::ScheduledResult> scheduled_list(2, scheduled);

    ascir::FusedScheduledResult fused;
    fused.fused_graph_name = af::AscendString(graph.GetName().c_str());
    fused.input_nodes.push_back(x_node);
    fused.output_nodes.push_back(y_node);
    fused.node_idx_to_scheduled_results.push_back(scheduled_list);
    fused.origin_vars = origin_vars;
    return fused;
  }

  // Runs Generate() on the basic fused result, with an empty third argument and "0" as the fourth.
  std::map<std::string, std::string> GenTilingCode(const std::vector<af::Expression> &origin_vars = {},
                                                   const std::map<std::string, std::string> &shape_info = {}) {
    auto fused = GenBasicFusedScheduleResult(origin_vars);
    return this->Generate(fused, shape_info, "", "0");
  }

  // Runs GenerateForInductor() on the basic fused result.
  std::map<std::string, std::string> GenTilingCodeForInductor(const std::vector<af::Expression> &origin_vars = {}) {
    auto fused = GenBasicFusedScheduleResult(origin_vars);
    return this->GenerateForInductor(fused);
  }

 protected:
  TestCodegenTiling() : codegen::TilingLib("test", "test") {}
};
// With no workspace nodes registered, the generated GetWorkspaceSize() is expected to add 0 for
// the single tiling key.
TEST_F(TestCodegenTiling, NoWorkspaceTest) {
  ascir::ImplGraph impl_graph("test_graph0");
  impl_graph.CreateSizeVar("s0");
  impl_graph.CreateSizeVar("s1");

  // One fused node -> one scheduled result -> one group -> one impl graph.
  ascir::ScheduleGroup group;
  group.impl_graphs.push_back(impl_graph);
  ascir::ScheduledResult scheduled;
  scheduled.schedule_groups.push_back(group);
  ascir::FusedScheduledResult fused;
  fused.node_idx_to_scheduled_results.push_back(std::vector<ascir::ScheduledResult>{scheduled});

  const std::string expected =
      "uint32_t GetWorkspaceSize(const AutofuseTilingData &t) {\n"
      " using namespace optiling;\n"
      " uint32_t ws_size = 0;\n"
      " if (t.tiling_key == 0) {\n"
      " ws_size += 0;\n"
      " }\n"
      "\n"
      " ws_size = (ws_size + 512 - 1) / 512 * 512;\n"
      " return ws_size;\n"
      "}\n";
  EXPECT_EQ(this->GenGetWorkspaceSizeFunc("AutofuseTilingData", fused), expected);
}
// Single schedule group with one FLOAT16 workspace read through a load over symbolic sizes s0 x s1:
// the generated GetWorkspaceSize() is expected to contain the symbolic size expression below.
TEST_F(TestCodegenTiling, SingleGroupWorkspaceSymbolTest) {
  ascir::ImplGraph impl_graph("test_graph0");
  auto dim0 = impl_graph.CreateSizeVar("s0");
  auto dim1 = impl_graph.CreateSizeVar("s1");
  auto axis0 = impl_graph.CreateAxis("z0", dim0);
  auto axis1 = impl_graph.CreateAxis("z1", dim1);

  // workspace -> load
  af::ascir_op::Workspace ws_op("workspace");
  impl_graph.AddNode(ws_op);
  ws_op.y.dtype = ge::DT_FLOAT16;
  af::ascir_op::Load load_op("load");
  impl_graph.AddNode(load_op);
  load_op.x = ws_op.y;
  load_op.attr.sched.axis = {axis0.id, axis1.id};
  *load_op.y.axis = {axis0.id, axis1.id};
  *load_op.y.repeats = {dim0, dim1};
  *load_op.y.strides = {dim1, af::ops::One};

  // Output dtype and tensor ids on the graph nodes.
  auto ws_node = impl_graph.FindNode("workspace");
  ws_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  ws_node->outputs[0].attr.mem.tensor_id = 0;
  auto load_node = impl_graph.FindNode("load");
  load_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  load_node->outputs[0].attr.mem.tensor_id = 1;

  // One fused node -> one scheduled result -> one group -> one impl graph.
  ascir::ScheduleGroup group;
  group.impl_graphs.push_back(impl_graph);
  ascir::ScheduledResult scheduled;
  scheduled.schedule_groups.push_back(group);
  ascir::FusedScheduledResult fused;
  fused.workspace_nodes.push_back(ws_node);
  fused.node_idx_to_scheduled_results.push_back(std::vector<ascir::ScheduledResult>{scheduled});

  const std::string expected =
      "uint32_t GetWorkspaceSize(const AutofuseTilingData &t) {\n"
      " using namespace optiling;\n"
      " uint32_t ws_size = 0;\n"
      " if (t.tiling_key == 0) {\n"
      " ws_size += Max(0, (2 * Max(Max(1, t.s1), (t.s0 * t.s1))));\n"
      " }\n"
      "\n"
      " ws_size = (ws_size + 512 - 1) / 512 * 512;\n"
      " return ws_size;\n"
      "}\n";
  EXPECT_EQ(this->GenGetWorkspaceSizeFunc("AutofuseTilingData", fused), expected);
}
// Same topology as the symbolic single-group case, but with constant sizes 150 x 2: the generated
// GetWorkspaceSize() is expected to contain the folded constant 600.
TEST_F(TestCodegenTiling, SingleGroupWorkspaceValueTest) {
  ascir::ImplGraph impl_graph("test_graph0");
  auto dim0 = impl_graph.CreateSizeVar(150);
  auto dim1 = impl_graph.CreateSizeVar(2);
  auto axis0 = impl_graph.CreateAxis("z0", dim0);
  auto axis1 = impl_graph.CreateAxis("z1", dim1);

  // workspace -> load
  af::ascir_op::Workspace ws_op("workspace");
  impl_graph.AddNode(ws_op);
  ws_op.y.dtype = ge::DT_FLOAT16;
  af::ascir_op::Load load_op("load");
  impl_graph.AddNode(load_op);
  load_op.x = ws_op.y;
  load_op.attr.sched.axis = {axis0.id, axis1.id};
  *load_op.y.axis = {axis0.id, axis1.id};
  *load_op.y.repeats = {dim0, dim1};
  *load_op.y.strides = {dim1, af::ops::One};

  // Output dtype and tensor ids on the graph nodes.
  auto ws_node = impl_graph.FindNode("workspace");
  ws_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  ws_node->outputs[0].attr.mem.tensor_id = 0;
  auto load_node = impl_graph.FindNode("load");
  load_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  load_node->outputs[0].attr.mem.tensor_id = 1;

  // One fused node -> one scheduled result -> one group -> one impl graph.
  ascir::ScheduleGroup group;
  group.impl_graphs.push_back(impl_graph);
  ascir::ScheduledResult scheduled;
  scheduled.schedule_groups.push_back(group);
  ascir::FusedScheduledResult fused;
  fused.workspace_nodes.push_back(ws_node);
  fused.node_idx_to_scheduled_results.push_back(std::vector<ascir::ScheduledResult>{scheduled});

  const std::string expected =
      "uint32_t GetWorkspaceSize(const AutofuseTilingData &t) {\n"
      " using namespace optiling;\n"
      " uint32_t ws_size = 0;\n"
      " if (t.tiling_key == 0) {\n"
      " ws_size += 600;\n"
      " }\n"
      "\n"
      " ws_size = (ws_size + 512 - 1) / 512 * 512;\n"
      " return ws_size;\n"
      "}\n";
  EXPECT_EQ(this->GenGetWorkspaceSizeFunc("AutofuseTilingData", fused), expected);
}
// Two schedule groups inside one scheduled result, each with its own FLOAT16 workspace over
// symbolic sizes s0 x s1: the generated GetWorkspaceSize() is expected to add one symbolic term per
// group, each reading that group's own tiling-data member.
TEST_F(TestCodegenTiling, MultiGroupWorkspaceSymbolTest) {
  // ---- group 0 graph ----
  ascir::ImplGraph first_graph("test_graph0");
  auto dim0 = first_graph.CreateSizeVar("s0");
  auto dim1 = first_graph.CreateSizeVar("s1");
  auto axis0 = first_graph.CreateAxis("z0", dim0);
  auto axis1 = first_graph.CreateAxis("z1", dim1);
  af::ascir_op::Workspace first_ws_op("workspace");
  first_graph.AddNode(first_ws_op);
  first_ws_op.y.dtype = ge::DT_FLOAT16;
  af::ascir_op::Load first_load_op("load");
  first_graph.AddNode(first_load_op);
  first_load_op.x = first_ws_op.y;
  first_load_op.attr.sched.axis = {axis0.id, axis1.id};
  *first_load_op.y.axis = {axis0.id, axis1.id};
  *first_load_op.y.repeats = {dim0, dim1};
  *first_load_op.y.strides = {dim1, af::ops::One};
  auto first_ws_node = first_graph.FindNode("workspace");
  first_ws_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  first_ws_node->outputs[0].attr.mem.tensor_id = 0;
  auto first_load_node = first_graph.FindNode("load");
  first_load_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  first_load_node->outputs[0].attr.mem.tensor_id = 2;

  // ---- group 1 graph (the size/axis locals are re-bound to the second graph) ----
  ascir::ImplGraph second_graph("test_graph1");
  dim0 = second_graph.CreateSizeVar("s0");
  dim1 = second_graph.CreateSizeVar("s1");
  axis0 = second_graph.CreateAxis("z0", dim0);
  axis1 = second_graph.CreateAxis("z1", dim1);
  af::ascir_op::Workspace second_ws_op("workspace1");
  second_graph.AddNode(second_ws_op);
  second_ws_op.y.dtype = ge::DT_FLOAT16;
  af::ascir_op::Load second_load_op("load1");
  second_graph.AddNode(second_load_op);
  second_load_op.x = second_ws_op.y;
  second_load_op.attr.sched.axis = {axis0.id, axis1.id};
  *second_load_op.y.axis = {axis0.id, axis1.id};
  *second_load_op.y.repeats = {dim0, dim1};
  *second_load_op.y.strides = {dim1, af::ops::One};
  auto second_ws_node = second_graph.FindNode("workspace1");
  second_ws_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  second_ws_node->outputs[0].attr.mem.tensor_id = 1;
  auto second_load_node = second_graph.FindNode("load1");
  second_load_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  second_load_node->outputs[0].attr.mem.tensor_id = 3;

  // ---- one scheduled result holding both groups ----
  ascir::ScheduleGroup group0;
  group0.impl_graphs = {first_graph};
  ascir::ScheduleGroup group1;
  group1.impl_graphs = {second_graph};
  ascir::ScheduledResult scheduled;
  scheduled.schedule_groups.push_back(group0);
  scheduled.schedule_groups.push_back(group1);

  ascir::FusedScheduledResult fused;
  fused.workspace_nodes.push_back(first_ws_node);
  fused.workspace_nodes.push_back(second_ws_node);
  fused.node_idx_to_scheduled_results.push_back(std::vector<ascir::ScheduledResult>{scheduled});

  const std::string expected =
      "uint32_t GetWorkspaceSize(const AutofuseTilingData &t) {\n"
      " using namespace optiling;\n"
      " uint32_t ws_size = 0;\n"
      " if (t.graph0_tiling_key == 0) {\n"
      " if (t.graph0_result0_g0_tiling_data.tiling_key == 0) {\n"
      " ws_size += Max(0, (2 * Max(Max(1, t.graph0_result0_g0_tiling_data.s1), (t.graph0_result0_g0_tiling_data.s0 * t.graph0_result0_g0_tiling_data.s1))));\n"
      " }\n"
      " if (t.graph0_result0_g1_tiling_data.tiling_key == 0) {\n"
      " ws_size += Max(0, (2 * Max(Max(1, t.graph0_result0_g1_tiling_data.s1), (t.graph0_result0_g1_tiling_data.s0 * t.graph0_result0_g1_tiling_data.s1))));\n"
      " }\n"
      " }\n"
      " ws_size = (ws_size + 512 - 1) / 512 * 512;\n"
      " return ws_size;\n"
      "}\n";
  EXPECT_EQ(this->GenGetWorkspaceSizeFunc("AutofuseTilingData", fused), expected);
}
// Two schedule groups with constant sizes (16 x 32 and 5 x 100, FLOAT16): the generated
// GetWorkspaceSize() is expected to add the folded constants 1024 and 1000 respectively.
TEST_F(TestCodegenTiling, MultiGroupWorkspaceValueTest) {
  // ---- group 0 graph: 16 x 32 ----
  ascir::ImplGraph first_graph("test_graph0");
  auto dim0 = first_graph.CreateSizeVar(16);
  auto dim1 = first_graph.CreateSizeVar(32);
  auto axis0 = first_graph.CreateAxis("z0", dim0);
  auto axis1 = first_graph.CreateAxis("z1", dim1);
  af::ascir_op::Workspace first_ws_op("workspace");
  first_graph.AddNode(first_ws_op);
  first_ws_op.y.dtype = ge::DT_FLOAT16;
  af::ascir_op::Load first_load_op("load");
  first_graph.AddNode(first_load_op);
  first_load_op.x = first_ws_op.y;
  first_load_op.attr.sched.axis = {axis0.id, axis1.id};
  *first_load_op.y.axis = {axis0.id, axis1.id};
  *first_load_op.y.repeats = {dim0, dim1};
  *first_load_op.y.strides = {dim1, af::ops::One};
  auto first_ws_node = first_graph.FindNode("workspace");
  first_ws_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  first_ws_node->outputs[0].attr.mem.tensor_id = 0;
  auto first_load_node = first_graph.FindNode("load");
  first_load_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  first_load_node->outputs[0].attr.mem.tensor_id = 2;

  // ---- group 1 graph: 5 x 100 (the size/axis locals are re-bound to the second graph) ----
  ascir::ImplGraph second_graph("test_graph1");
  dim0 = second_graph.CreateSizeVar(5);
  dim1 = second_graph.CreateSizeVar(100);
  axis0 = second_graph.CreateAxis("z0", dim0);
  axis1 = second_graph.CreateAxis("z1", dim1);
  af::ascir_op::Workspace second_ws_op("workspace1");
  second_graph.AddNode(second_ws_op);
  second_ws_op.y.dtype = ge::DT_FLOAT16;
  af::ascir_op::Load second_load_op("load1");
  second_graph.AddNode(second_load_op);
  second_load_op.x = second_ws_op.y;
  second_load_op.attr.sched.axis = {axis0.id, axis1.id};
  *second_load_op.y.axis = {axis0.id, axis1.id};
  *second_load_op.y.repeats = {dim0, dim1};
  *second_load_op.y.strides = {dim1, af::ops::One};
  auto second_ws_node = second_graph.FindNode("workspace1");
  second_ws_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  second_ws_node->outputs[0].attr.mem.tensor_id = 1;
  auto second_load_node = second_graph.FindNode("load1");
  second_load_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  second_load_node->outputs[0].attr.mem.tensor_id = 3;

  // ---- one scheduled result holding both groups ----
  ascir::ScheduleGroup group0;
  group0.impl_graphs = {first_graph};
  ascir::ScheduleGroup group1;
  group1.impl_graphs = {second_graph};
  ascir::ScheduledResult scheduled;
  scheduled.schedule_groups.push_back(group0);
  scheduled.schedule_groups.push_back(group1);

  ascir::FusedScheduledResult fused;
  fused.workspace_nodes.push_back(first_ws_node);
  fused.workspace_nodes.push_back(second_ws_node);
  fused.node_idx_to_scheduled_results.push_back(std::vector<ascir::ScheduledResult>{scheduled});

  const std::string expected =
      "uint32_t GetWorkspaceSize(const AutofuseTilingData &t) {\n"
      " using namespace optiling;\n"
      " uint32_t ws_size = 0;\n"
      " if (t.graph0_tiling_key == 0) {\n"
      " if (t.graph0_result0_g0_tiling_data.tiling_key == 0) {\n"
      " ws_size += 1024;\n"
      " }\n"
      " if (t.graph0_result0_g1_tiling_data.tiling_key == 0) {\n"
      " ws_size += 1000;\n"
      " }\n"
      " }\n"
      " ws_size = (ws_size + 512 - 1) / 512 * 512;\n"
      " return ws_size;\n"
      "}\n";
  EXPECT_EQ(this->GenGetWorkspaceSizeFunc("AutofuseTilingData", fused), expected);
}
// For a graph whose only axis has size Zero, the generated TilingFunc body is expected to start
// with: block dim 1, workspace size 0, immediate success return.
TEST_F(TestCodegenTiling, EmptyTensorKernel) {
  af::AscGraph graph("test_graph");
  graph.CreateSizeVar("s0");
  auto z0 = graph.CreateAxis("z0", af::ops::Zero);

  // x -> load -> store -> y
  af::ascir_op::Data x_op("x", graph);
  x_op.ir_attr.SetIndex(0);
  af::ascir_op::Load load_op("load");
  af::ascir_op::Store store_op("store");
  af::ascir_op::Output y_op("y");
  y_op.ir_attr.SetIndex(0);
  graph.AddNode(load_op);
  graph.AddNode(store_op);
  graph.AddNode(y_op);
  load_op.x = x_op.y;
  load_op.y.dtype = ge::DT_FLOAT16;
  store_op.x = load_op.y;
  y_op.x = store_op.y;

  auto x = graph.FindNode("x");
  auto load = graph.FindNode("load");
  auto store = graph.FindNode("store");
  auto y = graph.FindNode("y");
  x->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  x->outputs[0].attr.mem.tensor_id = 0;
  x->attr.api.unit = af::ComputeUnit::kUnitNone;
  y->attr.api.unit = af::ComputeUnit::kUnitNone;

  // Load output: single axis, unit stride, queue allocation at VecIn.
  auto &load_attr = load->outputs[0].attr;
  load_attr.axis = {z0.id};
  load_attr.vectorized_axis = {z0.id};
  load_attr.vectorized_strides = {af::ops::One};
  load_attr.repeats = {z0.size};
  load_attr.strides = {af::ops::One};
  load_attr.mem.position = af::Position::kPositionVecIn;
  load_attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  load_attr.mem.tensor_id = 1;
  load_attr.que.id = 0;
  load_attr.mem.reuse_id = 0;
  load_attr.que.depth = 2;
  load_attr.que.buf_num = 2;
  load_attr.opt.merge_scope = af::kIdNone;

  // Store output: single axis, unit stride, global allocation.
  auto &store_attr = store->outputs[0].attr;
  store_attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  store_attr.mem.tensor_id = 2;
  store_attr.axis = {z0.id};
  store_attr.vectorized_axis = {z0.id};
  store_attr.vectorized_strides = {af::ops::One};
  store_attr.repeats = {z0.size};
  store_attr.strides = {af::ops::One};

  ::ascir::ScheduledResult schedule_result;
  schedule_result.schedule_groups.resize(1);
  for (auto &schedule_group : schedule_result.schedule_groups) {
    schedule_group.impl_graphs.emplace_back(graph);
  }
  // The same scheduled result is registered twice for node index 0.
  std::vector<ascir::ScheduledResult> schedule_results;
  schedule_results.push_back(schedule_result);
  schedule_results.push_back(schedule_result);

  ascir::FusedScheduledResult fused_schedule_result;
  fused_schedule_result.fused_graph_name = af::AscendString(graph.GetName().c_str());
  fused_schedule_result.input_nodes.push_back(x);
  fused_schedule_result.output_nodes.push_back(y);
  fused_schedule_result.node_idx_to_scheduled_results.push_back(schedule_results);

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  const std::string &tiling_src = res["tiling_def_and_tiling_const"];

  // Locate the TilingFunc header first. Without this check a failed find() (npos) would wrap
  // around when the header length is added, and substr() would then read from an unrelated
  // offset or throw std::out_of_range instead of producing a clear test failure.
  const std::string tiling_func_declare{"TilingFunc(gert::TilingSymbolEvalContext *context)\n{\n"};
  const auto declare_pos = tiling_src.find(tiling_func_declare);
  ASSERT_NE(declare_pos, std::string::npos);
  const auto body_pos = declare_pos + tiling_func_declare.size();

  // Compare only the first statements of the function body.
  const std::string expect_str{" context->SetBlockDim(1);\n *context->GetWorkspaceSizes(1) = 0;\n return ge::GRAPH_SUCCESS;\n"};
  const std::string tiling_func_content = tiling_src.substr(body_pos, expect_str.size());
  EXPECT_EQ(expect_str, tiling_func_content);
}
// Checks the exact text GenDfxInputSymbolInfo() produces for three origin symbols (s0, s1, s2),
// each bound to a source snippet through the shape-info map.
TEST_F(TestCodegenTiling, TestGenDfxInputSymbolInfo) {
  // Symbol name -> source snippet that evaluates the symbol.
  std::map<std::string, std::string> symbol_sources;
  symbol_sources.emplace("s0", R"([&]() -> int64_t {
const auto *tensor = context->GetInputTensor(0);
if (tensor == nullptr) {
return -1;
}
return tensor->GetOriginShape().GetDim(0);
}())");
  symbol_sources.emplace("s1", R"([&]() -> int64_t {
const auto *tensor = context->GetInputTensor(0);
if (tensor == nullptr) {
return -1;
}
return tensor->GetOriginShape().GetDim(1);
}())");
  symbol_sources.emplace("s2", R"([&]() -> int64_t {
const auto *tensor = context->GetInputTensor(1);
if (tensor == nullptr) {
return -1;
}
return tensor->GetOriginShape().GetDim(0);
}())");

  ascir::FusedScheduledResult fused;
  fused.origin_vars = std::vector<af::Expression>{af::Symbol("s0"), af::Symbol("s1"), af::Symbol("s2")};

  const auto generated = this->GenDfxInputSymbolInfo(fused, symbol_sources);
  const auto expected = R"(extern "C" ge::graphStatus DfxInputSymbolInfo(gert::TilingSymbolEvalContext *context, char *out_symbol_info, size_t size)
{
if (out_symbol_info == nullptr || size == 0) {
return ge::GRAPH_SUCCESS;
}
std::string symbol_info;
auto s0 = [&]() -> int64_t {
const auto *tensor = context->GetInputTensor(0);
if (tensor == nullptr) {
return -1;
}
return tensor->GetOriginShape().GetDim(0);
}();
symbol_info += ("s0: " + std::to_string(s0));
auto s1 = [&]() -> int64_t {
const auto *tensor = context->GetInputTensor(0);
if (tensor == nullptr) {
return -1;
}
return tensor->GetOriginShape().GetDim(1);
}();
symbol_info += (", s1: " + std::to_string(s1));
auto s2 = [&]() -> int64_t {
const auto *tensor = context->GetInputTensor(1);
if (tensor == nullptr) {
return -1;
}
return tensor->GetOriginShape().GetDim(0);
}();
symbol_info += (", s2: " + std::to_string(s2));
if (symbol_info.empty()) {
out_symbol_info[0] = '\0';
return ge::GRAPH_SUCCESS;
}
symbol_info += ".";
if (strncpy_s(out_symbol_info, size, symbol_info.c_str(), std::min(symbol_info.size(), size - 1)) != 0) {
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
)";
  EXPECT_EQ(generated, expected);
}
// Wraps the generated DfxInputSymbolInfo() function in a small translation unit (includes, helper
// macros and a TilingSymbolEvalContext definition) and checks that CompileCode() accepts it.
TEST_F(TestCodegenTiling, TestCompileSuccess) {
  std::stringstream source;

  // Headers; the register/exe_graph ones are skipped when __CCE_KT_TEST__ is defined.
  source << "#include <stdexcept>" << '\n';
  source << "#include <sstream>" << '\n';
  source << "#include <cmath>" << '\n';
  source << "#ifndef __CCE_KT_TEST__" << '\n';
  source << "#include \"register/op_def_registry.h\"" << '\n';
  source << "#include \"exe_graph/runtime/infer_shape_context.h\"" << '\n';
  source << "#include \"exe_graph/runtime/kernel_context.h\"" << '\n';
  source << "#include \"exe_graph/runtime/continuous_vector.h\"" << '\n';
  source << "#endif" << '\n';

  // Macro definitions for the math helper names Max, Min, Log, Pow and Rational.
  source << "#define Max(a, b) ((double)(a) > (double)(b) ? (a) : (b))" << '\n';
  source << "#define Min(a, b) ((double)(a) < (double)(b) ? (a) : (b))" << '\n';
  source << "#define Log(a) (log((double)(a)))" << '\n';
  source << "#define Pow(a, b) pow(a, b)" << '\n';
  source << "#define Rational(a, b) ((double)(a) / (double)(b))" << '\n';

  // Definition of the context type that appears in the generated function's signature.
  const std::string context_decl = R"(
namespace gert {
class TilingSymbolEvalContext : public TilingContext {
public:
const gert::Tensor *GetGraphInputTensor(size_t data_index) const {
auto *tensor = GetInputPointer<gert::Tensor>(data_index + 1);
if (tensor == nullptr) {
return nullptr;
}
return tensor;
}
};
})";
  source << context_decl << '\n';

  // Symbol name -> source snippet that evaluates the symbol.
  std::map<std::string, std::string> symbol_sources;
  symbol_sources.emplace("s0", R"([&]() -> int64_t {
const auto *tensor = context->GetInputTensor(0);
if (tensor == nullptr) {
return -1;
}
return tensor->GetOriginShape().GetDim(0);
}())");
  symbol_sources.emplace("s1", R"([&]() -> int64_t {
const auto *tensor = context->GetInputTensor(0);
if (tensor == nullptr) {
return -1;
}
return tensor->GetOriginShape().GetDim(1);
}())");
  symbol_sources.emplace("s2", R"([&]() -> int64_t {
const auto *tensor = context->GetInputTensor(1);
if (tensor == nullptr) {
return -1;
}
return tensor->GetOriginShape().GetDim(0);
}())");

  ascir::FusedScheduledResult fused;
  fused.origin_vars = std::vector<af::Expression>{af::Symbol("s0"), af::Symbol("s1"), af::Symbol("s2")};

  source << this->GenDfxInputSymbolInfo(fused, symbol_sources) << '\n';
  ASSERT_TRUE(CompileCode(source.str()));
}
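* Codegen FindBestTilingKey tests
* 1. Single graph, single result, single group
* 2. Multiple graphs: only occurs in the inductor scenario; not supported in this iteration
* 3. Single graph, combination of multiple results
* result1: single group, single graph
* result2: single group, multiple graphs
* result3: combination of multiple groups
* group1: single graph
* group2: multiple graphs
* 4. enable_group_parallel scenario: the function is not generated
*/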
// One scheduled result with a single group that holds three impl graphs: the generated
// FindBestTilingKey() is expected to map tiling_key 0/1/2 straight to 0/1/2.
TEST_F(TestCodegenTiling, TestGenFindBestTilingKeyFuncFor1Group) {
  af::AscGraph graph1("graph1");
  af::ascir_op::Workspace workspace("workspace");
  graph1.AddNode(workspace);
  af::AscGraph graph2("graph2");
  af::AscGraph graph3("graph3");

  ascir::ScheduleGroup schedule_group;
  schedule_group.impl_graphs.push_back(graph1);
  schedule_group.impl_graphs.push_back(graph2);
  schedule_group.impl_graphs.push_back(graph3);
  ascir::ScheduledResult schedule_result;
  schedule_result.schedule_groups.push_back(schedule_group);

  ascir::FusedScheduledResult fused_schedule_result;
  std::vector<ascir::ScheduledResult> graph0_results = {schedule_result};
  fused_schedule_result.node_idx_to_scheduled_results.emplace_back(std::move(graph0_results));

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  const std::string expect = R"(extern "C" int64_t FindBestTilingKey(AutofuseTilingData &t)
{
if (t.tiling_key == 0) {
return 0;
} else if (t.tiling_key == 1) {
return 1;
} else if (t.tiling_key == 2) {
return 2;
}
return -1;
}
)";

  const std::string &tiling_src = res["tiling_def_and_tiling_const"];
  const auto pos = tiling_src.find("extern \"C\" int64_t FindBestTilingKey(AutofuseTilingData &t)");
  // Fail with a clear assertion when the function is missing; substr(npos, ...) would throw
  // std::out_of_range instead.
  ASSERT_NE(pos, std::string::npos);
  const auto func = tiling_src.substr(pos, expect.size());
  ASSERT_EQ(func, expect);
}
// Three scheduled results for one fused node:
//   result0: one group with one graph
//   result1: one group with two graphs
//   result2: two groups (one graph, then two graphs)
// The generated FindBestTilingKey() is expected to number every combination consecutively (0..4).
TEST_F(TestCodegenTiling, TestGenFindBestTilingKeyFuncForMultiResult) {
  af::AscGraph graph1("graph1");
  af::AscGraph graph2("graph2");

  ascir::ScheduleGroup schedule_group1;
  schedule_group1.impl_graphs.push_back(graph1);
  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(graph1);
  schedule_group2.impl_graphs.push_back(graph2);

  ascir::ScheduledResult schedule_result1;
  schedule_result1.schedule_groups.push_back(schedule_group1);
  ascir::ScheduledResult schedule_result2;
  schedule_result2.schedule_groups.push_back(schedule_group2);
  ascir::ScheduledResult schedule_result3;
  schedule_result3.schedule_groups.push_back(schedule_group1);
  schedule_result3.schedule_groups.push_back(schedule_group2);

  ascir::FusedScheduledResult fused_schedule_result;
  std::vector<ascir::ScheduledResult> graph0_results = {schedule_result1, schedule_result2, schedule_result3};
  fused_schedule_result.node_idx_to_scheduled_results.emplace_back(std::move(graph0_results));

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  const std::string expect = R"(extern "C" int64_t FindBestTilingKey(AutofuseTilingData &t)
{
if (t.graph0_tiling_key == 0) {
if (t.graph0_result0_g0_tiling_data.tiling_key == 0) {
return 0;
}
} else if (t.graph0_tiling_key == 1) {
if (t.graph0_result1_g0_tiling_data.tiling_key == 0) {
return 1;
} else if (t.graph0_result1_g0_tiling_data.tiling_key == 1) {
return 2;
}
} else if (t.graph0_tiling_key == 2) {
if (t.graph0_result2_g0_tiling_data.tiling_key == 0 && t.graph0_result2_g1_tiling_data.tiling_key == 0) {
return 3;
} else if (t.graph0_result2_g0_tiling_data.tiling_key == 0 && t.graph0_result2_g1_tiling_data.tiling_key == 1) {
return 4;
}
}
return -1;
}
)";

  const std::string &tiling_src = res["tiling_def_and_tiling_const"];
  const auto pos = tiling_src.find("extern \"C\" int64_t FindBestTilingKey(AutofuseTilingData &t)");
  // Fail with a clear assertion when the function is missing; substr(npos, ...) would throw
  // std::out_of_range instead.
  ASSERT_NE(pos, std::string::npos);
  const auto func = tiling_src.substr(pos, expect.size());
  ASSERT_EQ(func, expect);
}
// When one of the scheduled results has enable_group_parallel set, the generated tiling source is
// expected NOT to contain a FindBestTilingKey() definition.
TEST_F(TestCodegenTiling, TestGenFindBestTilingKeyFuncForEnableParallel) {
  af::AscGraph first_graph("graph1");
  af::AscGraph second_graph("graph2");

  ascir::ScheduleGroup single_graph_group;
  single_graph_group.impl_graphs.push_back(first_graph);
  ascir::ScheduleGroup dual_graph_group;
  dual_graph_group.impl_graphs.push_back(first_graph);
  dual_graph_group.impl_graphs.push_back(second_graph);

  ascir::ScheduledResult result_single;
  result_single.schedule_groups.push_back(single_graph_group);
  ascir::ScheduledResult result_dual;
  result_dual.schedule_groups.push_back(dual_graph_group);
  // Third result: both groups, flagged for group-parallel execution.
  ascir::ScheduledResult result_parallel;
  result_parallel.enable_group_parallel = true;
  result_parallel.schedule_groups.push_back(single_graph_group);
  result_parallel.schedule_groups.push_back(dual_graph_group);

  ascir::FusedScheduledResult fused;
  fused.node_idx_to_scheduled_results.emplace_back(
      std::vector<ascir::ScheduledResult>{result_single, result_dual, result_parallel});

  const std::map<std::string, std::string> shape_info;
  auto generated = this->Generate(fused, shape_info, "", "0");
  const auto key_func_at =
      generated["tiling_def_and_tiling_const"].find("extern \"C\" int64_t FindBestTilingKey(AutofuseTilingData &t)");
  ASSERT_EQ(key_func_at, std::string::npos);
}
TEST_F(TestCodegenTiling, TestGenExternTilingFunc) {
  // Generates tiling code for three scheduled results (the last one group-parallel) while a
  // RuntimeStubV2Common instance is installed, and expects FindBestTilingKey to be absent.
  ge::PlatformContext::GetInstance().Reset();
  auto stub_v2 = std::make_shared<ge::RuntimeStubV2Common>();
  ge::RuntimeStub::SetInstance(stub_v2);

  af::AscGraph graph1("graph1");
  af::AscGraph graph2("graph2");
  ascir::ScheduleGroup schedule_group1;
  schedule_group1.impl_graphs.push_back(graph1);
  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(graph1);
  schedule_group2.impl_graphs.push_back(graph2);

  ascir::ScheduledResult schedule_result1;
  schedule_result1.schedule_groups.push_back(schedule_group1);
  ascir::ScheduledResult schedule_result2;
  schedule_result2.schedule_groups.push_back(schedule_group2);
  ascir::ScheduledResult schedule_result3;
  schedule_result3.enable_group_parallel = true;
  schedule_result3.schedule_groups.push_back(schedule_group1);
  schedule_result3.schedule_groups.push_back(schedule_group2);

  ascir::FusedScheduledResult fused_schedule_result;
  std::vector<ascir::ScheduledResult> graph0_results = {schedule_result1, schedule_result2, schedule_result3};
  fused_schedule_result.node_idx_to_scheduled_results.emplace_back(std::move(graph0_results));

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  auto pos = res["tiling_def_and_tiling_const"].find("extern \"C\" int64_t FindBestTilingKey(AutofuseTilingData &t)");

  // Restore the global runtime stub and platform context BEFORE the fatal assertion: a failing
  // ASSERT_* returns from the test body immediately, so cleanup placed after it would be skipped
  // and the stub would leak into the tests that run afterwards.
  ge::RuntimeStub::Reset();
  ge::PlatformContext::GetInstance().Reset();

  ASSERT_EQ(pos, std::string::npos);
}
// Checks PGOSearchTensorMallocDef for a fused result with one input node ("x") and one output
// node ("y"): the returned text must equal the two aclrtMalloc snippets spelled out in `expect`.
TEST_F(TestCodegenTiling, TestPGOSearchTensorMallocDef) {
af::AscGraph graph("test_graph");
// s0 is created on the graph but not referenced by any tensor below; the only axis has size One.
auto s0 = graph.CreateSizeVar("s0");
auto z0 = graph.CreateAxis("z0", af::ops::One);
// The Data op is constructed directly on the graph; load/store/output are added via AddNode.
af::ascir_op::Data x_op("x", graph);
x_op.ir_attr.SetIndex(0);
af::ascir_op::Load load_op("load");
af::ascir_op::Store store_op("store");
af::ascir_op::Output y_op("y");
y_op.ir_attr.SetIndex(0);
graph.AddNode(load_op);
graph.AddNode(store_op);
graph.AddNode(y_op);
// Wire the chain x -> load -> store -> y.
load_op.x = x_op.y;
load_op.y.dtype = ge::DT_FLOAT16;
store_op.x = load_op.y;
y_op.x = store_op.y;
// Look the nodes up by name so their attributes can be edited.
auto x = graph.FindNode("x");
auto load = graph.FindNode("load");
auto store = graph.FindNode("store");
auto y = graph.FindNode("y");
// Input tensor: global allocation, tensor id 0.
x->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
x->outputs[0].attr.mem.tensor_id = 0;
x->attr.api.unit = af::ComputeUnit::kUnitNone;
y->attr.api.unit = af::ComputeUnit::kUnitNone;
// Load output: queue-allocated tensor (id 1) on the single axis z0.
load->outputs[0].attr.axis = {z0.id};
load->outputs[0].attr.vectorized_axis = {z0.id};
load->outputs[0].attr.vectorized_strides = {af::ops::One};
load->outputs[0].attr.repeats = {z0.size};
load->outputs[0].attr.strides = {af::ops::One};
load->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
load->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
load->outputs[0].attr.mem.tensor_id = 1;
load->outputs[0].attr.que.id = 0;
load->outputs[0].attr.mem.reuse_id = 0;
load->outputs[0].attr.que.depth = 2;
load->outputs[0].attr.que.buf_num = 2;
load->outputs[0].attr.opt.merge_scope = af::kIdNone;
// Store output: global allocation, tensor id 2, same axis layout as the load output.
store->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
store->outputs[0].attr.mem.tensor_id = 2;
store->outputs[0].attr.axis = {z0.id};
store->outputs[0].attr.vectorized_axis = {z0.id};
store->outputs[0].attr.vectorized_strides = {af::ops::One};
store->outputs[0].attr.repeats = {z0.size};
store->outputs[0].attr.strides = {af::ops::One};
// The output node's input tensor carries the FLOAT16 dtype and the z0 repeats/strides.
y->inputs[0].attr.repeats = {z0.size};
y->inputs[0].attr.strides = {af::ops::One};
y->inputs[0].attr.dtype = ge::DT_FLOAT16;
// One scheduled result with a single schedule group that holds the graph.
::ascir::ScheduledResult schedule_result;
schedule_result.schedule_groups.resize(1);
for (auto &schedule_group : schedule_result.schedule_groups) {
schedule_group.impl_graphs.emplace_back(graph);
}
// The same scheduled result is registered twice for node index 0.
std::vector<ascir::ScheduledResult> schedule_results;
schedule_results.push_back(schedule_result);
schedule_results.push_back(schedule_result);
ascir::FusedScheduledResult fused_schedule_result;
fused_schedule_result.fused_graph_name = af::AscendString(graph.GetName().c_str());
fused_schedule_result.input_nodes.push_back(x);
fused_schedule_result.output_nodes.push_back(y);
fused_schedule_result.node_idx_to_scheduled_results.push_back(schedule_results);
std::string mallocdef = this->PGOSearchTensorMallocDef(fused_schedule_result);
// Both buffers are expected to be 2 bytes -- presumably one FLOAT16 element; TODO confirm how the
// size is derived. The raw string is compared verbatim, so its whitespace must match the output.
const std::string expect = R"( size_t input0_size = 2;
ret = aclrtMalloc(&input0, input0_size, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_SUCCESS) {
DLOGE("aclrtMalloc input0 failed. ERROR: %d", ret);
return FAILED;
}
size_t output0_size = 2;
ret = aclrtMalloc(&output0, output0_size, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_SUCCESS) {
DLOGE("aclrtMalloc output0 failed. ERROR: %d", ret);
return FAILED;
}
)";
ASSERT_EQ(mallocdef, expect);
}
TEST_F(TestCodegenTiling, TestCalculateTensorMemorySizeStrWithNoRepeatsOrStrides) {
  // A tensor whose repeats and strides are both empty must yield the size string "0".
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto z0 = graph.CreateAxis("z0", af::ops::One);
  af::ascir_op::Data data_op("x", graph);
  data_op.ir_attr.SetIndex(0);

  auto data_node = graph.FindNode("x");
  data_node->attr.api.unit = af::ComputeUnit::kUnitNone;
  auto &&out_tensor = data_node->outputs[0];
  out_tensor.attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  out_tensor.attr.mem.tensor_id = 0;
  out_tensor.attr.repeats = {};
  out_tensor.attr.strides = {};

  const std::string memory_size = this->CalculateTensorMemorySizeStr(data_node->outputs[0]);
  ASSERT_EQ(memory_size, std::string("0"));
}
TEST_F(TestCodegenTiling, TestCalculateTensorMemorySizeStrWithZeroFirstStride) {
  // With strides {Zero, One} over repeats {s0, s1} the expected size string is the simplified
  // form of s1 * 2.
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);
  af::ascir_op::Data data_op("x", graph);
  data_op.ir_attr.SetIndex(0);

  auto data_node = graph.FindNode("x");
  data_node->attr.api.unit = af::ComputeUnit::kUnitNone;
  auto &&out_tensor = data_node->outputs[0];
  out_tensor.attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  out_tensor.attr.mem.tensor_id = 0;
  out_tensor.attr.dtype = ge::DT_FLOAT16;
  out_tensor.attr.axis = {z0.id, z1.id};
  out_tensor.attr.repeats = {s0, s1};
  out_tensor.attr.strides = {af::ops::Zero, af::ops::One};

  const std::string memory_size = this->CalculateTensorMemorySizeStr(data_node->outputs[0]);

  auto expected_expr = af::sym::Mul(s1, af::Expression::Parse("2")).Simplify();
  const std::string expected_size(expected_expr.Str().get());
  ASSERT_EQ(memory_size, expected_size);
}
TEST_F(TestCodegenTiling, TestCalculateTensorMemorySizeStrWithOnlyZeroStride) {
  // With a single axis whose stride is Zero the expected size string is the simplified form of
  // One * 2.
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto z0 = graph.CreateAxis("z0", s0);
  af::ascir_op::Data data_op("x", graph);
  data_op.ir_attr.SetIndex(0);

  auto data_node = graph.FindNode("x");
  data_node->attr.api.unit = af::ComputeUnit::kUnitNone;
  auto &&out_tensor = data_node->outputs[0];
  out_tensor.attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  out_tensor.attr.mem.tensor_id = 0;
  out_tensor.attr.dtype = ge::DT_FLOAT16;
  out_tensor.attr.axis = {z0.id};
  out_tensor.attr.repeats = {s0};
  out_tensor.attr.strides = {af::ops::Zero};

  const std::string memory_size = this->CalculateTensorMemorySizeStr(data_node->outputs[0]);

  auto expected_expr = af::sym::Mul(af::ops::One, af::Expression::Parse("2")).Simplify();
  const std::string expected_size(expected_expr.Str().get());
  ASSERT_EQ(memory_size, expected_size);
}
// Populates `graph` with data0 -> load0 and data1 -> load1 feeding a BatchMatMul, followed by a
// Store and an Output, then calls AscGraphInfoComplete::CompleteApiInfo on the graph.
// @param graph      graph that receives the size vars, axes and nodes.
// @param is_dynamic when true the two axis sizes are the symbolic vars "s0"/"s1"; otherwise they
//                   are the constants 31 and 1.
void CreateMatmulGraph(af::AscGraph &graph, bool is_dynamic = false) {
af::Expression s0;
af::Expression s1;
if (is_dynamic) {
s0 = graph.CreateSizeVar("s0");
s1 = graph.CreateSizeVar("s1");
} else {
s0 = graph.CreateSizeVar(31);
s1 = graph.CreateSizeVar(1);
}
auto z0 = graph.CreateAxis("z0", s0);
auto z1 = graph.CreateAxis("z1", s1);
// First operand: FLOAT16 input 0 with repeats {s0, s1} and strides {s1, One}.
af::ascir_op::Data data0("data0", graph);
data0.attr.sched.axis = {z0.id, z1.id};
data0.y.dtype = ge::DT_FLOAT16;
*data0.y.axis = {z0.id, z1.id};
data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
*data0.y.strides = {s1 ,af::ops::One};
*data0.y.repeats = {s0, s1};
data0.ir_attr.SetIndex(0);
af::ascir_op::Load load0("load0");
load0.attr.sched.axis = {z0.id, z1.id};
load0.x = data0.y;
*load0.y.axis = {z0.id, z1.id};
load0.y.dtype = ge::DT_FLOAT16;
*load0.y.strides = {s1 ,af::ops::One};
*load0.y.repeats = {s0, s1};
// Second operand: FLOAT16 input 1 with repeats {One, One} and all-zero strides.
af::ascir_op::Data data1("data1", graph);
data1.y.dtype = ge::DT_FLOAT16;
data1.attr.sched.axis = {z0.id, z1.id};
*data1.y.axis = {z0.id, z1.id};
data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
*data1.y.repeats = {af::ops::One, af::ops::One};
*data1.y.strides = {af::ops::Zero, af::ops::Zero};
data1.ir_attr.SetIndex(1);
af::ascir_op::Load load1("load1");
load1.x = data1.y;
load1.attr.sched.axis = {z0.id, z1.id};
load1.y.dtype = ge::DT_FLOAT16;
*load1.y.axis = {z0.id, z1.id};
*load1.y.strides = {af::ops::Zero, af::ops::Zero};
*load1.y.repeats = {af::ops::One, af::ops::One};
// BatchMatMul on the two loads; FLOAT output, compute type kComputeCube.
af::ascir_op::BatchMatMul matmul("matmul");
matmul.attr.sched.axis = {z0.id, z1.id};
matmul.x1 = load0.y;
matmul.x2 = load1.y;
matmul.y.dtype = ge::DT_FLOAT;
*matmul.y.axis = {z0.id, z1.id};
*matmul.y.repeats = {s0, s1};
*matmul.y.strides = {s1, af::ops::One};
matmul.attr.api.compute_type = af::ComputeType::kComputeCube;
// IR attributes set to non-default-looking values -- presumably so they show up in the generated
// tiling call; TODO confirm against the tests that consume this graph.
matmul.ir_attr.SetAdj_x1(1);
matmul.ir_attr.SetAdj_x2(0);
matmul.ir_attr.SetHas_relu(1);
matmul.ir_attr.SetEnable_hf32(1);
matmul.ir_attr.SetOffset_x(6);
// Store the matmul result (with an offset of One) and expose it as output 0.
af::ascir_op::Store store_op("store");
store_op.attr.sched.axis = {z0.id, z1.id};
store_op.x = matmul.y;
*store_op.y.axis = {z0.id, z1.id};
store_op.y.dtype = ge::DT_FLOAT;
*store_op.y.strides = {s1 ,af::ops::One};
*store_op.y.repeats = {s0, s1};
store_op.ir_attr.SetOffset(af::ops::One);
af::ascir_op::Output output_op("output");
output_op.x = store_op.y;
output_op.y.dtype = ge::DT_FLOAT;
output_op.ir_attr.SetIndex(0);
optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
}
TEST_F(TestCodegenTiling, TestMatmulElemwiseFuse) {
  // Builds a static 64x64 element-wise graph (abs(load0) + broadcast(scalar)) * broadcast(load1),
  // optimizes it, appends a matmul graph as an extra schedule group with the kUBFuse cube type, and
  // expects the generated tiling code to contain a FindBestTilingKey definition.
  af::AscGraph graph("matmul_elemwise_pro");
  auto s0 = graph.CreateSizeVar(64);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  // Input 0: full {s0, s1} FLOAT tensor.
  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z0.id, z1.id};
  data0.y.dtype = ge::DT_FLOAT;
  *data0.y.axis = {z0.id, z1.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {s1, af::ops::One};
  *data0.y.repeats = {s0, s1};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z0.id, z1.id};
  load0.x = data0.y;
  *load0.y.axis = {z0.id, z1.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.strides = {s1, af::ops::One};
  *load0.y.repeats = {s0, s1};

  af::ascir_op::Abs abs("abs");
  graph.AddNode(abs);
  abs.x = load0.y;
  abs.attr.sched.axis = {z0.id, z1.id};
  abs.y.dtype = ge::DT_FLOAT;
  *abs.y.axis = {z0.id, z1.id};
  *abs.y.repeats = {s0, s1};
  *abs.y.strides = {s1, af::ops::One};
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;

  // Scalar "0" broadcast in two steps: first along z1, then along z0.
  af::ascir_op::Scalar scalar0("scalar0", graph);
  scalar0.attr.sched.axis = {z0.id, z1.id};
  scalar0.ir_attr.SetValue("0");
  scalar0.y.dtype = ge::DT_FLOAT;
  *scalar0.y.axis = {z0.id, z1.id};
  *scalar0.y.repeats = {af::ops::One, af::ops::One};
  *scalar0.y.strides = {af::ops::Zero, af::ops::Zero};

  af::ascir_op::Broadcast broadcast0("broadcast0");
  broadcast0.x = scalar0.y;
  broadcast0.attr.sched.axis = {z0.id, z1.id};
  *broadcast0.y.axis = {z0.id, z1.id};
  broadcast0.y.dtype = ge::DT_FLOAT;
  *broadcast0.y.repeats = {af::ops::One, s1};
  *broadcast0.y.strides = {af::ops::Zero, af::ops::One};

  af::ascir_op::Broadcast broadcast1("broadcast1");
  broadcast1.x = broadcast0.y;
  broadcast1.attr.sched.axis = {z0.id, z1.id};
  *broadcast1.y.axis = {z0.id, z1.id};
  broadcast1.y.dtype = ge::DT_FLOAT;
  *broadcast1.y.repeats = {s0, s1};
  *broadcast1.y.strides = {s1, af::ops::One};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z0.id, z1.id};
  add_op.x1 = abs.y;
  add_op.x2 = broadcast1.y;
  add_op.y.dtype = ge::DT_FLOAT;
  *add_op.y.axis = {z0.id, z1.id};
  *add_op.y.repeats = {s0, s1};
  *add_op.y.strides = {s1, af::ops::One};

  // Input 1: {One, One} tensor, broadcast to {s0, s1} in two steps.
  af::ascir_op::Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT;
  data1.attr.sched.axis = {z0.id, z1.id};
  *data1.y.axis = {z0.id, z1.id};
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1.y.repeats = {af::ops::One, af::ops::One};
  *data1.y.strides = {af::ops::Zero, af::ops::Zero};
  data1.ir_attr.SetIndex(1);

  af::ascir_op::Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z0.id, z1.id};
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id};
  *load1.y.strides = {af::ops::Zero, af::ops::Zero};
  *load1.y.repeats = {af::ops::One, af::ops::One};

  af::ascir_op::Broadcast broadcast2("broadcast2");
  broadcast2.x = load1.y;
  broadcast2.attr.sched.axis = {z0.id, z1.id};
  *broadcast2.y.axis = {z0.id, z1.id};
  broadcast2.y.dtype = ge::DT_FLOAT;
  *broadcast2.y.repeats = {af::ops::One, s1};
  *broadcast2.y.strides = {af::ops::Zero, af::ops::One};

  af::ascir_op::Broadcast broadcast3("broadcast3");
  broadcast3.x = broadcast2.y;
  broadcast3.attr.sched.axis = {z0.id, z1.id};
  *broadcast3.y.axis = {z0.id, z1.id};
  broadcast3.y.dtype = ge::DT_FLOAT;
  *broadcast3.y.repeats = {s0, s1};
  *broadcast3.y.strides = {s1, af::ops::One};

  af::ascir_op::Mul mul("mul");
  mul.attr.sched.axis = {z0.id, z1.id};
  mul.x1 = add_op.y;
  mul.x2 = broadcast3.y;
  mul.y.dtype = ge::DT_FLOAT;
  *mul.y.axis = {z0.id, z1.id};
  *mul.y.repeats = {s0, s1};
  *mul.y.strides = {s1, af::ops::One};

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = mul.y;
  *store_op.y.axis = {z0.id, z1.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {s1, af::ops::One};
  *store_op.y.repeats = {s0, s1};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  // Place the output of data0 in a UB queue at the VecIn position.
  auto x1Local = graph.FindNode("data0");
  x1Local->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  x1Local->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  x1Local->outputs[0].attr.mem.position = af::Position::kPositionVecIn;

  af::AscGraph mm_graph("mutmul");
  CreateMatmulGraph(mm_graph);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal checks: the code below indexes node_idx_to_scheduled_results[0][0], which would be an
  // out-of-bounds access if optimization failed or produced no scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results.empty());
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results[0].empty());

  // Append the matmul graph as an additional schedule group and mark the result as UB-fused.
  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(mm_graph);
  auto &first_result = fused_schedule_result.node_idx_to_scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Dump the generated code into the working directory for manual inspection.
  std::fstream tiling_func("Mutmul_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];
  auto pos = res["tiling_def_and_tiling_const"].find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(pos, std::string::npos);
}
TEST_F(TestCodegenTiling, TestConv2DOffsetFuse) {
  // Optimizes an element-wise (relu) graph, appends a conv2d-with-offset graph as an extra schedule
  // group marked kUBFuse, generates the tiling code and hands it to VerifyConv2DOffsetTiling.
  af::AscGraph graph("conv2d_offset_elemwise_pro");
  CreateElemwiseGraphWithRelu(graph);
  af::AscGraph conv2d_offset_graph("conv2d_offset");
  CreateConv2DOffsetGraph(conv2d_offset_graph);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal checks: node_idx_to_scheduled_results[0][0] is indexed below, which would be an
  // out-of-bounds access if optimization failed or produced no scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results.empty());
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results[0].empty());

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_offset_graph);
  auto &first_result = fused_schedule_result.node_idx_to_scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Dump the generated code into the working directory for manual inspection.
  std::fstream tiling_func("Conv2d_offset_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];
  VerifyConv2DOffsetTiling(res);
}
TEST_F(TestCodegenTiling, TestConv2DOffsetBiasFuse) {
  // Optimizes an element-wise (relu) graph, appends a conv2d-with-offset-and-bias graph as an extra
  // schedule group marked kUBFuse, generates the tiling code and hands it to VerifyTilingCodeBasic.
  af::AscGraph graph("conv2d_offset_bias_elemwise_pro");
  CreateElemwiseGraphWithRelu(graph);
  af::AscGraph conv2d_offset_bias_graph("conv2d_offset_bias");
  CreateConv2DOffsetBiasGraph(conv2d_offset_bias_graph);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal checks: node_idx_to_scheduled_results[0][0] is indexed below, which would be an
  // out-of-bounds access if optimization failed or produced no scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results.empty());
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results[0].empty());

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_offset_bias_graph);
  auto &first_result = fused_schedule_result.node_idx_to_scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Dump the generated code into the working directory for manual inspection.
  std::fstream tiling_func("Conv2d_offset_bias_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];
  VerifyTilingCodeBasic(res);
}
namespace {
// Returns a fused result whose node 0 has exactly one scheduled result made of two schedule
// groups: the first holds "graph1", the second holds "graph1" and "graph2". The graphs are empty.
static ascir::FusedScheduledResult GenMultiGroupFusedScheduleResult() {
  af::AscGraph first_graph("graph1");
  af::AscGraph second_graph("graph2");

  ascir::ScheduleGroup single_graph_group;
  single_graph_group.impl_graphs.push_back(first_graph);
  ascir::ScheduleGroup dual_graph_group;
  dual_graph_group.impl_graphs.push_back(first_graph);
  dual_graph_group.impl_graphs.push_back(second_graph);

  ascir::ScheduledResult multi_group_result;
  multi_group_result.schedule_groups.push_back(single_graph_group);
  multi_group_result.schedule_groups.push_back(dual_graph_group);

  ascir::FusedScheduledResult fused;
  fused.node_idx_to_scheduled_results.emplace_back(std::vector<ascir::ScheduledResult>{multi_group_result});
  return fused;
}

// Returns an impl graph named `graph_name` that owns one size variable `var_name` and one axis
// "z0" sized by that variable.
static ascir::ImplGraph GenGraphWithSizeVar(const std::string &graph_name, const std::string &var_name) {
  ascir::ImplGraph graph(graph_name.c_str());
  static_cast<void>(graph.CreateAxis("z0", graph.CreateSizeVar(var_name.c_str())));
  return graph;
}

// Returns a fused result with a single scheduled result made of two schedule groups ("graph1" and
// "graph2"), each graph declaring the size variable `var_name`; the symbol is also recorded in
// origin_vars.
static ascir::FusedScheduledResult GenMultiGroupFusedScheduleResultWithSizeVar(const std::string &var_name) {
  ascir::ScheduleGroup first_group;
  first_group.impl_graphs.push_back(GenGraphWithSizeVar("graph1", var_name));
  ascir::ScheduleGroup second_group;
  second_group.impl_graphs.push_back(GenGraphWithSizeVar("graph2", var_name));

  ascir::ScheduledResult multi_group_result;
  multi_group_result.schedule_groups.push_back(first_group);
  multi_group_result.schedule_groups.push_back(second_group);

  ascir::FusedScheduledResult fused;
  fused.origin_vars.push_back(af::Symbol(var_name.c_str()));
  fused.node_idx_to_scheduled_results.push_back({multi_group_result});
  return fused;
}
}  // namespace
TEST_F(TestCodegenTiling, GenerateForInductorGetTilingDataReprShouldContainStableFields) {
  // The inductor tiling source must carry the GetTilingDataRepr description text, an emit_field
  // call for each of block_dim/corenum/ub_size/hbm_size, and one of the two tiling-key fields.
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  const auto impl_it = tiling_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != tiling_files.end());
  const auto &tiling_impl = impl_it->second;

  const char *const required_snippets[] = {
      "GetTilingDataRepr returns a valid C++ designated initializer string",
      "emit_field(\"block_dim\", tiling_data->get_block_dim()",
      "emit_field(\"corenum\", tiling_data->get_corenum()",
      "emit_field(\"ub_size\", tiling_data->get_ub_size()",
      "emit_field(\"hbm_size\", tiling_data->get_hbm_size()",
  };
  for (const char *snippet : required_snippets) {
    EXPECT_NE(tiling_impl.find(snippet), std::string::npos) << "missing snippet: " << snippet;
  }

  // Either the flat key or the graph-level key is acceptable.
  const bool has_flat_key = tiling_impl.find("emit_field(\"tiling_key\"") != std::string::npos;
  const bool has_graph_key = tiling_impl.find("emit_field(\"graph0_tiling_key\"") != std::string::npos;
  EXPECT_TRUE(has_flat_key || has_graph_key);
}
TEST_F(TestCodegenTiling, GenerateForInductorGetTilingDataReprShouldKeepWorkspaceBeforeSymbols) {
  // The tiling-key emit_field call must appear before the emit_field call of every symbol (s0/s1)
  // that is present in the generated source; symbols that are not emitted are not checked.
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s1"), af::Symbol("s0")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  const auto impl_it = tiling_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != tiling_files.end());
  const auto &tiling_impl = impl_it->second;

  // Prefer the flat key; fall back to the graph-level key when the flat one is absent.
  const size_t flat_key_pos = tiling_impl.find("emit_field(\"tiling_key\"");
  const size_t tiling_key_pos = (flat_key_pos != std::string::npos)
                                    ? flat_key_pos
                                    : tiling_impl.find("emit_field(\"graph0_tiling_key\"");
  ASSERT_NE(tiling_key_pos, std::string::npos);

  const char *const symbol_fields[] = {"emit_field(\"s0\"", "emit_field(\"s1\""};
  for (const char *symbol_field : symbol_fields) {
    const auto symbol_pos = tiling_impl.find(symbol_field);
    if (symbol_pos == std::string::npos) {
      continue;
    }
    EXPECT_LT(tiling_key_pos, symbol_pos);
  }
}
TEST_F(TestCodegenTiling, GenerateForInductorGetTilingDataReprShouldUseGraphLevelTilingKeysForMultiGroup) {
  // For a multi-group result the source must emit "graph0_tiling_key" and must not emit the flat
  // "tiling_key" field.
  auto fused_schedule_result = GenMultiGroupFusedScheduleResult();
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  const auto impl_it = tiling_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != tiling_files.end());
  const auto &tiling_impl = impl_it->second;

  const auto graph_key_pos = tiling_impl.find("emit_field(\"graph0_tiling_key\"");
  const auto flat_key_pos = tiling_impl.find("emit_field(\"tiling_key\"");
  EXPECT_NE(graph_key_pos, std::string::npos);
  EXPECT_EQ(flat_key_pos, std::string::npos);
}
TEST_F(TestCodegenTiling, GenerateForInductorGetTilingDataReprShouldKeepZeroValuedStableFields) {
  // The stable fields must be emitted unconditionally: each emit_field call is present and no
  // "!= 0" guard on the corresponding getter appears in the source.
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  const auto impl_it = tiling_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != tiling_files.end());
  const auto &tiling_impl = impl_it->second;

  const char *const required_snippets[] = {
      "emit_field(\"block_dim\", tiling_data->get_block_dim()",
      "emit_field(\"corenum\", tiling_data->get_corenum()",
      "emit_field(\"ub_size\", tiling_data->get_ub_size()",
      "emit_field(\"hbm_size\", tiling_data->get_hbm_size()",
  };
  for (const char *snippet : required_snippets) {
    EXPECT_NE(tiling_impl.find(snippet), std::string::npos) << "missing snippet: " << snippet;
  }

  const char *const forbidden_guards[] = {
      "if (tiling_data->get_block_dim() != 0)",
      "if (tiling_data->get_corenum() != 0)",
      "if (tiling_data->get_ub_size() != 0)",
      "if (tiling_data->get_hbm_size() != 0)",
  };
  for (const char *guard : forbidden_guards) {
    EXPECT_EQ(tiling_impl.find(guard), std::string::npos) << "unexpected guard: " << guard;
  }
}
TEST_F(TestCodegenTiling, GenerateForInductorShouldContainTopnMainOutputAbi) {
  // The inductor tiling source must define the extern "C" GenerateTopnSolutions entry point and
  // reference GetTilingDataRepr.
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  const auto impl_it = tiling_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != tiling_files.end());
  const auto &tiling_impl = impl_it->second;

  const auto topn_pos = tiling_impl.find("extern \"C\" int64_t GenerateTopnSolutions(");
  const auto repr_pos = tiling_impl.find("GetTilingDataRepr(");
  EXPECT_NE(topn_pos, std::string::npos);
  EXPECT_NE(repr_pos, std::string::npos);
}
TEST_F(TestCodegenTiling, GenerateForInductorShouldUseGetTilingDataReprAsTilingDataValidationAid) {
  // The inductor tiling source must reference GetTilingDataRepr.
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  const auto impl_it = tiling_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != tiling_files.end());
  const auto &tiling_impl = impl_it->second;

  const auto repr_pos = tiling_impl.find("GetTilingDataRepr(");
  EXPECT_NE(repr_pos, std::string::npos);
}
TEST_F(TestCodegenTiling, GenerateForInductorTopnAbiShouldNotEmitOutputConfigsMetadataLogic) {
  // None of the solution_config bookkeeping statements listed below may appear in the source.
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  const auto impl_it = tiling_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != tiling_files.end());
  const auto &tiling_impl = impl_it->second;

  const char *const forbidden_snippets[] = {
      "std::map<std::string, std::string> solution_config;",
      "configs.push_back(solution_config);",
      "solution_config[\"canonical_repr\"]",
      "solution_config[\"topn_status\"]",
  };
  for (const char *snippet : forbidden_snippets) {
    EXPECT_EQ(tiling_impl.find(snippet), std::string::npos) << "unexpected snippet: " << snippet;
  }
}
TEST_F(TestCodegenTiling, MultiGroupInductorShouldContainTopnMainOutputAbi) {
  // For a multi-group result the source must define GenerateTopnSolutions and mention each of the
  // parameter declarations listed below.
  auto fused_schedule_result = GenMultiGroupFusedScheduleResult();
  auto generated = this->GenerateForInductor(fused_schedule_result);
  const auto impl_it = generated.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated.end());
  const auto &tiling_impl = impl_it->second;

  const char *const required_snippets[] = {
      "extern \"C\" int64_t GenerateTopnSolutions(",
      "const std::vector<std::map<std::string, std::string>> &input_configs",
      "std::vector<AutofuseTilingData> &tiling_datas",
      "std::vector<int64_t> &workspaces",
      "std::vector<int64_t> &block_dims",
  };
  for (const char *snippet : required_snippets) {
    EXPECT_NE(tiling_impl.find(snippet), std::string::npos) << "missing snippet: " << snippet;
  }
}
TEST_F(TestCodegenTiling, MultiGroupTopnShouldSetShapeDimOnGroupTilingData) {
  // With two groups sharing the size variable "ks0", the setter must be called on each group's
  // tiling data member and never directly on search_tiling.
  auto fused_schedule_result = GenMultiGroupFusedScheduleResultWithSizeVar("ks0");
  auto generated = this->GenerateForInductor(fused_schedule_result);
  const auto impl_it = generated.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated.end());
  const auto &tiling_impl = impl_it->second;

  EXPECT_EQ(tiling_impl.find("search_tiling.set_ks0("), std::string::npos);

  const char *const per_group_setters[] = {
      "search_tiling.graph0_result0_g0_tiling_data.set_ks0(ks0);",
      "search_tiling.graph0_result0_g1_tiling_data.set_ks0(ks0);",
  };
  for (const char *setter : per_group_setters) {
    EXPECT_NE(tiling_impl.find(setter), std::string::npos) << "missing setter: " << setter;
  }
}
TEST_F(TestCodegenTiling, MultiGroupInductorShouldContainReprAbi) {
  // For a multi-group result the source must contain the GetTilingDataRepr signature.
  auto fused_schedule_result = GenMultiGroupFusedScheduleResult();
  auto generated = this->GenerateForInductor(fused_schedule_result);
  const auto impl_it = generated.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated.end());
  const auto &tiling_impl = impl_it->second;

  const auto signature_pos =
      tiling_impl.find("std::string GetTilingDataRepr(const AutofuseTilingData *tiling_data)");
  EXPECT_NE(signature_pos, std::string::npos);
}
TEST_F(TestCodegenTiling, TestMatmulElemwiseDynamicShapeFuse) {
  // Optimizes a dynamic-shape element-wise graph, appends a dynamic matmul graph as an extra
  // schedule group marked kUBFuse, generates tiling code with s0 = s1 = 64 and expects the code to
  // contain a wrapper.DoMatMulTiling call.
  af::AscGraph graph("matmul_elemwise_pro");
  CreateMatmulElemwiseDynamicGraph(graph);
  af::AscGraph mm_graph("mutmul");
  CreateMatmulGraph(mm_graph, true);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal checks: node_idx_to_scheduled_results[0][0] is indexed below, which would be an
  // out-of-bounds access if optimization failed or produced no scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results.empty());
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results[0].empty());

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(mm_graph);
  auto &first_result = fused_schedule_result.node_idx_to_scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  std::map<std::string, std::string> shape_info;
  shape_info["s0"] = "64";
  shape_info["s1"] = "64";
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Dump the generated code into the working directory for manual inspection.
  std::fstream tiling_func("Mutmul_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];
  auto pos = res["tiling_def_and_tiling_const"].find("TilingResult result = wrapper.DoMatMulTiling(");
  ASSERT_NE(pos, std::string::npos);
}
// Populates `graph` with data0 -> load0 (x) and data1 -> load1 (filter) feeding a Conv2D, followed
// by a Store and an Output, then calls AscGraphInfoComplete::CompleteApiInfo on the graph.
// @param graph      graph that receives the size vars, axes and nodes.
// @param is_dynamic when true the four axis sizes are the symbolic vars "n"/"c"/"h"/"w"; otherwise
//                   they are the constants 1, 64, 56 and 56.
void CreateConv2dGraph(af::AscGraph &graph, bool is_dynamic = false) {
af::Expression n, c, h, w;
if (is_dynamic) {
n = graph.CreateSizeVar("n");
c = graph.CreateSizeVar("c");
h = graph.CreateSizeVar("h");
w = graph.CreateSizeVar("w");
} else {
n = graph.CreateSizeVar(1);
c = graph.CreateSizeVar(64);
h = graph.CreateSizeVar(56);
w = graph.CreateSizeVar(56);
}
auto z_n = graph.CreateAxis("z_n", n);
auto z_c = graph.CreateAxis("z_c", c);
auto z_h = graph.CreateAxis("z_h", h);
auto z_w = graph.CreateAxis("z_w", w);
// Input 0: FLOAT16 tensor with repeats {n, c, h, w} and strides {c*h*w, h*w, w, One}.
af::ascir_op::Data data0("data0", graph);
data0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
data0.y.dtype = ge::DT_FLOAT16;
*data0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
*data0.y.strides = {c*h*w, h*w, w, af::ops::One};
*data0.y.repeats = {n, c, h, w};
data0.ir_attr.SetIndex(0);
af::ascir_op::Load load0("load0");
load0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
load0.x = data0.y;
*load0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
load0.y.dtype = ge::DT_FLOAT16;
*load0.y.strides = {c*h*w, h*w, w, af::ops::One};
*load0.y.repeats = {n, c, h, w};
// Input 1: FLOAT16 tensor with all-One repeats and all-Zero strides; it is used as the filter.
af::ascir_op::Data data1("data1", graph);
data1.y.dtype = ge::DT_FLOAT16;
data1.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
*data1.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
*data1.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
*data1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
data1.ir_attr.SetIndex(1);
af::ascir_op::Load load1("load1");
load1.x = data1.y;
load1.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
load1.y.dtype = ge::DT_FLOAT16;
*load1.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
*load1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
*load1.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
// Conv2D on load0 (x) and load1 (filter); FLOAT output, compute type kComputeCube.
af::ascir_op::Conv2D conv2d("conv2d");
conv2d.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
conv2d.x = load0.y;
conv2d.filter = load1.y;
conv2d.y.dtype = ge::DT_FLOAT;
*conv2d.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
*conv2d.y.repeats = {n, c, h, w};
*conv2d.y.strides = {c*h*w, h*w, w, af::ops::One};
conv2d.attr.api.compute_type = af::ComputeType::kComputeCube;
// Convolution IR attributes: unit strides/dilations, padding of 1 on all four sides, one group,
// NCHW data format, zero x offset, hf32 disabled.
conv2d.ir_attr.SetStrides({1, 1});
conv2d.ir_attr.SetPads({1, 1, 1, 1});
conv2d.ir_attr.SetDilations({1, 1});
conv2d.ir_attr.SetGroups(1);
conv2d.ir_attr.SetData_format("NCHW");
conv2d.ir_attr.SetOffset_x(0);
conv2d.ir_attr.SetEnable_hf32(false);
// Store the convolution result and expose it as output 0.
af::ascir_op::Store store_op("store");
store_op.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
store_op.x = conv2d.y;
*store_op.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
store_op.y.dtype = ge::DT_FLOAT;
*store_op.y.strides = {c*h*w, h*w, w, af::ops::One};
*store_op.y.repeats = {n, c, h, w};
af::ascir_op::Output output_op("output");
output_op.x = store_op.y;
output_op.y.dtype = ge::DT_FLOAT;
output_op.ir_attr.SetIndex(0);
optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
}
TEST_F(TestCodegenTiling, TestConv2dElemwiseFuse) {
  // Optimizes a static element-wise (abs + add) graph, appends a static conv2d graph as an extra
  // schedule group marked kUBFuse, generates the tiling code and hands it to
  // VerifyConv2dElemwiseTiling.
  af::AscGraph graph("conv2d_elemwise_pro");
  CreateElemwiseGraphWithAbsAndAddStatic(graph);
  af::AscGraph conv2d_graph("conv2d");
  CreateConv2dGraph(conv2d_graph, false);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal checks: node_idx_to_scheduled_results[0][0] is indexed below, which would be an
  // out-of-bounds access if optimization failed or produced no scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results.empty());
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results[0].empty());

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_graph);
  auto &first_result = fused_schedule_result.node_idx_to_scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Dump the generated code into the working directory for manual inspection.
  std::fstream tiling_func("Conv2d_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];
  VerifyConv2dElemwiseTiling(res);
}
TEST_F(TestCodegenTiling, TestConv2dElemwiseDynamicShapeFuse) {
  // Optimizes a dynamic element-wise (relu) graph, appends a dynamic conv2d graph as an extra
  // schedule group marked kUBFuse, and generates tiling code with n=1, c=64, h=56, w=56. The code
  // must pass VerifyDynamicShapeTiling, bind each symbol to its concrete value and define
  // DfxInputSymbolInfo.
  af::AscGraph graph("conv2d_elemwise_dynamic_pro");
  CreateElemwiseGraphWithReluDynamic(graph);
  af::AscGraph conv2d_graph("conv2d_dynamic");
  CreateConv2dGraph(conv2d_graph, true);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal checks: node_idx_to_scheduled_results[0][0] is indexed below, which would be an
  // out-of-bounds access if optimization failed or produced no scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results.empty());
  ASSERT_FALSE(fused_schedule_result.node_idx_to_scheduled_results[0].empty());

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_graph);
  auto &first_result = fused_schedule_result.node_idx_to_scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  std::map<std::string, std::string> shape_info;
  shape_info["n"] = "1";
  shape_info["c"] = "64";
  shape_info["h"] = "56";
  shape_info["w"] = "56";
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Dump the generated code into the working directory for manual inspection.
  std::fstream tiling_func("Conv2d_dynamic_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];
  VerifyDynamicShapeTiling(res);

  // Looked up after the verifier call, matching the original lookup order.
  const auto &tiling_code = res["tiling_def_and_tiling_const"];
  const char *const required_snippets[] = {
      "auto n = 1;",
      "auto c = 64;",
      "auto h = 56;",
      "auto w = 56;",
      "extern \"C\" ge::graphStatus DfxInputSymbolInfo",
  };
  for (const char *snippet : required_snippets) {
    ASSERT_NE(tiling_code.find(snippet), std::string::npos) << "missing snippet: " << snippet;
  }
}
// Populates `graph` with the node chain
//   data0 -> load0 --\
//   data1 -> load1 ---+-> conv2d_bias -> store -> output
//   data2 -> load2 --/
// and then runs AscGraphInfoComplete::CompleteApiInfo on it.
// When `is_dynamic` is true the four sizes are created from the names "n", "c", "h", "w";
// otherwise they are created from the constants 1, 64, 56 and 56.
void CreateConv2dBiasGraph(af::AscGraph &graph, bool is_dynamic = false) {
af::Expression n, c, h, w;
if (is_dynamic) {
n = graph.CreateSizeVar("n");
c = graph.CreateSizeVar("c");
h = graph.CreateSizeVar("h");
w = graph.CreateSizeVar("w");
} else {
n = graph.CreateSizeVar(1);
c = graph.CreateSizeVar(64);
h = graph.CreateSizeVar(56);
w = graph.CreateSizeVar(56);
}
// One axis per size: z_n, z_c, z_h, z_w.
auto z_n = graph.CreateAxis("z_n", n);
auto z_c = graph.CreateAxis("z_c", c);
auto z_h = graph.CreateAxis("z_h", h);
auto z_w = graph.CreateAxis("z_w", w);
// Input 0 (index 0): FLOAT16 over all four axes with strides {c*h*w, h*w, w, One}.
af::ascir_op::Data data0("data0", graph);
data0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
data0.y.dtype = ge::DT_FLOAT16;
*data0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
*data0.y.strides = {c*h*w, h*w, w, af::ops::One};
*data0.y.repeats = {n, c, h, w};
data0.ir_attr.SetIndex(0);
// load0 reads data0 and mirrors its dtype, axes, strides and repeats.
af::ascir_op::Load load0("load0");
load0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
load0.x = data0.y;
*load0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
load0.y.dtype = ge::DT_FLOAT16;
*load0.y.strides = {c*h*w, h*w, w, af::ops::One};
*load0.y.repeats = {n, c, h, w};
// Input 1 (index 1): FLOAT16 over all four axes, every repeat One and every stride Zero.
af::ascir_op::Data data1("data1", graph);
data1.y.dtype = ge::DT_FLOAT16;
data1.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
*data1.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
*data1.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
*data1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
data1.ir_attr.SetIndex(1);
// load1 reads data1 and mirrors its dtype, axes, strides and repeats.
af::ascir_op::Load load1("load1");
load1.x = data1.y;
load1.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
load1.y.dtype = ge::DT_FLOAT16;
*load1.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
*load1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
*load1.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
// Input 2 (index 2): FLOAT over the z_c axis only, repeat c and stride One.
af::ascir_op::Data data2("data2", graph);
data2.y.dtype = ge::DT_FLOAT;
data2.attr.sched.axis = {z_c.id};
*data2.y.axis = {z_c.id};
data2.attr.api.compute_type = af::ComputeType::kComputeInvalid;
*data2.y.repeats = {c};
*data2.y.strides = {af::ops::One};
data2.ir_attr.SetIndex(2);
// load2 reads data2 and mirrors its dtype, axis, stride and repeat.
af::ascir_op::Load load2("load2");
load2.x = data2.y;
load2.attr.sched.axis = {z_c.id};
load2.y.dtype = ge::DT_FLOAT;
*load2.y.axis = {z_c.id};
*load2.y.strides = {af::ops::One};
*load2.y.repeats = {c};
// Conv2DBias node: x = load0, filter = load1, bias = load2; FLOAT output over all four axes,
// compute type kComputeCube.
af::ascir_op::Conv2DBias conv2d_bias("conv2d_bias");
conv2d_bias.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
conv2d_bias.x = load0.y;
conv2d_bias.filter = load1.y;
conv2d_bias.bias = load2.y;
conv2d_bias.y.dtype = ge::DT_FLOAT;
*conv2d_bias.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
*conv2d_bias.y.repeats = {n, c, h, w};
*conv2d_bias.y.strides = {c*h*w, h*w, w, af::ops::One};
conv2d_bias.attr.api.compute_type = af::ComputeType::kComputeCube;
// IR attributes: strides {1, 1}, pads {1, 1, 1, 1}, dilations {1, 1}, groups 1, data format
// "NCHW", offset_x 0, hf32 disabled.
conv2d_bias.ir_attr.SetStrides({1, 1});
conv2d_bias.ir_attr.SetPads({1, 1, 1, 1});
conv2d_bias.ir_attr.SetDilations({1, 1});
conv2d_bias.ir_attr.SetGroups(1);
conv2d_bias.ir_attr.SetData_format("NCHW");
conv2d_bias.ir_attr.SetOffset_x(0);
conv2d_bias.ir_attr.SetEnable_hf32(false);
// store consumes the conv2d_bias output and keeps its FLOAT dtype, axes, strides and repeats.
af::ascir_op::Store store_op("store");
store_op.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
store_op.x = conv2d_bias.y;
*store_op.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
store_op.y.dtype = ge::DT_FLOAT;
*store_op.y.strides = {c*h*w, h*w, w, af::ops::One};
*store_op.y.repeats = {n, c, h, w};
// Graph output (index 0) fed by the store node.
af::ascir_op::Output output_op("output");
output_op.x = store_op.y;
output_op.y.dtype = ge::DT_FLOAT;
output_op.ir_attr.SetIndex(0);
optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
}
// Optimizes the graph built by CreateElemwiseGraphWithRelu, appends the graph built by
// CreateConv2dBiasGraph(..., false) as an extra schedule group marked kUBFuse, generates tiling
// code and hands the result to VerifyConv2DBiasElemwiseTiling.
TEST_F(TestCodegenTiling, TestConv2dBiasElemwiseFuse) {
  af::AscGraph graph("conv2d_bias_elemwise_pro");
  CreateElemwiseGraphWithRelu(graph);
  af::AscGraph conv2d_bias_graph("conv2d_bias");
  CreateConv2dBiasGraph(conv2d_bias_graph, false);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal assertions: the [0][0] access below is only valid when Optimize() succeeded and
  // actually produced at least one scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  auto &scheduled_results = fused_schedule_result.node_idx_to_scheduled_results;
  ASSERT_FALSE(scheduled_results.empty());
  ASSERT_FALSE(scheduled_results[0].empty());

  // Attach the second graph as an additional schedule group of the first scheduled result.
  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_bias_graph);
  auto &first_result = scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Dump the generated code into the current working directory for manual inspection.
  std::fstream tiling_func("Conv2d_bias_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];
  VerifyConv2DBiasElemwiseTiling(res);
}
// Same flow as the static conv2d-bias fuse test, but with the graphs built by
// CreateElemwiseGraphWithMulDynamic and CreateConv2dBiasGraph(..., true) and an explicit
// shape_info for n/c/h/w. The generated code must contain the substituted shape values.
TEST_F(TestCodegenTiling, TestConv2dBiasElemwiseDynamicShapeFuse) {
  af::AscGraph graph("conv2d_bias_elemwise_dynamic_pro");
  CreateElemwiseGraphWithMulDynamic(graph);
  af::AscGraph conv2d_bias_graph("conv2d_bias_dynamic");
  CreateConv2dBiasGraph(conv2d_bias_graph, true);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal assertions: the [0][0] access below is only valid when Optimize() succeeded and
  // actually produced at least one scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  auto &scheduled_results = fused_schedule_result.node_idx_to_scheduled_results;
  ASSERT_FALSE(scheduled_results.empty());
  ASSERT_FALSE(scheduled_results[0].empty());

  // Attach the second graph as an additional schedule group of the first scheduled result.
  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_bias_graph);
  auto &first_result = scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  const std::map<std::string, std::string> shape_info = {
      {"n", "1"}, {"c", "64"}, {"h", "56"}, {"w", "56"}};
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Look the generated code up once and reuse it for the dump and for every check below.
  const auto &tiling_code = res["tiling_def_and_tiling_const"];
  // Dump the generated code into the current working directory for manual inspection.
  std::fstream tiling_func("Conv2d_bias_dynamic_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << tiling_code;
  VerifyDynamicShapeTiling(res);

  for (const char *expected : {"auto n = 1;", "auto c = 64;", "auto h = 56;", "auto w = 56;"}) {
    ASSERT_NE(tiling_code.find(expected), std::string::npos) << "missing snippet: " << expected;
  }
}
// Populates `graph` with the node chain
//   data0 -> load0 --\
//                     +-> batch_matmul -> store -> output
//   data1 -> load1 --/
// using sizes created from the names "batch", "m", "n" and "k", and then runs
// AscGraphInfoComplete::CompleteApiInfo on it.
void CreateBatchMatmulDynamicGraph(af::AscGraph &graph) {
auto batch = graph.CreateSizeVar("batch");
auto m = graph.CreateSizeVar("m");
auto n = graph.CreateSizeVar("n");
auto k = graph.CreateSizeVar("k");
// One axis per size: z_batch, z_m, z_n, z_k.
auto z_batch = graph.CreateAxis("z_batch", batch);
auto z_m = graph.CreateAxis("z_m", m);
auto z_n = graph.CreateAxis("z_n", n);
auto z_k = graph.CreateAxis("z_k", k);
// Input 0 (index 0): FLOAT16 over (z_batch, z_m, z_k) with strides {m*k, k, One}.
af::ascir_op::Data data0("data0", graph);
data0.attr.sched.axis = {z_batch.id, z_m.id, z_k.id};
data0.y.dtype = ge::DT_FLOAT16;
*data0.y.axis = {z_batch.id, z_m.id, z_k.id};
data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
*data0.y.strides = {m*k, k, af::ops::One};
*data0.y.repeats = {batch, m, k};
data0.ir_attr.SetIndex(0);
// load0 reads data0 and mirrors its dtype, axes, strides and repeats.
af::ascir_op::Load load0("load0");
load0.attr.sched.axis = {z_batch.id, z_m.id, z_k.id};
load0.x = data0.y;
*load0.y.axis = {z_batch.id, z_m.id, z_k.id};
load0.y.dtype = ge::DT_FLOAT16;
*load0.y.strides = {m*k, k, af::ops::One};
*load0.y.repeats = {batch, m, k};
// Input 1 (index 1): FLOAT16 over (z_batch, z_k, z_n) with strides {k*n, n, One}.
af::ascir_op::Data data1("data1", graph);
data1.y.dtype = ge::DT_FLOAT16;
data1.attr.sched.axis = {z_batch.id, z_k.id, z_n.id};
*data1.y.axis = {z_batch.id, z_k.id, z_n.id};
data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
*data1.y.repeats = {batch, k, n};
*data1.y.strides = {k*n, n, af::ops::One};
data1.ir_attr.SetIndex(1);
// load1 reads data1 and mirrors its dtype, axes, strides and repeats.
af::ascir_op::Load load1("load1");
load1.x = data1.y;
load1.attr.sched.axis = {z_batch.id, z_k.id, z_n.id};
load1.y.dtype = ge::DT_FLOAT16;
*load1.y.axis = {z_batch.id, z_k.id, z_n.id};
*load1.y.strides = {k*n, n, af::ops::One};
*load1.y.repeats = {batch, k, n};
// BatchMatMul node: x1 = load0, x2 = load1; FLOAT output over (z_batch, z_m, z_n),
// compute type kComputeCube.
af::ascir_op::BatchMatMul batch_matmul("batch_matmul");
batch_matmul.attr.sched.axis = {z_batch.id, z_m.id, z_n.id};
batch_matmul.x1 = load0.y;
batch_matmul.x2 = load1.y;
batch_matmul.y.dtype = ge::DT_FLOAT;
*batch_matmul.y.axis = {z_batch.id, z_m.id, z_n.id};
*batch_matmul.y.repeats = {batch, m, n};
*batch_matmul.y.strides = {m*n, n, af::ops::One};
batch_matmul.attr.api.compute_type = af::ComputeType::kComputeCube;
// IR attributes: adj_x1 = 0, adj_x2 = 0, has_relu = 1, hf32 enabled, offset_x = 6.
batch_matmul.ir_attr.SetAdj_x1(0);
batch_matmul.ir_attr.SetAdj_x2(0);
batch_matmul.ir_attr.SetHas_relu(1);
batch_matmul.ir_attr.SetEnable_hf32(true);
batch_matmul.ir_attr.SetOffset_x(6);
// store consumes the batch_matmul output and keeps its FLOAT dtype, axes, strides and repeats.
af::ascir_op::Store store_op("store");
store_op.attr.sched.axis = {z_batch.id, z_m.id, z_n.id};
store_op.x = batch_matmul.y;
*store_op.y.axis = {z_batch.id, z_m.id, z_n.id};
store_op.y.dtype = ge::DT_FLOAT;
*store_op.y.strides = {m*n, n, af::ops::One};
*store_op.y.repeats = {batch, m, n};
// Graph output (index 0) fed by the store node.
af::ascir_op::Output output_op("output");
output_op.x = store_op.y;
output_op.y.dtype = ge::DT_FLOAT;
output_op.ir_attr.SetIndex(0);
optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
}
// Optimizes the graph built by CreateBatchMatmulElemwiseDynamicGraph, appends the graph built by
// CreateBatchMatmulDynamicGraph as an extra schedule group marked kUBFuse, and checks that the
// generated code calls DoMatMulTiling and reports a non-static shape.
TEST_F(TestCodegenTiling, TestBatchMatmulDynamicShapeFuse) {
  af::AscGraph graph("batch_matmul_dynamic_pro");
  CreateBatchMatmulElemwiseDynamicGraph(graph);
  af::AscGraph mm_graph("batch_matmul_dynamic");
  CreateBatchMatmulDynamicGraph(mm_graph);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal assertions: the [0][0] access below is only valid when Optimize() succeeded and
  // actually produced at least one scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  auto &scheduled_results = fused_schedule_result.node_idx_to_scheduled_results;
  ASSERT_FALSE(scheduled_results.empty());
  ASSERT_FALSE(scheduled_results[0].empty());

  // Attach the second graph as an additional schedule group of the first scheduled result.
  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(mm_graph);
  auto &first_result = scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  const std::map<std::string, std::string> shape_info = {
      {"batch", "16"}, {"m", "64"}, {"n", "64"}, {"k", "64"}};
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Look the generated code up once and reuse it for the dump and for every check below.
  const auto &tiling_code = res["tiling_def_and_tiling_const"];
  // Dump the generated code into the current working directory for manual inspection.
  std::fstream tiling_func("Batch_matmul_dynamic_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << tiling_code;

  for (const char *expected : {"TilingResult result = wrapper.DoMatMulTiling(",
                               "AutofuseIsStaticShape() {\n return false;"}) {
    ASSERT_NE(tiling_code.find(expected), std::string::npos) << "missing snippet: " << expected;
  }
}
// Optimizes the graph built by CreateElemwiseGraphWithRelu, appends the graph built by
// CreateConv2DGraphWithGroups as an extra schedule group marked kUBFuse, generates tiling code
// and hands the result to VerifyTilingCodeBasic.
TEST_F(TestCodegenTiling, TestConv2dWithGroups) {
  af::AscGraph graph("conv2d_groups_pro");
  CreateElemwiseGraphWithRelu(graph);
  af::AscGraph conv2d_graph("conv2d_groups");
  CreateConv2DGraphWithGroups(conv2d_graph);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal assertions: the [0][0] access below is only valid when Optimize() succeeded and
  // actually produced at least one scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  auto &scheduled_results = fused_schedule_result.node_idx_to_scheduled_results;
  ASSERT_FALSE(scheduled_results.empty());
  ASSERT_FALSE(scheduled_results[0].empty());

  // Attach the second graph as an additional schedule group of the first scheduled result.
  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_graph);
  auto &first_result = scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Dump the generated code into the current working directory for manual inspection.
  std::fstream tiling_func("Conv2d_groups_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];
  VerifyTilingCodeBasic(res);
}
// Optimizes the graph built by CreateElemwiseGraphWithRelu, appends the graph built by
// CreateConv2DGraphWithDilation as an extra schedule group marked kUBFuse, generates tiling code
// and hands the result to VerifyTilingCodeBasic.
TEST_F(TestCodegenTiling, TestConv2dWithDilation) {
  af::AscGraph graph("conv2d_dilation_pro");
  CreateElemwiseGraphWithRelu(graph);
  af::AscGraph conv2d_graph("conv2d_dilation");
  CreateConv2DGraphWithDilation(conv2d_graph);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  // Fatal assertions: the [0][0] access below is only valid when Optimize() succeeded and
  // actually produced at least one scheduled result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);
  auto &scheduled_results = fused_schedule_result.node_idx_to_scheduled_results;
  ASSERT_FALSE(scheduled_results.empty());
  ASSERT_FALSE(scheduled_results[0].empty());

  // Attach the second graph as an additional schedule group of the first scheduled result.
  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_graph);
  auto &first_result = scheduled_results[0][0];
  first_result.schedule_groups.push_back(schedule_group2);
  first_result.cube_type = ascir::CubeTemplateType::kUBFuse;

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  // Dump the generated code into the current working directory for manual inspection.
  std::fstream tiling_func("Conv2d_dilation_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];
  VerifyTilingCodeBasic(res);
}
// The code produced by GenerateForInductor must declare the GetTilingRequest, CandidateSolution
// and GetTilingResponse structs.
TEST_F(TestCodegenTiling, ProtocolHeaderShouldContainMinimalRequestResponse) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("struct GetTilingRequest"));
  EXPECT_TRUE(has_snippet("struct CandidateSolution"));
  EXPECT_TRUE(has_snippet("struct GetTilingResponse"));
}
// The generated request declaration must carry the input_configs pointer, the res_limit pointer
// and the topn field with their expected default initializers.
TEST_F(TestCodegenTiling, ProtocolRequestShouldContainMinimalFields) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet(
      "const std::vector<std::map<std::string, std::string>> *input_configs = nullptr;"));
  EXPECT_TRUE(has_snippet("ResLimit *res_limit = nullptr;"));
  EXPECT_TRUE(has_snippet("int64_t topn = 1;"));
}
// The generated code must declare modeled_perf, is_default and canonical_repr, and must not
// contain assignments to candidate.workspace or candidate.block_dim.
TEST_F(TestCodegenTiling, CandidateSolutionShouldContainOnlyMinimalFields) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  // Required field declarations.
  EXPECT_TRUE(has_snippet("double modeled_perf = 0.0;"));
  EXPECT_TRUE(has_snippet("bool is_default = false;"));
  EXPECT_TRUE(has_snippet("std::string canonical_repr;"));
  // Forbidden assignments.
  EXPECT_FALSE(has_snippet("candidate.workspace ="));
  EXPECT_FALSE(has_snippet("candidate.block_dim ="));
}
// Neither "schedule_result_key" nor "group_case_ids" may appear anywhere in the generated code.
TEST_F(TestCodegenTiling, ProtocolShouldNotContainBannedFields) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_FALSE(has_snippet("schedule_result_key"));
  EXPECT_FALSE(has_snippet("group_case_ids"));
}
// The generated GetTopnCandidateSolutions must go through optiling::PGOSearchTilingKey and must
// not contain a per-key loop or a direct AutofuseTiling(&default_tiling_data call. When a
// "solver_func" file is also generated, the combined text must reference
// SearchAllTilingbyCaseId and ExecutePGOSolver.
TEST_F(TestCodegenTiling, GetTilingShouldEnterMainSearchNotForKey) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [](const std::string &text, const char *snippet) {
    return text.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet(
      impl,
      "int64_t GetTopnCandidateSolutions(const GetTilingRequest &request, GetTilingResponse &response)"));
  EXPECT_TRUE(has_snippet(impl, "optiling::PGOSearchTilingKey("));

  // The solver file is optional; only check it when it was emitted.
  const std::string kSolverFunc = "solver_func";
  const auto solver_it = generated_files.find(kSolverFunc);
  if (solver_it != generated_files.end()) {
    const std::string all_tiling_code = impl + solver_it->second;
    EXPECT_TRUE(has_snippet(all_tiling_code, "SearchAllTilingbyCaseId("));
    EXPECT_TRUE(has_snippet(all_tiling_code, "ExecutePGOSolver("));
  }

  EXPECT_FALSE(has_snippet(impl, "for (int64_t key = 0; key < GetTilingKeyCount(); ++key)"));
  EXPECT_FALSE(has_snippet(impl, "if (AutofuseTiling(&default_tiling_data"));
}
// The generated code must not contain any of the listed early-exit or truncation patterns
// keyed on topn or on a found default.
TEST_F(TestCodegenTiling, NoEarlyStopByTopnOrDefault) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_FALSE(has_snippet("if (request.topn == 1) { return"));
  EXPECT_FALSE(has_snippet("response.candidate_solutions.resize(topn)"));
  EXPECT_FALSE(has_snippet("partial_sort"));
  EXPECT_FALSE(has_snippet("current_candidate_num >= request.topn"));
  EXPECT_FALSE(has_snippet("if (found_default) break"));
}
// For the multi-group schedule result the generated code must not declare a workspace_map of
// type std::unordered_map<int64_t, uint64_t>.
TEST_F(TestCodegenTiling, MultiGroupDoesNotCarryWorkspaceMap) {
  auto schedule_result = GenMultiGroupFusedScheduleResult();
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;

  const bool declares_workspace_map =
      impl.find("std::unordered_map<int64_t, uint64_t> workspace_map") != std::string::npos;
  EXPECT_FALSE(declares_workspace_map);
}
// The generated code must compute final_modeled_perf and assign it to solution.modeled_perf.
TEST_F(TestCodegenTiling, BridgeMapsModeledPerfFromFinalComparablePerf) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("double final_modeled_perf ="));
  EXPECT_TRUE(has_snippet("solution.modeled_perf = final_modeled_perf;"));
}
// The generated code must declare a CandidateSolution named solution, assign final_modeled_perf
// to its modeled_perf, and contain a std::isfinite(final_modeled_perf) check.
TEST_F(TestCodegenTiling, BridgePreservesSingleGroupComparablePerfSemantics) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("CandidateSolution solution;"));
  EXPECT_TRUE(has_snippet("solution.modeled_perf = final_modeled_perf;"));
  EXPECT_TRUE(has_snippet("std::isfinite(final_modeled_perf)"));
}
// The generated code must derive is_default from is_default_config_request together with a
// non-empty default_repr obtained via GetTilingDataRepr(&default_tiling), and must not contain
// the "is_default && default_repr.empty()" pattern.
TEST_F(TestCodegenTiling, BridgeDefaultFromOriginalPath) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("is_default = is_default_config_request && !default_repr.empty()"));
  EXPECT_TRUE(has_snippet("default_repr = GetTilingDataRepr(&default_tiling)"));
  EXPECT_FALSE(has_snippet("is_default && default_repr.empty()"));
}
// The generated code must not assign to solution.workspace or solution.block_dim.
TEST_F(TestCodegenTiling, BridgeDoesNotWriteWorkspaceOrBlockDim) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_FALSE(has_snippet("solution.workspace ="));
  EXPECT_FALSE(has_snippet("solution.block_dim ="));
}
// The generated code must call SelectTopnCandidateSolutions on response.candidate_solutions
// with topn.
TEST_F(TestCodegenTiling, WrapperUsesSelectorForTopn) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;

  const bool calls_selector =
      impl.find("SelectTopnCandidateSolutions(response.candidate_solutions, topn)") !=
      std::string::npos;
  EXPECT_TRUE(calls_selector);
}
// The generated code must reference GetWorkspaceSize(sol.tiling_data) and
// sol.tiling_data.get_block_dim().
TEST_F(TestCodegenTiling, WrapperBackfillsWorkspaceAndBlockDim) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("GetWorkspaceSize(sol.tiling_data)"));
  EXPECT_TRUE(has_snippet("sol.tiling_data.get_block_dim()"));
}
// The generated code must branch on input_configs.empty(), contain both the nullptr and the
// &input_configs assignments to request.input_configs, and must not mention normalized_configs.
TEST_F(TestCodegenTiling, TopnWrapperMapsEmptyConfigsToInternalNoConfigPath) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("if (input_configs.empty()) {"));
  EXPECT_TRUE(has_snippet("request.input_configs = nullptr;"));
  EXPECT_TRUE(has_snippet("request.input_configs = &input_configs;"));
  EXPECT_FALSE(has_snippet("normalized_configs"));
}
// The generated code must declare a request and a response object and pass both to
// GetTopnCandidateSolutions.
TEST_F(TestCodegenTiling, TopnWrapperConstructsRequestAndInvokesGetTiling) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("GetTilingRequest request;"));
  EXPECT_TRUE(has_snippet("GetTilingResponse response;"));
  EXPECT_TRUE(has_snippet("GetTopnCandidateSolutions(request, response)"));
}
// The generated code must define internal_no_config_path from a nullptr input_configs and
// build is_default_config_request from it plus the single-config condition.
TEST_F(TestCodegenTiling, GetTilingDefaultConfigDetectionIncludesInternalPath) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(
      has_snippet("const bool internal_no_config_path = (request.input_configs == nullptr);"));
  EXPECT_TRUE(has_snippet("const bool is_default_config_request = internal_no_config_path || "));
  EXPECT_TRUE(
      has_snippet("(request.input_configs != nullptr && request.input_configs->size() == 1 && "));
}
// The generated code must contain a range-for over configs and a PGOSearchTilingKey( call.
TEST_F(TestCodegenTiling, GetTilingIteratesConfigsInOrder) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("for (const auto &cfg : configs)"));
  EXPECT_TRUE(has_snippet("PGOSearchTilingKey("));
}
// The generated code must test request.input_configs against nullptr, mention
// internal_no_config_path, and push a default-constructed SearchConfig into configs.
TEST_F(TestCodegenTiling, GetTilingInternalPathOnlyForNullptr) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("request.input_configs == nullptr"));
  EXPECT_TRUE(has_snippet("internal_no_config_path"));
  EXPECT_TRUE(has_snippet("configs.push_back(SearchConfig())"));
}
// The generated code must contain neither an "if (raws.empty()) {" branch nor a
// "return {SearchConfig()};" statement.
TEST_F(TestCodegenTiling, ParseSearchConfigsParsesExplicitConfigsOnly) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_FALSE(has_snippet("if (raws.empty()) {"));
  EXPECT_FALSE(has_snippet("return {SearchConfig()};"));
}
// The generated code must set solution.is_default from is_default_config_request (with the
// default_repr non-empty and canonical_repr equality checks present) and must never set it
// from internal_no_config_path.
TEST_F(TestCodegenTiling, DefaultFlagComesOnlyFromExplicitDefaultConfigRequest) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("solution.is_default = is_default_config_request"));
  EXPECT_TRUE(has_snippet("!default_repr.empty()"));
  EXPECT_TRUE(has_snippet("solution.canonical_repr == default_repr"));
  EXPECT_FALSE(has_snippet("solution.is_default = internal_no_config_path"));
}
// The generated code must track found_default_candidate and return -1 when a default-config
// request ends without one.
TEST_F(TestCodegenTiling, ExplicitDefaultConfigRequestMustFailWithoutDefaultCandidate) {
  auto schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("bool found_default_candidate = false;"));
  EXPECT_TRUE(has_snippet("if (solution.is_default) { found_default_candidate = true; }"));
  EXPECT_TRUE(
      has_snippet("if (is_default_config_request && !found_default_candidate) { return -1; }"));
}
// For the multi-group schedule result the generated code must mention PGOSearchTilingKey, the
// range-for over configs, graph0_tiling_key and UpdateCurPerfAndBlockByGroup.
TEST_F(TestCodegenTiling, MultiGroupUsesGraphLevelTilingKeysAndPerfAggregation) {
  auto schedule_result = GenMultiGroupFusedScheduleResult();
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_TRUE(has_snippet("PGOSearchTilingKey"));
  EXPECT_TRUE(has_snippet("for (const auto &cfg : configs)"));
  EXPECT_TRUE(has_snippet("graph0_tiling_key"));
  EXPECT_TRUE(has_snippet("UpdateCurPerfAndBlockByGroup"));
}
// For the multi-group schedule result the generated code must contain neither of the two
// listed PGOSearchTilingKey(raw_candidates, ...) call/declaration prefixes.
TEST_F(TestCodegenTiling, MultiGroupMustNotUseBasicPGOSearchTilingKeyOverload) {
  auto schedule_result = GenMultiGroupFusedScheduleResult();
  auto generated_files = this->GenerateForInductor(schedule_result);
  const auto impl_it = generated_files.find(codegen::kTilingDefAndConstIdentify);
  ASSERT_TRUE(impl_it != generated_files.end());
  const auto &impl = impl_it->second;
  const auto has_snippet = [&impl](const char *snippet) {
    return impl.find(snippet) != std::string::npos;
  };

  EXPECT_FALSE(has_snippet("PGOSearchTilingKey(raw_candidates, AutofuseTilingData &"));
  EXPECT_FALSE(
      has_snippet("PGOSearchTilingKey(raw_candidates, search_tiling, -1, &search_tiling"));
}