/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <cstdint>

#include "ascendc_ir.h"
#include "graph/compute_graph.h"
#include "graph/node.h"
#include "graph/utils/graph_utils.h"
#include "graph/operator_factory.h"
#include "graph/utils/op_desc_utils.h"
#include "ascir_ops.h"
#include "attribute_group/attr_group_symbolic_desc.h"
#include "graph_dump_utils.h"
#include "fused_graph/fused_graph_modifier.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "graph_utils.h"
#include "ascir_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_op_types.h"
#include "ascgen_log.h"
#include "schedule_utils.h"
#include "fusion/autofuse_attrs.h"
#include "graph/symbolizer/symbolic.h"
#include "graph/expression/const_values.h"
#include "platform_context.h"
#include "platform/v1/platformv1.h"

namespace optimize {
using namespace af;
class FusedGraphModifierTest : public testing::Test {
 protected:
  void SetUp() override {
    dlog_setlevel(ASCGEN_MODULE_NAME, DLOG_ERROR, 0);
  }
  void TearDown() override {
    dlog_setlevel(ASCGEN_MODULE_NAME, DLOG_ERROR, 0);
  }
};

/**
 * Populates `graph` with a single linear chain: Data -> Load -> Abs -> Store -> Output.
 *
 * For each of the `axis_num` dimensions a size variable "s<i>" and an axis "z<i>" are created.
 * Every tensor along the chain uses those sizes as repeats with contiguous strides:
 * the innermost stride is 1 and each outer stride is the product of the inner repeats.
 *
 * @param graph    Graph that receives the size variables, axes and nodes.
 * @param prefix   Node-name prefix ("<prefix>_data", "<prefix>_load", "<prefix>_abs", ...).
 * @param axis_num Number of dimensions/axes to create (default 2).
 */
static void CreateAscBackendGraph(std::shared_ptr<af::AscGraph> &graph, const std::string &prefix, int64_t axis_num = 2) {
  std::vector<int64_t> axis_ids;
  std::vector<af::Expression> repeats;
  for (int64_t i = 0; i < axis_num; ++i) {
    const af::Expression exp = graph->CreateSizeVar("s" + std::to_string(i));
    // NOTE(review): axis ids below are taken to be 0..axis_num-1, i.e. this assumes `graph` had no axes
    // before this call so the i-th created axis gets id i — confirm if the helper is reused on other graphs.
    graph->CreateAxis("z" + std::to_string(i), exp);
    axis_ids.push_back(i);
    repeats.push_back(exp);
  }

  // Contiguous layout: strides[last] = 1, strides[i] = repeats[i + 1] * strides[i + 1].
  std::vector<af::Expression> strides(repeats.size(), af::sym::kSymbolOne);
  if (axis_num > 1) {
    for (int64_t i = axis_num - 2; i >= 0; --i) {
      strides[i] = repeats[i + 1] * strides[i + 1];
    }
  }

  // Only Data is constructed against `graph`; the remaining ops join it through their input connections.
  af::ascir_op::Data data((prefix + "_data").c_str(), *graph);
  data.attr.sched.axis = axis_ids;
  *data.y.axis = axis_ids;
  *data.y.repeats = repeats;
  *data.y.strides = strides;
  data.ir_attr.SetIndex(0);
  data.y.dtype = ge::DT_INT8;

  af::ascir_op::Load load((prefix + "_load").c_str());
  load.x = data.y;
  load.attr.sched.axis = axis_ids;
  *load.y.axis = axis_ids;
  *load.y.repeats = repeats;
  *load.y.strides = strides;

  af::ascir_op::Abs abs((prefix + "_abs").c_str());
  abs.x = load.y;
  abs.attr.sched.axis = axis_ids;
  *abs.y.axis = axis_ids;
  *abs.y.repeats = repeats;
  *abs.y.strides = strides;

  af::ascir_op::Store store((prefix + "_store").c_str());
  store.x = abs.y;
  store.attr.sched.axis = axis_ids;
  *store.y.axis = axis_ids;
  *store.y.repeats = repeats;
  *store.y.strides = strides;

  af::ascir_op::Output y((prefix + "_out").c_str());
  y.x = store.y;
  y.ir_attr.SetIndex(0);
  // NOTE(review): Data is DT_INT8 but Output is DT_FLOAT16; presumably irrelevant to the workspace tests — confirm.
  y.y.dtype = ge::DT_FLOAT16;
}

/**
 * Adds an "AscBackend" node named `name` to `compute_graph`.
 *
 * The node gets `in_num` dynamic inputs named "x" and `out_num` dynamic outputs named "y".
 * Its owner graph is set to `compute_graph`.
 *
 * @return The node returned by `compute_graph->AddNode`.
 */
static NodePtr CreateAscbcToAscGraph(const std::string &name, ComputeGraphPtr &compute_graph, int64_t in_num = 1,
                                     int64_t out_num = 1) {
  OpDescBuilder backend_builder(name, "AscBackend");
  backend_builder.AddDynamicInput("x", in_num);
  backend_builder.AddDynamicOutput("y", out_num);

  auto backend_node = compute_graph->AddNode(backend_builder.Build());
  backend_node->SetOwnerComputeGraph(compute_graph);
  return backend_node;
}

// Builds the fused graph  data0 -> ascbc0 -> ascbc1 -> ascbc2 -> ascbc3 -> output,  maps each AscBackend node
// to its own AscGraph (g0..g3), and checks where SubgraphConnectionsToWorkspace places the "fused_workspace<N>" nodes:
//   g0: fused_workspace0
//   g1: fused_workspace0 and fused_workspace1
//   g2: fused_workspace1 and fused_workspace0 (reuse of workspace0)
//   g3: fused_workspace0
// Pointers that are dereferenced while building the graph are guarded with ASSERT_* so a setup failure
// fails this test cleanly instead of crashing the test binary.
TEST_F(FusedGraphModifierTest, test_workspace_reuse) {
  std::shared_ptr<af::AscGraph> g0 = std::make_shared<af::AscGraph>("g0");
  CreateAscBackendGraph(g0, "g0", 2);
  std::shared_ptr<af::AscGraph> g1 = std::make_shared<af::AscGraph>("g1");
  CreateAscBackendGraph(g1, "g1", 1);
  std::shared_ptr<af::AscGraph> g2 = std::make_shared<af::AscGraph>("g2");
  CreateAscBackendGraph(g2, "g2", 2);
  std::shared_ptr<af::AscGraph> g3 = std::make_shared<af::AscGraph>("g3");
  CreateAscBackendGraph(g3, "g3", 1);

  af::AscGraph fused_asc_graph("fused_graph");

  af::ascir_op::Data data0("data0", fused_asc_graph);
  auto ir_attr = data0.attr.ir_attr->DownCastTo<af::AscDataIrAttrDef>();
  ASSERT_NE(ir_attr, nullptr);
  ir_attr->SetIndex(0);

  auto fused_graph = af::AscGraphUtils::GetComputeGraph(fused_asc_graph);
  ASSERT_NE(fused_graph, nullptr);
  auto data_node = fused_asc_graph.FindNode("data0");
  ASSERT_NE(data_node, nullptr);

  auto ascbc0 = CreateAscbcToAscGraph("ascbc0", fused_graph);
  auto ascbc1 = CreateAscbcToAscGraph("ascbc1", fused_graph);
  auto ascbc2 = CreateAscbcToAscGraph("ascbc2", fused_graph);
  auto ascbc3 = CreateAscbcToAscGraph("ascbc3", fused_graph);
  ASSERT_NE(ascbc0, nullptr);
  ASSERT_NE(ascbc1, nullptr);
  ASSERT_NE(ascbc2, nullptr);
  ASSERT_NE(ascbc3, nullptr);
  af::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), ascbc0->GetInDataAnchor(0));
  af::GraphUtils::AddEdge(ascbc0->GetOutDataAnchor(0), ascbc1->GetInDataAnchor(0));
  af::GraphUtils::AddEdge(ascbc1->GetOutDataAnchor(0), ascbc2->GetInDataAnchor(0));
  af::GraphUtils::AddEdge(ascbc2->GetOutDataAnchor(0), ascbc3->GetInDataAnchor(0));

  af::ascir_op::Output output("output");
  // NOTE(review): the Output op's ir-attr is downcast to the *Data* ir-attr definition here; presumably both
  // expose SetIndex — confirm this is the intended type.
  auto out_ir_attr = output.attr.ir_attr->DownCastTo<af::AscDataIrAttrDef>();
  ASSERT_NE(out_ir_attr, nullptr);
  out_ir_attr->SetIndex(0);
  auto out_desc = OpDescUtils::GetOpDescFromOperator(output);
  ASSERT_NE(out_desc, nullptr);
  auto output_node = fused_graph->AddNode(out_desc);
  ASSERT_NE(output_node, nullptr);
  af::GraphUtils::AddEdge(ascbc3->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));

  FusedGraphModifier modifier;
  std::map<af::Node *, af::AscGraph> asc_backend_to_ascgraph;
  asc_backend_to_ascgraph.emplace(ascbc0.get(), *g0);
  asc_backend_to_ascgraph.emplace(ascbc1.get(), *g1);
  asc_backend_to_ascgraph.emplace(ascbc2.get(), *g2);
  asc_backend_to_ascgraph.emplace(ascbc3.get(), *g3);
  // Fatal: if the modification fails, the workspace lookups below would only report misleading cascaded failures.
  ASSERT_EQ(modifier.SubgraphConnectionsToWorkspace(fused_graph, asc_backend_to_ascgraph), ge::SUCCESS);

  auto ws0_g0 = g0->FindNode("fused_workspace0");
  EXPECT_NE(ws0_g0, nullptr);
  auto ws0_g1 = g1->FindNode("fused_workspace0");
  EXPECT_NE(ws0_g1, nullptr);
  auto ws1_g1 = g1->FindNode("fused_workspace1");
  EXPECT_NE(ws1_g1, nullptr);
  auto ws1_g2 = g2->FindNode("fused_workspace1");
  EXPECT_NE(ws1_g2, nullptr);
  // g2 is expected to reuse fused_workspace0 rather than get a new workspace.
  auto ws0_g2 = g2->FindNode("fused_workspace0");
  EXPECT_NE(ws0_g2, nullptr);
  auto ws0_g3 = g3->FindNode("fused_workspace0");
  EXPECT_NE(ws0_g3, nullptr);
}
}  // namespace optimize