/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "gtest/gtest.h"

#include "autofuse_config/auto_fuse_config.h"
#define private public
#include "codegen_tiling.h"
#include "ascir_ops.h"
#include "ascir_ops_utils.h"
#include "schedule_result.h"
#include "runtime_stub.h"
#include "platform_context.h"
#include "ascgraph_info_complete.h"
#include "optimize/optimize.h"
#include <fstream>
#include <filesystem>

namespace {
std::pair<int, std::string> execute_command(const std::string& command) {
  std::array<char, 128> buffer;
  std::string output;
  std::unique_ptr<FILE, decltype(&pclose)> pipe(popen(command.c_str(), "r"), pclose);

  if (!pipe) {
    throw std::runtime_error("Failed to open pipe");
  }

  while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) {
    output += buffer.data();
  }

  return {WEXITSTATUS(pclose(pipe.release())), output};
}

bool CompileCode(const std::string &code){
  std::string cmake_dir = CMAKE_BINARY_DIR;
  // 临时目录
  std::string temp_dir = cmake_dir + "/tests/ut/temp_compile_codegen_tiling";
  std::filesystem::remove_all(temp_dir);
  std::filesystem::create_directories(temp_dir);
  // 生成的 C++ 文件路径
  std::string source_file = temp_dir + "/temp_codegen_infershape.cpp";
  // 生成 C++ 代码
  std::ofstream source_stream(source_file);
  source_stream << code << R"(
        int main() {
            return 0;
        }
    )";
  source_stream.close();
  // 头文件路径
  std::string ascend_install_path = ASCEND_INSTALL_PATH;
  std::string include_path = "-I" + ascend_install_path + "/include/ ";
  std::string link_path = "-L" + ascend_install_path + "/lib64";
  // 编译代码
  std::string compile_command = "g++ -std=c++17 " + include_path + " " + link_path + " " + source_file + " -lc_sec";
  auto [compile_exit_code, compile_output] = execute_command(compile_command);
  // 清理临时目录
  std::filesystem::remove_all(temp_dir);
  return compile_exit_code == 0;
}
}

namespace {

static void CreateElemwiseGraphWithRelu(af::AscGraph &graph) {
  auto n = graph.CreateSizeVar(1);
  auto c = graph.CreateSizeVar(64);
  auto h = graph.CreateSizeVar(56);
  auto w = graph.CreateSizeVar(56);
  
  auto z_n = graph.CreateAxis("z_n", n);
  auto z_c = graph.CreateAxis("z_c", c);
  auto z_h = graph.CreateAxis("z_h", h);
  auto z_w = graph.CreateAxis("z_w", w);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.y.dtype = ge::DT_FLOAT;
  *data0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *data0.y.repeats = {n, c, h, w};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.x = data0.y;
  *load0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *load0.y.repeats = {n, c, h, w};

  af::ascir_op::Relu relu("relu");
  graph.AddNode(relu);
  relu.x = load0.y;
  relu.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  relu.y.dtype = ge::DT_FLOAT;
  *relu.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *relu.y.repeats = {n, c, h, w};
  *relu.y.strides = {c*h*w, h*w, w, af::ops::One};
  relu.attr.api.compute_type = af::ComputeType::kComputeElewise;

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.x = relu.y;
  *store_op.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {c*h*w, h*w, w, af::ops::One};
  *store_op.y.repeats = {n, c, h, w};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  auto x1Local = graph.FindNode("data0");
  x1Local->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  x1Local->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  x1Local->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}

static void VerifyTilingCodeBasic(const std::map<std::string, std::string> &res) {
  auto pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(pos, std::string::npos);
  auto static_shape_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" bool AutofuseIsStaticShape() {\n  return true;");
  ASSERT_NE(static_shape_pos, std::string::npos);
  auto tiling_func_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" ge::graphStatus TilingFunc");
  ASSERT_NE(tiling_func_pos, std::string::npos);
  auto get_size_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" size_t GetTilingDataSize()");
  ASSERT_NE(get_size_pos, std::string::npos);
  auto tiling_data_pos = res.at("tiling_def_and_tiling_const").find("AutofuseTilingData");
  ASSERT_NE(tiling_data_pos, std::string::npos);
}

static void CreateConv2DOffsetBiasGraph(af::AscGraph &conv2d_offset_bias_graph) {
  auto n_ob = conv2d_offset_bias_graph.CreateSizeVar(1);
  auto c_ob = conv2d_offset_bias_graph.CreateSizeVar(64);
  auto h_ob = conv2d_offset_bias_graph.CreateSizeVar(56);
  auto w_ob = conv2d_offset_bias_graph.CreateSizeVar(56);
  
  auto z_n_ob = conv2d_offset_bias_graph.CreateAxis("z_n", n_ob);
  auto z_c_ob = conv2d_offset_bias_graph.CreateAxis("z_c", c_ob);
  auto z_h_ob = conv2d_offset_bias_graph.CreateAxis("z_h", h_ob);
  auto z_w_ob = conv2d_offset_bias_graph.CreateAxis("z_w", w_ob);

  af::ascir_op::Data data0_ob("data0", conv2d_offset_bias_graph);
  data0_ob.attr.sched.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  data0_ob.y.dtype = ge::DT_FLOAT16;
  *data0_ob.y.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  data0_ob.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0_ob.y.strides = {c_ob*h_ob*w_ob, h_ob*w_ob, w_ob, af::ops::One};
  *data0_ob.y.repeats = {n_ob, c_ob, h_ob, w_ob};
  data0_ob.ir_attr.SetIndex(0);

  af::ascir_op::Load load0_ob("load0");
  load0_ob.attr.sched.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  load0_ob.x = data0_ob.y;
  *load0_ob.y.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  load0_ob.y.dtype = ge::DT_FLOAT16;
  *load0_ob.y.strides = {c_ob*h_ob*w_ob, h_ob*w_ob, w_ob, af::ops::One};
  *load0_ob.y.repeats = {n_ob, c_ob, h_ob, w_ob};

  af::ascir_op::Data data1_ob("data1", conv2d_offset_bias_graph);
  data1_ob.y.dtype = ge::DT_FLOAT16;
  data1_ob.attr.sched.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  *data1_ob.y.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  data1_ob.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1_ob.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *data1_ob.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  data1_ob.ir_attr.SetIndex(1);

  af::ascir_op::Load load1_ob("load1");
  load1_ob.x = data1_ob.y;
  load1_ob.attr.sched.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  load1_ob.y.dtype = ge::DT_FLOAT16;
  *load1_ob.y.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  *load1_ob.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  *load1_ob.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};

  af::ascir_op::Data data2_ob("data2", conv2d_offset_bias_graph);
  data2_ob.y.dtype = ge::DT_FLOAT;
  data2_ob.attr.sched.axis = {z_c_ob.id};
  *data2_ob.y.axis = {z_c_ob.id};
  data2_ob.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data2_ob.y.repeats = {c_ob};
  *data2_ob.y.strides = {af::ops::One};
  data2_ob.ir_attr.SetIndex(2);

  af::ascir_op::Load load2_ob("load2");
  load2_ob.x = data2_ob.y;
  load2_ob.attr.sched.axis = {z_c_ob.id};
  load2_ob.y.dtype = ge::DT_FLOAT;
  *load2_ob.y.axis = {z_c_ob.id};
  *load2_ob.y.strides = {af::ops::One};
  *load2_ob.y.repeats = {c_ob};

  af::ascir_op::Data data3_ob("data3", conv2d_offset_bias_graph);
  data3_ob.y.dtype = ge::DT_FLOAT16;
  data3_ob.attr.sched.axis = {z_c_ob.id};
  *data3_ob.y.axis = {z_c_ob.id};
  data3_ob.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data3_ob.y.repeats = {c_ob};
  *data3_ob.y.strides = {af::ops::One};
  data3_ob.ir_attr.SetIndex(3);

  af::ascir_op::Load load3_ob("load3");
  load3_ob.x = data3_ob.y;
  load3_ob.attr.sched.axis = {z_c_ob.id};
  load3_ob.y.dtype = ge::DT_FLOAT16;
  *load3_ob.y.axis = {z_c_ob.id};
  *load3_ob.y.strides = {af::ops::One};
  *load3_ob.y.repeats = {c_ob};

  af::ascir_op::Conv2DOffsetBias conv2d_offset_bias("conv2d_offset_bias");
  conv2d_offset_bias.attr.sched.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  conv2d_offset_bias.x = load0_ob.y;
  conv2d_offset_bias.filter = load1_ob.y;
  conv2d_offset_bias.bias = load2_ob.y;
  conv2d_offset_bias.offset_w = load3_ob.y;
  conv2d_offset_bias.y.dtype = ge::DT_FLOAT;
  *conv2d_offset_bias.y.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  *conv2d_offset_bias.y.repeats = {n_ob, c_ob, h_ob, w_ob};
  *conv2d_offset_bias.y.strides = {c_ob*h_ob*w_ob, h_ob*w_ob, w_ob, af::ops::One};
  conv2d_offset_bias.attr.api.compute_type = af::ComputeType::kComputeCube;
  conv2d_offset_bias.ir_attr.SetStrides({1, 1});
  conv2d_offset_bias.ir_attr.SetPads({1, 1, 1, 1});
  conv2d_offset_bias.ir_attr.SetDilations({1, 1});
  conv2d_offset_bias.ir_attr.SetGroups(1);
  conv2d_offset_bias.ir_attr.SetData_format("NCHW");
  conv2d_offset_bias.ir_attr.SetOffset_x(0);
  conv2d_offset_bias.ir_attr.SetEnable_hf32(false);

  af::ascir_op::Store store_ob("store");
  store_ob.attr.sched.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  store_ob.x = conv2d_offset_bias.y;
  *store_ob.y.axis = {z_n_ob.id, z_c_ob.id, z_h_ob.id, z_w_ob.id};
  store_ob.y.dtype = ge::DT_FLOAT;
  *store_ob.y.strides = {c_ob*h_ob*w_ob, h_ob*w_ob, w_ob, af::ops::One};
  *store_ob.y.repeats = {n_ob, c_ob, h_ob, w_ob};

  af::ascir_op::Output output_ob("output");
  output_ob.x = store_ob.y;
  output_ob.y.dtype = ge::DT_FLOAT;
  output_ob.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(conv2d_offset_bias_graph);
}

static void CreateElemwiseGraphWithReluDynamic(af::AscGraph &graph) {
  auto n = graph.CreateSizeVar("n");
  auto c = graph.CreateSizeVar("c");
  auto h = graph.CreateSizeVar("h");
  auto w = graph.CreateSizeVar("w");
  
  auto z_n = graph.CreateAxis("z_n", n);
  auto z_c = graph.CreateAxis("z_c", c);
  auto z_h = graph.CreateAxis("z_h", h);
  auto z_w = graph.CreateAxis("z_w", w);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.y.dtype = ge::DT_FLOAT;
  *data0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *data0.y.repeats = {n, c, h, w};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.x = data0.y;
  *load0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *load0.y.repeats = {n, c, h, w};

  af::ascir_op::Relu relu("relu");
  graph.AddNode(relu);
  relu.x = load0.y;
  relu.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  relu.y.dtype = ge::DT_FLOAT;
  *relu.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *relu.y.repeats = {n, c, h, w};
  *relu.y.strides = {c*h*w, h*w, w, af::ops::One};
  relu.attr.api.compute_type = af::ComputeType::kComputeElewise;

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.x = relu.y;
  *store_op.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {c*h*w, h*w, w, af::ops::One};
  *store_op.y.repeats = {n, c, h, w};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  auto x1Local = graph.FindNode("data0");
  x1Local->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  x1Local->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  x1Local->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}

static void CreateMatmulElemwiseDynamicGraph(af::AscGraph &graph) {
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z0.id, z1.id};
  data0.y.dtype = ge::DT_FLOAT;
  *data0.y.axis = {z0.id, z1.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {s1 ,af::ops::One};
  *data0.y.repeats = {s0, s1};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z0.id, z1.id};
  load0.x = data0.y;
  *load0.y.axis = {z0.id, z1.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.strides = {s1 ,af::ops::One};
  *load0.y.repeats = {s0, s1};

  af::ascir_op::Abs abs("abs");
  graph.AddNode(abs);
  abs.x = load0.y;
  abs.attr.sched.axis = {z0.id, z1.id};
  abs.y.dtype = ge::DT_FLOAT;
  *abs.y.axis = {z0.id, z1.id};
  *abs.y.repeats = {s0, s1};
  *abs.y.strides = {s1, af::ops::One};
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;

  af::ascir_op::Scalar scalar0("scalar0", graph);
  scalar0.attr.sched.axis = {z0.id, z1.id};
  scalar0.ir_attr.SetValue("0");
  scalar0.y.dtype = ge::DT_FLOAT;
  *scalar0.y.axis = {z0.id, z1.id};
  *scalar0.y.repeats = {af::ops::One, af::ops::One};
  *scalar0.y.strides = {af::ops::Zero, af::ops::Zero};

  af::ascir_op::Broadcast broadcast0("broadcast0");
  broadcast0.x = scalar0.y;
  broadcast0.attr.sched.axis = {z0.id, z1.id};
  *broadcast0.y.axis = {z0.id, z1.id};
  broadcast0.y.dtype = ge::DT_FLOAT;
  *broadcast0.y.repeats = {af::ops::One, s1};
  *broadcast0.y.strides = {af::ops::Zero, af::ops::One};

  af::ascir_op::Broadcast broadcast1("broadcast1");
  broadcast1.x = broadcast0.y;
  broadcast1.attr.sched.axis = {z0.id, z1.id};
  *broadcast1.y.axis = {z0.id, z1.id};
  broadcast1.y.dtype = ge::DT_FLOAT;
  *broadcast1.y.repeats = {s0, s1};
  *broadcast1.y.strides = {s1, af::ops::One};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z0.id, z1.id};
  add_op.x1 = abs.y;
  add_op.x2 = broadcast1.y;
  add_op.y.dtype = ge::DT_FLOAT;
  *add_op.y.axis = {z0.id, z1.id};
  *add_op.y.repeats = {s0, s1};
  *add_op.y.strides = {s1, af::ops::One};

  af::ascir_op::Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT;
  data1.attr.sched.axis = {z0.id, z1.id};
  *data1.y.axis = {z0.id, z1.id};
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1.y.repeats = {af::ops::One, af::ops::One};
  *data1.y.strides = {af::ops::Zero, af::ops::Zero};
  data1.ir_attr.SetIndex(1);

  af::ascir_op::Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z0.id, z1.id};
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id};
  *load1.y.strides = {af::ops::Zero, af::ops::Zero};
  *load1.y.repeats = {af::ops::One, af::ops::One};

  af::ascir_op::Broadcast broadcast2("broadcast2");
  broadcast2.x = load1.y;
  broadcast2.attr.sched.axis = {z0.id, z1.id};
  *broadcast2.y.axis = {z0.id, z1.id};
  broadcast2.y.dtype = ge::DT_FLOAT;
  *broadcast2.y.repeats = {af::ops::One, s1};
  *broadcast2.y.strides = {af::ops::Zero, af::ops::One};

  af::ascir_op::Broadcast broadcast3("broadcast3");
  broadcast3.x = broadcast2.y;
  broadcast3.attr.sched.axis = {z0.id, z1.id};
  *broadcast3.y.axis = {z0.id, z1.id};
  broadcast3.y.dtype = ge::DT_FLOAT;
  *broadcast3.y.repeats = {s0, s1};
  *broadcast3.y.strides = {s1, af::ops::One};

  af::ascir_op::Mul mul("mul");
  mul.attr.sched.axis = {z0.id, z1.id};
  mul.x1 = add_op.y;
  mul.x2 = broadcast3.y;
  mul.y.dtype = ge::DT_FLOAT;
  *mul.y.axis = {z0.id, z1.id};
  *mul.y.repeats = {s0, s1};
  *mul.y.strides = {s1, af::ops::One};

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = mul.y;
  *store_op.y.axis = {z0.id, z1.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {s1 ,af::ops::One};
  *store_op.y.repeats = {s0, s1};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  auto x1Local = graph.FindNode("data0");
  x1Local->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  x1Local->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  x1Local->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}

static void CreateElemwiseGraphWithMulDynamic(af::AscGraph &graph) {
  auto n = graph.CreateSizeVar("n");
  auto c = graph.CreateSizeVar("c");
  auto h = graph.CreateSizeVar("h");
  auto w = graph.CreateSizeVar("w");
  
  auto z_n = graph.CreateAxis("z_n", n);
  auto z_c = graph.CreateAxis("z_c", c);
  auto z_h = graph.CreateAxis("z_h", h);
  auto z_w = graph.CreateAxis("z_w", w);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.y.dtype = ge::DT_FLOAT;
  *data0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *data0.y.repeats = {n, c, h, w};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.x = data0.y;
  *load0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *load0.y.repeats = {n, c, h, w};

  af::ascir_op::Scalar scalar("scalar", graph);
  scalar.ir_attr.SetValue("2.0");
  scalar.y.dtype = ge::DT_FLOAT;
  *scalar.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *scalar.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *scalar.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  
  af::ascir_op::Broadcast broadcast("broadcast");
  broadcast.x = scalar.y;
  broadcast.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *broadcast.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  broadcast.y.dtype = ge::DT_FLOAT;
  *broadcast.y.repeats = {n, c, h, w};
  *broadcast.y.strides = {c*h*w, h*w, w, af::ops::One};

  af::ascir_op::Mul mul("mul");
  mul.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  mul.x1 = load0.y;
  mul.x2 = broadcast.y;
  mul.y.dtype = ge::DT_FLOAT;
  *mul.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *mul.y.repeats = {n, c, h, w};
  *mul.y.strides = {c*h*w, h*w, w, af::ops::One};
  mul.attr.api.compute_type = af::ComputeType::kComputeElewise;

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.x = mul.y;
  *store_op.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {c*h*w, h*w, w, af::ops::One};
  *store_op.y.repeats = {n, c, h, w};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  auto x1Local = graph.FindNode("data0");
  x1Local->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  x1Local->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  x1Local->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}

static void CreateElemwiseGraphWithAbsAndAddStatic(af::AscGraph &graph) {
  auto n = graph.CreateSizeVar(1);
  auto c = graph.CreateSizeVar(64);
  auto h = graph.CreateSizeVar(56);
  auto w = graph.CreateSizeVar(56);
  
  auto z_n = graph.CreateAxis("z_n", n);
  auto z_c = graph.CreateAxis("z_c", c);
  auto z_h = graph.CreateAxis("z_h", h);
  auto z_w = graph.CreateAxis("z_w", w);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.y.dtype = ge::DT_FLOAT;
  *data0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *data0.y.repeats = {n, c, h, w};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.x = data0.y;
  *load0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *load0.y.repeats = {n, c, h, w};

  af::ascir_op::Abs abs("abs");
  graph.AddNode(abs);
  abs.x = load0.y;
  abs.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  abs.y.dtype = ge::DT_FLOAT;
  *abs.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *abs.y.repeats = {n, c, h, w};
  *abs.y.strides = {c*h*w, h*w, w, af::ops::One};
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;

  af::ascir_op::Scalar scalar0("scalar0", graph);
  scalar0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  scalar0.ir_attr.SetValue("0.1");
  scalar0.y.dtype = ge::DT_FLOAT;
  *scalar0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *scalar0.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *scalar0.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};

  af::ascir_op::Broadcast broadcast0("broadcast0");
  broadcast0.x = scalar0.y;
  broadcast0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *broadcast0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  broadcast0.y.dtype = ge::DT_FLOAT;
  *broadcast0.y.repeats = {n, c, h, w};
  *broadcast0.y.strides = {c*h*w, h*w, w, af::ops::One};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  add_op.x1 = abs.y;
  add_op.x2 = broadcast0.y;
  add_op.y.dtype = ge::DT_FLOAT;
  *add_op.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *add_op.y.repeats = {n, c, h, w};
  *add_op.y.strides = {c*h*w, h*w, w, af::ops::One};

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.x = add_op.y;
  *store_op.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {c*h*w, h*w, w, af::ops::One};
  *store_op.y.repeats = {n, c, h, w};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  auto x1Local = graph.FindNode("data0");
  x1Local->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  x1Local->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  x1Local->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}

static void CreateConv2DOffsetGraph(af::AscGraph &conv2d_offset_graph) {
  auto n_o = conv2d_offset_graph.CreateSizeVar(1);
  auto c_o = conv2d_offset_graph.CreateSizeVar(64);
  auto h_o = conv2d_offset_graph.CreateSizeVar(56);
  auto w_o = conv2d_offset_graph.CreateSizeVar(56);
  
  auto z_n_o = conv2d_offset_graph.CreateAxis("z_n", n_o);
  auto z_c_o = conv2d_offset_graph.CreateAxis("z_c", c_o);
  auto z_h_o = conv2d_offset_graph.CreateAxis("z_h", h_o);
  auto z_w_o = conv2d_offset_graph.CreateAxis("z_w", w_o);

  af::ascir_op::Data data0_o("data0", conv2d_offset_graph);
  data0_o.attr.sched.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  data0_o.y.dtype = ge::DT_FLOAT16;
  *data0_o.y.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  data0_o.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0_o.y.strides = {c_o*h_o*w_o, h_o*w_o, w_o, af::ops::One};
  *data0_o.y.repeats = {n_o, c_o, h_o, w_o};
  data0_o.ir_attr.SetIndex(0);

  af::ascir_op::Load load0_o("load0");
  load0_o.attr.sched.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  load0_o.x = data0_o.y;
  *load0_o.y.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  load0_o.y.dtype = ge::DT_FLOAT16;
  *load0_o.y.strides = {c_o*h_o*w_o, h_o*w_o, w_o, af::ops::One};
  *load0_o.y.repeats = {n_o, c_o, h_o, w_o};

  af::ascir_op::Data data1_o("data1", conv2d_offset_graph);
  data1_o.y.dtype = ge::DT_FLOAT16;
  data1_o.attr.sched.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  *data1_o.y.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  data1_o.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1_o.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *data1_o.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  data1_o.ir_attr.SetIndex(1);

  af::ascir_op::Load load1_o("load1");
  load1_o.x = data1_o.y;
  load1_o.attr.sched.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  load1_o.y.dtype = ge::DT_FLOAT16;
  *load1_o.y.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  *load1_o.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  *load1_o.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};

  af::ascir_op::Data data2_o("data2", conv2d_offset_graph);
  data2_o.y.dtype = ge::DT_FLOAT16;
  data2_o.attr.sched.axis = {z_c_o.id};
  *data2_o.y.axis = {z_c_o.id};
  data2_o.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data2_o.y.repeats = {c_o};
  *data2_o.y.strides = {af::ops::One};
  data2_o.ir_attr.SetIndex(2);

  af::ascir_op::Load load2_o("load2");
  load2_o.x = data2_o.y;
  load2_o.attr.sched.axis = {z_c_o.id};
  load2_o.y.dtype = ge::DT_FLOAT16;
  *load2_o.y.axis = {z_c_o.id};
  *load2_o.y.strides = {af::ops::One};
  *load2_o.y.repeats = {c_o};

  af::ascir_op::Conv2DOffset conv2d_offset("conv2d_offset");
  conv2d_offset.attr.sched.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  conv2d_offset.x = load0_o.y;
  conv2d_offset.filter = load1_o.y;
  conv2d_offset.offset_w = load2_o.y;
  conv2d_offset.y.dtype = ge::DT_FLOAT;
  *conv2d_offset.y.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  *conv2d_offset.y.repeats = {n_o, c_o, h_o, w_o};
  *conv2d_offset.y.strides = {c_o*h_o*w_o, h_o*w_o, w_o, af::ops::One};
  conv2d_offset.attr.api.compute_type = af::ComputeType::kComputeCube;
  conv2d_offset.ir_attr.SetStrides({1, 1});
  conv2d_offset.ir_attr.SetPads({1, 1, 1, 1});
  conv2d_offset.ir_attr.SetDilations({1, 1});
  conv2d_offset.ir_attr.SetGroups(1);
  conv2d_offset.ir_attr.SetData_format("NCHW");
  conv2d_offset.ir_attr.SetOffset_x(0);
  conv2d_offset.ir_attr.SetEnable_hf32(false);

  af::ascir_op::Store store_o("store");
  store_o.attr.sched.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  store_o.x = conv2d_offset.y;
  *store_o.y.axis = {z_n_o.id, z_c_o.id, z_h_o.id, z_w_o.id};
  store_o.y.dtype = ge::DT_FLOAT;
  *store_o.y.strides = {c_o*h_o*w_o, h_o*w_o, w_o, af::ops::One};
  *store_o.y.repeats = {n_o, c_o, h_o, w_o};

  af::ascir_op::Output output_o("output");
  output_o.x = store_o.y;
  output_o.y.dtype = ge::DT_FLOAT;
  output_o.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(conv2d_offset_graph);
}

static void CreateConv2DGraphWithGroups(af::AscGraph &conv2d_graph) {
  auto n_g = conv2d_graph.CreateSizeVar(1);
  auto c_g = conv2d_graph.CreateSizeVar(64);
  auto h_g = conv2d_graph.CreateSizeVar(56);
  auto w_g = conv2d_graph.CreateSizeVar(56);
  
  auto z_n_g = conv2d_graph.CreateAxis("z_n", n_g);
  auto z_c_g = conv2d_graph.CreateAxis("z_c", c_g);
  auto z_h_g = conv2d_graph.CreateAxis("z_h", h_g);
  auto z_w_g = conv2d_graph.CreateAxis("z_w", w_g);

  af::ascir_op::Data data0_g("data0", conv2d_graph);
  data0_g.attr.sched.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  data0_g.y.dtype = ge::DT_FLOAT16;
  *data0_g.y.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  data0_g.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0_g.y.strides = {c_g*h_g*w_g, h_g*w_g, w_g, af::ops::One};
  *data0_g.y.repeats = {n_g, c_g, h_g, w_g};
  data0_g.ir_attr.SetIndex(0);

  af::ascir_op::Load load0_g("load0");
  load0_g.attr.sched.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  load0_g.x = data0_g.y;
  *load0_g.y.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  load0_g.y.dtype = ge::DT_FLOAT16;
  *load0_g.y.strides = {c_g*h_g*w_g, h_g*w_g, w_g, af::ops::One};
  *load0_g.y.repeats = {n_g, c_g, h_g, w_g};

  af::ascir_op::Data data1_g("data1", conv2d_graph);
  data1_g.y.dtype = ge::DT_FLOAT16;
  data1_g.attr.sched.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  *data1_g.y.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  data1_g.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1_g.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *data1_g.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  data1_g.ir_attr.SetIndex(1);

  af::ascir_op::Load load1_g("load1");
  load1_g.x = data1_g.y;
  load1_g.attr.sched.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  load1_g.y.dtype = ge::DT_FLOAT16;
  *load1_g.y.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  *load1_g.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  *load1_g.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};

  af::ascir_op::Conv2D conv2d_g("conv2d");
  conv2d_g.attr.sched.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  conv2d_g.x = load0_g.y;
  conv2d_g.filter = load1_g.y;
  conv2d_g.y.dtype = ge::DT_FLOAT;
  *conv2d_g.y.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  *conv2d_g.y.repeats = {n_g, c_g, h_g, w_g};
  *conv2d_g.y.strides = {c_g*h_g*w_g, h_g*w_g, w_g, af::ops::One};
  conv2d_g.attr.api.compute_type = af::ComputeType::kComputeCube;
  conv2d_g.ir_attr.SetStrides({2, 2});
  conv2d_g.ir_attr.SetPads({1, 1, 1, 1});
  conv2d_g.ir_attr.SetDilations({1, 1});
  conv2d_g.ir_attr.SetGroups(4);
  conv2d_g.ir_attr.SetData_format("NCHW");
  conv2d_g.ir_attr.SetOffset_x(0);
  conv2d_g.ir_attr.SetEnable_hf32(false);

  af::ascir_op::Store store_g("store");
  store_g.attr.sched.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  store_g.x = conv2d_g.y;
  *store_g.y.axis = {z_n_g.id, z_c_g.id, z_h_g.id, z_w_g.id};
  store_g.y.dtype = ge::DT_FLOAT;
  *store_g.y.strides = {c_g*h_g*w_g, h_g*w_g, w_g, af::ops::One};
  *store_g.y.repeats = {n_g, c_g, h_g, w_g};

  af::ascir_op::Output output_g("output");
  output_g.x = store_g.y;
  output_g.y.dtype = ge::DT_FLOAT;
  output_g.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(conv2d_graph);
}

static void CreateConv2DGraphWithDilation(af::AscGraph &conv2d_graph) {
  auto n_d = conv2d_graph.CreateSizeVar(1);
  auto c_d = conv2d_graph.CreateSizeVar(64);
  auto h_d = conv2d_graph.CreateSizeVar(56);
  auto w_d = conv2d_graph.CreateSizeVar(56);
  
  auto z_n_d = conv2d_graph.CreateAxis("z_n", n_d);
  auto z_c_d = conv2d_graph.CreateAxis("z_c", c_d);
  auto z_h_d = conv2d_graph.CreateAxis("z_h", h_d);
  auto z_w_d = conv2d_graph.CreateAxis("z_w", w_d);

  af::ascir_op::Data data0_d("data0", conv2d_graph);
  data0_d.attr.sched.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  data0_d.y.dtype = ge::DT_FLOAT16;
  *data0_d.y.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  data0_d.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0_d.y.strides = {c_d*h_d*w_d, h_d*w_d, w_d, af::ops::One};
  *data0_d.y.repeats = {n_d, c_d, h_d, w_d};
  data0_d.ir_attr.SetIndex(0);

  af::ascir_op::Load load0_d("load0");
  load0_d.attr.sched.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  load0_d.x = data0_d.y;
  *load0_d.y.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  load0_d.y.dtype = ge::DT_FLOAT16;
  *load0_d.y.strides = {c_d*h_d*w_d, h_d*w_d, w_d, af::ops::One};
  *load0_d.y.repeats = {n_d, c_d, h_d, w_d};

  af::ascir_op::Data data1_d("data1", conv2d_graph);
  data1_d.y.dtype = ge::DT_FLOAT16;
  data1_d.attr.sched.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  *data1_d.y.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  data1_d.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1_d.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *data1_d.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  data1_d.ir_attr.SetIndex(1);

  af::ascir_op::Load load1_d("load1");
  load1_d.x = data1_d.y;
  load1_d.attr.sched.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  load1_d.y.dtype = ge::DT_FLOAT16;
  *load1_d.y.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  *load1_d.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  *load1_d.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};

  af::ascir_op::Conv2D conv2d_d("conv2d");
  conv2d_d.attr.sched.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  conv2d_d.x = load0_d.y;
  conv2d_d.filter = load1_d.y;
  conv2d_d.y.dtype = ge::DT_FLOAT;
  *conv2d_d.y.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  *conv2d_d.y.repeats = {n_d, c_d, h_d, w_d};
  *conv2d_d.y.strides = {c_d*h_d*w_d, h_d*w_d, w_d, af::ops::One};
  conv2d_d.attr.api.compute_type = af::ComputeType::kComputeCube;
  conv2d_d.ir_attr.SetStrides({1, 1});
  conv2d_d.ir_attr.SetPads({2, 2, 2, 2});
  conv2d_d.ir_attr.SetDilations({2, 2});
  conv2d_d.ir_attr.SetGroups(1);
  conv2d_d.ir_attr.SetData_format("NCHW");
  conv2d_d.ir_attr.SetOffset_x(0);
  conv2d_d.ir_attr.SetEnable_hf32(false);

  af::ascir_op::Store store_d("store");
  store_d.attr.sched.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  store_d.x = conv2d_d.y;
  *store_d.y.axis = {z_n_d.id, z_c_d.id, z_h_d.id, z_w_d.id};
  store_d.y.dtype = ge::DT_FLOAT;
  *store_d.y.strides = {c_d*h_d*w_d, h_d*w_d, w_d, af::ops::One};
  *store_d.y.repeats = {n_d, c_d, h_d, w_d};

  af::ascir_op::Output output_d("output");
  output_d.x = store_d.y;
  output_d.y.dtype = ge::DT_FLOAT;
  output_d.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(conv2d_graph);
}

static void CreateBatchMatmulElemwiseDynamicGraph(af::AscGraph &graph) {
  auto batch = graph.CreateSizeVar("batch");
  auto m = graph.CreateSizeVar("m");
  auto n = graph.CreateSizeVar("n");
  
  auto z_batch = graph.CreateAxis("z_batch", batch);
  auto z_m = graph.CreateAxis("z_m", m);
  auto z_n = graph.CreateAxis("z_n", n);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z_batch.id, z_m.id, z_n.id};
  data0.y.dtype = ge::DT_FLOAT;
  *data0.y.axis = {z_batch.id, z_m.id, z_n.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {m*n, n, af::ops::One};
  *data0.y.repeats = {batch, m, n};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z_batch.id, z_m.id, z_n.id};
  load0.x = data0.y;
  *load0.y.axis = {z_batch.id, z_m.id, z_n.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.strides = {m*n, n, af::ops::One};
  *load0.y.repeats = {batch, m, n};

  af::ascir_op::Relu relu("relu");
  graph.AddNode(relu);
  relu.x = load0.y;
  relu.attr.sched.axis = {z_batch.id, z_m.id, z_n.id};
  relu.y.dtype = ge::DT_FLOAT;
  *relu.y.axis = {z_batch.id, z_m.id, z_n.id};
  *relu.y.repeats = {batch, m, n};
  *relu.y.strides = {m*n, n, af::ops::One};
  relu.attr.api.compute_type = af::ComputeType::kComputeElewise;

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z_batch.id, z_m.id, z_n.id};
  store_op.x = relu.y;
  *store_op.y.axis = {z_batch.id, z_m.id, z_n.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {m*n, n, af::ops::One};
  *store_op.y.repeats = {batch, m, n};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  auto x1Local = graph.FindNode("data0");
  x1Local->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  x1Local->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  x1Local->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
}

static void VerifyDynamicShapeTiling(const std::map<std::string, std::string> &res) {
  auto pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(pos, std::string::npos);
  auto dynamic_shape_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" bool AutofuseIsStaticShape() {\n  return false;");
  ASSERT_NE(dynamic_shape_pos, std::string::npos);
  auto tiling_func_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" ge::graphStatus TilingFunc(gert::TilingSymbolEvalContext *context)");
  ASSERT_NE(tiling_func_pos, std::string::npos);
  auto tiling_call_pos = res.at("tiling_def_and_tiling_const").find("AutofuseTilingWithConfig");
  ASSERT_NE(tiling_call_pos, std::string::npos);
  auto cache_key_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" ge::graphStatus GetSymbolTilingCacheKey");
  ASSERT_NE(cache_key_pos, std::string::npos);
  auto tiling_data_pos = res.at("tiling_def_and_tiling_const").find("AutofuseTilingData");
  ASSERT_NE(tiling_data_pos, std::string::npos);
}

static void VerifyConv2dElemwiseTiling(const std::map<std::string, std::string> &res) {
  auto pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(pos, std::string::npos);
  auto static_shape_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" bool AutofuseIsStaticShape() {\n  return true;");
  ASSERT_NE(static_shape_pos, std::string::npos);
  auto tiling_func_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" ge::graphStatus TilingFunc");
  ASSERT_NE(tiling_func_pos, std::string::npos);
  auto tiling_parse_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" ge::graphStatus TilingParse");
  ASSERT_NE(tiling_parse_pos, std::string::npos);
  auto get_size_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" size_t GetTilingDataSize()");
  ASSERT_NE(get_size_pos, std::string::npos);
  auto workspace_pos = res.at("tiling_def_and_tiling_const").find("*context->GetWorkspaceSizes(1) = 16 * 1024 * 1024");
  ASSERT_NE(workspace_pos, std::string::npos);
  auto tiling_data_pos = res.at("tiling_def_and_tiling_const").find("AutofuseTilingData");
  ASSERT_NE(tiling_data_pos, std::string::npos);
  auto block_dim_pos = res.at("tiling_def_and_tiling_const").find("set_block_dim");
  ASSERT_NE(block_dim_pos, std::string::npos);
}

static void VerifyConv2DBiasElemwiseTiling(const std::map<std::string, std::string> &res) {
  auto pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(pos, std::string::npos);
  auto static_shape_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" bool AutofuseIsStaticShape() {\n  return true;");
  ASSERT_NE(static_shape_pos, std::string::npos);
  auto tiling_func_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" ge::graphStatus TilingFunc");
  ASSERT_NE(tiling_func_pos, std::string::npos);
  auto get_size_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" size_t GetTilingDataSize()");
  ASSERT_NE(get_size_pos, std::string::npos);
  auto tiling_data_pos = res.at("tiling_def_and_tiling_const").find("AutofuseTilingData");
  ASSERT_NE(tiling_data_pos, std::string::npos);
  auto workspace_pos = res.at("tiling_def_and_tiling_const").find("*context->GetWorkspaceSizes(1) = 16 * 1024 * 1024");
  ASSERT_NE(workspace_pos, std::string::npos);
}

static void VerifyConv2DOffsetTiling(const std::map<std::string, std::string> &res) {
  auto pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(pos, std::string::npos);
  auto static_shape_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" bool AutofuseIsStaticShape() {\n  return true;");
  ASSERT_NE(static_shape_pos, std::string::npos);
  auto tiling_func_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" ge::graphStatus TilingFunc");
  ASSERT_NE(tiling_func_pos, std::string::npos);
  auto get_size_pos = res.at("tiling_def_and_tiling_const").find("extern \"C\" size_t GetTilingDataSize()");
  ASSERT_NE(get_size_pos, std::string::npos);
  auto tiling_data_pos = res.at("tiling_def_and_tiling_const").find("AutofuseTilingData");
  ASSERT_NE(tiling_data_pos, std::string::npos);
}

} // namespace

class TestCodegenTiling : public testing::Test, public codegen::TilingLib {
  public:
  void SetUp() override {
    dlog_setlevel(ASCGEN_MODULE_NAME, DLOG_DEBUG, 0);
  }
  void TearDown() override {
  }
  void SetupLoadAttrs(af::AscNode &load, uint64_t z0_id, const af::Expression &z0_size) {
    auto &attr = load.outputs[0].attr;
    attr.axis = {static_cast<int64_t>(z0_id)};
    attr.vectorized_axis = {static_cast<int64_t>(z0_id)};
    attr.vectorized_strides = {af::ops::One};
    attr.repeats = {z0_size};
    attr.strides = {af::ops::One};
    attr.mem.position = af::Position::kPositionVecIn;
    attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
    attr.mem.tensor_id = 1;
    attr.que.id = 0;
    attr.mem.reuse_id = 0;
    attr.que.depth = 2;
    attr.que.buf_num = 2;
    attr.opt.merge_scope = af::kIdNone;
  }

  void SetupStoreAttrs(af::AscNode &store, uint64_t z0_id, const af::Expression &z0_size) {
    auto &attr = store.outputs[0].attr;
    attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
    attr.mem.tensor_id = 2;
    attr.axis = {static_cast<int64_t>(z0_id)};
    attr.vectorized_axis = {static_cast<int64_t>(z0_id)};
    attr.vectorized_strides = {af::ops::One};
    attr.repeats = {z0_size};
    attr.strides = {af::ops::One};
  }

  ascir::FusedScheduledResult GenBasicFusedScheduleResult(const std::vector<af::Expression> &origin_vars = {}) {
    af::AscGraph graph("test_graph");
    auto s0 = graph.CreateSizeVar("s0");
    auto z0 = graph.CreateAxis("z0", af::ops::Zero);

    af::ascir_op::Data x_op("x", graph);
    x_op.ir_attr.SetIndex(0);
    af::ascir_op::Load load_op("load");
    af::ascir_op::Store store_op("store");
    af::ascir_op::Output y_op("y");
    y_op.ir_attr.SetIndex(0);
    graph.AddNode(load_op);
    graph.AddNode(store_op);
    graph.AddNode(y_op);

    load_op.x = x_op.y;
    load_op.y.dtype = ge::DT_FLOAT16;
    store_op.x = load_op.y;
    y_op.x = store_op.y;

    auto x = graph.FindNode("x");
    auto load = graph.FindNode("load");
    auto store = graph.FindNode("store");
    auto y = graph.FindNode("y");

    x->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
    x->outputs[0].attr.mem.tensor_id = 0;
    x->attr.api.unit = af::ComputeUnit::kUnitNone;
    y->attr.api.unit = af::ComputeUnit::kUnitNone;

    SetupLoadAttrs(*load, z0.id, z0.size);
    SetupStoreAttrs(*store, z0.id, z0.size);

    ::ascir::ScheduledResult schedule_result;
    schedule_result.schedule_groups.resize(1);
    for (auto &schedule_group : schedule_result.schedule_groups) {
      schedule_group.impl_graphs.emplace_back(graph);
    }
    std::vector<ascir::ScheduledResult> schedule_results;
    schedule_results.push_back(schedule_result);
    schedule_results.push_back(schedule_result);

    ascir::FusedScheduledResult fused_schedule_result;
    fused_schedule_result.fused_graph_name = af::AscendString(graph.GetName().c_str());
    fused_schedule_result.input_nodes.push_back(x);
    fused_schedule_result.output_nodes.push_back(y);
    fused_schedule_result.node_idx_to_scheduled_results.push_back(schedule_results);
    fused_schedule_result.origin_vars = origin_vars;
    return fused_schedule_result;
  }

  std::map<std::string, std::string> GenTilingCode(const std::vector<af::Expression> &origin_vars = {},
                                                   const std::map<std::string, std::string> &shape_info = {}) {
    auto fused_schedule_result = GenBasicFusedScheduleResult(origin_vars);
    return this->Generate(fused_schedule_result, shape_info, "", "0");
  }

  std::map<std::string, std::string> GenTilingCodeForInductor(const std::vector<af::Expression> &origin_vars = {}) {
    auto fused_schedule_result = GenBasicFusedScheduleResult(origin_vars);
    return this->GenerateForInductor(fused_schedule_result);
  }

  protected:
  TestCodegenTiling() : codegen::TilingLib("test", "test") {}
};

TEST_F(TestCodegenTiling, NoWorkspaceTest) {
  ascir::ImplGraph graph0("test_graph0");
  graph0.CreateSizeVar("s0");
  graph0.CreateSizeVar("s1");

  std::vector<ascir::ImplGraph> impl_graphs;
  impl_graphs.push_back(graph0);

  std::vector<ascir::ScheduledResult> schedule_results;
  ascir::ScheduledResult schedule_result;
  ascir::ScheduleGroup schedule_group;
  schedule_group.impl_graphs = impl_graphs;
  schedule_result.schedule_groups.push_back(schedule_group);
  schedule_results.push_back(schedule_result);
  ascir::FusedScheduledResult fused_schedule_result;
  fused_schedule_result.node_idx_to_scheduled_results.push_back(schedule_results);

  EXPECT_EQ(this->GenGetWorkspaceSizeFunc("AutofuseTilingData", fused_schedule_result), std::string{
    "uint32_t GetWorkspaceSize(const AutofuseTilingData &t) {\n"
    "  using namespace optiling;\n"
    "  uint32_t ws_size = 0;\n"
    "    if (t.tiling_key == 0) {\n"
    "      ws_size += 0;\n"
    "    }\n"
    "\n"
    "  ws_size = (ws_size + 512 - 1) / 512 * 512;\n"
    "  return ws_size;\n"
    "}\n"});
}

TEST_F(TestCodegenTiling, SingleGroupWorkspaceSymbolTest) {
  ascir::ImplGraph graph0("test_graph0");
  auto s0 = graph0.CreateSizeVar("s0");
  auto s1 = graph0.CreateSizeVar("s1");

  auto z0 = graph0.CreateAxis("z0", s0);
  auto z1 = graph0.CreateAxis("z1", s1);

  af::ascir_op::Workspace workspace("workspace");
  graph0.AddNode(workspace);
  workspace.y.dtype = ge::DT_FLOAT16;

  af::ascir_op::Load load("load");
  graph0.AddNode(load);
  load.x = workspace.y;
  load.attr.sched.axis = {z0.id, z1.id};
  *load.y.axis = {z0.id, z1.id};
  *load.y.repeats = {s0, s1};
  *load.y.strides = {s1, af::ops::One};

  auto load_node = graph0.FindNode("load");
  auto workspace_node = graph0.FindNode("workspace");

  workspace_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  workspace_node->outputs[0].attr.mem.tensor_id = 0;

  load_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  load_node->outputs[0].attr.mem.tensor_id = 1;

  std::vector<ascir::ImplGraph> impl_graphs;
  impl_graphs.push_back(graph0);

  std::vector<ascir::ScheduledResult> schedule_results;
  ascir::ScheduledResult schedule_result;
  ascir::ScheduleGroup schedule_group;
  schedule_group.impl_graphs = impl_graphs;
  schedule_result.schedule_groups.push_back(schedule_group);
  schedule_results.push_back(schedule_result);

  ascir::FusedScheduledResult fused_schedule_result;
  fused_schedule_result.workspace_nodes.push_back(workspace_node);
  fused_schedule_result.node_idx_to_scheduled_results.push_back(schedule_results);
  EXPECT_EQ(this->GenGetWorkspaceSizeFunc("AutofuseTilingData", fused_schedule_result), std::string{
    "uint32_t GetWorkspaceSize(const AutofuseTilingData &t) {\n"
    "  using namespace optiling;\n"
    "  uint32_t ws_size = 0;\n"
    "    if (t.tiling_key == 0) {\n"
    "      ws_size += Max(0, (2 * Max(Max(1, t.s1), (t.s0 * t.s1))));\n"
    "    }\n"
    "\n"
    "  ws_size = (ws_size + 512 - 1) / 512 * 512;\n"
    "  return ws_size;\n"
    "}\n"
  });
}

TEST_F(TestCodegenTiling, SingleGroupWorkspaceValueTest) {
  ascir::ImplGraph graph0("test_graph0");
  auto s0 = graph0.CreateSizeVar(150);
  auto s1 = graph0.CreateSizeVar(2);

  auto z0 = graph0.CreateAxis("z0", s0);
  auto z1 = graph0.CreateAxis("z1", s1);

  af::ascir_op::Workspace workspace("workspace");
  graph0.AddNode(workspace);
  workspace.y.dtype = ge::DT_FLOAT16;

  af::ascir_op::Load load("load");
  graph0.AddNode(load);
  load.x = workspace.y;
  load.attr.sched.axis = {z0.id, z1.id};
  *load.y.axis = {z0.id, z1.id};
  *load.y.repeats = {s0, s1};
  *load.y.strides = {s1, af::ops::One};

  auto load_node = graph0.FindNode("load");
  auto workspace_node = graph0.FindNode("workspace");

  workspace_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  workspace_node->outputs[0].attr.mem.tensor_id = 0;

  load_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  load_node->outputs[0].attr.mem.tensor_id = 1;

  std::vector<ascir::ImplGraph> impl_graphs;
  impl_graphs.push_back(graph0);

  std::vector<ascir::ScheduledResult> schedule_results;
  ascir::ScheduledResult schedule_result;
  ascir::ScheduleGroup schedule_group;
  schedule_group.impl_graphs = impl_graphs;
  schedule_result.schedule_groups.push_back(schedule_group);
  schedule_results.push_back(schedule_result);

  ascir::FusedScheduledResult fused_schedule_result;
  fused_schedule_result.workspace_nodes.push_back(workspace_node);
  fused_schedule_result.node_idx_to_scheduled_results.push_back(schedule_results);
  EXPECT_EQ(this->GenGetWorkspaceSizeFunc("AutofuseTilingData", fused_schedule_result), std::string{
    "uint32_t GetWorkspaceSize(const AutofuseTilingData &t) {\n"
    "  using namespace optiling;\n"
    "  uint32_t ws_size = 0;\n"
    "    if (t.tiling_key == 0) {\n"
    "      ws_size += 600;\n"
    "    }\n"
    "\n"
    "  ws_size = (ws_size + 512 - 1) / 512 * 512;\n"
    "  return ws_size;\n"
    "}\n"});
}

TEST_F(TestCodegenTiling, MultiGroupWorkspaceSymbolTest) {
  ascir::ImplGraph graph0("test_graph0");
  auto s0 = graph0.CreateSizeVar("s0");
  auto s1 = graph0.CreateSizeVar("s1");

  auto z0 = graph0.CreateAxis("z0", s0);
  auto z1 = graph0.CreateAxis("z1", s1);

  af::ascir_op::Workspace workspace("workspace");
  graph0.AddNode(workspace);
  workspace.y.dtype = ge::DT_FLOAT16;

  af::ascir_op::Load load("load");
  graph0.AddNode(load);
  load.x = workspace.y;
  load.attr.sched.axis = {z0.id, z1.id};
  *load.y.axis = {z0.id, z1.id};
  *load.y.repeats = {s0, s1};
  *load.y.strides = {s1, af::ops::One};

  auto load_node = graph0.FindNode("load");
  auto workspace_node = graph0.FindNode("workspace");

  workspace_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  workspace_node->outputs[0].attr.mem.tensor_id = 0;

  load_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  load_node->outputs[0].attr.mem.tensor_id = 2;

  ascir::ImplGraph graph1("test_graph1");
  s0 = graph1.CreateSizeVar("s0");
  s1 = graph1.CreateSizeVar("s1");

  z0 = graph1.CreateAxis("z0", s0);
  z1 = graph1.CreateAxis("z1", s1);

  af::ascir_op::Workspace workspace1("workspace1");
  graph1.AddNode(workspace1);
  workspace1.y.dtype = ge::DT_FLOAT16;

  af::ascir_op::Load load1("load1");
  graph1.AddNode(load1);
  load1.x = workspace1.y;
  load1.attr.sched.axis = {z0.id, z1.id};
  *load1.y.axis = {z0.id, z1.id};
  *load1.y.repeats = {s0, s1};
  *load1.y.strides = {s1, af::ops::One};

  auto load1_node = graph1.FindNode("load1");
  auto workspace1_node = graph1.FindNode("workspace1");

  workspace1_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  workspace1_node->outputs[0].attr.mem.tensor_id = 1;

  load1_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  load1_node->outputs[0].attr.mem.tensor_id = 3;

  std::vector<ascir::ScheduledResult> schedule_results;
  ascir::ScheduledResult schedule_result;
  ascir::ScheduleGroup sch_groups0;
  ascir::ScheduleGroup sch_groups1;
  sch_groups0.impl_graphs = {graph0};
  sch_groups1.impl_graphs = {graph1};
  schedule_result.schedule_groups.push_back(sch_groups0);
  schedule_result.schedule_groups.push_back(sch_groups1);
  schedule_results.push_back(schedule_result);

  ascir::FusedScheduledResult fused_schedule_result;
  fused_schedule_result.workspace_nodes.push_back(workspace_node);
  fused_schedule_result.workspace_nodes.push_back(workspace1_node);
  fused_schedule_result.node_idx_to_scheduled_results.push_back(schedule_results);
  EXPECT_EQ(this->GenGetWorkspaceSizeFunc("AutofuseTilingData", fused_schedule_result), std::string{
    "uint32_t GetWorkspaceSize(const AutofuseTilingData &t) {\n"
    "  using namespace optiling;\n"
    "  uint32_t ws_size = 0;\n"
    "  if (t.graph0_tiling_key == 0) {\n"
    "    if (t.graph0_result0_g0_tiling_data.tiling_key == 0) {\n"
    "      ws_size += Max(0, (2 * Max(Max(1, t.graph0_result0_g0_tiling_data.s1), (t.graph0_result0_g0_tiling_data.s0 * t.graph0_result0_g0_tiling_data.s1))));\n"
    "    }\n"
    "    if (t.graph0_result0_g1_tiling_data.tiling_key == 0) {\n"
    "      ws_size += Max(0, (2 * Max(Max(1, t.graph0_result0_g1_tiling_data.s1), (t.graph0_result0_g1_tiling_data.s0 * t.graph0_result0_g1_tiling_data.s1))));\n"
    "    }\n"
    "  }\n"
    "  ws_size = (ws_size + 512 - 1) / 512 * 512;\n"
    "  return ws_size;\n"
    "}\n"
  });
}

TEST_F(TestCodegenTiling, MultiGroupWorkspaceValueTest) {
  ascir::ImplGraph graph0("test_graph0");
  auto s0 = graph0.CreateSizeVar(16);
  auto s1 = graph0.CreateSizeVar(32);

  auto z0 = graph0.CreateAxis("z0", s0);
  auto z1 = graph0.CreateAxis("z1", s1);

  af::ascir_op::Workspace workspace("workspace");
  graph0.AddNode(workspace);
  workspace.y.dtype = ge::DT_FLOAT16;

  af::ascir_op::Load load("load");
  graph0.AddNode(load);
  load.x = workspace.y;
  load.attr.sched.axis = {z0.id, z1.id};
  *load.y.axis = {z0.id, z1.id};
  *load.y.repeats = {s0, s1};
  *load.y.strides = {s1, af::ops::One};

  auto load_node = graph0.FindNode("load");
  auto workspace_node = graph0.FindNode("workspace");

  workspace_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  workspace_node->outputs[0].attr.mem.tensor_id = 0;

  load_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  load_node->outputs[0].attr.mem.tensor_id = 2;

  ascir::ImplGraph graph1("test_graph1");
  s0 = graph1.CreateSizeVar(5);
  s1 = graph1.CreateSizeVar(100);

  z0 = graph1.CreateAxis("z0", s0);
  z1 = graph1.CreateAxis("z1", s1);

  af::ascir_op::Workspace workspace1("workspace1");
  graph1.AddNode(workspace1);
  workspace1.y.dtype = ge::DT_FLOAT16;

  af::ascir_op::Load load1("load1");
  graph1.AddNode(load1);
  load1.x = workspace1.y;
  load1.attr.sched.axis = {z0.id, z1.id};
  *load1.y.axis = {z0.id, z1.id};
  *load1.y.repeats = {s0, s1};
  *load1.y.strides = {s1, af::ops::One};

  auto load1_node = graph1.FindNode("load1");
  auto workspace1_node = graph1.FindNode("workspace1");

  workspace1_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  workspace1_node->outputs[0].attr.mem.tensor_id = 1;

  load1_node->outputs[0].attr.dtype = ge::DT_FLOAT16;
  load1_node->outputs[0].attr.mem.tensor_id = 3;

  std::vector<ascir::ScheduledResult> schedule_results;
  ascir::ScheduledResult schedule_result;
  ascir::ScheduleGroup sch_groups0;
  ascir::ScheduleGroup sch_groups1;
  sch_groups0.impl_graphs = {graph0};
  sch_groups1.impl_graphs = {graph1};
  schedule_result.schedule_groups.push_back(sch_groups0);
  schedule_result.schedule_groups.push_back(sch_groups1);
  schedule_results.push_back(schedule_result);

  ascir::FusedScheduledResult fused_schedule_result;
  fused_schedule_result.workspace_nodes.push_back(workspace_node);
  fused_schedule_result.workspace_nodes.push_back(workspace1_node);
  fused_schedule_result.node_idx_to_scheduled_results.push_back(schedule_results);
  EXPECT_EQ(this->GenGetWorkspaceSizeFunc("AutofuseTilingData", fused_schedule_result), std::string{
    "uint32_t GetWorkspaceSize(const AutofuseTilingData &t) {\n"
    "  using namespace optiling;\n"
    "  uint32_t ws_size = 0;\n"
    "  if (t.graph0_tiling_key == 0) {\n"
    "    if (t.graph0_result0_g0_tiling_data.tiling_key == 0) {\n"
    "      ws_size += 1024;\n"
    "    }\n"
    "    if (t.graph0_result0_g1_tiling_data.tiling_key == 0) {\n"
    "      ws_size += 1000;\n"
    "    }\n"
    "  }\n"
    "  ws_size = (ws_size + 512 - 1) / 512 * 512;\n"
    "  return ws_size;\n"
    "}\n"});
}

TEST_F(TestCodegenTiling, EmptyTensorKernel) {
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto z0 = graph.CreateAxis("z0", af::ops::Zero);

  af::ascir_op::Data x_op("x", graph);
  x_op.ir_attr.SetIndex(0);
  af::ascir_op::Load load_op("load");
  af::ascir_op::Store store_op("store");
  af::ascir_op::Output y_op("y");
  y_op.ir_attr.SetIndex(0);
  graph.AddNode(load_op);
  graph.AddNode(store_op);
  graph.AddNode(y_op);

  load_op.x = x_op.y;
  load_op.y.dtype = ge::DT_FLOAT16;
  store_op.x = load_op.y;
  y_op.x = store_op.y;

  auto x = graph.FindNode("x");
  auto load = graph.FindNode("load");
  auto store = graph.FindNode("store");
  auto y = graph.FindNode("y");

  x->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  x->outputs[0].attr.mem.tensor_id = 0;
  x->attr.api.unit = af::ComputeUnit::kUnitNone;
  y->attr.api.unit = af::ComputeUnit::kUnitNone;

  load->outputs[0].attr.axis = {z0.id};
  load->outputs[0].attr.vectorized_axis = {z0.id};
  load->outputs[0].attr.vectorized_strides = {af::ops::One};
  load->outputs[0].attr.repeats = {z0.size};
  load->outputs[0].attr.strides = {af::ops::One};
  load->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
  load->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  load->outputs[0].attr.mem.tensor_id = 1;
  load->outputs[0].attr.que.id = 0;
  load->outputs[0].attr.mem.reuse_id = 0;
  load->outputs[0].attr.que.depth = 2;
  load->outputs[0].attr.que.buf_num = 2;
  load->outputs[0].attr.opt.merge_scope = af::kIdNone;

  store->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  store->outputs[0].attr.mem.tensor_id = 2;
  store->outputs[0].attr.axis = {z0.id};
  store->outputs[0].attr.vectorized_axis = {z0.id};
  store->outputs[0].attr.vectorized_strides = {af::ops::One};
  store->outputs[0].attr.repeats = {z0.size};
  store->outputs[0].attr.strides = {af::ops::One};

  ::ascir::ScheduledResult schedule_result;
  schedule_result.schedule_groups.resize(1);
  for (auto &schedule_group : schedule_result.schedule_groups) {
    schedule_group.impl_graphs.emplace_back(graph);
  }
  std::vector<ascir::ScheduledResult> schedule_results;
  schedule_results.push_back(schedule_result);
  schedule_results.push_back(schedule_result);

  ascir::FusedScheduledResult fused_schedule_result;
  fused_schedule_result.fused_graph_name = af::AscendString(graph.GetName().c_str());
  fused_schedule_result.input_nodes.push_back(x);
  fused_schedule_result.output_nodes.push_back(y);
  fused_schedule_result.node_idx_to_scheduled_results.push_back(schedule_results);
  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");
  std::string tiling_func_declare {"TilingFunc(gert::TilingSymbolEvalContext *context)\n{\n"};
  auto pos = res["tiling_def_and_tiling_const"].find(tiling_func_declare) + tiling_func_declare.size();
  std::string expect_str {"  context->SetBlockDim(1);\n  *context->GetWorkspaceSizes(1) = 0;\n  return ge::GRAPH_SUCCESS;\n"};
  std::string tiling_func_content = res["tiling_def_and_tiling_const"].substr(pos, expect_str.size());
  EXPECT_EQ(expect_str, tiling_func_content);
}

TEST_F(TestCodegenTiling, TestGenDfxInputSymbolInfo) {
  std::map<std::string, std::string> shape_info;
  shape_info["s0"] = R"([&]() -> int64_t {
    const auto *tensor = context->GetInputTensor(0);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(0);
  }())";
  shape_info["s1"] = R"([&]() -> int64_t {
    const auto *tensor = context->GetInputTensor(0);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(1);
  }())";
  shape_info["s2"] = R"([&]() -> int64_t {
    const auto *tensor = context->GetInputTensor(1);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(0);
  }())";

  ascir::FusedScheduledResult fused_schedule_result;
  std::vector<af::Expression> origin_vars{af::Symbol("s0"), af::Symbol("s1"), af::Symbol("s2")};
  fused_schedule_result.origin_vars = origin_vars;
  auto gen_func = this->GenDfxInputSymbolInfo(fused_schedule_result, shape_info);
  auto expect_func = R"(extern "C" ge::graphStatus DfxInputSymbolInfo(gert::TilingSymbolEvalContext *context, char *out_symbol_info, size_t size)
{
  if (out_symbol_info == nullptr || size == 0) {
    return ge::GRAPH_SUCCESS;
  }
  std::string symbol_info;
  auto s0 = [&]() -> int64_t {
    const auto *tensor = context->GetInputTensor(0);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(0);
  }();
  symbol_info += ("s0: " + std::to_string(s0));

  auto s1 = [&]() -> int64_t {
    const auto *tensor = context->GetInputTensor(0);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(1);
  }();
  symbol_info += (", s1: " + std::to_string(s1));

  auto s2 = [&]() -> int64_t {
    const auto *tensor = context->GetInputTensor(1);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(0);
  }();
  symbol_info += (", s2: " + std::to_string(s2));


  if (symbol_info.empty()) {
    out_symbol_info[0] = '\0';
    return ge::GRAPH_SUCCESS;
  }
  symbol_info += ".";
  if (strncpy_s(out_symbol_info, size, symbol_info.c_str(), std::min(symbol_info.size(), size - 1)) != 0) {
    return ge::GRAPH_FAILED;
  }
  return ge::GRAPH_SUCCESS;
}
)";
  EXPECT_EQ(gen_func, expect_func);
}

TEST_F(TestCodegenTiling, TestCompileSuccess) {
  std::stringstream ss;

  ss << "#include <stdexcept>" << std::endl;
  ss << "#include <sstream>" << std::endl;
  ss << "#include <cmath>" << std::endl;
  ss << "#ifndef __CCE_KT_TEST__" << std::endl;
  ss << "#include \"register/op_def_registry.h\"" << std::endl;
  ss << "#include \"exe_graph/runtime/infer_shape_context.h\"" << std::endl;
  ss << "#include \"exe_graph/runtime/kernel_context.h\"" << std::endl;
  ss << "#include \"exe_graph/runtime/continuous_vector.h\"" << std::endl;
  ss << "#endif" << std::endl;

  ss << "#define Max(a, b) ((double)(a) > (double)(b) ? (a) : (b))" << std::endl;
  ss << "#define Min(a, b) ((double)(a) < (double)(b) ? (a) : (b))" << std::endl;
  ss << "#define Log(a) (log((double)(a)))" << std::endl;
  ss << "#define Pow(a, b) pow(a, b)" << std::endl;
  ss << "#define Rational(a, b) ((double)(a) / (double)(b))" << std::endl;

  std::string tiling_context = R"(
namespace gert {
  class TilingSymbolEvalContext : public TilingContext {
    public:
      const gert::Tensor *GetGraphInputTensor(size_t data_index) const {
        auto *tensor = GetInputPointer<gert::Tensor>(data_index + 1);
        if (tensor == nullptr) {
          return nullptr;
        }
        return tensor;
      }
  };
})";

  ss << tiling_context << std::endl;

  std::map<std::string, std::string> shape_info;
  shape_info["s0"] = R"([&]() -> int64_t {
    const auto *tensor = context->GetInputTensor(0);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(0);
  }())";
  shape_info["s1"] = R"([&]() -> int64_t {
    const auto *tensor = context->GetInputTensor(0);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(1);
  }())";
  shape_info["s2"] = R"([&]() -> int64_t {
    const auto *tensor = context->GetInputTensor(1);
    if (tensor == nullptr) {
      return -1;
    }
    return tensor->GetOriginShape().GetDim(0);
  }())";
  ascir::FusedScheduledResult fused_schedule_result;
  std::vector<af::Expression> origin_vars{af::Symbol("s0"), af::Symbol("s1"), af::Symbol("s2")};
  fused_schedule_result.origin_vars = origin_vars;
  auto dfx_func = this->GenDfxInputSymbolInfo(fused_schedule_result, shape_info);
  ss << dfx_func << std::endl;
  ASSERT_TRUE(CompileCode(ss.str()));
}

/*
 * Codegen FindBestTilingKey测试
 * 1、单graph,单result单group
 * 2、多graph,仅在inductor场景下有,本轮暂不支持
 * 3、单graph,多result组合场景
 *    result1:单group,单graph
 *    result2:单group,多graph
 *    result3:多group场景组合
 *             group1:单graph
 *             group2:多graph
 * 4、enable_group_parallel场景, 不支持生成
 */

TEST_F(TestCodegenTiling, TestGenFindBestTilingKeyFuncFor1Group) {
  af::AscGraph graph1("graph1");
  af::ascir_op::Workspace workspace("workspace");
  graph1.AddNode(workspace);

  af::AscGraph graph2("graph2");
  af::AscGraph graph3("graph3");

  ascir::ScheduleGroup schedule_group;
  schedule_group.impl_graphs.push_back(graph1);
  schedule_group.impl_graphs.push_back(graph2);
  schedule_group.impl_graphs.push_back(graph3);

  ascir::ScheduledResult schedule_result;
  schedule_result.schedule_groups.push_back(schedule_group);

  ascir::FusedScheduledResult fused_schedule_result;
  std::vector<ascir::ScheduledResult> graph0_results = {schedule_result};
  fused_schedule_result.node_idx_to_scheduled_results.emplace_back(std::move(graph0_results));

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::string expect = R"(extern "C" int64_t FindBestTilingKey(AutofuseTilingData &t)
{
  if (t.tiling_key == 0) {
    return 0;
  } else if (t.tiling_key == 1) {
    return 1;
  } else if (t.tiling_key == 2) {
    return 2;
  }
  return -1;
}
)";
  auto pos = res["tiling_def_and_tiling_const"].find("extern \"C\" int64_t FindBestTilingKey(AutofuseTilingData &t)");
  auto func = res["tiling_def_and_tiling_const"].substr(pos, expect.size());
  ASSERT_EQ(func, expect);
}

TEST_F(TestCodegenTiling, TestGenFindBestTilingKeyFuncForMultiResult) {
  af::AscGraph graph1("graph1");
  af::AscGraph graph2("graph2");

  ascir::ScheduleGroup schedule_group1;
  schedule_group1.impl_graphs.push_back(graph1);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(graph1);
  schedule_group2.impl_graphs.push_back(graph2);


  // 单group,单graph
  ascir::ScheduledResult schedule_result1;
  schedule_result1.schedule_groups.push_back(schedule_group1);

  // 单group,多graph
  ascir::ScheduledResult schedule_result2;
  schedule_result2.schedule_groups.push_back(schedule_group2);

  // 多group
  // group1单graph
  // group2多graph
  ascir::ScheduledResult schedule_result3;
  schedule_result3.schedule_groups.push_back(schedule_group1);
  schedule_result3.schedule_groups.push_back(schedule_group2);

  ascir::FusedScheduledResult fused_schedule_result;
  std::vector<ascir::ScheduledResult> graph0_results = {schedule_result1, schedule_result2, schedule_result3};
  fused_schedule_result.node_idx_to_scheduled_results.emplace_back(std::move(graph0_results));

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::string expect = R"(extern "C" int64_t FindBestTilingKey(AutofuseTilingData &t)
{
  if (t.graph0_tiling_key == 0) {
    if (t.graph0_result0_g0_tiling_data.tiling_key == 0) {
      return 0;
    }
  }  else if (t.graph0_tiling_key == 1) {
    if (t.graph0_result1_g0_tiling_data.tiling_key == 0) {
      return 1;
    } else if (t.graph0_result1_g0_tiling_data.tiling_key == 1) {
      return 2;
    }
  }  else if (t.graph0_tiling_key == 2) {
    if (t.graph0_result2_g0_tiling_data.tiling_key == 0 && t.graph0_result2_g1_tiling_data.tiling_key == 0) {
      return 3;
    } else if (t.graph0_result2_g0_tiling_data.tiling_key == 0 && t.graph0_result2_g1_tiling_data.tiling_key == 1) {
      return 4;
    }
  }
  return -1;
}
)";
  auto pos = res["tiling_def_and_tiling_const"].find("extern \"C\" int64_t FindBestTilingKey(AutofuseTilingData &t)");
  auto func = res["tiling_def_and_tiling_const"].substr(pos, expect.size());
  ASSERT_EQ(func, expect);
}

TEST_F(TestCodegenTiling, TestGenFindBestTilingKeyFuncForEnableParallel) {
  af::AscGraph graph1("graph1");
  af::AscGraph graph2("graph2");

  ascir::ScheduleGroup schedule_group1;
  schedule_group1.impl_graphs.push_back(graph1);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(graph1);
  schedule_group2.impl_graphs.push_back(graph2);


  // 单group,单graph
  ascir::ScheduledResult schedule_result1;
  schedule_result1.schedule_groups.push_back(schedule_group1);

  // 单group,多graph
  ascir::ScheduledResult schedule_result2;
  schedule_result2.schedule_groups.push_back(schedule_group2);

  // 多group
  // group1单graph
  // group2多graph
  ascir::ScheduledResult schedule_result3;
  schedule_result3.enable_group_parallel = true;
  schedule_result3.schedule_groups.push_back(schedule_group1);
  schedule_result3.schedule_groups.push_back(schedule_group2);

  ascir::FusedScheduledResult fused_schedule_result;
  std::vector<ascir::ScheduledResult> graph0_results = {schedule_result1, schedule_result2, schedule_result3};
  fused_schedule_result.node_idx_to_scheduled_results.emplace_back(std::move(graph0_results));

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  auto pos = res["tiling_def_and_tiling_const"].find("extern \"C\" int64_t FindBestTilingKey(AutofuseTilingData &t)");
  ASSERT_EQ(pos, std::string::npos);
}

TEST_F(TestCodegenTiling, TestGenExternTilingFunc) {
  ge::PlatformContext::GetInstance().Reset();
  auto stub_v2 = std::make_shared<ge::RuntimeStubV2Common>();
  ge::RuntimeStub::SetInstance(stub_v2);
  af::AscGraph graph1("graph1");
  af::AscGraph graph2("graph2");

  ascir::ScheduleGroup schedule_group1;
  schedule_group1.impl_graphs.push_back(graph1);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(graph1);
  schedule_group2.impl_graphs.push_back(graph2);


  // 单group,单graph
  ascir::ScheduledResult schedule_result1;
  schedule_result1.schedule_groups.push_back(schedule_group1);

  // 单group,多graph
  ascir::ScheduledResult schedule_result2;
  schedule_result2.schedule_groups.push_back(schedule_group2);

  // 多group
  // group1单graph
  // group2多graph
  ascir::ScheduledResult schedule_result3;
  schedule_result3.enable_group_parallel = true;
  schedule_result3.schedule_groups.push_back(schedule_group1);
  schedule_result3.schedule_groups.push_back(schedule_group2);

  ascir::FusedScheduledResult fused_schedule_result;
  std::vector<ascir::ScheduledResult> graph0_results = {schedule_result1, schedule_result2, schedule_result3};
  fused_schedule_result.node_idx_to_scheduled_results.emplace_back(std::move(graph0_results));

  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  auto pos = res["tiling_def_and_tiling_const"].find("extern \"C\" int64_t FindBestTilingKey(AutofuseTilingData &t)");
  ASSERT_EQ(pos, std::string::npos);
  ge::RuntimeStub::Reset();
  ge::PlatformContext::GetInstance().Reset();
}

TEST_F(TestCodegenTiling, TestPGOSearchTensorMallocDef) {
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto z0 = graph.CreateAxis("z0", af::ops::One);

  af::ascir_op::Data x_op("x", graph);
  x_op.ir_attr.SetIndex(0);
  af::ascir_op::Load load_op("load");
  af::ascir_op::Store store_op("store");
  af::ascir_op::Output y_op("y");
  y_op.ir_attr.SetIndex(0);
  graph.AddNode(load_op);
  graph.AddNode(store_op);
  graph.AddNode(y_op);

  load_op.x = x_op.y;
  load_op.y.dtype = ge::DT_FLOAT16;
  store_op.x = load_op.y;
  y_op.x = store_op.y;

  auto x = graph.FindNode("x");
  auto load = graph.FindNode("load");
  auto store = graph.FindNode("store");
  auto y = graph.FindNode("y");

  x->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  x->outputs[0].attr.mem.tensor_id = 0;
  x->attr.api.unit = af::ComputeUnit::kUnitNone;
  y->attr.api.unit = af::ComputeUnit::kUnitNone;

  load->outputs[0].attr.axis = {z0.id};
  load->outputs[0].attr.vectorized_axis = {z0.id};
  load->outputs[0].attr.vectorized_strides = {af::ops::One};
  load->outputs[0].attr.repeats = {z0.size};
  load->outputs[0].attr.strides = {af::ops::One};
  load->outputs[0].attr.mem.position = af::Position::kPositionVecIn;
  load->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  load->outputs[0].attr.mem.tensor_id = 1;
  load->outputs[0].attr.que.id = 0;
  load->outputs[0].attr.mem.reuse_id = 0;
  load->outputs[0].attr.que.depth = 2;
  load->outputs[0].attr.que.buf_num = 2;
  load->outputs[0].attr.opt.merge_scope = af::kIdNone;

  store->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  store->outputs[0].attr.mem.tensor_id = 2;
  store->outputs[0].attr.axis = {z0.id};
  store->outputs[0].attr.vectorized_axis = {z0.id};
  store->outputs[0].attr.vectorized_strides = {af::ops::One};
  store->outputs[0].attr.repeats = {z0.size};
  store->outputs[0].attr.strides = {af::ops::One};

  y->inputs[0].attr.repeats = {z0.size};
  y->inputs[0].attr.strides = {af::ops::One};
  y->inputs[0].attr.dtype = ge::DT_FLOAT16;

  ::ascir::ScheduledResult schedule_result;
  schedule_result.schedule_groups.resize(1);
  for (auto &schedule_group : schedule_result.schedule_groups) {
    schedule_group.impl_graphs.emplace_back(graph);
  }
  std::vector<ascir::ScheduledResult> schedule_results;
  schedule_results.push_back(schedule_result);
  schedule_results.push_back(schedule_result);

  ascir::FusedScheduledResult fused_schedule_result;
  fused_schedule_result.fused_graph_name = af::AscendString(graph.GetName().c_str());
  fused_schedule_result.input_nodes.push_back(x);
  fused_schedule_result.output_nodes.push_back(y);
  fused_schedule_result.node_idx_to_scheduled_results.push_back(schedule_results);

  std::string mallocdef = this->PGOSearchTensorMallocDef(fused_schedule_result);
  const std::string expect = R"(  size_t input0_size = 2;
  ret = aclrtMalloc(&input0, input0_size, ACL_MEM_MALLOC_HUGE_FIRST);
  if (ret != ACL_SUCCESS) {
    DLOGE("aclrtMalloc input0 failed. ERROR: %d", ret);
    return FAILED;
  }
  size_t output0_size = 2;
  ret = aclrtMalloc(&output0, output0_size, ACL_MEM_MALLOC_HUGE_FIRST);
  if (ret != ACL_SUCCESS) {
    DLOGE("aclrtMalloc output0 failed. ERROR: %d", ret);
    return FAILED;
  }
)";
  ASSERT_EQ(mallocdef, expect);
}

TEST_F(TestCodegenTiling, TestCalculateTensorMemorySizeStrWithNoRepeatsOrStrides) {
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto z0 = graph.CreateAxis("z0", af::ops::One);

  af::ascir_op::Data x_op("x", graph);
  x_op.ir_attr.SetIndex(0);

  auto x = graph.FindNode("x");

  x->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  x->outputs[0].attr.mem.tensor_id = 0;
  x->attr.api.unit = af::ComputeUnit::kUnitNone;
  x->outputs[0].attr.repeats = {};
  x->outputs[0].attr.strides = {};

  std::string memory_size = this->CalculateTensorMemorySizeStr(x->outputs[0]);
  const std::string expect = "0";
  ASSERT_EQ(memory_size, expect);
}

TEST_F(TestCodegenTiling, TestCalculateTensorMemorySizeStrWithZeroFirstStride) {
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto s1 = graph.CreateSizeVar("s1");
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  af::ascir_op::Data x_op("x", graph);
  x_op.ir_attr.SetIndex(0);

  auto x = graph.FindNode("x");

  x->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  x->outputs[0].attr.mem.tensor_id = 0;
  x->attr.api.unit = af::ComputeUnit::kUnitNone;
  x->outputs[0].attr.dtype = ge::DT_FLOAT16;
  x->outputs[0].attr.axis = {z0.id, z1.id};
  x->outputs[0].attr.repeats = {s0, s1};
  x->outputs[0].attr.strides = {af::ops::Zero, af::ops::One};

  std::string memory_size = this->CalculateTensorMemorySizeStr(x->outputs[0]);
  const std::string expect = std::string(af::sym::Mul(s1, af::Expression::Parse("2")).Simplify().Str().get());
  ASSERT_EQ(memory_size, expect);
}

TEST_F(TestCodegenTiling, TestCalculateTensorMemorySizeStrWithOnlyZeroStride) {
  af::AscGraph graph("test_graph");
  auto s0 = graph.CreateSizeVar("s0");
  auto z0 = graph.CreateAxis("z0", s0);

  af::ascir_op::Data x_op("x", graph);
  x_op.ir_attr.SetIndex(0);

  auto x = graph.FindNode("x");

  x->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeGlobal;
  x->outputs[0].attr.mem.tensor_id = 0;
  x->attr.api.unit = af::ComputeUnit::kUnitNone;
  x->outputs[0].attr.dtype = ge::DT_FLOAT16;
  x->outputs[0].attr.axis = {z0.id};
  x->outputs[0].attr.repeats = {s0};
  x->outputs[0].attr.strides = {af::ops::Zero};

  std::string memory_size = this->CalculateTensorMemorySizeStr(x->outputs[0]);
  const std::string expect = std::string(af::sym::Mul(af::ops::One, af::Expression::Parse("2")).Simplify().Str().get());
  ASSERT_EQ(memory_size, expect);
}

void CreateMatmulGraph(af::AscGraph &graph, bool is_dynamic = false) {
  af::Expression s0;
  af::Expression s1;
  if (is_dynamic) {
    s0 = graph.CreateSizeVar("s0");
    s1 = graph.CreateSizeVar("s1");
  } else {
    s0 = graph.CreateSizeVar(31);
    s1 = graph.CreateSizeVar(1);
  }
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z0.id, z1.id};
  data0.y.dtype = ge::DT_FLOAT16;
  *data0.y.axis = {z0.id, z1.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {s1 ,af::ops::One};
  *data0.y.repeats = {s0, s1};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z0.id, z1.id};
  load0.x = data0.y;
  *load0.y.axis = {z0.id, z1.id};
  load0.y.dtype = ge::DT_FLOAT16;
  *load0.y.strides = {s1 ,af::ops::One};
  *load0.y.repeats = {s0, s1};

  af::ascir_op::Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT16;
  data1.attr.sched.axis = {z0.id, z1.id};
  *data1.y.axis = {z0.id, z1.id};
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1.y.repeats = {af::ops::One, af::ops::One};
  *data1.y.strides = {af::ops::Zero, af::ops::Zero};
  data1.ir_attr.SetIndex(1);

  af::ascir_op::Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z0.id, z1.id};
  load1.y.dtype = ge::DT_FLOAT16;
  *load1.y.axis = {z0.id, z1.id};
  *load1.y.strides = {af::ops::Zero, af::ops::Zero};
  *load1.y.repeats = {af::ops::One, af::ops::One};

  af::ascir_op::BatchMatMul matmul("matmul");
  matmul.attr.sched.axis = {z0.id, z1.id};
  matmul.x1 = load0.y;
  matmul.x2 = load1.y;
  matmul.y.dtype = ge::DT_FLOAT;
  *matmul.y.axis = {z0.id, z1.id};
  *matmul.y.repeats = {s0, s1};
  *matmul.y.strides = {s1, af::ops::One};
  matmul.attr.api.compute_type = af::ComputeType::kComputeCube;
  matmul.ir_attr.SetAdj_x1(1);
  matmul.ir_attr.SetAdj_x2(0);
  matmul.ir_attr.SetHas_relu(1);
  matmul.ir_attr.SetEnable_hf32(1);
  matmul.ir_attr.SetOffset_x(6);

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = matmul.y;
  *store_op.y.axis = {z0.id, z1.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {s1 ,af::ops::One};
  *store_op.y.repeats = {s0, s1};
  store_op.ir_attr.SetOffset(af::ops::One);

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
}

TEST_F(TestCodegenTiling, TestMatmulElemwiseFuse) {
  af::AscGraph graph("matmul_elemwise_pro");

  auto s0 = graph.CreateSizeVar(64);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z0.id, z1.id};
  data0.y.dtype = ge::DT_FLOAT;
  *data0.y.axis = {z0.id, z1.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {s1 ,af::ops::One};
  *data0.y.repeats = {s0, s1};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z0.id, z1.id};
  load0.x = data0.y;
  *load0.y.axis = {z0.id, z1.id};
  load0.y.dtype = ge::DT_FLOAT;
  *load0.y.strides = {s1 ,af::ops::One};
  *load0.y.repeats = {s0, s1};

  af::ascir_op::Abs abs("abs");
  graph.AddNode(abs);
  abs.x = load0.y;
  abs.attr.sched.axis = {z0.id, z1.id};
  abs.y.dtype = ge::DT_FLOAT;
  *abs.y.axis = {z0.id, z1.id};
  *abs.y.repeats = {s0, s1};
  *abs.y.strides = {s1, af::ops::One};
  abs.attr.api.compute_type = af::ComputeType::kComputeElewise;

  af::ascir_op::Scalar scalar0("scalar0", graph);
  scalar0.attr.sched.axis = {z0.id, z1.id};
  scalar0.ir_attr.SetValue("0");
  scalar0.y.dtype = ge::DT_FLOAT;
  *scalar0.y.axis = {z0.id, z1.id};
  *scalar0.y.repeats = {af::ops::One, af::ops::One};
  *scalar0.y.strides = {af::ops::Zero, af::ops::Zero};

  af::ascir_op::Broadcast broadcast0("broadcast0");
  broadcast0.x = scalar0.y;
  broadcast0.attr.sched.axis = {z0.id, z1.id};
  *broadcast0.y.axis = {z0.id, z1.id};
  broadcast0.y.dtype = ge::DT_FLOAT;
  *broadcast0.y.repeats = {af::ops::One, s1};
  *broadcast0.y.strides = {af::ops::Zero, af::ops::One};

  af::ascir_op::Broadcast broadcast1("broadcast1");
  broadcast1.x = broadcast0.y;
  broadcast1.attr.sched.axis = {z0.id, z1.id};
  *broadcast1.y.axis = {z0.id, z1.id};
  broadcast1.y.dtype = ge::DT_FLOAT;
  *broadcast1.y.repeats = {s0, s1};
  *broadcast1.y.strides = {s1, af::ops::One};

  af::ascir_op::Add add_op("add");
  add_op.attr.sched.axis = {z0.id, z1.id};
  add_op.x1 = abs.y;
  add_op.x2 = broadcast1.y;
  add_op.y.dtype = ge::DT_FLOAT;
  *add_op.y.axis = {z0.id, z1.id};
  *add_op.y.repeats = {s0, s1};
  *add_op.y.strides = {s1, af::ops::One};

  af::ascir_op::Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT;
  data1.attr.sched.axis = {z0.id, z1.id};
  *data1.y.axis = {z0.id, z1.id};
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1.y.repeats = {af::ops::One, af::ops::One};
  *data1.y.strides = {af::ops::Zero, af::ops::Zero};
  data1.ir_attr.SetIndex(1);

  af::ascir_op::Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z0.id, z1.id};
  load1.y.dtype = ge::DT_FLOAT;
  *load1.y.axis = {z0.id, z1.id};
  *load1.y.strides = {af::ops::Zero, af::ops::Zero};
  *load1.y.repeats = {af::ops::One, af::ops::One};

  af::ascir_op::Broadcast broadcast2("broadcast2");
  broadcast2.x = load1.y;
  broadcast2.attr.sched.axis = {z0.id, z1.id};
  *broadcast2.y.axis = {z0.id, z1.id};
  broadcast2.y.dtype = ge::DT_FLOAT;
  *broadcast2.y.repeats = {af::ops::One, s1};
  *broadcast2.y.strides = {af::ops::Zero, af::ops::One};

  af::ascir_op::Broadcast broadcast3("broadcast3");
  broadcast3.x = broadcast2.y;
  broadcast3.attr.sched.axis = {z0.id, z1.id};
  *broadcast3.y.axis = {z0.id, z1.id};
  broadcast3.y.dtype = ge::DT_FLOAT;
  *broadcast3.y.repeats = {s0, s1};
  *broadcast3.y.strides = {s1, af::ops::One};

  af::ascir_op::Mul mul("mul");
  mul.attr.sched.axis = {z0.id, z1.id};
  mul.x1 = add_op.y;
  mul.x2 = broadcast3.y;
  mul.y.dtype = ge::DT_FLOAT;
  *mul.y.axis = {z0.id, z1.id};
  *mul.y.repeats = {s0, s1};
  *mul.y.strides = {s1, af::ops::One};

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z0.id, z1.id};
  store_op.x = mul.y;
  *store_op.y.axis = {z0.id, z1.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {s1 ,af::ops::One};
  *store_op.y.repeats = {s0, s1};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);

  auto x1Local = graph.FindNode("data0");
  x1Local->outputs[0].attr.mem.alloc_type = af::AllocType::kAllocTypeQueue;
  x1Local->outputs[0].attr.mem.hardware = af::MemHardware::kMemHardwareUB;
  x1Local->outputs[0].attr.mem.position = af::Position::kPositionVecIn;

  af::AscGraph mm_graph("mutmul");
  CreateMatmulGraph(mm_graph);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(mm_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;;
  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Mutmul_fuse_tiling_func.cpp", std::ios::out);
tiling_func << res["tiling_def_and_tiling_const"];

  auto pos = res["tiling_def_and_tiling_const"].find("extern \"C\" int64_t FindBestTilingKey");
  ASSERT_NE(pos, std::string::npos);
}

// ==================== Conv2DOffset 融合测试 ====================

TEST_F(TestCodegenTiling, TestConv2DOffsetFuse) {
  af::AscGraph graph("conv2d_offset_elemwise_pro");
  CreateElemwiseGraphWithRelu(graph);
  
  af::AscGraph conv2d_offset_graph("conv2d_offset");
  CreateConv2DOffsetGraph(conv2d_offset_graph);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_offset_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;
  
  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Conv2d_offset_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];

  VerifyConv2DOffsetTiling(res);
}

// ==================== Conv2DOffsetBias 融合测试 ====================

TEST_F(TestCodegenTiling, TestConv2DOffsetBiasFuse) {
  af::AscGraph graph("conv2d_offset_bias_elemwise_pro");
  CreateElemwiseGraphWithRelu(graph);
  
  af::AscGraph conv2d_offset_bias_graph("conv2d_offset_bias");
  CreateConv2DOffsetBiasGraph(conv2d_offset_bias_graph);

  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_offset_bias_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;
  
  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Conv2d_offset_bias_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];

  VerifyTilingCodeBasic(res);
}

namespace {
static ascir::FusedScheduledResult GenMultiGroupFusedScheduleResult() {
  af::AscGraph graph1("graph1");
  af::AscGraph graph2("graph2");

  ascir::ScheduleGroup schedule_group1;
  schedule_group1.impl_graphs.push_back(graph1);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(graph1);
  schedule_group2.impl_graphs.push_back(graph2);

  ascir::ScheduledResult schedule_result;
  schedule_result.schedule_groups.push_back(schedule_group1);
  schedule_result.schedule_groups.push_back(schedule_group2);

  ascir::FusedScheduledResult fused_schedule_result;
  std::vector<ascir::ScheduledResult> graph0_results = {schedule_result};
  fused_schedule_result.node_idx_to_scheduled_results.emplace_back(std::move(graph0_results));
  return fused_schedule_result;
}

static ascir::ImplGraph GenGraphWithSizeVar(const std::string &graph_name, const std::string &var_name) {
  ascir::ImplGraph graph(graph_name.c_str());
  auto size = graph.CreateSizeVar(var_name.c_str());
  (void)graph.CreateAxis("z0", size);
  return graph;
}

static ascir::FusedScheduledResult GenMultiGroupFusedScheduleResultWithSizeVar(const std::string &var_name) {
  ascir::ScheduleGroup schedule_group1;
  schedule_group1.impl_graphs.push_back(GenGraphWithSizeVar("graph1", var_name));

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(GenGraphWithSizeVar("graph2", var_name));

  ascir::ScheduledResult schedule_result;
  schedule_result.schedule_groups.push_back(schedule_group1);
  schedule_result.schedule_groups.push_back(schedule_group2);

  ascir::FusedScheduledResult fused_schedule_result;
  fused_schedule_result.origin_vars.push_back(af::Symbol(var_name.c_str()));
  fused_schedule_result.node_idx_to_scheduled_results.push_back({schedule_result});
  return fused_schedule_result;
}
}  // namespace

TEST_F(TestCodegenTiling, GenerateForInductorGetTilingDataReprShouldContainStableFields) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("GetTilingDataRepr returns a valid C++ designated initializer string"),
            std::string::npos);
  EXPECT_NE(tiling_impl.find("emit_field(\"block_dim\", tiling_data->get_block_dim()"), std::string::npos);
  EXPECT_NE(tiling_impl.find("emit_field(\"corenum\", tiling_data->get_corenum()"), std::string::npos);
  EXPECT_NE(tiling_impl.find("emit_field(\"ub_size\", tiling_data->get_ub_size()"), std::string::npos);
  EXPECT_NE(tiling_impl.find("emit_field(\"hbm_size\", tiling_data->get_hbm_size()"), std::string::npos);
  EXPECT_TRUE(tiling_impl.find("emit_field(\"tiling_key\"") != std::string::npos ||
              tiling_impl.find("emit_field(\"graph0_tiling_key\"") != std::string::npos);
}

TEST_F(TestCodegenTiling, GenerateForInductorGetTilingDataReprShouldKeepWorkspaceBeforeSymbols) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s1"), af::Symbol("s0")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  size_t tiling_key_pos = tiling_impl.find("emit_field(\"tiling_key\"");
  if (tiling_key_pos == std::string::npos) {
    tiling_key_pos = tiling_impl.find("emit_field(\"graph0_tiling_key\"");
  }
  const auto s0_pos = tiling_impl.find("emit_field(\"s0\"");
  const auto s1_pos = tiling_impl.find("emit_field(\"s1\"");
  ASSERT_NE(tiling_key_pos, std::string::npos);
  if (s0_pos != std::string::npos) {
    EXPECT_LT(tiling_key_pos, s0_pos);
  }
  if (s1_pos != std::string::npos) {
    EXPECT_LT(tiling_key_pos, s1_pos);
  }
}

TEST_F(TestCodegenTiling, GenerateForInductorGetTilingDataReprShouldUseGraphLevelTilingKeysForMultiGroup) {
  auto fused_schedule_result = GenMultiGroupFusedScheduleResult();
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("emit_field(\"graph0_tiling_key\""),
            std::string::npos);
  EXPECT_EQ(tiling_impl.find("emit_field(\"tiling_key\""), std::string::npos);
}

TEST_F(TestCodegenTiling, GenerateForInductorGetTilingDataReprShouldKeepZeroValuedStableFields) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("emit_field(\"block_dim\", tiling_data->get_block_dim()"), std::string::npos);
  EXPECT_NE(tiling_impl.find("emit_field(\"corenum\", tiling_data->get_corenum()"), std::string::npos);
  EXPECT_NE(tiling_impl.find("emit_field(\"ub_size\", tiling_data->get_ub_size()"), std::string::npos);
  EXPECT_NE(tiling_impl.find("emit_field(\"hbm_size\", tiling_data->get_hbm_size()"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("if (tiling_data->get_block_dim() != 0)"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("if (tiling_data->get_corenum() != 0)"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("if (tiling_data->get_ub_size() != 0)"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("if (tiling_data->get_hbm_size() != 0)"), std::string::npos);
}

TEST_F(TestCodegenTiling, GenerateForInductorShouldContainTopnMainOutputAbi) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("extern \"C\" int64_t GenerateTopnSolutions("), std::string::npos);
  EXPECT_NE(tiling_impl.find("GetTilingDataRepr("), std::string::npos);
}

TEST_F(TestCodegenTiling, GenerateForInductorShouldUseGetTilingDataReprAsTilingDataValidationAid) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("GetTilingDataRepr("), std::string::npos);
}

TEST_F(TestCodegenTiling, GenerateForInductorTopnAbiShouldNotEmitOutputConfigsMetadataLogic) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_EQ(tiling_impl.find("std::map<std::string, std::string> solution_config;"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("configs.push_back(solution_config);"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("solution_config[\"canonical_repr\"]"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("solution_config[\"topn_status\"]"), std::string::npos);
}

TEST_F(TestCodegenTiling, MultiGroupInductorShouldContainTopnMainOutputAbi) {
  auto fused_schedule_result = GenMultiGroupFusedScheduleResult();
  auto res = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(res.find(codegen::kTilingDefAndConstIdentify) != res.end());
  const auto &tiling_impl = res.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("extern \"C\" int64_t GenerateTopnSolutions("), std::string::npos);
  EXPECT_NE(tiling_impl.find("const std::vector<std::map<std::string, std::string>> &input_configs"), std::string::npos);
  EXPECT_NE(tiling_impl.find("std::vector<AutofuseTilingData> &tiling_datas"), std::string::npos);
  EXPECT_NE(tiling_impl.find("std::vector<int64_t> &workspaces"), std::string::npos);
  EXPECT_NE(tiling_impl.find("std::vector<int64_t> &block_dims"), std::string::npos);
}

TEST_F(TestCodegenTiling, MultiGroupTopnShouldSetShapeDimOnGroupTilingData) {
  auto fused_schedule_result = GenMultiGroupFusedScheduleResultWithSizeVar("ks0");
  auto res = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(res.find(codegen::kTilingDefAndConstIdentify) != res.end());
  const auto &tiling_impl = res.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_EQ(tiling_impl.find("search_tiling.set_ks0("), std::string::npos);
  EXPECT_NE(tiling_impl.find("search_tiling.graph0_result0_g0_tiling_data.set_ks0(ks0);"), std::string::npos);
  EXPECT_NE(tiling_impl.find("search_tiling.graph0_result0_g1_tiling_data.set_ks0(ks0);"), std::string::npos);
}

TEST_F(TestCodegenTiling, MultiGroupInductorShouldContainReprAbi) {
  auto fused_schedule_result = GenMultiGroupFusedScheduleResult();
  auto res = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(res.find(codegen::kTilingDefAndConstIdentify) != res.end());
  const auto &tiling_impl = res.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("std::string GetTilingDataRepr(const AutofuseTilingData *tiling_data)"), std::string::npos);
}

TEST_F(TestCodegenTiling, TestMatmulElemwiseDynamicShapeFuse) {
  af::AscGraph graph("matmul_elemwise_pro");
  CreateMatmulElemwiseDynamicGraph(graph);
  
  af::AscGraph mm_graph("mutmul");
  CreateMatmulGraph(mm_graph, true);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(mm_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;
  
  std::map<std::string, std::string> shape_info;
  shape_info["s0"] = "64";
  shape_info["s1"] = "64";
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Mutmul_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];

  auto pos = res["tiling_def_and_tiling_const"].find("TilingResult result = wrapper.DoMatMulTiling(");
  ASSERT_NE(pos, std::string::npos);
}

// ==================== Conv2D 相关辅助函数 ====================

void CreateConv2dGraph(af::AscGraph &graph, bool is_dynamic = false) {
  af::Expression n, c, h, w;
  if (is_dynamic) {
    n = graph.CreateSizeVar("n");
    c = graph.CreateSizeVar("c");
    h = graph.CreateSizeVar("h");
    w = graph.CreateSizeVar("w");
  } else {
    n = graph.CreateSizeVar(1);
    c = graph.CreateSizeVar(64);
    h = graph.CreateSizeVar(56);
    w = graph.CreateSizeVar(56);
  }
  
  auto z_n = graph.CreateAxis("z_n", n);
  auto z_c = graph.CreateAxis("z_c", c);
  auto z_h = graph.CreateAxis("z_h", h);
  auto z_w = graph.CreateAxis("z_w", w);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.y.dtype = ge::DT_FLOAT16;
  *data0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *data0.y.repeats = {n, c, h, w};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.x = data0.y;
  *load0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.y.dtype = ge::DT_FLOAT16;
  *load0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *load0.y.repeats = {n, c, h, w};

  af::ascir_op::Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT16;
  data1.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *data1.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *data1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  data1.ir_attr.SetIndex(1);

  af::ascir_op::Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load1.y.dtype = ge::DT_FLOAT16;
  *load1.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *load1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  *load1.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};

  af::ascir_op::Conv2D conv2d("conv2d");
  conv2d.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  conv2d.x = load0.y;
  conv2d.filter = load1.y;
  conv2d.y.dtype = ge::DT_FLOAT;
  *conv2d.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *conv2d.y.repeats = {n, c, h, w};
  *conv2d.y.strides = {c*h*w, h*w, w, af::ops::One};
  conv2d.attr.api.compute_type = af::ComputeType::kComputeCube;
  conv2d.ir_attr.SetStrides({1, 1});
  conv2d.ir_attr.SetPads({1, 1, 1, 1});
  conv2d.ir_attr.SetDilations({1, 1});
  conv2d.ir_attr.SetGroups(1);
  conv2d.ir_attr.SetData_format("NCHW");
  conv2d.ir_attr.SetOffset_x(0);
  conv2d.ir_attr.SetEnable_hf32(false);

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.x = conv2d.y;
  *store_op.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {c*h*w, h*w, w, af::ops::One};
  *store_op.y.repeats = {n, c, h, w};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
}

// ==================== Conv2D + Elemwise 融合测试用例 ====================

TEST_F(TestCodegenTiling, TestConv2dElemwiseFuse) {
  af::AscGraph graph("conv2d_elemwise_pro");
  CreateElemwiseGraphWithAbsAndAddStatic(graph);

  af::AscGraph conv2d_graph("conv2d");
  CreateConv2dGraph(conv2d_graph, false);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;
  
  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Conv2d_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];

  VerifyConv2dElemwiseTiling(res);
}

// ==================== Conv2D + Elemwise 动态 Shape 融合测试用例 ====================

TEST_F(TestCodegenTiling, TestConv2dElemwiseDynamicShapeFuse) {
  af::AscGraph graph("conv2d_elemwise_dynamic_pro");
  CreateElemwiseGraphWithReluDynamic(graph);

  af::AscGraph conv2d_graph("conv2d_dynamic");
  CreateConv2dGraph(conv2d_graph, true);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;
  
  std::map<std::string, std::string> shape_info;
  shape_info["n"] = "1";
  shape_info["c"] = "64";
  shape_info["h"] = "56";
  shape_info["w"] = "56";
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Conv2d_dynamic_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];

  VerifyDynamicShapeTiling(res);
  
  auto n_val_pos = res["tiling_def_and_tiling_const"].find("auto n = 1;");
  ASSERT_NE(n_val_pos, std::string::npos);
  
  auto c_val_pos = res["tiling_def_and_tiling_const"].find("auto c = 64;");
  ASSERT_NE(c_val_pos, std::string::npos);
  
  auto h_val_pos = res["tiling_def_and_tiling_const"].find("auto h = 56;");
  ASSERT_NE(h_val_pos, std::string::npos);
  
  auto w_val_pos = res["tiling_def_and_tiling_const"].find("auto w = 56;");
  ASSERT_NE(w_val_pos, std::string::npos);
  
  auto dfx_pos = res["tiling_def_and_tiling_const"].find("extern \"C\" ge::graphStatus DfxInputSymbolInfo");
  ASSERT_NE(dfx_pos, std::string::npos);
}

// ==================== Conv2DBias 辅助函数 ====================

void CreateConv2dBiasGraph(af::AscGraph &graph, bool is_dynamic = false) {
  af::Expression n, c, h, w;
  if (is_dynamic) {
    n = graph.CreateSizeVar("n");
    c = graph.CreateSizeVar("c");
    h = graph.CreateSizeVar("h");
    w = graph.CreateSizeVar("w");
  } else {
    n = graph.CreateSizeVar(1);
    c = graph.CreateSizeVar(64);
    h = graph.CreateSizeVar(56);
    w = graph.CreateSizeVar(56);
  }
  
  auto z_n = graph.CreateAxis("z_n", n);
  auto z_c = graph.CreateAxis("z_c", c);
  auto z_h = graph.CreateAxis("z_h", h);
  auto z_w = graph.CreateAxis("z_w", w);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.y.dtype = ge::DT_FLOAT16;
  *data0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *data0.y.repeats = {n, c, h, w};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.x = data0.y;
  *load0.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load0.y.dtype = ge::DT_FLOAT16;
  *load0.y.strides = {c*h*w, h*w, w, af::ops::One};
  *load0.y.repeats = {n, c, h, w};

  af::ascir_op::Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT16;
  data1.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *data1.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};
  *data1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  data1.ir_attr.SetIndex(1);

  af::ascir_op::Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  load1.y.dtype = ge::DT_FLOAT16;
  *load1.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *load1.y.strides = {af::ops::Zero, af::ops::Zero, af::ops::Zero, af::ops::Zero};
  *load1.y.repeats = {af::ops::One, af::ops::One, af::ops::One, af::ops::One};

  af::ascir_op::Data data2("data2", graph);
  data2.y.dtype = ge::DT_FLOAT;
  data2.attr.sched.axis = {z_c.id};
  *data2.y.axis = {z_c.id};
  data2.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data2.y.repeats = {c};
  *data2.y.strides = {af::ops::One};
  data2.ir_attr.SetIndex(2);

  af::ascir_op::Load load2("load2");
  load2.x = data2.y;
  load2.attr.sched.axis = {z_c.id};
  load2.y.dtype = ge::DT_FLOAT;
  *load2.y.axis = {z_c.id};
  *load2.y.strides = {af::ops::One};
  *load2.y.repeats = {c};

  af::ascir_op::Conv2DBias conv2d_bias("conv2d_bias");
  conv2d_bias.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  conv2d_bias.x = load0.y;
  conv2d_bias.filter = load1.y;
  conv2d_bias.bias = load2.y;
  conv2d_bias.y.dtype = ge::DT_FLOAT;
  *conv2d_bias.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  *conv2d_bias.y.repeats = {n, c, h, w};
  *conv2d_bias.y.strides = {c*h*w, h*w, w, af::ops::One};
  conv2d_bias.attr.api.compute_type = af::ComputeType::kComputeCube;
  conv2d_bias.ir_attr.SetStrides({1, 1});
  conv2d_bias.ir_attr.SetPads({1, 1, 1, 1});
  conv2d_bias.ir_attr.SetDilations({1, 1});
  conv2d_bias.ir_attr.SetGroups(1);
  conv2d_bias.ir_attr.SetData_format("NCHW");
  conv2d_bias.ir_attr.SetOffset_x(0);
  conv2d_bias.ir_attr.SetEnable_hf32(false);

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.x = conv2d_bias.y;
  *store_op.y.axis = {z_n.id, z_c.id, z_h.id, z_w.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {c*h*w, h*w, w, af::ops::One};
  *store_op.y.repeats = {n, c, h, w};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
}

// ==================== Conv2DBias + Elemwise 静态 Shape 融合测试 ====================

TEST_F(TestCodegenTiling, TestConv2dBiasElemwiseFuse) {
  af::AscGraph graph("conv2d_bias_elemwise_pro");
  CreateElemwiseGraphWithRelu(graph);

  af::AscGraph conv2d_bias_graph("conv2d_bias");
  CreateConv2dBiasGraph(conv2d_bias_graph, false);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_bias_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;
  
  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Conv2d_bias_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];

  VerifyConv2DBiasElemwiseTiling(res);
}

// ==================== Conv2DBias + Elemwise 动态 Shape 融合测试 ====================

TEST_F(TestCodegenTiling, TestConv2dBiasElemwiseDynamicShapeFuse) {
  af::AscGraph graph("conv2d_bias_elemwise_dynamic_pro");
  CreateElemwiseGraphWithMulDynamic(graph);

  af::AscGraph conv2d_bias_graph("conv2d_bias_dynamic");
  CreateConv2dBiasGraph(conv2d_bias_graph, true);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_bias_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;
  
  std::map<std::string, std::string> shape_info;
  shape_info["n"] = "1";
  shape_info["c"] = "64";
  shape_info["h"] = "56";
  shape_info["w"] = "56";
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Conv2d_bias_dynamic_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];

  VerifyDynamicShapeTiling(res);
  
  auto n_val_pos = res["tiling_def_and_tiling_const"].find("auto n = 1;");
  ASSERT_NE(n_val_pos, std::string::npos);
  
  auto c_val_pos = res["tiling_def_and_tiling_const"].find("auto c = 64;");
  ASSERT_NE(c_val_pos, std::string::npos);
  
  auto h_val_pos = res["tiling_def_and_tiling_const"].find("auto h = 56;");
  ASSERT_NE(h_val_pos, std::string::npos);
  
  auto w_val_pos = res["tiling_def_and_tiling_const"].find("auto w = 56;");
  ASSERT_NE(w_val_pos, std::string::npos);
}

// ==================== BatchMatmul Dynamic Shape 融合测试 ====================

void CreateBatchMatmulDynamicGraph(af::AscGraph &graph) {
  auto batch = graph.CreateSizeVar("batch");
  auto m = graph.CreateSizeVar("m");
  auto n = graph.CreateSizeVar("n");
  auto k = graph.CreateSizeVar("k");
  
  auto z_batch = graph.CreateAxis("z_batch", batch);
  auto z_m = graph.CreateAxis("z_m", m);
  auto z_n = graph.CreateAxis("z_n", n);
  auto z_k = graph.CreateAxis("z_k", k);

  af::ascir_op::Data data0("data0", graph);
  data0.attr.sched.axis = {z_batch.id, z_m.id, z_k.id};
  data0.y.dtype = ge::DT_FLOAT16;
  *data0.y.axis = {z_batch.id, z_m.id, z_k.id};
  data0.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data0.y.strides = {m*k, k, af::ops::One};
  *data0.y.repeats = {batch, m, k};
  data0.ir_attr.SetIndex(0);

  af::ascir_op::Load load0("load0");
  load0.attr.sched.axis = {z_batch.id, z_m.id, z_k.id};
  load0.x = data0.y;
  *load0.y.axis = {z_batch.id, z_m.id, z_k.id};
  load0.y.dtype = ge::DT_FLOAT16;
  *load0.y.strides = {m*k, k, af::ops::One};
  *load0.y.repeats = {batch, m, k};

  af::ascir_op::Data data1("data1", graph);
  data1.y.dtype = ge::DT_FLOAT16;
  data1.attr.sched.axis = {z_batch.id, z_k.id, z_n.id};
  *data1.y.axis = {z_batch.id, z_k.id, z_n.id};
  data1.attr.api.compute_type = af::ComputeType::kComputeInvalid;
  *data1.y.repeats = {batch, k, n};
  *data1.y.strides = {k*n, n, af::ops::One};
  data1.ir_attr.SetIndex(1);

  af::ascir_op::Load load1("load1");
  load1.x = data1.y;
  load1.attr.sched.axis = {z_batch.id, z_k.id, z_n.id};
  load1.y.dtype = ge::DT_FLOAT16;
  *load1.y.axis = {z_batch.id, z_k.id, z_n.id};
  *load1.y.strides = {k*n, n, af::ops::One};
  *load1.y.repeats = {batch, k, n};

  af::ascir_op::BatchMatMul batch_matmul("batch_matmul");
  batch_matmul.attr.sched.axis = {z_batch.id, z_m.id, z_n.id};
  batch_matmul.x1 = load0.y;
  batch_matmul.x2 = load1.y;
  batch_matmul.y.dtype = ge::DT_FLOAT;
  *batch_matmul.y.axis = {z_batch.id, z_m.id, z_n.id};
  *batch_matmul.y.repeats = {batch, m, n};
  *batch_matmul.y.strides = {m*n, n, af::ops::One};
  batch_matmul.attr.api.compute_type = af::ComputeType::kComputeCube;
  batch_matmul.ir_attr.SetAdj_x1(0);
  batch_matmul.ir_attr.SetAdj_x2(0);
  batch_matmul.ir_attr.SetHas_relu(1);
  batch_matmul.ir_attr.SetEnable_hf32(true);
  batch_matmul.ir_attr.SetOffset_x(6);

  af::ascir_op::Store store_op("store");
  store_op.attr.sched.axis = {z_batch.id, z_m.id, z_n.id};
  store_op.x = batch_matmul.y;
  *store_op.y.axis = {z_batch.id, z_m.id, z_n.id};
  store_op.y.dtype = ge::DT_FLOAT;
  *store_op.y.strides = {m*n, n, af::ops::One};
  *store_op.y.repeats = {batch, m, n};

  af::ascir_op::Output output_op("output");
  output_op.x = store_op.y;
  output_op.y.dtype = ge::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
  optimize::AscGraphInfoComplete::CompleteApiInfo(graph);
}

TEST_F(TestCodegenTiling, TestBatchMatmulDynamicShapeFuse) {
  af::AscGraph graph("batch_matmul_dynamic_pro");
  CreateBatchMatmulElemwiseDynamicGraph(graph);

  af::AscGraph mm_graph("batch_matmul_dynamic");
  CreateBatchMatmulDynamicGraph(mm_graph);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(mm_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;
  
  std::map<std::string, std::string> shape_info;
  shape_info["batch"] = "16";
  shape_info["m"] = "64";
  shape_info["n"] = "64";
  shape_info["k"] = "64";
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Batch_matmul_dynamic_fuse_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];

  auto pos = res["tiling_def_and_tiling_const"].find("TilingResult result = wrapper.DoMatMulTiling(");
  ASSERT_NE(pos, std::string::npos);
  
  auto dynamic_pos = res["tiling_def_and_tiling_const"].find("AutofuseIsStaticShape() {\n  return false;");
  ASSERT_NE(dynamic_pos, std::string::npos);
}

// ==================== Conv2D with Groups 测试 ====================

TEST_F(TestCodegenTiling, TestConv2dWithGroups) {
  af::AscGraph graph("conv2d_groups_pro");
  CreateElemwiseGraphWithRelu(graph);

  af::AscGraph conv2d_graph("conv2d_groups");
  CreateConv2DGraphWithGroups(conv2d_graph);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;
  
  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Conv2d_groups_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];

  VerifyTilingCodeBasic(res);
}

// ==================== Conv2D with Dilation 测试 ====================

TEST_F(TestCodegenTiling, TestConv2dWithDilation) {
  af::AscGraph graph("conv2d_dilation_pro");
  CreateElemwiseGraphWithRelu(graph);

  af::AscGraph conv2d_graph("conv2d_dilation");
  CreateConv2DGraphWithDilation(conv2d_graph);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ascir::FusedScheduledResult fused_schedule_result;
  EXPECT_EQ(optimizer.Optimize(graph, fused_schedule_result), 0);

  ascir::ScheduleGroup schedule_group2;
  schedule_group2.impl_graphs.push_back(conv2d_graph);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].schedule_groups.push_back(schedule_group2);
  fused_schedule_result.node_idx_to_scheduled_results[0][0].cube_type = ascir::CubeTemplateType::kUBFuse;
  
  const std::map<std::string, std::string> shape_info;
  auto res = this->Generate(fused_schedule_result, shape_info, "", "0");

  std::fstream tiling_func("Conv2d_dilation_tiling_func.cpp", std::ios::out);
  tiling_func << res["tiling_def_and_tiling_const"];

  VerifyTilingCodeBasic(res);
}
// --- Task 1: spec 协议最小字段集红灯测试 ---

TEST_F(TestCodegenTiling, ProtocolHeaderShouldContainMinimalRequestResponse) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("struct GetTilingRequest"), std::string::npos);
  EXPECT_NE(tiling_impl.find("struct CandidateSolution"), std::string::npos);
  EXPECT_NE(tiling_impl.find("struct GetTilingResponse"), std::string::npos);
}

TEST_F(TestCodegenTiling, ProtocolRequestShouldContainMinimalFields) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("const std::vector<std::map<std::string, std::string>> *input_configs = nullptr;"),
            std::string::npos);
  EXPECT_NE(tiling_impl.find("ResLimit *res_limit = nullptr;"), std::string::npos);
  EXPECT_NE(tiling_impl.find("int64_t topn = 1;"), std::string::npos);
}

TEST_F(TestCodegenTiling, CandidateSolutionShouldContainOnlyMinimalFields) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("double modeled_perf = 0.0;"), std::string::npos);
  EXPECT_NE(tiling_impl.find("bool is_default = false;"), std::string::npos);
  EXPECT_NE(tiling_impl.find("std::string canonical_repr;"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("candidate.workspace ="), std::string::npos);
  EXPECT_EQ(tiling_impl.find("candidate.block_dim ="), std::string::npos);
}

TEST_F(TestCodegenTiling, ProtocolShouldNotContainBannedFields) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_EQ(tiling_impl.find("schedule_result_key"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("group_case_ids"), std::string::npos);
}

TEST_F(TestCodegenTiling, GetTilingShouldEnterMainSearchNotForKey) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("int64_t GetTopnCandidateSolutions(const GetTilingRequest &request, GetTilingResponse &response)"),
            std::string::npos);
  // GetTiling must enter ATT main search skeleton via PGOSearchTilingKey (which internally calls
  // SearchAllTilingbyCaseId / ExecutePGOSolver). In codegen UT (no ATT callback), verify PGOSearchTilingKey
  // is used. In integrated environments (solver_func key present), also verify the internal functions.
  EXPECT_NE(tiling_impl.find("optiling::PGOSearchTilingKey("), std::string::npos);
  const std::string kSolverFunc = "solver_func";
  if (tiling_files.find(kSolverFunc) != tiling_files.end()) {
    std::string all_tiling_code = tiling_impl + tiling_files.at(kSolverFunc);
    EXPECT_NE(all_tiling_code.find("SearchAllTilingbyCaseId("), std::string::npos);
    EXPECT_NE(all_tiling_code.find("ExecutePGOSolver("), std::string::npos);
  }
  EXPECT_EQ(tiling_impl.find("for (int64_t key = 0; key < GetTilingKeyCount(); ++key)"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("if (AutofuseTiling(&default_tiling_data"), std::string::npos);
}

TEST_F(TestCodegenTiling, NoEarlyStopByTopnOrDefault) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_EQ(tiling_impl.find("if (request.topn == 1) { return"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("response.candidate_solutions.resize(topn)"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("partial_sort"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("current_candidate_num >= request.topn"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("if (found_default) break"), std::string::npos);
}

// Task 5: multi-group must carry workspace_map / block_dim_vec

TEST_F(TestCodegenTiling, MultiGroupDoesNotCarryWorkspaceMap) {
  auto fused_schedule_result = GenMultiGroupFusedScheduleResult();
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  // Multi-group GetTiling does not declare workspace_map (single-group only)
  EXPECT_EQ(tiling_impl.find("std::unordered_map<int64_t, uint64_t> workspace_map"), std::string::npos);
}

// Task 6: Bridge layer — raw-candidate to CandidateSolution mapping

TEST_F(TestCodegenTiling, BridgeMapsModeledPerfFromFinalComparablePerf) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("double final_modeled_perf ="), std::string::npos);
  EXPECT_NE(tiling_impl.find("solution.modeled_perf = final_modeled_perf;"), std::string::npos);
}

TEST_F(TestCodegenTiling, BridgePreservesSingleGroupComparablePerfSemantics) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("CandidateSolution solution;"), std::string::npos);
  EXPECT_NE(tiling_impl.find("solution.modeled_perf = final_modeled_perf;"), std::string::npos);
  EXPECT_NE(tiling_impl.find("std::isfinite(final_modeled_perf)"), std::string::npos);
}

TEST_F(TestCodegenTiling, BridgeDefaultFromOriginalPath) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  // Default must come from default config path, not arbitrary first candidate
  EXPECT_NE(tiling_impl.find("is_default = is_default_config_request && !default_repr.empty()"),
            std::string::npos);
  // default_repr is set from GetTiling(default_tiling, -1) on default-config path only
  EXPECT_NE(tiling_impl.find("default_repr = GetTilingDataRepr(&default_tiling)"), std::string::npos);
  // No unreachable branch: is_default already requires !default_repr.empty()
  EXPECT_EQ(tiling_impl.find("is_default && default_repr.empty()"), std::string::npos);
}

TEST_F(TestCodegenTiling, BridgeDoesNotWriteWorkspaceOrBlockDim) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  // Bridge layer must not write workspace into protocol object
  EXPECT_EQ(tiling_impl.find("solution.workspace ="), std::string::npos);
  EXPECT_EQ(tiling_impl.find("solution.block_dim ="), std::string::npos);
}

// Task 7: Selector + ABI backfill

TEST_F(TestCodegenTiling, WrapperUsesSelectorForTopn) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  // Wrapper must call SelectTopnCandidateSolutions
  EXPECT_NE(tiling_impl.find("SelectTopnCandidateSolutions(response.candidate_solutions, topn)"), std::string::npos);
}

TEST_F(TestCodegenTiling, WrapperBackfillsWorkspaceAndBlockDim) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  // Workspace must be dynamically computed from tiling_data
  EXPECT_NE(tiling_impl.find("GetWorkspaceSize(sol.tiling_data)"), std::string::npos);
  // Block dim must be dynamically extracted from tiling_data
  EXPECT_NE(tiling_impl.find("sol.tiling_data.get_block_dim()"), std::string::npos);
}

// Task 2: Config truth table — request construction & config semantics

TEST_F(TestCodegenTiling, TopnWrapperMapsEmptyConfigsToInternalNoConfigPath) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("if (input_configs.empty()) {"), std::string::npos);
  EXPECT_NE(tiling_impl.find("request.input_configs = nullptr;"), std::string::npos);
  EXPECT_NE(tiling_impl.find("request.input_configs = &input_configs;"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("normalized_configs"), std::string::npos);
}

TEST_F(TestCodegenTiling, TopnWrapperConstructsRequestAndInvokesGetTiling) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  // Wrapper should construct request/response objects and invoke GetTiling with them.
  EXPECT_NE(tiling_impl.find("GetTilingRequest request;"), std::string::npos);
  EXPECT_NE(tiling_impl.find("GetTilingResponse response;"), std::string::npos);
  EXPECT_NE(tiling_impl.find("GetTopnCandidateSolutions(request, response)"), std::string::npos);
}

TEST_F(TestCodegenTiling, GetTilingDefaultConfigDetectionIncludesInternalPath) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("const bool internal_no_config_path = (request.input_configs == nullptr);"),
            std::string::npos);
  EXPECT_NE(tiling_impl.find("const bool is_default_config_request = internal_no_config_path || "),
            std::string::npos);
  EXPECT_NE(tiling_impl.find("(request.input_configs != nullptr && request.input_configs->size() == 1 && "),
            std::string::npos);
}

TEST_F(TestCodegenTiling, GetTilingIteratesConfigsInOrder) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  // Multi-config must iterate in input order
  EXPECT_NE(tiling_impl.find("for (const auto &cfg : configs)"), std::string::npos);
  // Each config feeds PGOSearchTilingKey within the loop
  EXPECT_NE(tiling_impl.find("PGOSearchTilingKey("), std::string::npos);
}

TEST_F(TestCodegenTiling, GetTilingInternalPathOnlyForNullptr) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  // nullptr comparison exists in GetTiling for internal path detection
  EXPECT_NE(tiling_impl.find("request.input_configs == nullptr"), std::string::npos);
  // Internal path pushes default SearchConfig directly (not via ParseSearchConfigs)
  EXPECT_NE(tiling_impl.find("internal_no_config_path"), std::string::npos);
  // The pattern: "if (internal_no_config_path) { configs.push_back(SearchConfig()); }"
  // ensures nullptr path bypasses ParseSearchConfigs
  EXPECT_NE(tiling_impl.find("configs.push_back(SearchConfig())"), std::string::npos);
}

TEST_F(TestCodegenTiling, ParseSearchConfigsParsesExplicitConfigsOnly) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_EQ(tiling_impl.find("if (raws.empty()) {"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("return {SearchConfig()};"), std::string::npos);
}

TEST_F(TestCodegenTiling, DefaultFlagComesOnlyFromExplicitDefaultConfigRequest) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("solution.is_default = is_default_config_request"), std::string::npos);
  EXPECT_NE(tiling_impl.find("!default_repr.empty()"), std::string::npos);
  EXPECT_NE(tiling_impl.find("solution.canonical_repr == default_repr"), std::string::npos);
  EXPECT_EQ(tiling_impl.find("solution.is_default = internal_no_config_path"), std::string::npos);
}

TEST_F(TestCodegenTiling, ExplicitDefaultConfigRequestMustFailWithoutDefaultCandidate) {
  auto fused_schedule_result = this->GenBasicFusedScheduleResult({af::Symbol("s0"), af::Symbol("s1")});
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  EXPECT_NE(tiling_impl.find("bool found_default_candidate = false;"), std::string::npos);
  EXPECT_NE(tiling_impl.find("if (solution.is_default) { found_default_candidate = true; }"), std::string::npos);
  EXPECT_NE(tiling_impl.find("if (is_default_config_request && !found_default_candidate) { return -1; }"), std::string::npos);
}

// Task 1 Step 2: multi-group semantic annotations — must match TF PGO real implementation

// Verify multi-group generates graph-level tiling keys and perf aggregation via UpdateCurPerfAndBlockByGroup
TEST_F(TestCodegenTiling, MultiGroupUsesGraphLevelTilingKeysAndPerfAggregation) {
  auto fused_schedule_result = GenMultiGroupFusedScheduleResult();
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  // Multi-group must generate PGOSearchTilingKey with best_perf pointer (not ref)
  EXPECT_NE(tiling_impl.find("PGOSearchTilingKey"), std::string::npos);
  // Must have config iteration for SearchConfig
  EXPECT_NE(tiling_impl.find("for (const auto &cfg : configs)"), std::string::npos);
  // Multi-group repr must use graph-level tiling keys
  EXPECT_NE(tiling_impl.find("graph0_tiling_key"), std::string::npos);
  // Multi-group perf aggregation must use UpdateCurPerfAndBlockByGroup
  EXPECT_NE(tiling_impl.find("UpdateCurPerfAndBlockByGroup"), std::string::npos);
}

TEST_F(TestCodegenTiling, MultiGroupMustNotUseBasicPGOSearchTilingKeyOverload) {
  auto fused_schedule_result = GenMultiGroupFusedScheduleResult();
  auto tiling_files = this->GenerateForInductor(fused_schedule_result);
  ASSERT_TRUE(tiling_files.find(codegen::kTilingDefAndConstIdentify) != tiling_files.end());
  const auto &tiling_impl = tiling_files.at(codegen::kTilingDefAndConstIdentify);
  // Must not generate basic PGOSearchTilingKey that takes AutofuseTilingData & directly
  EXPECT_EQ(tiling_impl.find("PGOSearchTilingKey(raw_candidates, AutofuseTilingData &"), std::string::npos);
  // Must not generate basic PGOSearchTilingKey with search_tiling output at -1 case id
  EXPECT_EQ(tiling_impl.find("PGOSearchTilingKey(raw_candidates, search_tiling, -1, &search_tiling"),
            std::string::npos);
}