/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "gtest/gtest.h"

#include "ascendc_ir.h"
#include "ascendc_ir_def.h"
#include "ascir_ops.h"
#define private public
#include "optimize.h"
#include "autoschedule/autoschedule.h"
#undef private
#include "ascir_ops_utils.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "graph/compute_graph.h"
#include "graph/node.h"
#include "graph/utils/graph_utils.h"
#include "attr_utils.h"
#include "graph/debug/ge_op_types.h"
#include "autoschedule/axis_group.h"
#include "schedule_utils.h"
#include "fused_graph/fused_graph_unfolder.h"
#include "platform_context.h"
#include "platform/v1/platformv1.h"

#include "asc_graph_builder.h"

using namespace std;
using namespace af;
using namespace af::ops;
using namespace af::ascir_op;
using optimize::autoschedule::AxisGroup;
using af::testing::Sym;
using af::testing::AscGraphBuilder;

// Builds a graph over four symbolic loop sizes (s0..s3):
// Data -> Load -> Abs -> Max -> Store -> Output.
// Max is given {0, 2}; presumably these are the reduced loop-axis positions,
// matching the "RARA" suffix in the function name.
AscGraph Construct_Reduce_RARA(const std::string &name) {
  auto s0 = Sym("s0");
  auto s1 = Sym("s1");
  auto s2 = Sym("s2");
  auto s3 = Sym("s3");
  return AscGraphBuilder(name)
      .Loops({s0, s1, s2, s3})
      .Data("arg4_1", 0)
      .Load("b0_load", "arg4_1")
      .Abs("abs", "b0_load")
      .Max("b0_max", "abs", {0, 2})
      .Store("b3_store", "b0_max")
      .Output("buf3", "b3_store", 0)
      .Build();
}

// Builds a graph over four symbolic loop sizes (s0..s3):
// Data -> Load -> Abs -> Max -> Store -> Output.
// Max is given {1, 3}; presumably these are the reduced loop-axis positions,
// matching the "ARAR" suffix in the function name.
AscGraph Construct_Reduce_ARAR(const std::string &name) {
  auto s0 = Sym("s0");
  auto s1 = Sym("s1");
  auto s2 = Sym("s2");
  auto s3 = Sym("s3");
  return AscGraphBuilder(name)
      .Loops({s0, s1, s2, s3})
      .Data("arg4_1", 0)
      .Load("b0_load", "arg4_1")
      .Abs("abs", "b0_load")
      .Max("b0_max", "abs", {1, 3})
      .Store("b3_store", "b0_max")
      .Output("buf3", "b3_store", 0)
      .Build();
}

// Populates `graph` with a hand-built chain Data -> Load -> Sum -> Store -> Output
// over two axes z0 (size var 128) and z1 (size var 64). All tensors are DT_FLOAT.
// Every op after Data is scheduled on {z0, z1}; Sum and Store outputs carry
// strides {One, One} and repeats {Zero, Zero}.
void Construct_Reduce_RR(af::AscGraph &graph) {
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data("data", graph);
  data.y.dtype = af::DT_FLOAT;
  data.ir_attr.SetIndex(0);

  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id};
  load.y.dtype = af::DT_FLOAT;
  *load.y.strides = {s1, af::ops::One};
  // NOTE(review): repeats are {s1, s0} although the axes are {z0, z1} (sizes s0, s1)
  // and the strides are {s1, One}; {s0, s1} looks like the intended order. Left
  // unchanged because the test expectations may have been recorded with this value
  // -- confirm before changing.
  *load.y.repeats = {s1, s0};

  Sum sum("sum");
  sum.attr.sched.axis = {z0.id, z1.id};
  sum.x = load.y;
  *sum.y.axis = {z0.id, z1.id};
  sum.y.dtype = af::DT_FLOAT;
  *sum.y.strides = {af::ops::One, af::ops::One};
  *sum.y.repeats = {af::ops::Zero, af::ops::Zero};

  Store store_op1("store1");
  store_op1.attr.sched.axis = {z0.id, z1.id};
  store_op1.x = sum.y;
  *store_op1.y.axis = {z0.id, z1.id};
  store_op1.y.dtype = af::DT_FLOAT;
  *store_op1.y.strides = {af::ops::One, af::ops::One};
  *store_op1.y.repeats = {af::ops::Zero, af::ops::Zero};

  Output output_op("output");
  output_op.x = store_op1.y;
  output_op.y.dtype = af::DT_FLOAT;
  output_op.ir_attr.SetIndex(0);
}

// Builds a two-loop graph whose first loop size is the product s0 * s1 * s2 and
// whose second is s3. The Exp result ("b1_exp") has two consumers:
//   b1_exp -> Abs -> Max({1}) -> Broadcast -> Store -> Output "buf0" (index 1)
//   b1_exp -> Relu -> Store -> Output "buf1" (index 2)
// Data is declared DT_FLOAT16 and both outputs DT_FLOAT.
AscGraph Construct_Mul_Consumer_Struct(const std::string &name) {
  auto s0 = Sym("s0");
  auto s1 = Sym("s1");
  auto s2 = Sym("s2");
  auto s3 = Sym("s3");
  // The merged outer size is reused by Loops, Data, Load and Broadcast.
  auto outer = s0 * s1 * s2;
  return AscGraphBuilder(name)
      .Loops({outer, s3})
      .Data("arg4_1", 0, {outer, s3}, {s3, af::ops::One}, af::DT_FLOAT16)
      .Load("b0_load", "arg4_1", {outer, s3}, {s3, af::ops::One})
      .Exp("b1_exp", "b0_load")
      .Abs("b0_abs", "b1_exp")
      .Max("b0_max", "b0_abs", {1})
      .Broadcast("b1_broadcast", "b0_max", {outer, s3})
      .Store("b0_store", "b1_broadcast")
      .Output("buf0", "b0_store", 1, af::DT_FLOAT)
      .Op<af::ascir_op::Relu>("b0_relu", {"b1_exp"})
      .Store("b1_store", "b0_relu")
      .Output("buf1", "b1_store", 2, af::DT_FLOAT)
      .Build();
}

// Builds a graph with loop sizes 128 and 64:
// Data -> Load -> Exp -> Sum({0, 1}) -> Broadcast back to {128, 64},
// then Sub(broadcast, exp) -> Store -> Output. "exp" feeds both Sum and Sub.
AscGraph ConstructNormStruct(const std::string &name) {
  auto d0 = Sym(128);
  auto d1 = Sym(64);
  return AscGraphBuilder(name)
      .Loops({d0, d1})
      .Data("data", 0)
      .Load("load", "data")
      .Exp("exp", "load")
      .Sum("sum", "exp", {0, 1})
      .Broadcast("broadcast", "sum", {d0, d1})
      .Sub("sub", "broadcast", "exp")
      .Store("store1", "sub")
      .Output("output", "store1", 0)
      .Build();
}

// Builds a graph with loop sizes 128 and 64 in which Sum({0, 1}) is followed by
// three chained unary ops: Abs -> Exp -> Relu, then Store -> Output.
AscGraph ConstructNormStruct3Elewise(const std::string &name) {
  auto d0 = Sym(128);
  auto d1 = Sym(64);
  return AscGraphBuilder(name)
      .Loops({d0, d1})
      .Data("data", 0)
      .Load("load", "data")
      .Sum("sum", "load", {0, 1})
      .Abs("abs", "sum")
      .Exp("exp", "abs")
      .Relu("b0_relu", "exp")
      .Store("store1", "b0_relu")
      .Output("output", "store1", 0)
      .Build();
}

// Builds a graph with loop sizes 128 and 64 in which Sum({0, 1}) is followed by
// a single Abs, then Store -> Output.
AscGraph ConstructNormStruct1Elewise(const std::string &name) {
  auto d0 = Sym(128);
  auto d1 = Sym(64);
  return AscGraphBuilder(name)
      .Loops({d0, d1})
      .Data("data", 0)
      .Load("load", "data")
      .Sum("sum", "load", {0, 1})
      .Abs("abs", "sum")
      .Store("store1", "abs")
      .Output("output", "store1", 0)
      .Build();
}

// Builds a graph with loop sizes 128 and 64 in which Sum({0, 1}) is followed by
// four chained unary ops: Abs -> Tanh -> Exp -> Relu, then Store -> Output.
AscGraph ConstructNormStruct4Elewise(const std::string &name) {
  auto d0 = Sym(128);
  auto d1 = Sym(64);
  return AscGraphBuilder(name)
      .Loops({d0, d1})
      .Data("data", 0)
      .Load("load", "data")
      .Sum("sum", "load", {0, 1})
      .Abs("abs", "sum")
      .Op<af::ascir_op::Tanh>("tanh", {"abs"})
      .Exp("exp", "tanh")
      .Relu("b0_relu", "exp")
      .Store("store1", "b0_relu")
      .Output("output", "store1", 0)
      .Build();
}

// Builds a graph with loop sizes 128 and 64 where the Sum({0, 1}) result is
// consumed twice (by Abs and by Tanh); the two branches are joined by Add and
// passed through Relu -> Store -> Output.
AscGraph ConstructNormStruct4Elewise4ReduceMultipleCitations(const std::string &name) {
  auto d0 = Sym(128);
  auto d1 = Sym(64);
  return AscGraphBuilder(name)
      .Loops({d0, d1})
      .Data("data", 0)
      .Load("load", "data")
      .Sum("sum", "load", {0, 1})
      .Abs("abs", "sum")
      .Op<af::ascir_op::Tanh>("tanh", {"sum"})
      .Add("add", "abs", "tanh")
      .Relu("b0_relu", "add")
      .Store("store1", "b0_relu")
      .Output("output", "store1", 0)
      .Build();
}

// Builds a graph with loop sizes 128 and 64 where the Sum({0, 1}) result is
// consumed twice (by Abs and by Tanh); the two branches are joined by Add,
// whose result goes straight to Store -> Output.
AscGraph ConstructNormStruct4Elewise3ReduceMultipleCitations(const std::string &name) {
  auto d0 = Sym(128);
  auto d1 = Sym(64);
  return AscGraphBuilder(name)
      .Loops({d0, d1})
      .Data("data", 0)
      .Load("load", "data")
      .Sum("sum", "load", {0, 1})
      .Abs("abs", "sum")
      .Op<af::ascir_op::Tanh>("tanh", {"sum"})
      .Add("add", "abs", "tanh")
      .Store("store1", "add")
      .Output("output", "store1", 0)
      .Build();
}

// Builds a graph with loop sizes 128 and 64 and two inputs:
// Mul(load0, load1) -> Sum({0, 1}) -> Relu; the Relu result feeds both Tanh and
// Add(tanh, b0_relu), followed by Abs -> Store -> Output.
AscGraph ConstructNormStruct4Elewise3(const std::string &name) {
  auto d0 = Sym(128);
  auto d1 = Sym(64);
  return AscGraphBuilder(name)
      .Loops({d0, d1})
      .Data("data0", 0)
      .Load("load0", "data0")
      .Data("data1", 1)
      .Load("load1", "data1")
      .Mul("mul", "load0", "load1")
      .Sum("sum", "mul", {0, 1})
      .Relu("b0_relu", "sum")
      .Op<af::ascir_op::Tanh>("tanh", {"b0_relu"})
      .Add("add", "tanh", "b0_relu")
      .Abs("abs", "add")
      .Store("store1", "abs")
      .Output("output", "store1", 0)
      .Build();
}

// Populates `graph` with a hand-built chain Data -> Load -> Sum -> Cast -> Store -> Output
// over two axes z0 (size var 128) and z1 (size var 64). Data, Load and Sum are
// DT_FLOAT; Cast, Store and Output are DT_FLOAT16. Every op after Data is scheduled
// on {z0, z1}; Sum, Cast and Store outputs carry strides {One, One} and
// repeats {Zero, Zero}.
void Construct_Reduce_Cast_RR(af::AscGraph &graph) {
  auto s0 = graph.CreateSizeVar(128);
  auto s1 = graph.CreateSizeVar(64);
  auto z0 = graph.CreateAxis("z0", s0);
  auto z1 = graph.CreateAxis("z1", s1);

  Data data("data", graph);
  data.y.dtype = af::DT_FLOAT;
  data.ir_attr.SetIndex(0);

  Load load("load");
  load.attr.sched.axis = {z0.id, z1.id};
  load.x = data.y;
  *load.y.axis = {z0.id, z1.id};
  load.y.dtype = af::DT_FLOAT;
  *load.y.strides = {s1, af::ops::One};
  // NOTE(review): repeats are {s1, s0} although the axes are {z0, z1} (sizes s0, s1)
  // and the strides are {s1, One}; {s0, s1} looks like the intended order. Left
  // unchanged because the test expectations may have been recorded with this value
  // -- confirm before changing.
  *load.y.repeats = {s1, s0};

  Sum sum("sum");
  sum.attr.sched.axis = {z0.id, z1.id};
  sum.x = load.y;
  *sum.y.axis = {z0.id, z1.id};
  sum.y.dtype = af::DT_FLOAT;
  *sum.y.strides = {af::ops::One, af::ops::One};
  *sum.y.repeats = {af::ops::Zero, af::ops::Zero};

  Cast cast("cast");
  cast.attr.sched.axis = {z0.id, z1.id};
  cast.x = sum.y;
  *cast.y.axis = {z0.id, z1.id};
  cast.y.dtype = af::DT_FLOAT16;
  *cast.y.strides = {af::ops::One, af::ops::One};
  *cast.y.repeats = {af::ops::Zero, af::ops::Zero};

  Store store_op1("store1");
  store_op1.attr.sched.axis = {z0.id, z1.id};
  store_op1.x = cast.y;
  *store_op1.y.axis = {z0.id, z1.id};
  store_op1.y.dtype = af::DT_FLOAT16;
  *store_op1.y.strides = {af::ops::One, af::ops::One};
  *store_op1.y.repeats = {af::ops::Zero, af::ops::Zero};

  Output output_op("output");
  output_op.x = store_op1.y;
  output_op.y.dtype = af::DT_FLOAT16;
  output_op.ir_attr.SetIndex(0);
}

// Builds a three-loop graph (sizes 256, 39, 80) with two inputs and three outputs:
//   data0 is loaded with repeats {256, 39, One} and broadcast to the full loop sizes;
//   mul = brc * load1                 -> Store -> Output "output"  (index 0)
//   Sum(mul, {1})                     -> Store -> Output "output2" (index 1)
//   Sum(Mul(mul, mul), {1})           -> Store -> Output "output3" (index 2)
AscGraph ConstructNormStruct4MulReduce(const std::string &name) {
  auto d0 = Sym(256);
  auto d1 = Sym(39);
  auto d2 = Sym(80);
  return AscGraphBuilder(name)
      .Loops({d0, d1, d2})
      .Data("data0", 0, {d0, d1, af::ops::One}, {d1, af::ops::One, af::ops::Zero}, af::DT_FLOAT)
      .Load("load0", "data0", {d0, d1, af::ops::One}, {d1, af::ops::One, af::ops::Zero})
      .Broadcast("brc", "load0", {d0, d1, d2})
      .Data("data1", 1, {d0, d1, d2}, {d1 * d2, d2, af::ops::One}, af::DT_FLOAT)
      .Load("load1", "data1")
      .Mul("mul", "brc", "load1")
      .Store("store1", "mul")
      .Output("output", "store1", 0, af::DT_FLOAT)
      .Sum("sum", "mul", {1})
      .Store("store2", "sum")
      .Output("output2", "store2", 1, af::DT_FLOAT)
      .Mul("mul1", "mul", "mul")
      .Sum("sum1", "mul1", {1})
      .Store("store3", "sum1")
      .Output("output3", "store3", 2, af::DT_FLOAT)
      .Build();
}

// Builds a graph with loop sizes 128 and 64 and two inputs:
// Mul(load0, load1) -> Sum({0, 1}) -> Relu; the Relu result feeds both Tanh and
// Add(tanh, b0_relu), whose result goes to Store -> Output.
AscGraph ConstructNormStruct3ElemwiseReducePostMulInput(const std::string &name) {
  auto d0 = Sym(128);
  auto d1 = Sym(64);
  return AscGraphBuilder(name)
      .Loops({d0, d1})
      .Data("data0", 0)
      .Load("load0", "data0")
      .Data("data1", 1)
      .Load("load1", "data1")
      .Mul("mul", "load0", "load1")
      .Sum("sum", "mul", {0, 1})
      .Relu("b0_relu", "sum")
      .Op<af::ascir_op::Tanh>("tanh", {"b0_relu"})
      .Add("add", "tanh", "b0_relu")
      .Store("store1", "add")
      .Output("output", "store1", 0)
      .Build();
}

// Builds a graph with loop sizes 128 and 64 and two inputs:
// Mul(load0, load1) -> Sum({0, 1}) -> Relu -> Tanh, then Add(tanh, mul), i.e. the
// Add takes the pre-Sum "mul" result as its second input, followed by
// Store -> Output.
AscGraph ConstructNormStruct3ElemwiseReducePostMulInputV2(const std::string &name) {
  auto d0 = Sym(128);
  auto d1 = Sym(64);
  return AscGraphBuilder(name)
      .Loops({d0, d1})
      .Data("data0", 0)
      .Load("load0", "data0")
      .Data("data1", 1)
      .Load("load1", "data1")
      .Mul("mul", "load0", "load1")
      .Sum("sum", "mul", {0, 1})
      .Relu("b0_relu", "sum")
      .Op<af::ascir_op::Tanh>("tanh", {"b0_relu"})
      .Add("add", "tanh", "mul")
      .Store("store1", "add")
      .Output("output", "store1", 0)
      .Build();
}

// Builds a graph with loop sizes 128 and 64 where the Sum({0, 1}) result has
// several consumers (Abs, Tanh, Add, Relu) and three outputs are produced:
//   Relu(sum)      -> Store -> Output "output"  (index 0)
//   Add(sum, tanh) -> Store -> Output "output1" (index 1)
//   Abs(sum)       -> Store -> Output "output2" (index 2)
AscGraph ConstructNormStruct4Elewise4ReduceMultipleCitationsMulOut(const std::string &name) {
  auto d0 = Sym(128);
  auto d1 = Sym(64);
  return AscGraphBuilder(name)
      .Loops({d0, d1})
      .Data("data", 0)
      .Load("load", "data")
      .Sum("sum", "load", {0, 1})
      .Abs("abs", "sum")
      .Op<af::ascir_op::Tanh>("tanh", {"sum"})
      .Add("add", "sum", "tanh")
      .Relu("b0_relu", "sum")
      .Store("store1", "b0_relu")
      .Output("output", "store1", 0)
      .Store("store2", "add")
      .Output("output1", "store2", 1)
      .Store("store3", "abs")
      .Output("output2", "store3", 2)
      .Build();
}

namespace optimize {
class OptimizerReduceSt : public ::testing::Test {
 protected:
  void SetUp() override {
    dlog_setlevel(ASCGEN_MODULE_NAME, DLOG_ERROR, 0);
  }
  void TearDown() override {
    dlog_setlevel(ASCGEN_MODULE_NAME, DLOG_ERROR, 0);
  }

  optimize::Optimizer optimizer;

  OptimizerReduceSt() : optimizer(optimize::OptimizerOptions{}) {}

  static std::string ExpressToStr(std::vector<af::Expression> exprs) {
    std::stringstream ss;
    for (auto &size_expr : exprs) {
      ss << std::string(size_expr.Str().get()) << ", ";
    }
    return ss.str();
  }

  static std::string RepeatsToStr(const af::AscGraph &graph, const char *node_name) {
    auto node = graph.FindNode(node_name);
    if (node == nullptr) {
      return "";
    }
    return ExpressToStr(node->outputs[0].attr.repeats);
  }

  static std::string StridesToStr(const af::AscGraph &graph, const char *node_name) {
    auto node = graph.FindNode(node_name);
    if (node == nullptr) {
      return "";
    }
    return ExpressToStr(node->outputs[0].attr.strides);
  }

  static std::string AxisToStr(af::AscGraph &graph, const char *node_name) {
    auto node = graph.FindNode(node_name);
    if (node == nullptr) {
      return "";
    }
    std::stringstream ss;
    for (auto axis_id : node->outputs[0].attr.axis) {
      ss << graph.FindAxis(axis_id)->name << ", ";
    }
    return ss.str();
  }
};

// Optimizes the RARA reduce graph and checks the number of scheduled results,
// schedule groups and impl graphs that come back.
TEST_F(OptimizerReduceSt, TestReduce_RARA) {
  auto graph = Construct_Reduce_RARA("REDUCE_RARA");
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal on failure: every check below indexes into the result.
  ASSERT_EQ(res, af::SUCCESS);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 5UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 4UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][1].schedule_groups.size(), 2UL);
}

// Optimizes the graph whose Exp result has two consumers and checks the number
// of scheduled results, schedule groups and impl graphs that come back.
TEST_F(OptimizerReduceSt, TestReduce_MUL_CONSUMER) {
  auto graph = Construct_Mul_Consumer_Struct("REDUCE_MUL_CONSUMER");
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal on failure: every check below indexes into the result.
  ASSERT_EQ(res, af::SUCCESS);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][1].schedule_groups.size(), 3UL);
}

// Optimizes the ARAR reduce graph and checks the number of scheduled results,
// schedule groups and impl graphs that come back.
TEST_F(OptimizerReduceSt, TestReduce_ARAR) {
  auto graph = Construct_Reduce_ARAR("REDUCE_ARAR");
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal on failure: every check below indexes into the result.
  ASSERT_EQ(res, af::SUCCESS);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 5UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][0].schedule_groups[0].impl_graphs.size(), 4UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0][1].schedule_groups.size(), 2UL);
}

// Optimizes the hand-built all-reduce (Sum) graph and checks how many scheduled
// results are produced.
TEST_F(OptimizerReduceSt, TestReduce_RR) {
  af::AscGraph graph("REDUCE_RR");
  Construct_Reduce_RR(graph);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal on failure: the check below indexes into the result.
  ASSERT_EQ(res, af::SUCCESS);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 3UL);
}

// Optimizes the hand-built Sum -> Cast graph and checks how many scheduled
// results are produced.
TEST_F(OptimizerReduceSt, TestReduce_Cast_RR) {
  af::AscGraph graph("REDUCE_Cast_RR");
  Construct_Reduce_Cast_RR(graph);
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  ::ascir::FusedScheduledResult fused_scheduled_result;
  Status res = optimizer.Optimize(graph, fused_scheduled_result);
  // Fatal on failure: the check below indexes into the result.
  ASSERT_EQ(res, af::SUCCESS);
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results.empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0].size(), 3UL);
}

// Optimizes the Exp -> Sum -> Broadcast -> Sub graph and checks the number of
// schedule groups in the first scheduled result.
TEST_F(OptimizerReduceSt, TestReduce_PatitionNorm) {
  auto graph = ConstructNormStruct("reduce_patition_norm");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  // Stop here if optimization fails so the checks below never read a partial result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  // Guard the [0UL][0UL] access below.
  ASSERT_FALSE(fused_scheduled_result.node_idx_to_scheduled_results[0UL].empty());
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][0UL].schedule_groups.size(), 2UL);
}

// Optimizes the Sum -> Abs -> Exp -> Relu graph and checks the schedule-group
// count of each of the three scheduled results.
TEST_F(OptimizerReduceSt, TestReduce_Three_Elewise_Store) {
  auto graph = ConstructNormStruct3Elewise("reduce_three_elewise_store");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  // Stop here if optimization fails so the checks below never read a partial result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][0UL].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][1UL].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][2UL].schedule_groups.size(), 1UL);
}

// Optimizes the Sum -> Abs graph and checks the schedule-group count of the
// first three scheduled results.
TEST_F(OptimizerReduceSt, TestReduce_One_Elewise_Store) {
  auto graph = ConstructNormStruct1Elewise("reduce_one_elewise_store");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  // Stop here if optimization fails so the checks below never read a partial result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  // Entries 0..2 are read below, so at least three must exist.
  ASSERT_GE(fused_scheduled_result.node_idx_to_scheduled_results[0UL].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][0UL].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][1UL].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][2UL].schedule_groups.size(), 1UL);
}

// Optimizes the Sum -> Abs -> Tanh -> Exp -> Relu graph and checks the
// schedule-group count of each of the three scheduled results.
TEST_F(OptimizerReduceSt, TestReduce_Four_Elewise_Store) {
  auto graph = ConstructNormStruct4Elewise("reduce_four_elewise_store");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  // Stop here if optimization fails so the checks below never read a partial result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][0UL].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][1UL].schedule_groups.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][2UL].schedule_groups.size(), 1UL);
}

// Optimizes the graph whose Sum result is consumed by both Abs and Tanh (joined
// by Add, then Relu) and checks the schedule-group count of each scheduled result.
TEST_F(OptimizerReduceSt, TestReduce_Four_Elewise_Store_V2) {
  auto graph = ConstructNormStruct4Elewise4ReduceMultipleCitations("reduce_four_elewise_store");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  // Stop here if optimization fails so the checks below never read a partial result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][0UL].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][1UL].schedule_groups.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][2UL].schedule_groups.size(), 1UL);
}

// Optimizes the graph whose Sum result is consumed by both Abs and Tanh (joined
// by Add) and checks the schedule-group count of each scheduled result.
TEST_F(OptimizerReduceSt, TestReduce_Three_Elewise_Store_Multi_Citation) {
  auto graph = ConstructNormStruct4Elewise3ReduceMultipleCitations("reduce_three_elewise_store");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  // Stop here if optimization fails so the checks below never read a partial result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][0UL].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][1UL].schedule_groups.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][2UL].schedule_groups.size(), 1UL);
}

// Optimizes the two-input Mul -> Sum -> Relu -> Tanh/Add -> Abs graph and checks
// the schedule-group count of each of the three scheduled results.
TEST_F(OptimizerReduceSt, TestReduce_Four_Elewise_Store_V3) {
  auto graph = ConstructNormStruct4Elewise3("reduce_four_elewise_store");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  // Stop here if optimization fails so the checks below never read a partial result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][0UL].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][1UL].schedule_groups.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][2UL].schedule_groups.size(), 1UL);
}

// Optimizes the three-output graph with two Sum reductions and checks the
// schedule-group count of each of the six scheduled results.
TEST_F(OptimizerReduceSt, TestReduce_Elewise_Store_MulReduce) {
  auto graph = ConstructNormStruct4MulReduce("reduce_four_elewise_store");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  EXPECT_EQ(optimizer.Optimize(graph, fused_scheduled_result), SUCCESS);

  auto &per_node_results = fused_scheduled_result.node_idx_to_scheduled_results;
  ASSERT_EQ(per_node_results.size(), 1UL);

  // Only bound after the outer size check above has passed.
  auto &results = per_node_results[0UL];
  ASSERT_EQ(results.size(), 6UL);
  ASSERT_EQ(results[0UL].schedule_groups.size(), 2UL);
  ASSERT_EQ(results[1UL].schedule_groups.size(), 4UL);
  ASSERT_EQ(results[2UL].schedule_groups.size(), 4UL);
  ASSERT_EQ(results[3UL].schedule_groups.size(), 4UL);
  ASSERT_EQ(results[4UL].schedule_groups.size(), 4UL);
  ASSERT_EQ(results[5UL].schedule_groups.size(), 1UL);
}

// Optimizes the two-input graph whose post-Sum Add takes (tanh, b0_relu) and
// checks the schedule-group count of each of the three scheduled results.
TEST_F(OptimizerReduceSt, TestReduce_Three_Elewise_Reduce_Post_Node_Multi_Input_V1) {
  auto graph = ConstructNormStruct3ElemwiseReducePostMulInput("reduce_three_elewise_store");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  // Stop here if optimization fails so the checks below never read a partial result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][0UL].schedule_groups.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][1UL].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][2UL].schedule_groups.size(), 1UL);
}

// Optimizes the two-input graph whose post-Sum Add takes (tanh, mul) and checks
// the schedule-group count of each of the three scheduled results.
TEST_F(OptimizerReduceSt, TestReduce_Three_Elewise_Reduce_Post_Node_Multi_Input_V2) {
  auto graph = ConstructNormStruct3ElemwiseReducePostMulInputV2("reduce_three_elewise_store");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  // Stop here if optimization fails so the checks below never read a partial result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][0UL].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][1UL].schedule_groups.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][2UL].schedule_groups.size(), 1UL);
}

// Optimizes the three-output graph whose Sum result has several consumers and
// checks the schedule-group count of each of the three scheduled results.
TEST_F(OptimizerReduceSt, TestReduce_Three_Elewise_Store_Multi_Citation_Multi_Out) {
  auto graph = ConstructNormStruct4Elewise4ReduceMultipleCitationsMulOut("reduce_three_elewise_store");
  ::ascir::FusedScheduledResult fused_scheduled_result;
  optimize::Optimizer optimizer(optimize::OptimizerOptions{});
  // Stop here if optimization fails so the checks below never read a partial result.
  ASSERT_EQ(optimizer.Optimize(graph, fused_scheduled_result), af::SUCCESS);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results.size(), 1UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL].size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][0UL].schedule_groups.size(), 2UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][1UL].schedule_groups.size(), 3UL);
  ASSERT_EQ(fused_scheduled_result.node_idx_to_scheduled_results[0UL][2UL].schedule_groups.size(), 1UL);
}
}  // namespace optimize