/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include <queue>
#include "graph/utils/graph_utils.h"
#include "graph/symbolizer/symbolic_utils.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "ascir_ops.h"
#include "autoschedule/axis_group.h"
#include "ascir_utils.h"
#include "node_utils.h"
#include "reduce_schedule_case_generator.h"
#include "register/op_def_factory.h"
#include "base/err_msg.h"
#include "graph/symbolizer/symbolic.h"

namespace optimize {
namespace {
size_t TWO = 2;
size_t kMaxFullLoadAxisSize = 3UL;
size_t NODE_COUNT_AFTER_REDUCE = 4UL;
std::string GetNewNodeName(const af::AscNodePtr &src_node, const af::AscNodePtr &dst_node,
                           const std::string &type, int32_t idx) {
  return src_node->GetName() + "_to_" + dst_node->GetName() + "_" + type + "_" + to_string(idx);
}

Status DoCopyAscNodeTensorAttr(const af::AscNodePtr &src_node, af::AscNodePtr &dst_node) {
  auto op_desc = dst_node->GetOpDesc();
  auto dst_asc_node_attr = op_desc->GetOrCreateAttrsGroup<af::AscNodeAttr>();
  auto src_asc_node_attr = src_node->GetOpDesc()->GetOrCreateAttrsGroup<af::AscNodeAttr>();
  if (src_asc_node_attr != nullptr && dst_asc_node_attr != nullptr) {
    dst_asc_node_attr->sched = src_asc_node_attr->sched;
    if (src_asc_node_attr->ir_attr) {
      dst_asc_node_attr->ir_attr = src_asc_node_attr->ir_attr->Clone();
    }
  }
  for (size_t i = 0U; i < src_node->outputs().size(); i++) {
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(i));
    auto tensor_attr_group = op_desc->MutableOutputDesc(i)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    *tensor_attr_group = src_node->outputs[i].attr;
  }
  return ge::SUCCESS;
}

Status DoCopyWorkspaceTensorAttr(const af::AscNodePtr &src_node, af::AscNodePtr &workspace_node) {
  GE_ASSERT_NOTNULL(src_node);
  GE_ASSERT_NOTNULL(workspace_node);
  GE_ASSERT_TRUE(!src_node->outputs().empty());
  GE_ASSERT_TRUE(!workspace_node->outputs().empty());
  workspace_node->outputs[0].attr.dtype = src_node->outputs[0].attr.dtype;
  return ge::SUCCESS;
}

const std::unordered_map<std::string, std::function<ReduceType(const char*)>> reducers = {
  {"Max",  [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Max>{}, n}; }},
  {"ArgMax",  [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::ArgMaxMultiRPhase2>{}, n}; }},  // ArgMax 二阶段使用 ArgMaxMultiRPhase2
  {"ArgMaxMultiRPhase1",  [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::ArgMaxMultiRPhase2>{}, n}; }},  // ArgMaxMultiRPhase1 二阶段使用 ArgMaxMultiRPhase2
  {"ArgMaxMultiRPhase2",  [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::ArgMaxMultiRPhase2>{}, n}; }},  // ArgMaxMultiRPhase2 二阶段使用 ArgMaxMultiRPhase2
  {"Sum",  [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Sum>{}, n}; }},
  {"Mean", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Sum>{}, n}; }},  // Mean 二阶段使用 Sum
  {"Min",  [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Min>{}, n}; }},
  {"Prod", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Prod>{}, n}; }},
  {"Any",  [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Any>{}, n}; }},
  {"All",  [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::All>{}, n}; }}
};

bool IsNotPartitionReduce(const af::AscNodePtr &reduce_node, size_t threshold) {
  std::queue<af::NodePtr> node_queue;
  size_t node_count = 0UL;
  std::unordered_set<const af::Node*> visited;
  visited.insert(reduce_node.get());
  for (const auto &reduce_out_node : reduce_node->GetOutNodes()) {
    node_queue.emplace(reduce_out_node);
    visited.insert(reduce_out_node.get());
  }
  while (!node_queue.empty()) {
    const auto current_node = node_queue.front();
    node_queue.pop();

    if (current_node->GetInDataNodesSize() > 1UL) {
      for (const auto &in_current_node : current_node->GetInDataNodes()) {
        if (visited.find(in_current_node.get()) == visited.end()) {
          GELOGW("Node [%s] has multiple inputs, with input node [%s] not being a post-reduction node.",
                 current_node->GetNamePtr(), in_current_node->GetNamePtr());
          return false;
        }
      }
    }

    node_count += 1UL;
    if (node_count > threshold) {
      GELOGW("The total count of nodes after the reduce node[%s](including the store node) is above the threshold[%zu].",
             reduce_node->GetNamePtr(), threshold);
      return false;
    }
    const auto &asc_current_node = std::dynamic_pointer_cast<af::AscNode>(current_node);
    GE_ASSERT_NOTNULL(asc_current_node);
    if (af::ops::IsOps<af::ascir_op::Store>(asc_current_node) ||
        asc_current_node->attr.api.type == af::ApiType::kAPITypeBuffer) {
      continue;
    }
    
    if (!ScheduleUtils::IsElewise(asc_current_node)) {
      GELOGW("The node[%s] after the reduce node[%s] is not elewise type.",
             asc_current_node->GetNamePtr(), reduce_node->GetNamePtr());
      return false;
    }

    if (node_count > threshold - 1UL) {
      GELOGW("The count of nodes after the reduce node[%s] is above the threshold[%zu].",
             reduce_node->GetNamePtr(), threshold - 1UL);
      return false;
    }

    for (const auto &next_node : current_node->GetOutAllNodes()) {
      if (visited.find(next_node.get()) == visited.end()) {
        visited.insert(next_node.get());
        node_queue.emplace(next_node);
      }
    }
  }
  return true;
}
}

Status ReducePartitionCaseGenerator::GeneratorGeneralTask(ascir::HintGraph &optimize_graph,
                                                          std::vector<ScheduleTask> &tasks) {
  std::vector<ascir::ImplGraph> optimize_graphs;
  std::vector<std::string> score_funcs;
  GE_CHK_STATUS_RET(GenerateGeneralCase(optimize_graph, optimize_graphs, score_funcs), "GenerateScheduleCases failed");
  score_funcs.resize(optimize_graphs.size());
  for (size_t i = 0U; i < optimize_graphs.size(); ++i) {
    const auto &graph = optimize_graphs[i];
    ScheduleTask task{graph, {}, score_funcs[i], {}, ReduceTemplateType::kCommon};
    GE_CHK_STATUS_RET(ScheduleGroupGraphPartitioner::PartitionByConnectivity(graph, task.grouped_graphs, node_order_),
                      "Failed to partition graph");
    tasks.emplace_back(std::move(task));
  }
  return ge::GRAPH_SUCCESS;
}

Status ReducePartitionCaseGenerator::GeneratorAllLoadTask(ascir::HintGraph &optimize_graph,
                                                          std::vector<ScheduleTask> &tasks) {
  if (!CanReduceFuse(optimize_graph)) {
    return ge::GRAPH_SUCCESS;
  }
  std::vector<ascir::ImplGraph> optimize_graphs;
  std::vector<std::string> score_funcs;
  GE_CHK_STATUS_RET(GenerateAllLoadCase(optimize_graph, optimize_graphs, score_funcs), "GenerateScheduleCases failed");
  score_funcs.resize(optimize_graphs.size());
  for (size_t i = 0U; i < optimize_graphs.size(); ++i) {
    const auto &graph = optimize_graphs[i];
    ScheduleTask task{graph, {}, score_funcs[i], {}, ReduceTemplateType::kAllLoad};
    GE_CHK_STATUS_RET(ScheduleGroupGraphPartitioner::PartitionByConnectivity(graph, task.grouped_graphs, node_order_),
                      "Failed to partition graph");
    tasks.emplace_back(std::move(task));
  }
  return ge::GRAPH_SUCCESS;
}

Status ReducePartitionCaseGenerator::GeneratorRCoreTask(ascir::HintGraph &optimize_graph,
                                                        std::vector<ScheduleTask> &tasks) const {
  std::vector<ScheduleTask> new_tasks;
  for (const auto &task : tasks) {
    if (task.reduce_type != ReduceTemplateType::kCommon) {
      continue;
    }
    std::vector<::ascir::ImplGraph> new_task_grouped_graphs;
    std::map<size_t, std::vector<size_t>> map;
    size_t phase_2_graph_size = 0;
    for (size_t i = 0; i < task.grouped_graphs.size(); i++) {
      GE_ASSERT_TRUE(IsGroupGraphLegal(task.grouped_graphs[i]));
      if (!HasReduce(task.grouped_graphs[i])) {
        ::ascir::ImplGraph graph((task.grouped_graphs[i].GetName() + "_r_multicore").c_str());
        graph.CopyFrom(task.grouped_graphs[i]);
        new_task_grouped_graphs.emplace_back(std::move(graph));
        continue;
      }
      ascir::ImplGraph phase_graph((task.grouped_graphs[i].GetName() + "_r_multicore_phase_graph").c_str());
      phase_graph.CopyFrom(task.grouped_graphs[i]);
      af::AscNodePtr reduce_node;
      for (auto node : phase_graph.GetAllNodes()) {
        if (ScheduleUtils::IsReduce(node)) {
          reduce_node = node;
          break;
        }
      }
      GE_CHECK_NOTNULL(reduce_node);
      GE_ASSERT_TRUE(IsNotPartitionReduce(reduce_node, NODE_COUNT_AFTER_REDUCE));
      ascir::ImplGraph phase_1_graph((task.grouped_graphs[i].GetName() + "_r_multicore_phase_1_graph").c_str());
      ascir::ImplGraph phase_2_graph((task.grouped_graphs[i].GetName() + "_r_multicore_phase_2_graph").c_str());
      GE_CHK_STATUS_RET(RMulticorePhase2Graph(phase_2_graph, phase_1_graph, phase_graph, reduce_node).Construct());
      ascir::utils::DumpGraph(phase_2_graph, "phase2graph_construct");
      new_task_grouped_graphs.emplace_back(std::move(phase_1_graph));
      new_task_grouped_graphs.emplace_back(std::move(phase_2_graph));
      map[i + phase_2_graph_size] = {i + phase_2_graph_size + 1};
      phase_2_graph_size++;
    }
    if (phase_2_graph_size == 0) {
      continue;
    }
    ScheduleTask new_task{optimize_graph, new_task_grouped_graphs, task.score_func, map, ReduceTemplateType::kRCore};
    new_tasks.push_back(new_task);
  }
  tasks.insert(tasks.end(), new_tasks.begin(), new_tasks.end());
  return ge::GRAPH_SUCCESS;
}

Status ReducePartitionCaseGenerator::GeneratorTask(ascir::HintGraph &optimize_graph, std::vector<ScheduleTask> &tasks,
  	                                                    const OptimizerOptions &options) {
  	   bool is_norm_like_reduce = autoschedule::IsNormLikeReduceGraph(optimize_graph);
  	   (void)options;
  	   if (is_norm_like_reduce) {
  	     GELOGI("Graph %s satisfies norm-like reduce conditions, only generate AllLoad tasks", optimize_graph.GetName().c_str());
  	     GE_CHK_STATUS_RET(GeneratorAllLoadTask(optimize_graph, tasks));
  	   } else {
  	     GELOGI("Graph %s does not satisfy norm-like reduce conditions, use general strategy", optimize_graph.GetName().c_str());
  	     GE_CHK_STATUS_RET(GeneratorGeneralTask(optimize_graph, tasks));
  	     GE_CHK_STATUS_RET(GeneratorRCoreTask(optimize_graph, tasks));
  	     GE_CHK_STATUS_RET(GeneratorAllLoadTask(optimize_graph, tasks));
  	   }
  	   return ge::GRAPH_SUCCESS;
  	 }

Status ReducePartitionCaseGenerator::Generate([[maybe_unused]] ascir::HintGraph &graph,
                                              [[maybe_unused]] std::vector<ascir::ImplGraph> &graphs,
                                              [[maybe_unused]] std::vector<std::string> &score_functions) {
  return ge::GRAPH_SUCCESS;
}

Status ReducePartitionCaseGenerator::GenerateGeneralCase(ascir::HintGraph &graph,
                                                         std::vector<ascir::ImplGraph> &graphs,
                                                         std::vector<std::string> &score_functions) {
  if (!HasReduce(graph)) {
    return ge::GRAPH_SUCCESS;
  }
  ascir::ImplGraph optimize_graph(graph.GetName().c_str());
  optimize_graph.CopyFrom(graph);

  // 以多输出节点为起点遍历,找环路的终点:如有环路则返回终点的列表
  std::vector<std::pair<af::AscNodePtr, af::AscNodePtr>> loop_start_end;
  for (auto node : optimize_graph.GetAllNodes()) {
    if (node->GetOutDataNodes().empty()) {
      node_order_.emplace_back(node);
    }
    if (node->GetOutNodes().size() <= 1UL) {
      continue;
    }
    std::vector<af::AscNodePtr> loop_ends;
    FindNormLoop(node, loop_ends);
    for (const auto &end : loop_ends) {
      loop_start_end.emplace_back(node, end);
    }
  }
  std::sort(loop_start_end.begin(), loop_start_end.end(), [](
    const std::pair<af::AscNodePtr, af::AscNodePtr> &lhs, const std::pair<af::AscNodePtr, af::AscNodePtr> &rhs) {
    return lhs.second->GetOpDescBarePtr()->GetId() < rhs.second->GetOpDescBarePtr()->GetId();
  });

  // reduce 后融合切分
  GE_CHK_STATUS_RET(ReducePartitionPostFusion(optimize_graph));

  // 按照前面获取的环路起点、终点,进行norm的切分
  GE_CHK_STATUS_RET(PartitionNorm(optimize_graph, loop_start_end));

  // reduce 多引用结构需要切分的,切分开
  GE_CHK_STATUS_RET(ReducePartitionMultipleCitations(optimize_graph));

  if (partition_) {
    std::sort(node_order_.begin(), node_order_.end(), [](const af::AscNodePtr &lhs, af::AscNodePtr &rhs) {
      return lhs->GetOpDescBarePtr()->GetId() < rhs->GetOpDescBarePtr()->GetId();
    });

    ascir::utils::DumpGraph(graph, "before_partition");
    ascir::utils::DumpGraph(optimize_graph, "after_partition");
    graphs.emplace_back(optimize_graph);
    score_functions.resize(graphs.size());
  } else {
    node_order_.clear();
    graphs.emplace_back(graph);
  }
  return ge::GRAPH_SUCCESS;
}

Status ReducePartitionCaseGenerator::GenerateAllLoadCase(ascir::HintGraph &graph,
                                                         std::vector<ascir::ImplGraph> &graphs,
                                                         const std::vector<std::string> &score_functions) {
  (void)score_functions;
  if (!HasReduce(graph)) {
    return ge::GRAPH_SUCCESS;
  }
  node_order_.clear();
  graphs.emplace_back(graph);
  return ge::GRAPH_SUCCESS;
}

Status ReducePartitionCaseGenerator::ReducePartitionMultipleCitations(ascir::ImplGraph &impl_graph) {
  if (IsGroupGraphLegal(impl_graph)) {
    return ge::GRAPH_SUCCESS;
  }
  std::vector<af::AscNodePtr> multi_output_nodes;
  for (auto node : impl_graph.GetAllNodes()) {
    if (node->GetOutNodes().size() > 1UL) {
      multi_output_nodes.emplace_back(node);
    }
  }
  std::sort(multi_output_nodes.begin(), multi_output_nodes.end(), [](const af::AscNodePtr &lhs, af::AscNodePtr &rhs) {
      return lhs->GetOpDescBarePtr()->GetId() > rhs->GetOpDescBarePtr()->GetId();
  });
  for (auto node : multi_output_nodes) {
    std::set<af::AscNodePtr> reduce_nodes;
    for (const auto &output_node : node->GetOutNodes()) {
      af::AscNodePtr out_asc_node = std::dynamic_pointer_cast<af::AscNode>(output_node);
      if (af::AscNodePtr reduce_node = nullptr; FindOutputReduce(out_asc_node, reduce_node)) {
        if (!reduce_nodes.empty() && reduce_nodes.find(reduce_node) == reduce_nodes.end()) {
          PartitionByNode(node, out_asc_node, impl_graph);
        }
        reduce_nodes.emplace(reduce_node);
      }
    }
  }
  return ge::GRAPH_SUCCESS;
}

bool ReducePartitionCaseGenerator::FindOutputReduce(const af::AscNodePtr &node, af::AscNodePtr &reduce_node) {
  if (ScheduleUtils::IsReduce(node)) {
    reduce_node = node;
    return true;
  }
  bool output_has_reduce = false;
  if (node->GetOutNodes().empty()) {
    return output_has_reduce;
  }
  for (const auto &output_node : node->GetOutNodes()) {
    auto output_asc_node = std::dynamic_pointer_cast<af::AscNode>(output_node);
    output_has_reduce = output_has_reduce || FindOutputReduce(output_asc_node, reduce_node);
  }
  return output_has_reduce;
}

Status ReducePartitionCaseGenerator::PartitionReduce(af::AscNodePtr &src_node, ascir::ImplGraph &impl_graph) {
  partition_ = true;
  node_order_.emplace_back(src_node);
  af::ascir_op::Workspace workspace_pre((src_node->GetName() + "_Workspace").c_str());
  af::ascir_op::Workspace workspace_post((src_node->GetName() + "_Workspace").c_str());
  af::ascir_op::Load load((src_node->GetName() + "_Load").c_str());
  af::ascir_op::Store store((src_node->GetName() + "_Store").c_str());
  auto workspace_pre_node = impl_graph.AddNode(workspace_pre);
  auto workspace_post_node = impl_graph.AddNode(workspace_post);
  auto load_node = impl_graph.AddNode(load);
  auto store_node = impl_graph.AddNode(store);
  GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(src_node, load_node));
  GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(src_node, store_node));
  GE_CHK_STATUS_RET(DoCopyWorkspaceTensorAttr(store_node, workspace_pre_node));
  GE_CHK_STATUS_RET(DoCopyWorkspaceTensorAttr(load_node, workspace_post_node));
  for (const auto &out_anchor : src_node->GetAllOutDataAnchors()) {
    GE_CHK_BOOL_EXEC(out_anchor != nullptr,
                     REPORT_INNER_ERR_MSG("E18888", "out data anchor is null, node:%s.", src_node->GetName().c_str());
                     return ge::GRAPH_FAILED, "[Check][Param] Out data anchor is null, node:%s",
                            src_node->GetName().c_str());
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL(peer_in_anchor);
      auto dst_node = peer_in_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(dst_node, "peer node is null, src node: %s", src_node->GetNamePtr());
      // remove src->dst
      GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(src_node->GetOutAnchor(out_anchor->GetIdx()),
                                                   dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
      GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(load_node->GetOutAnchor(out_anchor->GetIdx()),
                                                dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
    }
  }
  // add src->store->workspace_pre_node
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(src_node->GetOutAnchor(0UL), store_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(store_node->GetOutAnchor(0UL), workspace_pre_node->GetInAnchor(0UL)));
  // add workspace_post_node->load->dst
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(workspace_post_node->GetOutAnchor(0UL), load_node->GetInAnchor(0UL)));
  return ge::GRAPH_SUCCESS;
}

Status ReducePartitionCaseGenerator::ReducePartitionPostFusion(ascir::ImplGraph &impl_graph) {
  for (auto node : impl_graph.GetAllNodes()) {
    if (ScheduleUtils::IsReduce(node)) {
      if (IsNotPartitionReduce(node, NODE_COUNT_AFTER_REDUCE)) {
        continue;
      }
      GE_CHK_STATUS_RET(PartitionReduce(node, impl_graph));
    }
  }
  return ge::GRAPH_SUCCESS;
}

Status ReducePartitionCaseGenerator::PartitionByNode(af::AscNodePtr &src_node, af::AscNodePtr &dst_node,
                                                     ascir::ImplGraph &impl_graph) {
  partition_ = true;
  node_order_.emplace_back(src_node);
  if (ScheduleUtils::IsLoad(src_node)) {
    return PartitionLoad(src_node, dst_node, impl_graph);
  }
  if (ScheduleUtils::IsScalarLikeNode(src_node)) {
    return PartitionScalar(src_node, dst_node, impl_graph);
  };

  for (const auto &out_anchor : src_node->GetAllOutDataAnchors()) {
    GE_CHK_BOOL_EXEC(out_anchor != nullptr,
                     REPORT_INNER_ERR_MSG("E18888", "out data anchor is null, node:%s.", src_node->GetName().c_str());
                     return ge::GRAPH_FAILED, "[Check][Param] Out data anchor is null, node:%s",
                            src_node->GetName().c_str());
    af::ascir_op::Workspace workspace_pre(GetNewNodeName(src_node, dst_node, "Workspace", out_anchor->GetIdx()).c_str());
    af::ascir_op::Workspace workspace_post(GetNewNodeName(src_node, dst_node, "Workspace", out_anchor->GetIdx()).c_str());
    af::ascir_op::Load load(GetNewNodeName(src_node, dst_node, "Load", out_anchor->GetIdx()).c_str());
    af::ascir_op::Store store(GetNewNodeName(src_node, dst_node, "Store", out_anchor->GetIdx()).c_str());
    auto workspace_pre_node = impl_graph.AddNode(workspace_pre);
    auto workspace_post_node = impl_graph.AddNode(workspace_post);
    auto load_node = impl_graph.AddNode(load);
    auto store_node = impl_graph.AddNode(store);
    GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(src_node, load_node));
    GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(src_node, store_node));
    GE_CHK_STATUS_RET(DoCopyWorkspaceTensorAttr(store_node, workspace_pre_node));
    GE_CHK_STATUS_RET(DoCopyWorkspaceTensorAttr(load_node, workspace_post_node));
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL(peer_in_anchor);
      GE_CHK_BOOL_EXEC(peer_in_anchor->GetOwnerNodeBarePtr() != nullptr,
                       REPORT_INNER_ERR_MSG("E18888", "Peer in node:%s is null", src_node->GetName().c_str());
                       return ge::GRAPH_FAILED, "Peer in node:%s is null", src_node->GetName().c_str());
      if (peer_in_anchor->GetOwnerNodeBarePtr()->GetName() == dst_node->GetName()) {
        // remove src->dst
        GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(src_node->GetOutAnchor(out_anchor->GetIdx()),
                                                     dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
        GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(load_node->GetOutAnchor(0UL),
                                                  dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
      }
    }
    // add src->store->workspace_pre_node
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(src_node->GetOutAnchor(out_anchor->GetIdx()),
                                              store_node->GetInAnchor(0UL)));
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(store_node->GetOutAnchor(0UL),
                                              workspace_pre_node->GetInAnchor(0UL)));
    // add workspace_post_node->load->dst
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(workspace_post_node->GetOutAnchor(0UL),
                                              load_node->GetInAnchor(0UL)));
  }
  return ge::GRAPH_SUCCESS;
}

Status ReducePartitionCaseGenerator::PartitionLoad(af::AscNodePtr &src_node, af::AscNodePtr &dst_node,
                                                   ascir::ImplGraph &impl_graph) {
  auto load_input_node = src_node->GetInNodes().at(0UL);
  auto load_input_asc_node = std::dynamic_pointer_cast<af::AscNode>(load_input_node);
  GE_ASSERT_TRUE(af::ops::IsOps<af::ascir_op::Data>(load_input_asc_node) || af::ops::IsOps<af::ascir_op::Workspace>(load_input_asc_node));
  af::ascir_op::Load load(("copy_from_" + src_node->GetName()).c_str());

  af::AscNodePtr new_load_input_node;
  if (af::ops::IsOps<af::ascir_op::Data>(load_input_asc_node)) {
    af::ascir_op::Data data(("copy_from_" + load_input_asc_node->GetName()).c_str());
    new_load_input_node = impl_graph.AddNode(data);
  } else {
    af::ascir_op::Workspace workspace(("copy_from_" + load_input_asc_node->GetName()).c_str());
    new_load_input_node = impl_graph.AddNode(workspace);
  }
  auto load_node = impl_graph.AddNode(load);
  DoCopyAscNodeTensorAttr(load_input_asc_node, new_load_input_node);
  DoCopyAscNodeTensorAttr(src_node, load_node);
  for (const auto &out_anchor : src_node->GetAllOutDataAnchors()) {
    GE_CHK_BOOL_EXEC(out_anchor != nullptr,
                     REPORT_INNER_ERR_MSG("E18888", "out data anchor is null, node:%s.", src_node->GetName().c_str());
                     return ge::GRAPH_FAILED, "[Check][Param] Out data anchor is null, node:%s", src_node->GetName().c_str());
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL(peer_in_anchor);
      GE_CHK_BOOL_EXEC(peer_in_anchor->GetOwnerNodeBarePtr() != nullptr,
                       REPORT_INNER_ERR_MSG("E18888", "Peer in node:%s is null", src_node->GetName().c_str());
                       return ge::GRAPH_FAILED, "Peer in node:%s is null", src_node->GetName().c_str());
      if (peer_in_anchor->GetOwnerNodeBarePtr()->GetName() == dst_node->GetName()) {
        // remove load->dst
        GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(src_node->GetOutAnchor(out_anchor->GetIdx()),
                                                     dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
        // add new_load_input->new_load->dst
        GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(new_load_input_node->GetOutAnchor(0UL),
                                                  load_node->GetInAnchor(0UL)));
        GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(load_node->GetOutAnchor(0UL),
                                                  dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
        return ge::GRAPH_SUCCESS;
      }
    }
  }
  return ge::GRAPH_SUCCESS;
}

Status ReducePartitionCaseGenerator::PartitionScalar(af::AscNodePtr &src_node, af::AscNodePtr &dst_node,
                                                     ascir::ImplGraph &impl_graph) {
  af::AscNodePtr scalar_node;
  if (af::ops::IsOps<af::ascir_op::ScalarData>(src_node)) {
    af::ascir_op::ScalarData scalar_data(("copy_from_" + src_node->GetName()).c_str());
    scalar_node = impl_graph.AddNode(scalar_data);
  } else {
    af::ascir_op::Scalar scalar(("copy_from_" + src_node->GetName()).c_str());
    scalar_node = impl_graph.AddNode(scalar);
  }
  DoCopyAscNodeTensorAttr(src_node, scalar_node);
  for (const auto &out_anchor : src_node->GetAllOutDataAnchors()) {
    GE_CHK_BOOL_EXEC(out_anchor != nullptr,
                     REPORT_INNER_ERR_MSG("E18888", "out data anchor is null, node:%s.", src_node->GetName().c_str());
                     return ge::GRAPH_FAILED, "[Check][Param] Out data anchor is null, node:%s", src_node->GetName().c_str());
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL(peer_in_anchor);
      GE_CHK_BOOL_EXEC(peer_in_anchor->GetOwnerNodeBarePtr() != nullptr,
                       REPORT_INNER_ERR_MSG("E18888", "Peer in node:%s is null", src_node->GetName().c_str());
                       return ge::GRAPH_FAILED, "Peer in node:%s is null", src_node->GetName().c_str());
      if (peer_in_anchor->GetOwnerNodeBarePtr()->GetName() == dst_node->GetName()) {
        // remove src->dst
        GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(src_node->GetOutAnchor(out_anchor->GetIdx()),
                                                     dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
        // add new_scalar->dst
        GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(scalar_node->GetOutAnchor(0UL),
                                                  dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
        return ge::GRAPH_SUCCESS;
      }
    }
  }
  return ge::GRAPH_SUCCESS;
}

bool ReducePartitionCaseGenerator::IsInputNodePartitioned(const std::shared_ptr<af::Node>& start,
  const std::shared_ptr<af::Node>& node) {
  // 向前找 start 节点,如找到则未经过切分,到根节点未找到则已切分或不需要切分
  if (node == start) {
    return false;
  }
  bool partitioned = true;
  if (node->GetInNodes().empty()) {
    return partitioned;
  }
  for (const auto &in_node : node->GetInNodes()) {
    partitioned = partitioned && IsInputNodePartitioned(start, in_node);
  }
  return partitioned;
}

Status ReducePartitionCaseGenerator::FindNormLoop(const af::AscNodePtr &start, std::vector<af::AscNodePtr> &ends) {
  std::set<af::NodePtr> visited{start};
  std::list<af::NodePtr> next_nodes{start};
  while (!next_nodes.empty()) {
    const auto node = next_nodes.front();
    next_nodes.pop_front();
    for (auto &out_node : node->GetOutDataNodes()) {
      if (visited.find(out_node) == visited.cend()) {
        next_nodes.emplace_back(out_node);
        visited.emplace(out_node);
      } else {
        auto asc_node = std::dynamic_pointer_cast<af::AscNode>(out_node);
        ends.emplace_back(asc_node);
      }
    }
  }
  ends.erase(
    std::remove_if(ends.begin(), ends.end(),
      [start, this](const af::AscNodePtr &end) { return !IsNorm(start, end); }), ends.end()
    );
  return ge::GRAPH_SUCCESS;
}

void ReducePartitionCaseGenerator::FindAllPath(const af::AscNodePtr& start, const af::AscNodePtr& end,
                                               std::vector<af::AscNodePtr> &path,
                                               std::vector<std::vector<af::AscNodePtr>> &all_paths) {
  // 需保证图为有向无环图
  path.emplace_back(start);
  if (start == end) {
    all_paths.emplace_back(path);
  }
  for (const auto &out_node : start->GetOutNodes()) {
    const auto asc_node = std::dynamic_pointer_cast<af::AscNode>(out_node);
    FindAllPath(asc_node, end, path, all_paths);
  }
  path.pop_back();
}

bool ReducePartitionCaseGenerator::IsNorm(const af::AscNodePtr &start, const af::AscNodePtr &end) {
  std::vector<af::AscNodePtr> path;
  std::vector<std::vector<af::AscNodePtr>> all_paths;
  FindAllPath(start, end, path, all_paths);
  bool is_norm = false;
  std::set<af::AscNodePtr> end_in_nodes;
  for (const auto &path_temp : all_paths) {
    for (const auto &node : path_temp) {
      if (ScheduleUtils::IsReduce(node)) {
        is_norm =  true;
        break;
      }
    }
    end_in_nodes.insert(path_temp.at(path_temp.size() - TWO));
  }
  return is_norm && end_in_nodes.size() > 1;
}

Status ReducePartitionCaseGenerator::PartitionNorm(ascir::ImplGraph &impl_graph, std::vector<std::pair<af::AscNodePtr,
                                                   af::AscNodePtr>> &loop_start_end) {
  for (auto loop : loop_start_end) {
    for (auto &in_node : loop.second->GetInNodes()) {
      if(IsInputNodePartitioned(loop.first, in_node)) {
        continue;
      }
      af::AscNodePtr src_node = std::dynamic_pointer_cast<af::AscNode>(in_node);
      GE_CHK_STATUS_RET(PartitionByNode(src_node, loop.second, impl_graph));
    }
  }
  return ge::GRAPH_SUCCESS;
}

bool ReducePartitionCaseGenerator::HasReduce(const ascir::ImplGraph &impl_graph) {
  for (const auto &node : impl_graph.GetAllNodes()) {
    if (ScheduleUtils::IsReduce(node)) {
      return true;
    }
  }
  return false;
}

// 全载模板只支持reduce AR ARA
bool ReducePartitionCaseGenerator::CanReduceFuse(const ascir::ImplGraph &impl_graph) {
  std::vector<ascir::SizeExpr> temp_strides;
  for (const auto &node : impl_graph.GetAllNodes()) {
    if (!ScheduleUtils::IsReduce(node)) {
      continue;
    }
    std::vector<ascir::SizeExpr> input_repeats = node->inputs[0].attr.repeats;
    std::vector<ascir::SizeExpr> output_repeats = node->outputs[0].attr.repeats;
    GE_ASSERT_TRUE(input_repeats.size() == output_repeats.size());
    if (output_repeats.empty() || (output_repeats.size() > kMaxFullLoadAxisSize)) {
      return false;
    }

    if (af::SymbolicUtils::StaticCheckEq(input_repeats[0], output_repeats[0]) != af::TriBool::kTrue) {
      return false;
    }
  }

  return true;
}

bool ReducePartitionCaseGenerator::IsGroupGraphLegal(const ascir::ImplGraph &impl_graph) {
  int reduce_count = 0;
  for (const auto &node : impl_graph.GetAllNodes()) {
    if (ScheduleUtils::IsReduce(node)) {
      reduce_count += 1;
    }
  }
  return reduce_count <= 1;
}

Status RMulticorePhase2Graph::Construct() {
  GE_ASSERT_TRUE(reducers.find(reduce_node->GetType()) != reducers.end());
  ReduceType phase2graph_reduce = reducers.find(reduce_node->GetType())->
                                  second((phase2graph.GetName() + "_" + reduce_node->GetName() + "_reduce").c_str());
  std::visit([](auto&& reduce_op) {
    reduce_op.attr.sched.axis = {0, 1};
  }, phase2graph_reduce);

  // 对于ArgMax,需要在调用CompletePhaseGraph之前保存输入dtype(value类型)
  // 因为CompletePhaseGraph会调用PartitionByReduce,后者会删除reduce_node
  af::AscTensorDataType argmax_input_dtype;
  if (reduce_node->GetType() == "ArgMax") {
    GE_ASSERT_TRUE(!reduce_node->inputs().empty(), "ArgMax node should have at least 1 input");
    argmax_input_dtype = reduce_node->inputs[0].attr.dtype;
  }

  GE_CHK_STATUS_RET(CompletePhaseGraph(phase2graph_reduce));
  GE_CHK_STATUS_RET(CreateVarAxis());

  auto workspace_node = phase2graph.FindNode((phase2graph.GetName() + "_workspace").c_str());
  GE_ASSERT_NOTNULL(workspace_node);
  workspace_node->attr.sched.axis = {0, 1};
  // 对于ArgMax,workspace存储的是value(输入类型),不是index(输出类型)
  af::AscTensorDataType workspace_dtype;
  if (reduce_node->GetType() == "ArgMax") {
    workspace_dtype = argmax_input_dtype;
  } else {
    workspace_dtype = reduce_node->outputs[0].attr.dtype;
  }
  GE_CHK_STATUS_RET(CompleteNodeAttr(workspace_node, true, workspace_dtype));
  auto load_node = phase2graph.FindNode((phase2graph.GetName() + "_load").c_str());
  GE_ASSERT_NOTNULL(load_node);
  load_node->attr.sched.axis = {0, 1};
  // load从workspace读取,dtype与workspace相同
  GE_CHK_STATUS_RET(CompleteNodeAttr(load_node, true, workspace_dtype));

  // ArgMax特殊处理:需要设置load_index_node和workspace_index_node的属性
  GE_CHK_STATUS_RET(SetupArgMaxIndexNodes(reduce_node, phase2graph));

  auto reduce_node_parse2graph = phase2graph.FindNode((phase2graph.GetName() + "_" + reduce_node->GetName() + "_reduce").c_str());
  GE_ASSERT_NOTNULL(reduce_node_parse2graph);
  std::set<af::NodePtr> visited{reduce_node_parse2graph};
  std::list<af::NodePtr> next_nodes{reduce_node_parse2graph};
  while (!next_nodes.empty()) {
    const auto node = next_nodes.front();
    next_nodes.pop_front();
    auto asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
    asc_node->attr.sched.axis = {0, 1};
    GE_CHK_STATUS_RET(CompleteNodeAttr(asc_node, false, asc_node->outputs[0].attr.dtype));
    for (auto &out_node : node->GetOutDataNodes()) {
      if (visited.find(out_node) == visited.cend()) {
        next_nodes.emplace_back(out_node);
        visited.emplace(out_node);
      }
    }
  }
  return ge::GRAPH_SUCCESS;
}

Status RMulticorePhase2Graph::CreateVarAxis() {
  // 创建符号:A轴符号:s1,R轴符号:s2
  auto  compute_graph = af::AscGraphUtils::GetComputeGraph(phase2graph);
  GE_ASSERT_NOTNULL(compute_graph);
  auto attr = compute_graph->GetOrCreateAttrsGroup<af::AscGraphAttr>();
  GE_ASSERT_NOTNULL(attr);
  attr->axis.clear();
  attr->size_vars.clear();
  Rm_org_size = phase2graph.CreateSizeVar("Rm_org_size");
  A_org_size = phase2graph.CreateSizeVar("A_org_size");
  // 创建轴:a,r
  phase2graph.CreateAxis("Rm", Rm_org_size);
  phase2graph.CreateAxis("A", A_org_size);
  return ge::GRAPH_SUCCESS;
}

Status RMulticorePhase2Graph::CompleteNodeAttr(af::AscNodePtr &node, bool before_reduce,
                                               const af::AscTensorDataType& data_type) {
  node->outputs[0].attr.dtype = data_type;
  node->outputs[0].attr.axis = {0, 1};
  if (before_reduce) {
    node->outputs[0].attr.strides = {A_org_size, af::ops::One};
    node->outputs[0].attr.repeats = {Rm_org_size, A_org_size};
  } else {
    node->outputs[0].attr.strides = {af::ops::Zero, af::ops::One};
    node->outputs[0].attr.repeats = {af::ops::One, A_org_size};
  }
  return ge::GRAPH_SUCCESS;
}

Status RMulticorePhase2Graph::CompleteNodeAttrBeforeReduce(af::AscNodePtr &node) {
  node->outputs[0].attr.axis = {0, 1};
  node->outputs[0].attr.strides = {A_org_size, af::ops::One};
  node->outputs[0].attr.repeats = {Rm_org_size, A_org_size};
  return ge::GRAPH_SUCCESS;
}

Status RMulticorePhase2Graph::SetupArgMaxIndexNodes(const af::AscNodePtr &reduce_node,
                                                     ascir::ImplGraph &phase2graph) {
  // ArgMax特殊处理:需要设置load_index_node和workspace_index_node的属性
  // 注意:虽然这些设置可能被后续的CopyFrom部分覆盖,但轴信息(strides, repeats等)需要保留
  if (reduce_node->GetType() == "ArgMax") {
    auto workspace_index_node = phase2graph.FindNode((phase2graph.GetName() + "_workspace_index").c_str());
    if (workspace_index_node != nullptr) {
      workspace_index_node->attr.sched.axis = {0, 1};
      GE_CHK_STATUS_RET(CompleteNodeAttrBeforeReduce(workspace_index_node));
    }

    auto index_load_node = phase2graph.FindNode((phase2graph.GetName() + "_load_index").c_str());
    if (index_load_node != nullptr) {
      index_load_node->attr.sched.axis = {0, 1};
      GE_CHK_STATUS_RET(CompleteNodeAttrBeforeReduce(index_load_node));
    }
  }
  return ge::GRAPH_SUCCESS;
}

Status RMulticorePhase2Graph::CompletePhaseGraph(ReduceType &phase2graph_reduce) {
  std::vector<af::AscNodePtr> node_order;
  GE_ASSERT_GRAPH_SUCCESS(PartitionByReduce(phase_graph, phase2graph_reduce, node_order));
  GE_ASSERT_GRAPH_SUCCESS(SetNodeOrder(node_order));
  std::vector<::ascir::ImplGraph> sub_optimize_graphs;
  GE_ASSERT_GRAPH_SUCCESS(ScheduleGroupGraphPartitioner::PartitionByConnectivity(phase_graph, sub_optimize_graphs, node_order));
  GE_ASSERT_EQ(sub_optimize_graphs.size(), 2UL);
  phase1graph.CopyFrom(sub_optimize_graphs[0]);
  phase2graph.CopyFrom(sub_optimize_graphs[1]);
  return ge::GRAPH_SUCCESS;
}

// 辅助函数:设置ArgMaxMultiRPhase1的输出属性
// ArgMaxMultiRPhase1有2个输出:output[0]=value(T类型), output[1]=index(DT_INT64)
// value的shape从原始ArgMax输出复制,dtype从输入复制
// index的shape从原始ArgMax输出复制,dtype固定为DT_INT64
static Status SetupArgMaxPhase1OutputAttrs(const af::AscNodePtr &phase1_node,
                                    const af::AscTensorAttr &input_attr,
                                    const af::AscTensorAttr &output_attr) {
  // 设置输出0(value):shape从argmax_node输出复制,dtype从输入复制
  {
    auto op_desc = phase1_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(0));
    auto tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    // 复制shape等属性,但dtype使用输入的类型
    *tensor_attr_group = output_attr;
    tensor_attr_group->dtype = input_attr.dtype;
  }

  // 设置输出1(index):shape和格式等从输出0复制,dtype固定为DT_INT64
  {
    auto op_desc = phase1_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(1));
    auto tensor_attr_group = op_desc->MutableOutputDesc(1)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    // 使用之前保存的output_attr,而不是访问phase1_node->outputs[0]
    *tensor_attr_group = output_attr;
    tensor_attr_group->dtype = ge::DT_INT64;
  }
  return ge::GRAPH_SUCCESS;
}

// 辅助函数:复制节点级别的属性(sched、ir_attr等)
static void CopyNodeLevelAttrs(const af::AscNodePtr &dst_node, const af::AscNodePtr &src_node) {
  auto dst_op_desc = dst_node->GetOpDesc();
  auto dst_asc_node_attr = dst_op_desc->GetOrCreateAttrsGroup<af::AscNodeAttr>();
  auto src_op_desc = src_node->GetOpDesc();
  auto src_asc_node_attr = src_op_desc->GetOrCreateAttrsGroup<af::AscNodeAttr>();
  if (src_asc_node_attr != nullptr && dst_asc_node_attr != nullptr) {
    dst_asc_node_attr->sched = src_asc_node_attr->sched;
    if (src_asc_node_attr->ir_attr) {
      dst_asc_node_attr->ir_attr = src_asc_node_attr->ir_attr->Clone();
    }
  }
}

// R轴分核时,在阶段1将ArgMax替换为ArgMaxMultiRPhase1
Status ReplaceArgMaxInPhase1(ascir::ImplGraph &phase_graph,
                              const af::AscNodePtr &argmax_node,
                              af::AscNodePtr &store_node,
                              af::AscNodePtr &workspace_pre_node,
                              af::AscNodePtr &workspace_pre_index_node_out) {
  // 创建ArgMaxMultiRPhase1节点替换原始ArgMax
  af::ascir_op::ArgMaxMultiRPhase1 argmax_phase1((argmax_node->GetName() + "_phase1").c_str());
  auto new_argmax_phase1_node = phase_graph.AddNode(argmax_phase1);

  // 复制节点级别的属性(sched、ir_attr等)
  CopyNodeLevelAttrs(new_argmax_phase1_node, argmax_node);

  // 保存输入和输出属性,因为后续重定向边会导致inputs被清空
  const auto &saved_input_attr  = argmax_node->inputs[0].attr;
  const auto &saved_output_attr = argmax_node->outputs[0].attr;

  // 为ArgMaxMultiRPhase1的两个输出设置正确的tensor属性
  (void)SetupArgMaxPhase1OutputAttrs(new_argmax_phase1_node, saved_input_attr, saved_output_attr);

  // 将ArgMax节点的所有输入边重定向到ArgMaxMultiRPhase1节点
  for (const auto &in_anchor : argmax_node->GetAllInDataAnchors()) {
    GE_ASSERT_NOTNULL(in_anchor);
    auto src_anchor = in_anchor->GetPeerOutAnchor();
    if (src_anchor == nullptr) {
      continue;
    }
    GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(src_anchor, in_anchor));
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(src_anchor, new_argmax_phase1_node->GetInAnchor(in_anchor->GetIdx())));
  }

  // 为ArgMaxMultiRPhase1的第二个输出(index)创建额外的store节点
  // 注意:workspace_pre_index_node由调用方创建并传入,通过名称"_workspace_index"与Phase2的workspace_post_index关联
  af::ascir_op::Store store_index((argmax_node->GetName() + "_Store_index").c_str());
  auto new_store_index_node = phase_graph.AddNode(store_index);

  // 复制tensor属性
  // 注意:argmax_phase1_node有两个输出,需要正确设置属性

  // store_node接收argmax_phase1_node的输出0(value),类型应该是T(和输入相同)
  // 不能从argmax_node复制(argmax_node的输出是index,类型DT_INT64)
  // 注意:此时输入边已被重定向,argmax_node->inputs()可能为空,需要使用之前保存的属性
  {
    // 复制节点级别属性
    CopyNodeLevelAttrs(store_node, argmax_node);
    // 复制输出tensor属性:shape从output_attr复制,dtype从input_attr复制
    auto op_desc = store_node->GetOpDesc();
    auto tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    *tensor_attr_group = saved_output_attr;
    tensor_attr_group->dtype = saved_input_attr.dtype;
  }

  // store_index_node接收argmax_phase1_node的输出1(index),类型是DT_INT64
  // 使用之前保存的属性设置,因为此时输入边已被重定向
  {
    // 复制节点级别属性
    CopyNodeLevelAttrs(new_store_index_node, argmax_node);
    // 复制输出tensor属性:使用保存的output_attr
    auto op_desc = new_store_index_node->GetOpDesc();
    auto tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    *tensor_attr_group = saved_output_attr;  // index的shape和argmax输出相同
  }

  // workspace_pre_node连接到store_node,从store_node复制属性
  GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(store_node, workspace_pre_node));
  GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(new_store_index_node, workspace_pre_index_node_out));

  // 连接ArgMaxMultiRPhase1的第一个输出(value)到store节点
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(new_argmax_phase1_node->GetOutAnchor(0UL),
                                            store_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(store_node->GetOutAnchor(0UL),
                                            workspace_pre_node->GetInAnchor(0UL)));

  // 连接ArgMaxMultiRPhase1的第二个输出(index)到store_index和workspace
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(new_argmax_phase1_node->GetOutAnchor(1UL),
                                            new_store_index_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(new_store_index_node->GetOutAnchor(0UL),
                                            workspace_pre_index_node_out->GetInAnchor(0UL)));

  // 删除原始的ArgMax节点,避免"not visited"错误
  // 输入边已被重定向到argmax_phase1_node
  // 输出边将在PartitionByReduce中重定向到new_reduce_node
  auto compute_graph = af::AscGraphUtils::GetComputeGraph(phase_graph);
  GE_ASSERT_GRAPH_SUCCESS(af::GraphUtils::RemoveNodeWithoutRelink(compute_graph, argmax_node));

  return ge::GRAPH_SUCCESS;
}

// 已保存的节点属性(封装属性指针和属性值)
struct ReduceSavedNodeAttrInfo {
  const af::AscNodeAttr *attr_ptr;  // 属性指针
  af::AscNodeAttr attr_value;        // 属性值
};

// ArgMax的index路径节点
struct ArgMaxIndexNodes {
  af::AscNodePtr index_workspace_post_node;  // index workspace节点
  af::AscNodePtr index_load_node;            // index load节点
};

// Phase1节点集合结构体
struct ReducePhase1Nodes {
  af::AscNodePtr store_node;          // store节点
  af::AscNodePtr workspace_pre_node;  // workspace pre节点
};

// Phase2节点集合结构体
struct ReducePhase2Nodes {
  af::AscNodePtr workspace_post_node;  // workspace post节点
  af::AscNodePtr load_node;            // load节点
  af::AscNodePtr new_reduce_node;      // Phase2的reduce节点
};

// 图上下文结构体
struct ArgMaxPartitionGraphContext {
  ascir::ImplGraph &impl_graph;              // 图
  std::vector<af::AscNodePtr> &node_order;   // 节点顺序
  std::string graph_name;                    // 图名称(使用值而非引用,避免悬空引用)
};

// 设置ArgMax Phase2的所有属性
static Status SetupArgMaxPhase2Attrs(const ReduceSavedNodeAttrInfo &saved_attr,
                                     const ReducePhase2Nodes &phase2_nodes,
                                     const af::AscTensorAttr &reduce_input_attr,
                                     const af::AscTensorAttr &reduce_output_attr,
                                     const ArgMaxIndexNodes &index_nodes) {
  // 复制节点级别的属性到new_reduce_node
  if (saved_attr.attr_ptr != nullptr) {
    auto op_desc = phase2_nodes.new_reduce_node->GetOpDesc();
    auto dst_asc_node_attr = op_desc->GetOrCreateAttrsGroup<af::AscNodeAttr>();
    if (dst_asc_node_attr != nullptr) {
      dst_asc_node_attr->sched = saved_attr.attr_value.sched;
      if (saved_attr.attr_value.ir_attr) {
        dst_asc_node_attr->ir_attr = saved_attr.attr_value.ir_attr->Clone();
      }
    }
  }

  // 设置输出(index):shape和reduce输出相同,dtype为DT_INT64
  {
    auto op_desc = phase2_nodes.new_reduce_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(0));
    auto tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    *tensor_attr_group = reduce_output_attr;
    tensor_attr_group->dtype = ge::DT_INT64;
  }

  // ArgMax特殊处理:先设置workspace和load的输出dtype
  {
    // 设置workspace_post和load的输出dtype为value类型
    auto op_desc = phase2_nodes.workspace_post_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(0));
    auto tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    tensor_attr_group->dtype = reduce_input_attr.dtype;

    op_desc = phase2_nodes.load_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(0));
    tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    tensor_attr_group->dtype = reduce_input_attr.dtype;
  }
  {
    // 设置workspace_post_index和load_index的输出dtype为DT_INT64
    auto op_desc = index_nodes.index_workspace_post_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(0));
    auto tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    tensor_attr_group->dtype = ge::DT_INT64;

    op_desc = index_nodes.index_load_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(0));
    tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    tensor_attr_group->dtype = ge::DT_INT64;
  }

  return ge::GRAPH_SUCCESS;
}

// 连接ArgMax Phase2的所有边
// 参数:1.输出锚点 2.输出边 3.Phase2节点集合 4.index workspace节点 5.index load节点
static Status ConnectArgMaxPhase2Edges(af::OutDataAnchorPtr argmax_out_anchor,
                                     const std::vector<af::InDataAnchorPtr> &argmax_out_edges,
                                     const ReducePhase2Nodes &phase2_nodes,
                                     const af::AscNodePtr &index_workspace_post_node,
                                     const af::AscNodePtr &index_load_node) {
  // 重定向原始ArgMax的输出边到new_reduce_node(ArgMaxMultiRPhase2)
  for (const auto &peer_in_anchor : argmax_out_edges) {
    // 移除原始边(虽然reduce_node已被删除,但移除边的操作应该还能执行)
    GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(argmax_out_anchor, peer_in_anchor));
    // 添加新边:从new_reduce_node的输出0到下游节点
    // ArgMaxMultiRPhase2只有1个输出(最终的index),所以输出索引是0
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(phase2_nodes.new_reduce_node->GetOutAnchor(0UL),
                                              peer_in_anchor));
  }

  // 连接Phase2的所有边
  // Phase2 value路径:workspace_post -> load -> new_reduce[输入0]
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(phase2_nodes.workspace_post_node->GetOutAnchor(0UL),
                                            phase2_nodes.load_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(phase2_nodes.load_node->GetOutAnchor(0UL),
                                            phase2_nodes.new_reduce_node->GetInAnchor(0UL)));

  // Phase2 index路径:workspace_post_index -> index_load -> new_reduce[输入1]
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(index_workspace_post_node->GetOutAnchor(0UL),
                                            index_load_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(index_load_node->GetOutAnchor(0UL),
                                            phase2_nodes.new_reduce_node->GetInAnchor(1UL)));

  return ge::GRAPH_SUCCESS;
}

// 处理ArgMax在R轴分核时的完整逻辑
// 包括:创建index路径节点、替换为Phase1算子、设置属性、连接边等
static Status HandleArgMaxPartition(const af::AscNodePtr &reduce_node,
                                     ReducePhase1Nodes &phase1_nodes,
                                     ReducePhase2Nodes &phase2_nodes,
                                     const ArgMaxPartitionGraphContext &ctx) {
  // ArgMax特殊处理:需要额外的workspace用于index
  af::ascir_op::Workspace index_workspace_pre((ctx.graph_name + "_workspace_index").c_str());
  af::ascir_op::Workspace index_workspace_post((ctx.graph_name + "_workspace_index").c_str());

  // 为index路径创建额外的workspace和load节点
  af::AscNodePtr index_workspace_pre_node  = ctx.impl_graph.AddNode(index_workspace_pre);
  af::AscNodePtr index_workspace_post_node = ctx.impl_graph.AddNode(index_workspace_post);
  ctx.node_order.emplace_back(index_workspace_pre_node);

  af::ascir_op::Load index_load((ctx.graph_name + "_load_index").c_str());
  af::AscNodePtr index_load_node = ctx.impl_graph.AddNode(index_load);
  // 参考load_node的处理方式,不复制属性(避免从Phase1节点复制错误的sched.axis)

  const auto &saved_reduce_input_attr = reduce_node->inputs[0].attr;
  const auto &saved_reduce_output_attr = reduce_node->outputs[0].attr;

  // 保存ArgMax的输出边,用于后续重定向到new_reduce_node
  af::OutDataAnchorPtr argmax_out_anchor = reduce_node->GetOutDataAnchor(0UL);
  std::vector<af::InDataAnchorPtr> argmax_out_edges;
  if (argmax_out_anchor != nullptr) {
    for (const auto &peer_in_anchor : argmax_out_anchor->GetPeerInDataAnchors()) {
      GE_ASSERT_NOTNULL(peer_in_anchor);
      argmax_out_edges.push_back(peer_in_anchor);
    }
  }

  // 复制节点级别的属性(sched、ir_attr等)
  // 在调用ReplaceArgMaxInPhase1之前保存,因为调用后reduce_node将被删除
  auto src_asc_node_attr = reduce_node->GetOpDesc()->GetOrCreateAttrsGroup<af::AscNodeAttr>();
  af::AscNodeAttr saved_node_attr;
  if (src_asc_node_attr != nullptr) {
    saved_node_attr = *src_asc_node_attr;
  }

  // 在阶段1替换为ArgMaxMultiRPhase1,并处理额外的index输出
  GE_CHK_STATUS_RET(ReplaceArgMaxInPhase1(ctx.impl_graph, reduce_node, phase1_nodes.store_node,
                                            phase1_nodes.workspace_pre_node, index_workspace_pre_node));

  // 设置ArgMax Phase2的所有属性
  ReduceSavedNodeAttrInfo saved_attr_info = {src_asc_node_attr, saved_node_attr};
  ArgMaxIndexNodes index_nodes = {index_workspace_post_node, index_load_node};
  GE_CHK_STATUS_RET(SetupArgMaxPhase2Attrs(saved_attr_info, phase2_nodes, saved_reduce_input_attr,
                                           saved_reduce_output_attr, index_nodes));

  // 复制Phase2的workspace属性,与Phase1的workspace通过相同名称关联到同一buffer
  // 注意:在设置dtype之后调用,避免覆盖正确的dtype
  GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(phase2_nodes.load_node, phase2_nodes.workspace_post_node));
  GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(index_load_node, index_workspace_post_node));

  // 连接ArgMax Phase2的所有边
  GE_CHK_STATUS_RET(ConnectArgMaxPhase2Edges(argmax_out_anchor, argmax_out_edges,
                                            phase2_nodes, index_workspace_post_node,
                                            index_load_node));

  return ge::GRAPH_SUCCESS;
}

// R轴分核时,为ArgMax创建额外的workspace和load节点用于index
Status RMulticorePhase2Graph::PartitionByReduce(ascir::ImplGraph &impl_graph,
                                                ReduceType &phase2graph_reduce,
                                                std::vector<af::AscNodePtr> &node_order) {
  af::ascir_op::Workspace workspace_pre((phase2graph.GetName() + "_workspace").c_str());
  af::ascir_op::Workspace workspace_post((phase2graph.GetName() + "_workspace").c_str());
  af::ascir_op::Load load((phase2graph.GetName() + "_load").c_str());
  af::ascir_op::Store store((phase1graph.GetName() + "Store").c_str());

  af::AscNodePtr new_reduce_node;
  std::visit([&new_reduce_node, &impl_graph](auto&& reduce_op) {
    new_reduce_node = impl_graph.AddNode(reduce_op);
  }, phase2graph_reduce);

  auto workspace_pre_node = impl_graph.AddNode(workspace_pre);
  node_order.emplace_back(workspace_pre_node);
  auto workspace_post_node = impl_graph.AddNode(workspace_post);
  auto load_node = impl_graph.AddNode(load);
  GE_ASSERT_NOTNULL(load_node);
  auto store_node = impl_graph.AddNode(store);
  GE_ASSERT_NOTNULL(store_node);

  if (reduce_node->GetType() == "ArgMax") {
    // ArgMax特殊处理:调用专用函数处理双路径(value和index)逻辑
    ReducePhase1Nodes phase1_nodes = {store_node, workspace_pre_node};
    ReducePhase2Nodes phase2_nodes = {workspace_post_node, load_node, new_reduce_node};
    ArgMaxPartitionGraphContext ctx = {impl_graph, node_order, phase2graph.GetName()};
    GE_CHK_STATUS_RET(HandleArgMaxPartition(reduce_node, phase1_nodes, phase2_nodes, ctx));
  } else {
    // 普通reduce算子的处理逻辑
    GE_ASSERT_GRAPH_SUCCESS(DoCopyAscNodeTensorAttr(reduce_node, new_reduce_node));
    GE_ASSERT_GRAPH_SUCCESS(DoCopyAscNodeTensorAttr(reduce_node, store_node));
    GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(reduce_node, workspace_pre_node));
    GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(load_node, workspace_post_node));
    for (const auto &reduce_out_anchor : reduce_node->GetAllOutDataAnchors()) {
      GE_ASSERT_NOTNULL(reduce_out_anchor);
      for (const auto &peer_in_anchor : reduce_out_anchor->GetPeerInDataAnchors()) {
        GE_ASSERT_NOTNULL(peer_in_anchor);
        auto reduce_out_node = peer_in_anchor->GetOwnerNode();
        GE_ASSERT_NOTNULL(reduce_out_node);
        GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(reduce_node->GetOutAnchor(reduce_out_anchor->GetIdx()),
                                                    reduce_out_node->GetInAnchor(peer_in_anchor->GetIdx())));
        GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(new_reduce_node->GetOutAnchor(reduce_out_anchor->GetIdx()),
                                                  reduce_out_node->GetInAnchor(peer_in_anchor->GetIdx())));
      }
    }
    // add reduce->store->workspace_pre_node
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(reduce_node->GetOutAnchor(0UL),
                                              store_node->GetInAnchor(0UL)));
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(store_node->GetOutAnchor(0UL),
                                              workspace_pre_node->GetInAnchor(0UL)));
    // add workspace_post_node->load->new reduce
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(workspace_post_node->GetOutAnchor(0UL),
                                              load_node->GetInAnchor(0UL)));
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(load_node->GetOutAnchor(0UL),
                                              new_reduce_node->GetInAnchor(0UL)));
  }

  return ge::GRAPH_SUCCESS;
}

Status RMulticorePhase2Graph::SetNodeOrder (std::vector<af::AscNodePtr> &node_order) {
  auto new_reduce_node = phase_graph.FindNode((phase2graph.GetName() + "_" + reduce_node->GetName() + "_reduce").c_str());
  GE_ASSERT_NOTNULL(new_reduce_node);
  std::set<af::NodePtr> visited{new_reduce_node};
  std::list<af::NodePtr> next_nodes{new_reduce_node};
  while (!next_nodes.empty()) {
    const auto node = next_nodes.front();
    GE_ASSERT_NOTNULL(node);
    next_nodes.pop_front();
    if (node->GetOutDataNodes().empty()) {
      node_order.emplace_back(std::dynamic_pointer_cast<af::AscNode>(node));
    }
    for (auto &out_node : node->GetOutDataNodes()) {
      if (visited.find(out_node) == visited.cend()) {
        next_nodes.emplace_back(out_node);
        visited.emplace(out_node);
      }
    }
  }
  return ge::SUCCESS;
}
}