/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "schedule_group_partitioner.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "ascendc_ir.h"
#include "ascir_ops.h"
#include "schedule_utils.h"
#include "util/mem_utils.h"

using namespace af;

namespace optimize {
/**
 * Splits `optimize_graph` into one sub-graph per connected component.
 *
 * Every node whose GetOutDataNodes() is empty is treated as an output node. When `node_order`
 * is empty these output nodes (iterated in name order, since they are kept in a std::map) are
 * used as the component roots; otherwise the caller-supplied order is used as is.
 * For each root that has not been reached yet, a new sub-graph named
 * "<graph name>_<index>" is created, receives a copy of the source graph attributes, is filled
 * through AddConnectedNodes(), topologically sorted and appended to `sub_optimize_graphs`.
 *
 * @param optimize_graph       graph to partition
 * @param sub_optimize_graphs  receives the created sub-graphs (appended, not cleared)
 * @param node_order           optional list of root nodes; taken by value because it is filled
 *                             in locally when empty. Entries must not be null.
 * @return ge::SUCCESS when every node of `optimize_graph` ended up in some sub-graph;
 *         a failure status when a root is null, a step fails, or some node was never visited.
 */
Status ScheduleGroupGraphPartitioner::PartitionByConnectivity(const ::ascir::ImplGraph &optimize_graph,
                                                              std::vector<::ascir::ImplGraph> &sub_optimize_graphs,
                                                              std::vector<AscNodePtr> node_order) {
  std::map<std::string, ::ascir::NodeView> name_to_output_nodes;
  size_t num_nodes = 0;
  for (const auto &node : optimize_graph.GetAllNodes()) {
    GE_ASSERT_NOTNULL(node);
    if (node->GetOutDataNodes().empty()) {
      name_to_output_nodes.emplace(node->GetName(), node);
      GELOGI("Output node: %s[%s]", node->GetName().c_str(), node->GetType().c_str());
    }
    ++num_nodes;
  }
  if (node_order.empty()) {
    // No explicit order requested: fall back to the output nodes, sorted by name.
    node_order.reserve(name_to_output_nodes.size());
    for (const auto &name_and_node : name_to_output_nodes) {
      node_order.emplace_back(name_and_node.second);
    }
  }

  int32_t index = 0;
  std::set<af::NodePtr> visited;
  for (const auto &root_node : node_order) {
    // A null root would be dereferenced below (and inside AddConnectedNodes), so reject it up front.
    GE_ASSERT_NOTNULL(root_node);
    if (visited.find(root_node) != visited.cend()) {
      continue;  // already part of a previously built component
    }
    ::ascir::ImplGraph sub_graph((optimize_graph.GetName() + "_" + std::to_string(index)).c_str());
    GE_ASSERT_TRUE(sub_graph.CopyAttrFrom(optimize_graph));
    GE_CHK_STATUS_RET(AddConnectedNodes(root_node, sub_graph, visited),
                      "Failed to add connected nodes, root_node = %s", root_node->GetNamePtr());
    GE_ASSERT_GRAPH_SUCCESS(ScheduleUtils::TopologicalSorting(sub_graph));
    sub_optimize_graphs.emplace_back(sub_graph);
    ++index;
  }

  if (sub_optimize_graphs.size() > 1U) {  // splitting the graph leaves redundant axis info in the sub-graphs
    for (auto &sub_optimize_graph : sub_optimize_graphs) {
      GE_CHK_STATUS_RET(ScheduleUtils::RemoveUnusedAxes(sub_optimize_graph), "Failed to remove unused axes");
    }
  }

  // Every node of the source graph must have been reached from one of the roots.
  if (visited.size() != num_nodes) {
    for (const auto &node : optimize_graph.GetAllNodes()) {
      if (visited.find(node) == visited.cend()) {
        GELOGE(ge::FAILED, "node: %s[%s] not visited", node->GetName().c_str(), node->GetType().c_str());
      }
    }
    GELOGE(ge::FAILED, "number of visited nodes = %zu, num_nodes = %zu", visited.size(), num_nodes);
    return ge::FAILED;
  }

  GELOGI("Partition success, subgraph number = %zu", sub_optimize_graphs.size());
  return ge::SUCCESS;
}

/**
 * Decides whether axis sizes need to be refreshed for `optimize_graph`.
 *
 * `need_refresh` is set to true when ScheduleUtils::FindFirstNodeOfType finds a Concat node in
 * the graph. When no Concat node is found the function returns early and does NOT write
 * `need_refresh`, so callers should initialise it before the call.
 *
 * @param optimize_graph  graph to inspect
 * @param need_refresh    output flag; only ever assigned true here
 * @return ge::SUCCESS on both paths visible in this function
 */
Status ScheduleGroupGraphPartitioner::NeedRefreshAxisSize(const ::ascir::ImplGraph &optimize_graph,
                                                          bool &need_refresh) {
  const auto concat_node = ScheduleUtils::FindFirstNodeOfType<af::ascir_op::Concat>(optimize_graph);
  // Early return with SUCCESS when there is no Concat; need_refresh keeps the caller's value.
  GE_CHK_BOOL_RET_SPECIAL_STATUS(concat_node == nullptr, ge::SUCCESS, "no Concat node was found");
  need_refresh = true;
  return ge::SUCCESS;
}

/**
 * Breadth-first collection of every node reachable from `root_node` through in- or out-edges
 * (GetInNodes() / GetOutNodes()), i.e. edge direction is ignored.
 *
 * @param root_node  node the traversal starts from; it is always enqueued and marked visited
 * @param visited    nodes already reached; entries present on entry are not traversed again,
 *                   and every newly reached node is added
 * @param asc_nodes  receives the reached nodes, in dequeue order, cast to AscNode
 * @return ge::SUCCESS, or the GE_CHECK_NOTNULL failure status when a dequeued node is not an AscNode
 */
Status ScheduleGroupGraphPartitioner::CollectConnectedNodes(const af::AscNodePtr &root_node,
                                                            std::set<af::NodePtr> &visited,
                                                            std::vector<af::AscNodePtr> &asc_nodes) {
  std::list<af::NodePtr> pending;
  pending.emplace_back(root_node);
  visited.emplace(root_node);

  // Enqueue a neighbour exactly once: emplace() reports whether the node was newly inserted.
  const auto enqueue_if_unseen = [&visited, &pending](const af::NodePtr &neighbour) {
    if (visited.emplace(neighbour).second) {
      pending.emplace_back(neighbour);
    }
  };

  while (!pending.empty()) {
    const af::NodePtr current = pending.front();
    pending.pop_front();

    auto asc_node = std::dynamic_pointer_cast<af::AscNode>(current);
    GE_CHECK_NOTNULL(asc_node);
    asc_nodes.emplace_back(std::move(asc_node));

    for (const auto &producer : current->GetInNodes()) {
      enqueue_if_unseen(producer);
    }
    for (const auto &consumer : current->GetOutNodes()) {
      enqueue_if_unseen(consumer);
    }
  }
  return ge::SUCCESS;
}

/**
 * Copies the connected component containing `root_node` into `sub_graph`.
 *
 * The component is gathered with CollectConnectedNodes(), sorted with CompareByNodeId (ascending
 * op-desc id), and each node is re-created in `sub_graph` from a copy of its OpDesc (keeping the
 * original name) together with its tensor attributes. Afterwards GraphUtils::RelinkGraphEdges is
 * called for every collected source node with the name -> new-node map
 * (NOTE(review): presumably this re-creates the edges between the copies — confirm against GraphUtils).
 *
 * @param root_node    any node of the component to copy; must not be null
 * @param sub_graph    destination graph that receives the copied nodes
 * @param all_visited  accumulates the source nodes handled by this call
 * @return ge::SUCCESS on success, a failure status if any step fails
 */
Status ScheduleGroupGraphPartitioner::AddConnectedNodes(const af::AscNodePtr &root_node, ::ascir::ImplGraph &sub_graph,
                                                        std::set<af::NodePtr> &all_visited) {
  // Validate before the log statement below dereferences the pointer.
  GE_ASSERT_NOTNULL(root_node);
  GELOGI("AddConnectedNodes in, root_node = %s[%s]", root_node->GetName().c_str(), root_node->GetType().c_str());
  std::unordered_map<std::string, af::NodePtr> all_new_nodes;
  std::set<af::NodePtr> visited;
  std::vector<af::AscNodePtr> asc_nodes;
  GE_ASSERT_SUCCESS(CollectConnectedNodes(root_node, visited, asc_nodes));
  // Add nodes to the sub-graph in ascending op-desc id order rather than traversal order.
  std::sort(asc_nodes.begin(), asc_nodes.end(), CompareByNodeId);
  for (const auto &asc_node : asc_nodes) {
    const auto &op_desc = af::GraphUtils::CopyOpDesc(asc_node->GetOpDesc(), nullptr);
    GE_CHECK_NOTNULL(op_desc);
    op_desc->SetName(asc_node->GetName());
    af::Operator op = af::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
    auto dst_new_node = sub_graph.AddNode(op);
    // AddNode's result is dereferenced right below; fail cleanly instead of crashing on null.
    GE_ASSERT_NOTNULL(dst_new_node);
    all_new_nodes[dst_new_node->GetName()] = dst_new_node;
    GE_ASSERT_TRUE(AscGraph::CopyAscNodeTensorAttr(asc_node, dst_new_node),
                   "DoCopyAscNodeTensorAttr failed, node = %s[%s]", asc_node->GetNamePtr(), asc_node->GetTypePtr());
  }

  for (const auto &src_node : visited) {
    GE_CHK_STATUS_RET(af::GraphUtils::RelinkGraphEdges(src_node, "", all_new_nodes), "RelinkGraphEdges failed");
  }
  all_visited.insert(visited.cbegin(), visited.cend());
  return ge::SUCCESS;
}

/**
 * Strict "less than" ordering of two nodes by the id stored on their OpDesc.
 * Used as the comparator for std::sort in AddConnectedNodes().
 * Neither the node pointers nor their OpDesc are checked for null here.
 *
 * @return true when the id of `lhs` is smaller than the id of `rhs`
 */
bool ScheduleGroupGraphPartitioner::CompareByNodeId(const AscNodePtr &lhs, const AscNodePtr &rhs) {
  const auto lhs_id = lhs->GetOpDesc()->GetId();
  const auto rhs_id = rhs->GetOpDesc()->GetId();
  return lhs_id < rhs_id;
}

/**
 * Records, per axis id, the first repeat expression that is not statically equal to one.
 *
 * `repeats[j]` is paired with `axis_ids[j]`. Entries whose repeat is statically equal to
 * ops::One are skipped. The first other size seen for an axis id is stored in
 * `axis_id_to_size`; later sizes for the same id never overwrite it and only trigger a warning
 * when they cannot be statically proven equal to the stored one.
 *
 * @param repeats          per-dimension repeat expressions
 * @param axis_ids         axis id for each entry of `repeats`; must have at least repeats.size() entries
 * @param axis_id_to_size  in/out map from axis id to the recorded size
 * @return SUCCESS, or a failure status when `axis_ids` is shorter than `repeats`
 */
Status ScheduleGroupGraphPartitioner::RecordAxisSizes(const std::vector<af::Expression> &repeats,
                                                      const std::vector<int64_t> &axis_ids,
                                                      std::map<af::AxisId, af::Expression> &axis_id_to_size) {
  // axis_ids is indexed with positions of repeats below; guard against reading past its end.
  GE_ASSERT_TRUE(repeats.size() <= axis_ids.size(), "repeats.size() = %zu exceeds axis_ids.size() = %zu",
                 repeats.size(), axis_ids.size());
  for (size_t j = 0UL; j < repeats.size(); ++j) {
    const auto &repeat = repeats[j];
    if (af::SymbolicUtils::StaticCheckEq(repeat, ops::One) == af::TriBool::kTrue) {
      continue;  // size-one repeats carry no axis size information
    }
    const auto axis_id = axis_ids[j];
    // emplace() keeps an existing entry untouched and hands back an iterator to it,
    // so no further map lookups are needed to compare against the recorded size.
    const auto insert_result = axis_id_to_size.emplace(axis_id, repeat);
    if (!insert_result.second) {
      const auto &recorded_size = insert_result.first->second;
      GE_LOGW_IF(af::SymbolicUtils::StaticCheckEq(repeat, recorded_size) != TriBool::kTrue,
                 "inconsistent axis size, id = %ld, size0 = %s, size1 = %s", axis_id,
                 SymbolicUtils::ToString(repeat).c_str(),
                 SymbolicUtils::ToString(recorded_size).c_str());
    }
  }
  return SUCCESS;
}

/**
 * Updates the axis sizes stored in the graph attributes of `sub_graph` from the output tensors
 * of its nodes.
 *
 * For every node for which ScheduleUtils::IsBuffer() is false, the (repeats, axis) pairs of each
 * output are fed to RecordAxisSizes(). Then each axis in the graph's AscGraphAttr whose id was
 * recorded, and whose current size is not statically equal to the recorded one, is overwritten
 * with the recorded size.
 *
 * @param sub_graph  graph whose axis sizes are refreshed; the axis objects reached through its
 *                   attribute group are modified even though the graph is passed as const
 * @return SUCCESS on success, a failure status when the compute graph, its attribute group, a node
 *         or an axis is null, or when an output's repeats/axis lists differ in length
 */
Status ScheduleGroupGraphPartitioner::RefreshAxisSize(const ::ascir::ImplGraph &sub_graph) {
  GELOGI("RefreshAxisSize start, graph_name = %s", sub_graph.GetName().c_str());
  // Check the compute graph before dereferencing it to fetch the attribute group.
  const auto compute_graph = af::AscGraphUtils::GetComputeGraph(sub_graph);
  GE_ASSERT_NOTNULL(compute_graph);
  const auto graph_attr = compute_graph->GetOrCreateAttrsGroup<af::AscGraphAttr>();
  GE_ASSERT_NOTNULL(graph_attr);

  // Pass 1: collect the size observed for each axis id on the node outputs.
  std::map<af::AxisId, af::Expression> axis_id_to_size;
  for (const auto &node : sub_graph.GetAllNodes()) {
    GE_ASSERT_NOTNULL(node);
    if (ScheduleUtils::IsBuffer(node)) {
      continue;
    }
    for (uint32_t i = 0U; i < node->GetAllOutDataAnchorsSize(); ++i) {
      const auto &repeats = node->outputs[i].attr.repeats;
      const auto &axis_ids = node->outputs[i].attr.axis;
      GE_ASSERT_TRUE(repeats.size() == axis_ids.size(), "node: %s:%u repeats.size() = %zu, but axis_ids.size() = %zu",
                     node->GetNamePtr(), i, repeats.size(), axis_ids.size());
      GE_ASSERT_SUCCESS(RecordAxisSizes(repeats, axis_ids, axis_id_to_size));
    }
  }

  // Pass 2: overwrite graph-level axis sizes that differ from what the outputs report.
  for (const auto &axis : graph_attr->axis) {
    GE_ASSERT_NOTNULL(axis);
    const auto it = axis_id_to_size.find(axis->id);
    if ((it != axis_id_to_size.cend()) && (SymbolicUtils::StaticCheckEq(axis->size, it->second) != TriBool::kTrue)) {
      GELOGI("update axis, id = %ld, name = %s, from %s to %s", axis->id, axis->name.c_str(),
             SymbolicUtils::ToString(axis->size).c_str(), SymbolicUtils::ToString(it->second).c_str());
      axis->size = it->second;
    }
  }
  GELOGI("RefreshAxisSize end, graph_name = %s", sub_graph.GetName().c_str());
  return SUCCESS;
}
}  // namespace optimize