/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "pre_process/improve_precision.h"

#include <atomic>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "ascir_ops.h"
#include "schedule_utils.h"
#include "graph_properties_cache.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "graph/ascendc_ir/ascendc_ir_core/ascendc_ir_def.h"
#include "common/platform_context.h"
#include "pre_process/pre_process_config.h"
#include "ascgen_log.h"


namespace af::pre_process {
namespace {
using af::AscGraph;
using af::AscGraphUtils;
using af::AscNodeAttr;
using af::AscTensorAttr;
using af::ComputeGraphPtr;
using ge::DataType;
using ge::DT_BF16;
using ge::DT_FLOAT;
using ge::DT_FLOAT16;
using ge::DT_INT4;
using ge::DT_INT8;
using ge::DT_UINT8;
using af::GeTensorDescPtr;
using af::GraphUtils;
using af::NodePtr;
using af::NodeUtils;
using af::OpDescBuilder;
using af::OpDescUtils;
using af::TypeUtils;

// Returns true when dtype is DT_FLOAT16 or DT_BF16.
bool IsLowPrecisionDataType(DataType dtype) {
  switch (dtype) {
    case DT_FLOAT16:
    case DT_BF16:
      return true;
    default:
      return false;
  }
}

// Returns true only when dtype is DT_FLOAT.
bool IsHighPrecisionDataType(DataType dtype) {
  const bool is_float = (DT_FLOAT == dtype);
  return is_float;
}

// Returns true when dtype is any of the float types handled by this pass
// (DT_FLOAT16, DT_BF16 or DT_FLOAT).
bool IsFloatDataType(DataType dtype) {
  if (IsLowPrecisionDataType(dtype)) {
    return true;
  }
  return IsHighPrecisionDataType(dtype);
}

// Returns true when dtype is DT_INT4, DT_INT8 or DT_UINT8.
bool IsUltraLowPrecisionDataType(DataType dtype) {
  switch (dtype) {
    case DT_INT4:
    case DT_INT8:
    case DT_UINT8:
      return true;
    default:
      return false;
  }
}

// Returns true when the producer dtype is a float type and the consumer dtype is one of
// DT_INT4 / DT_INT8 / DT_UINT8.
bool IsFloatToUltraLowPrecision(DataType peer_output_dtype, DataType output_dtype) {
  if (!IsFloatDataType(peer_output_dtype)) {
    return false;
  }
  return IsUltraLowPrecisionDataType(output_dtype);
}

// Returns true when the producer dtype is DT_INT4 / DT_INT8 / DT_UINT8 and the consumer dtype is
// DT_FLOAT16 or DT_BF16.
bool IsUltraLowToLowPrecision(DataType peer_output_dtype, DataType output_dtype) {
  if (!IsUltraLowPrecisionDataType(peer_output_dtype)) {
    return false;
  }
  return IsLowPrecisionDataType(output_dtype);
}

bool IsNodeTypeInPeerInNodes(const std::string &node_type, const std::vector<NodePtr> &peer_in_nodes) {
  for (const auto &peer_in_node: peer_in_nodes) {
    if (peer_in_node->GetType() == node_type) {
      return true;
    }
  }
  return false;
}

// Looks up the node that feeds input `idx` of `node` and stores it in `peer_out_node`.
// The input anchor, its peer output anchor and the owning node are each checked with
// GE_ASSERT_NOTNULL; ge::SUCCESS is returned once all three checks pass.
Status GetPeerOutNode(const NodePtr &node, NodePtr &peer_out_node, int32_t idx) {
  const auto input_anchor = node->GetInDataAnchor(idx);
  GE_ASSERT_NOTNULL(input_anchor);

  const auto producer_anchor = input_anchor->GetPeerOutAnchor();
  GE_ASSERT_NOTNULL(producer_anchor);

  peer_out_node = producer_anchor->GetOwnerNode();
  GE_ASSERT_NOTNULL(peer_out_node);
  return ge::SUCCESS;
}

// Appends to `peer_out_nodes` the producer node of every connected input of `node`, in input
// order. Inputs whose peer output anchor is null are skipped, so the position of an entry in
// `peer_out_nodes` is not necessarily the input anchor index.
Status GetPeerOutNodes(const NodePtr &node, std::vector<NodePtr> &peer_out_nodes) {
  const auto input_count = static_cast<int32_t>(node->GetAllInDataAnchorsSize());
  for (int32_t input_idx = 0; input_idx < input_count; input_idx++) {
    const auto input_anchor = node->GetInDataAnchor(input_idx);
    GE_ASSERT_NOTNULL(input_anchor);
    const auto producer_anchor = input_anchor->GetPeerOutAnchor();
    if (producer_anchor != nullptr) {
      auto producer = producer_anchor->GetOwnerNode();
      GE_ASSERT_NOTNULL(producer);
      peer_out_nodes.push_back(producer);
    }
  }
  return ge::SUCCESS;
}

// Appends to `peer_in_nodes` the owner node of every input anchor connected to output
// `out_data_idx` of `node`. A null output anchor, peer anchor or owner node trips a
// GE_ASSERT_NOTNULL check.
Status GetPeerInNodes(const NodePtr &node, std::vector<NodePtr> &peer_in_nodes, int32_t out_data_idx) {
  const auto output_anchor = node->GetOutDataAnchor(out_data_idx);
  GE_ASSERT_NOTNULL(output_anchor);
  for (const auto &consumer_anchor : output_anchor->GetPeerInDataAnchorsPtr()) {
    GE_ASSERT_NOTNULL(consumer_anchor);
    const auto consumer = consumer_anchor->GetOwnerNode();
    GE_ASSERT_NOTNULL(consumer);
    peer_in_nodes.push_back(consumer);
  }
  return ge::SUCCESS;
}

// Stores the mutable tensor description of output 0 of `node` in `output_tensor_desc`.
// Only output index 0 is ever queried here.
Status GetOutputTensorDesc(const NodePtr &node, GeTensorDescPtr &output_tensor_desc) {
  const auto node_desc = node->GetOpDesc();
  GE_ASSERT_NOTNULL(node_desc);

  output_tensor_desc = node_desc->MutableOutputDesc(0);
  GE_ASSERT_NOTNULL(output_tensor_desc);
  return ge::SUCCESS;
}

// Removes `node` from `asc_graph` and reconnects its neighbours: the producer feeding input 0 is
// wired directly to every consumer of output 0. Only input 0 and output 0 are rewired here.
Status DelNode(AscGraph &asc_graph, const NodePtr &node) {
  const auto node_in_anchor = node->GetInDataAnchor(0);
  GE_ASSERT_NOTNULL(node_in_anchor);
  const auto node_out_anchor = node->GetOutDataAnchor(0);
  GE_ASSERT_NOTNULL(node_out_anchor);
  const auto producer_anchor = node_in_anchor->GetPeerOutAnchor();
  GE_ASSERT_NOTNULL(producer_anchor);

  // Detach the node from its producer first.
  GE_ASSERT_GRAPH_SUCCESS(GraphUtils::RemoveEdge(producer_anchor, node_in_anchor));

  // Move every consumer edge from the node's output onto the producer's output.
  for (const auto &consumer_anchor : node_out_anchor->GetPeerInDataAnchors()) {
    GE_ASSERT_NOTNULL(consumer_anchor);
    GE_ASSERT_GRAPH_SUCCESS(GraphUtils::RemoveEdge(node_out_anchor, consumer_anchor));
    GE_ASSERT_GRAPH_SUCCESS(GraphUtils::AddEdge(producer_anchor, consumer_anchor));
  }

  GELOGD("Remove node: %s(%s) from asc_graph:%s.", node->GetName().c_str(), node->GetType().c_str(),
         asc_graph.GetName().c_str());
  const auto owner_graph = AscGraphUtils::GetComputeGraph(asc_graph);
  GE_ASSERT_GRAPH_SUCCESS(GraphUtils::RemoveJustNode(owner_graph, node));
  NodeUtils::UnlinkAll(*node);
  return ge::SUCCESS;
}

// Adds `topo_id_increment` to the id of every node in `asc_graph` whose id is strictly greater
// than the id of `node`. The id of `node` itself is left untouched.
Status UpdateTopoId(AscGraph &asc_graph, const NodePtr &node, int64_t topo_id_increment) {
  const auto pivot_desc = node->GetOpDesc();
  GE_ASSERT_NOTNULL(pivot_desc);
  const auto pivot_id = pivot_desc->GetId();

  const auto graph = AscGraphUtils::GetComputeGraph(asc_graph);
  GE_ASSERT_NOTNULL(graph);
  for (const auto &graph_node : graph->GetAllNodes()) {
    const auto graph_node_desc = graph_node->GetOpDesc();
    GE_ASSERT_NOTNULL(graph_node_desc);
    const auto current_id = graph_node_desc->GetId();
    if (current_id <= pivot_id) {
      continue;
    }
    graph_node_desc->SetId(current_id + topo_id_increment);
  }
  return ge::SUCCESS;
}

// Rewrites every output tensor of `node` whose data type equals `s_dtype` so that it carries
// `d_dtype` instead. Outputs with any other data type are left unchanged.
Status FromDtypeToOtherDtype(const NodePtr &node, DataType s_dtype, DataType d_dtype) {
  const auto op_desc = node->GetOpDesc();
  GE_ASSERT_NOTNULL(op_desc);
  const auto output_count = op_desc->GetAllOutputsDescSize();
  for (uint32_t out_idx = 0U; out_idx < output_count; ++out_idx) {
    const auto out_desc = op_desc->MutableOutputDesc(out_idx);
    GE_ASSERT_NOTNULL(out_desc);
    if (out_desc->GetDataType() != s_dtype) {
      continue;
    }
    out_desc->SetDataType(d_dtype);
  }
  return ge::SUCCESS;
}

// Asks `graph` to sort its nodes using the op-desc id as the ordering key (ascending).
void TopologicalSorting(const ComputeGraphPtr &graph) {
  const auto by_ascending_id = [](const af::NodePtr &lhs, const af::NodePtr &rhs) {
    return lhs->GetOpDesc()->GetId() < rhs->GetOpDesc()->GetId();
  };
  graph->TopologicalSorting(by_ascending_id);
}

// Returns true when the Cast dtype inference accepts `input_dtype` as input with `output_dtype`
// as the expected output, i.e. CallAscirInferDataType<Cast> reports ge::SUCCESS.
bool CheckCastDtype(DataType input_dtype, DataType output_dtype) {
  std::vector<DataType> cast_inputs = {input_dtype};
  std::vector<DataType> cast_expected_outputs = {output_dtype};
  const auto infer_status =
      optimize::ScheduleUtils::CallAscirInferDataType<af::ascir_op::Cast>(cast_inputs, cast_expected_outputs);
  return infer_status == ge::SUCCESS;
}

// Counter backing GenUniqueNumber(); BuildCastNode uses it to suffix generated Cast node names.
std::atomic<int64_t> g_unique_number{0};

// Returns the current counter value and then advances the counter by one.
int64_t GenUniqueNumber() { return g_unique_number++; }

// Rewinds the counter so that the next GenUniqueNumber() call returns 0.
void ResetUniqueNumber() { g_unique_number = 0; }

// ====================== Blacklist ======================

// Built-in blacklist of node types. Every entry maps a type name to itself, so only the keys
// carry information. IsAllNodesInBlacklist treats these types as blacklisted without a dtype check.
const std::unordered_map<std::string, std::string> kBlackList1 = {
  {af::ascir_op::Data::Type, af::ascir_op::Data::Type},
  {af::ascir_op::Load::Type, af::ascir_op::Load::Type},
  {af::ascir_op::Scalar::Type, af::ascir_op::Scalar::Type},
  {af::ascir_op::Store::Type, af::ascir_op::Store::Type},
  {af::ascir_op::Output::Type, af::ascir_op::Output::Type},
  {af::ascir_op::Broadcast::Type, af::ascir_op::Broadcast::Type},
  {af::ascir_op::Transpose::Type, af::ascir_op::Transpose::Type},
  {af::ascir_op::Concat::Type, af::ascir_op::Concat::Type},
  {af::ascir_op::Gather::Type, af::ascir_op::Gather::Type},
  {"Slice", "Slice"}
};

// Returns true when the type of `node` is a key of kBlackList1.
bool IsInBlackList1(const NodePtr &node) {
  const auto &node_type = node->GetType();
  return kBlackList1.count(node_type) != 0U;
}

// Returns true when the type of `node` is contained in the caller-supplied `blacklist2`.
bool IsInBlackList2(const NodePtr &node, const std::unordered_set<std::string> &blacklist2) {
  const auto &node_type = node->GetType();
  return blacklist2.count(node_type) != 0U;
}

// Verifies that `node` supports its current dtypes as they stand. The input dtypes are taken
// from output 0 of every connected producer, the expected output dtype from output 0 of `node`,
// and both are passed to CommonInferDtype together with the current platform string.
// Returns ge::FAILED (after logging) when the inference rejects the combination.
Status CheckNodeDtype(const NodePtr &node) {
  std::vector<NodePtr> producers;
  GE_ASSERT_SUCCESS(GetPeerOutNodes(node, producers));

  std::vector<DataType> input_dtypes;
  for (const auto &producer : producers) {
    GeTensorDescPtr producer_out_desc;
    GE_ASSERT_SUCCESS(GetOutputTensorDesc(producer, producer_out_desc));
    input_dtypes.push_back(producer_out_desc->GetDataType());
  }

  GeTensorDescPtr node_out_desc;
  GE_ASSERT_SUCCESS(GetOutputTensorDesc(node, node_out_desc));
  std::vector<DataType> expect_output_dtypes = {node_out_desc->GetDataType()};

  std::string npu_arch;
  GE_ASSERT_SUCCESS(ge::PlatformContext::GetInstance().GetCurrentPlatformString(npu_arch));

  const auto infer_status = af::ascir::CommonInferDtype(node->GetType(), input_dtypes, expect_output_dtypes, npu_arch);
  if (infer_status == ge::SUCCESS) {
    return ge::SUCCESS;
  }
  GELOGE(ge::FAILED,
         "Node %s(%s) with dtype(%s) is not supported. "
         "Do not configure it in autofuse_enhance_precision_blacklist",
         node->GetName().c_str(), node->GetType().c_str(),
         TypeUtils::DataTypeToSerialString(node_out_desc->GetDataType()).c_str());
  return ge::FAILED;
}

// Maps a node type to the processing group used by GroupNodesByType. Each listed type forms its
// own group (key == value); types that are not listed here end up in the "Other" group.
const std::unordered_map<std::string, std::string> kTypeToGroup = {
  {af::ascir_op::Cast::Type, af::ascir_op::Cast::Type},
  {af::ascir_op::Load::Type, af::ascir_op::Load::Type},
  {af::ascir_op::Gather::Type, af::ascir_op::Gather::Type},
  {af::ascir_op::Scalar::Type, af::ascir_op::Scalar::Type},
  {af::ascir_op::Store::Type, af::ascir_op::Store::Type}
};

// A Cast is a candidate for deletion when both its input dtype and its output dtype are float
// types (DT_FLOAT16 / DT_BF16 / DT_FLOAT).
bool ShouldDeleteCastNode(DataType peer_output_dtype, DataType output_dtype) {
  if (!IsFloatDataType(output_dtype)) {
    return false;
  }
  return IsFloatDataType(peer_output_dtype);
}

bool ShouldChangeDataType(const NodePtr &node, const std::vector<NodePtr> &peer_in_nodes, DataType peer_output_dtype,
                          DataType output_dtype) {
  (void) node;
  if (IsNodeTypeInPeerInNodes(af::ascir_op::Store::Type, peer_in_nodes)) {
    return false;
  }
  if (IsUltraLowToLowPrecision(peer_output_dtype, output_dtype) && !CheckCastDtype(peer_output_dtype, DT_FLOAT)) {
    return false;
  }
  return (output_dtype == DT_FLOAT16 || output_dtype == DT_BF16);
}

// Decides whether an extra Cast has to be placed in front of a float -> int4/int8/uint8 Cast.
// No extra Cast is needed when:
//   - the conversion is not float -> int4/int8/uint8 at all,
//   - CheckCastDtype accepts DT_FLOAT -> output_dtype directly, or
//   - the producer is a Load / Gather / Cast whose output is fp16 / bf16.
bool IsFloatToUltraLowNeedInsertCast(const NodePtr &peer_out_node, DataType peer_output_dtype, DataType output_dtype) {
  if (!IsFloatToUltraLowPrecision(peer_output_dtype, output_dtype)) {
    return false;
  }
  if (CheckCastDtype(DT_FLOAT, output_dtype)) {
    return false;
  }
  const auto &producer_type = peer_out_node->GetType();
  const bool producer_is_load_gather_or_cast = (producer_type == af::ascir_op::Load::Type) ||
                                               (producer_type == af::ascir_op::Gather::Type) ||
                                               (producer_type == af::ascir_op::Cast::Type);
  const bool low_precision_feed = producer_is_load_gather_or_cast && IsLowPrecisionDataType(peer_output_dtype);
  return !low_precision_feed;
}

// Creates a Cast node with one input "x" and one output "y" and adds it to `asc_graph`.
// The node name is "Cast_<ref_node name>_<unique number>". The new node is not connected to
// anything yet; the result of asc_graph.AddNode is returned as-is.
NodePtr BuildCastNode(AscGraph &asc_graph, const NodePtr &ref_node) {
  const std::string cast_name = "Cast_" + ref_node->GetName() + "_" + std::to_string(GenUniqueNumber());
  OpDescBuilder cast_builder(cast_name, af::ascir_op::Cast::Type);
  cast_builder.AddInput("x");
  cast_builder.AddOutput("y");
  auto cast_desc = cast_builder.Build();
  GE_ASSERT_NOTNULL(cast_desc);

  // Register the IR-level input/output alongside the plain ones.
  cast_desc->AppendIrInput("x", af::kIrInputRequired);
  cast_desc->AppendIrOutput("y", af::kIrOutputRequired);

  auto cast_operator = std::make_shared<af::Operator>(OpDescUtils::CreateOperatorFromOpDesc(cast_desc));
  GE_ASSERT_NOTNULL(cast_operator);
  return asc_graph.AddNode(*cast_operator);
}

// Builds a new Cast node (returned through `cast_node`) and splices it in front of input
// `input_idx` of `target`: the existing input connection is handed over to the Cast via
// ReplaceNodeDataAnchors, then output 0 of the Cast is connected to that input of `target`.
Status WireCastBeforeInput(AscGraph &asc_graph, const NodePtr &target, NodePtr &cast_node,
                           int32_t input_idx) {
  cast_node = BuildCastNode(asc_graph, target);
  GE_ASSERT_NOTNULL(cast_node);

  const auto owner_graph = AscGraphUtils::GetComputeGraph(asc_graph);
  GE_ASSERT_SUCCESS(cast_node->SetOwnerComputeGraph(owner_graph));

  GE_ASSERT_GRAPH_SUCCESS(GraphUtils::ReplaceNodeDataAnchors(cast_node, target, {input_idx}, {}));

  const auto cast_out_anchor = cast_node->GetOutDataAnchor(0);
  const auto target_in_anchor = target->GetInDataAnchor(input_idx);
  GE_ASSERT_GRAPH_SUCCESS(GraphUtils::AddEdge(cast_out_anchor, target_in_anchor));
  return ge::SUCCESS;
}

// Copies the scheduling axis from `src_node` to `dst_node` and reassigns ids: `dst_node` takes
// over the current id of `src_node`, and `src_node` moves to that id + 1. Callers in this file
// invoke UpdateTopoId beforehand to shift the ids of later nodes.
Status TransferNodeAttrs(const NodePtr &src_node, const NodePtr &dst_node) {
  const auto src_desc = src_node->GetOpDesc();
  GE_ASSERT_NOTNULL(src_desc);
  const auto src_node_attr = src_desc->GetAttrsGroup<AscNodeAttr>();
  GE_ASSERT_NOTNULL(src_node_attr);

  const auto dst_desc = dst_node->GetOpDesc();
  GE_ASSERT_NOTNULL(dst_desc);
  auto dst_node_attr = dst_desc->GetOrCreateAttrsGroup<AscNodeAttr>();
  GE_ASSERT_NOTNULL(dst_node_attr);

  dst_node_attr->sched.axis = src_node_attr->sched.axis;

  const auto original_src_id = src_desc->GetId();
  dst_desc->SetId(original_src_id);
  src_desc->SetId(original_src_id + 1);
  return ge::SUCCESS;
}

// Fills in output 0 of `cast_node`.
//   - dtype: DT_FLOAT when `is_increase` is true; otherwise the dtype of output 0 of `next_node`
//     if that is fp16/bf16, else DT_FLOAT16.
//   - axis / repeats / strides: copied from the AscTensorAttr of `src_tensor_desc`.
// Note that the dtype is written before `src_tensor_desc` is validated.
Status ConfigureCastTensor(const GeTensorDescPtr &src_tensor_desc, const NodePtr &cast_node,
                           const NodePtr &next_node, bool is_increase) {
  const auto cast_desc = cast_node->GetOpDesc();
  GE_ASSERT_NOTNULL(cast_desc);
  auto cast_out_desc = cast_desc->MutableOutputDesc(0);
  GE_ASSERT_NOTNULL(cast_out_desc);

  // Step 1: pick the output dtype of the Cast.
  if (!is_increase) {
    const auto consumer_desc = next_node->GetOpDesc();
    GE_ASSERT_NOTNULL(consumer_desc);
    auto consumer_out_desc = consumer_desc->MutableOutputDesc(0);
    GE_ASSERT_NOTNULL(consumer_out_desc);
    if (IsLowPrecisionDataType(consumer_out_desc->GetDataType())) {
      cast_out_desc->SetDataType(consumer_out_desc->GetDataType());
    } else {
      cast_out_desc->SetDataType(DT_FLOAT16);
    }
  } else {
    cast_out_desc->SetDataType(DT_FLOAT);
  }

  // Step 2: mirror the tensor layout attributes of the source tensor.
  auto cast_tensor_attr = cast_out_desc->GetOrCreateAttrsGroup<AscTensorAttr>();
  GE_ASSERT_NOTNULL(cast_tensor_attr);
  GE_ASSERT_NOTNULL(src_tensor_desc);
  const auto src_tensor_attr = src_tensor_desc->GetAttrsGroup<AscTensorAttr>();
  GE_ASSERT_NOTNULL(src_tensor_attr);
  cast_tensor_attr->axis = src_tensor_attr->axis;
  cast_tensor_attr->repeats = src_tensor_attr->repeats;
  cast_tensor_attr->strides = src_tensor_attr->strides;
  return ge::SUCCESS;
}

// ====================== Per-type processing ======================
// Handles one existing Cast node. Exactly one of the following happens, checked in this order:
//   1. float -> float cast: the node is deleted and its producer is wired to its consumers.
//   2. float -> int4/int8/uint8 cast that needs an intermediate step: a new Cast is inserted in
//      front of input 0 and configured with is_increase == false.
//   3. ShouldChangeDataType() holds: the output dtype of the node is rewritten to DT_FLOAT.
// If none applies the node is left as it is.
Status CastNodeProc(AscGraph &asc_graph, const NodePtr &node) {
  NodePtr producer;
  GE_ASSERT_SUCCESS(GetPeerOutNode(node, producer, 0));
  std::vector<NodePtr> consumers;
  GE_ASSERT_SUCCESS(GetPeerInNodes(node, consumers, 0));

  GeTensorDescPtr producer_out_desc;
  GE_ASSERT_SUCCESS(GetOutputTensorDesc(producer, producer_out_desc));
  GeTensorDescPtr cast_out_desc;
  GE_ASSERT_SUCCESS(GetOutputTensorDesc(node, cast_out_desc));
  const auto src_dtype = producer_out_desc->GetDataType();
  const auto dst_dtype = cast_out_desc->GetDataType();

  if (ShouldDeleteCastNode(src_dtype, dst_dtype)) {
    GE_ASSERT_SUCCESS(DelNode(asc_graph, node));
  } else if (IsFloatToUltraLowNeedInsertCast(producer, src_dtype, dst_dtype)) {
    // Make room in the id sequence before the extra Cast is created.
    GE_ASSERT_SUCCESS(UpdateTopoId(asc_graph, node, 1));
    NodePtr extra_cast = nullptr;
    GE_ASSERT_SUCCESS(WireCastBeforeInput(asc_graph, node, extra_cast, 0));
    NodePtr extra_cast_producer;
    GE_ASSERT_SUCCESS(GetPeerOutNode(extra_cast, extra_cast_producer, 0));
    GeTensorDescPtr extra_cast_src_desc;
    GE_ASSERT_SUCCESS(GetOutputTensorDesc(extra_cast_producer, extra_cast_src_desc));
    GE_ASSERT_SUCCESS(ConfigureCastTensor(extra_cast_src_desc, extra_cast, node, false));
    GE_ASSERT_SUCCESS(TransferNodeAttrs(node, extra_cast));
  } else if (ShouldChangeDataType(node, consumers, src_dtype, dst_dtype)) {
    cast_out_desc->SetDataType(DT_FLOAT);
  }
  return ge::SUCCESS;
}

// Sets `is_need_insert_cast` to true when output 0 of `node` is fp16/bf16 and none of its
// consumers is a Cast or a Store. In every other case the flag is left untouched (it is never
// written to false here), so callers must initialise it.
Status IsNeedInsertCastAfterLoad(const NodePtr &node, bool &is_need_insert_cast) {
  const auto op_desc = node->GetOpDesc();
  GE_ASSERT_NOTNULL(op_desc);
  const auto out_desc = op_desc->MutableOutputDesc(0);
  GE_ASSERT_NOTNULL(out_desc);
  std::vector<NodePtr> consumers;
  GE_ASSERT_SUCCESS(GetPeerInNodes(node, consumers, 0));

  if (IsNodeTypeInPeerInNodes(af::ascir_op::Cast::Type, consumers)) {
    return ge::SUCCESS;
  }
  if (IsNodeTypeInPeerInNodes(af::ascir_op::Store::Type, consumers)) {
    return ge::SUCCESS;
  }
  const auto out_dtype = out_desc->GetDataType();
  if ((out_dtype != DT_FLOAT16) && (out_dtype != DT_BF16)) {
    return ge::SUCCESS;
  }
  is_need_insert_cast = true;
  return ge::SUCCESS;
}

// When `is_need_insert_cast` is true, inserts a Cast to DT_FLOAT directly behind output 0 of
// `load_node`: all former consumers of that output are moved onto the Cast, the Cast copies the
// tensor axis/repeats/strides and the scheduling axis from `load_node`, and it receives the id
// load_id + 1 (later ids are shifted by one beforehand). Does nothing when the flag is false.
Status InsertCastToIncreasePrecision(AscGraph &asc_graph, const NodePtr &load_node, bool is_need_insert_cast) {
  if (!is_need_insert_cast) {
    return ge::SUCCESS;
  }
  GE_ASSERT_SUCCESS(UpdateTopoId(asc_graph, load_node, 1));

  // Create the Cast and hand the Load's consumers over to it.
  auto cast_node = BuildCastNode(asc_graph, load_node);
  GE_ASSERT_NOTNULL(cast_node);
  GE_ASSERT_GRAPH_SUCCESS(GraphUtils::ReplaceNodeDataAnchors(cast_node, load_node, {}, {0}));
  GE_ASSERT_GRAPH_SUCCESS(GraphUtils::AddEdge(load_node->GetOutDataAnchor(0), cast_node->GetInDataAnchor(0)));

  // Output tensor of the Cast: DT_FLOAT.
  const auto cast_desc = cast_node->GetOpDesc();
  GE_ASSERT_NOTNULL(cast_desc);
  const auto cast_out_desc = cast_desc->MutableOutputDesc(0);
  GE_ASSERT_NOTNULL(cast_out_desc);
  cast_out_desc->SetDataType(DT_FLOAT);
  const auto cast_tensor_attr = cast_out_desc->GetOrCreateAttrsGroup<AscTensorAttr>();
  GE_ASSERT_NOTNULL(cast_tensor_attr);

  // Tensor layout attributes are mirrored from the Load's output.
  const auto load_desc = load_node->GetOpDesc();
  GE_ASSERT_NOTNULL(load_desc);
  const auto load_out_desc = load_desc->MutableOutputDesc(0);
  GE_ASSERT_NOTNULL(load_out_desc);
  const auto load_tensor_attr = load_out_desc->GetAttrsGroup<AscTensorAttr>();
  GE_ASSERT_NOTNULL(load_tensor_attr);
  cast_tensor_attr->axis = load_tensor_attr->axis;
  cast_tensor_attr->repeats = load_tensor_attr->repeats;
  cast_tensor_attr->strides = load_tensor_attr->strides;

  // Node-level scheduling axis and id.
  const auto cast_node_attr = cast_desc->GetOrCreateAttrsGroup<AscNodeAttr>();
  GE_ASSERT_NOTNULL(cast_node_attr);
  const auto load_node_attr = load_desc->GetAttrsGroup<AscNodeAttr>();
  GE_ASSERT_NOTNULL(load_node_attr);
  cast_node_attr->sched.axis = load_node_attr->sched.axis;
  cast_desc->SetId(load_desc->GetId() + 1);
  return ge::SUCCESS;
}

// Collects the inputs of `other_node` that need a Cast in front of them: every connected input
// whose producer is a Cast / Load / Gather with an fp16 / bf16 output (output 0 of the producer).
//
// For each such input its real input-anchor index is appended to `input_idxs` and `need_insert`
// is set to true. Neither out-parameter is reset here, so callers must pass them initialised.
//
// The input anchors are walked directly instead of going through GetPeerOutNodes(): that helper
// skips unconnected inputs, so a position in its result vector is not a valid anchor index once
// an unconnected input precedes a connected one. The indices recorded here are later used as
// anchor indices by WireCastBeforeInput().
Status IsNeedInsertCastBeforeOther(const NodePtr &other_node, bool &need_insert, std::vector<int32_t> &input_idxs) {
  const auto input_count = static_cast<int32_t>(other_node->GetAllInDataAnchorsSize());
  for (int32_t input_idx = 0; input_idx < input_count; ++input_idx) {
    const auto in_anchor = other_node->GetInDataAnchor(input_idx);
    GE_ASSERT_NOTNULL(in_anchor);
    const auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
    if (peer_out_anchor == nullptr) {
      // Unconnected input: nothing to cast, but the index still advances.
      continue;
    }
    const auto peer_out_node = peer_out_anchor->GetOwnerNode();
    GE_ASSERT_NOTNULL(peer_out_node);
    GeTensorDescPtr peer_output_tensor_desc;
    GE_ASSERT_SUCCESS(GetOutputTensorDesc(peer_out_node, peer_output_tensor_desc));

    const auto &type = peer_out_node->GetType();
    const bool producer_is_cast_load_or_gather = (type == af::ascir_op::Cast::Type) ||
                                                 (type == af::ascir_op::Load::Type) ||
                                                 (type == af::ascir_op::Gather::Type);
    if (producer_is_cast_load_or_gather && IsLowPrecisionDataType(peer_output_tensor_desc->GetDataType())) {
      need_insert = true;
      input_idxs.push_back(input_idx);
    }
  }
  return ge::SUCCESS;
}

// Inserts one Cast in front of every input of `other_node` listed in `input_idxs`, provided
// `is_need_insert_cast` is true. Per input: later ids are shifted by one, the Cast is spliced
// in, its output tensor is configured from the producer's output 0 (dtype chosen according to
// `is_increase_precision`), and the scheduling axis / id are handed over from `other_node`.
Status InsertCastBeforeNode(AscGraph &asc_graph, const NodePtr &other_node, bool is_need_insert_cast,
                            bool is_increase_precision, const std::vector<int32_t> &input_idxs) {
  if (!is_need_insert_cast) {
    return ge::SUCCESS;
  }
  for (auto idx_it = input_idxs.cbegin(); idx_it != input_idxs.cend(); ++idx_it) {
    const int32_t anchor_idx = *idx_it;
    GE_ASSERT_SUCCESS(UpdateTopoId(asc_graph, other_node, 1));

    NodePtr new_cast = nullptr;
    GE_ASSERT_SUCCESS(WireCastBeforeInput(asc_graph, other_node, new_cast, anchor_idx));

    // The Cast now sits between the original producer and `other_node`.
    NodePtr cast_producer;
    GE_ASSERT_SUCCESS(GetPeerOutNode(new_cast, cast_producer, 0));
    GeTensorDescPtr cast_src_desc;
    GE_ASSERT_SUCCESS(GetOutputTensorDesc(cast_producer, cast_src_desc));

    GE_ASSERT_SUCCESS(ConfigureCastTensor(cast_src_desc, new_cast, other_node, is_increase_precision));
    GE_ASSERT_SUCCESS(TransferNodeAttrs(other_node, new_cast));
  }
  return ge::SUCCESS;
}

// Compares the dtype produced into input 0 of `store_node` with the dtype of the Store's own
// output 0. `need_insert` becomes true exactly when the two differ; `is_increase_precision`
// becomes true exactly when the Store's dtype is DT_FLOAT.
Status IsNeedInsertCastBeforeStore(const NodePtr &store_node, bool &need_insert, bool &is_increase_precision) {
  NodePtr producer;
  GE_ASSERT_SUCCESS(GetPeerOutNode(store_node, producer, 0));
  GeTensorDescPtr producer_out_desc;
  GE_ASSERT_SUCCESS(GetOutputTensorDesc(producer, producer_out_desc));
  GeTensorDescPtr store_out_desc;
  GE_ASSERT_SUCCESS(GetOutputTensorDesc(store_node, store_out_desc));

  const auto store_dtype = store_out_desc->GetDataType();
  is_increase_precision = IsHighPrecisionDataType(store_dtype);
  need_insert = (producer_out_desc->GetDataType() != store_dtype);
  return ge::SUCCESS;
}
// Group name (a node type from kTypeToGroup, or "Other") -> nodes belonging to that group.
using TypeToNodesMap = std::unordered_map<std::string, std::vector<NodePtr>>;

// Returns true (and logs) when the cached graph properties report a cube, concat or split node;
// such graphs are excluded from precision improvement.
bool ShouldSkipGraph(optimize::GraphPropertiesCache &cache, const AscGraph &asc_graph) {
  const bool skip = cache.HasCube() || cache.HasConcat() || cache.HasSplit();
  if (skip) {
    GELOGI("Graph %s is cube/concat/split type, skip precision improvement.", asc_graph.GetName().c_str());
  }
  return skip;
}

// Determines whether every node of `asc_graph` (Data and Output nodes excluded) is blacklisted
// from precision improvement, and writes the answer to `result`.
//
// Two blacklists are consulted: the built-in kBlackList1 and the configured one obtained from
// PreProcessConfig (blacklist2).
//   - If blacklist2 contains "all": every node is treated as blacklisted, but each node must pass
//     CheckNodeDtype(); the first node that does not makes `result` false (status stays SUCCESS).
//   - Otherwise: nodes in kBlackList1 are accepted as-is; nodes in blacklist2 must pass
//     CheckNodeDtype(), and a failure there is propagated as an error; any node in neither list
//     makes `result` false.
//
// The compute graph is checked for null before it is walked, matching UpdateTopoId().
Status IsAllNodesInBlacklist(const AscGraph &asc_graph, bool &result) {
  const auto compute_graph = AscGraphUtils::GetComputeGraph(asc_graph);
  GE_ASSERT_NOTNULL(compute_graph);

  const auto &blacklist2 = PreProcessConfig::Instance().GetImprovePrecisionBlacklist();
  constexpr char kAllNodesType[] = "all";
  const bool has_all = (blacklist2.find(kAllNodesType) != blacklist2.end());
  result = true;
  for (const auto &node : compute_graph->GetAllNodes()) {
    if (node->GetType() == af::ascir_op::Output::Type || node->GetType() == af::ascir_op::Data::Type) {
      continue;
    }
    if (has_all) {
      // A dtype the node cannot handle means the graph still has to be processed.
      if (CheckNodeDtype(node) != ge::SUCCESS) {
        result = false;
        return ge::SUCCESS;
      }
    } else if (IsInBlackList1(node)) {
      continue;
    } else if (IsInBlackList2(node, blacklist2)) {
      GE_ASSERT_SUCCESS(CheckNodeDtype(node));
    } else {
      result = false;
      return ge::SUCCESS;
    }
  }
  GELOGI("All nodes in graph %s are in the blacklist, skip precision improvement.", asc_graph.GetName().c_str());
  return ge::SUCCESS;
}

// Buckets the nodes of `asc_graph` by processing group. Data and Output nodes are dropped;
// types listed in kTypeToGroup go to their mapped group, everything else goes to "Other".
// Within a group, nodes keep the order in which GetAllNodes() yields them.
TypeToNodesMap GroupNodesByType(const AscGraph &asc_graph) {
  TypeToNodesMap groups;
  for (const auto &node : AscGraphUtils::GetComputeGraph(asc_graph)->GetAllNodes()) {
    const bool is_graph_boundary =
        (node->GetType() == af::ascir_op::Output::Type) || (node->GetType() == af::ascir_op::Data::Type);
    if (is_graph_boundary) {
      continue;
    }
    const auto group_it = kTypeToGroup.find(node->GetType());
    if (group_it == kTypeToGroup.end()) {
      groups["Other"].push_back(node);
    } else {
      groups[group_it->second].push_back(node);
    }
  }
  return groups;
}

// For each node in `nodes`, decides whether a Cast to DT_FLOAT has to follow it and, if so,
// inserts that Cast.
Status ProcessLoadGatherNodes(AscGraph &asc_graph, const std::vector<NodePtr> &nodes) {
  for (auto node_it = nodes.cbegin(); node_it != nodes.cend(); ++node_it) {
    const NodePtr &load_like_node = *node_it;
    bool cast_required = false;
    GE_ASSERT_SUCCESS(IsNeedInsertCastAfterLoad(load_like_node, cast_required));
    GE_ASSERT_SUCCESS(InsertCastToIncreasePrecision(asc_graph, load_like_node, cast_required));
  }
  return ge::SUCCESS;
}

// For each node in `nodes`: inserts precision-increasing Casts in front of the inputs that need
// one, then rewrites the node's own bf16 and fp16 outputs to DT_FLOAT.
Status ProcessOtherComputeNodes(AscGraph &asc_graph, const std::vector<NodePtr> &nodes) {
  constexpr bool kIncreasePrecision = true;
  for (auto node_it = nodes.cbegin(); node_it != nodes.cend(); ++node_it) {
    const NodePtr &compute_node = *node_it;
    bool cast_required = false;
    std::vector<int32_t> inputs_to_cast;
    GE_ASSERT_SUCCESS(IsNeedInsertCastBeforeOther(compute_node, cast_required, inputs_to_cast));
    GE_ASSERT_SUCCESS(InsertCastBeforeNode(asc_graph, compute_node, cast_required, kIncreasePrecision,
                                           inputs_to_cast));
    GE_ASSERT_SUCCESS(FromDtypeToOtherDtype(compute_node, DT_BF16, DT_FLOAT));
    GE_ASSERT_SUCCESS(FromDtypeToOtherDtype(compute_node, DT_FLOAT16, DT_FLOAT));
  }
  return ge::SUCCESS;
}

// For each Store in `nodes`: when the dtype arriving at input 0 differs from the Store's dtype,
// a Cast is inserted in front of input 0.
Status ProcessStoreNodes(AscGraph &asc_graph, const std::vector<NodePtr> &nodes) {
  // A Store is only ever adapted on its first input.
  const std::vector<int32_t> store_input_idxs = {0};
  for (auto node_it = nodes.cbegin(); node_it != nodes.cend(); ++node_it) {
    const NodePtr &store_node = *node_it;
    bool cast_required = false;
    bool increase_precision = true;
    GE_ASSERT_SUCCESS(IsNeedInsertCastBeforeStore(store_node, cast_required, increase_precision));
    GE_ASSERT_SUCCESS(InsertCastBeforeNode(asc_graph, store_node, cast_required, increase_precision,
                                           store_input_idxs));
  }
  return ge::SUCCESS;
}

// Runs the per-group handlers in a fixed order: Cast, Load, Gather, Scalar, "Other", Store.
// Groups are fetched with operator[], so a missing group is created empty in `type_to_nodes`.
Status ProcessNodeGroups(AscGraph &asc_graph, TypeToNodesMap &type_to_nodes) {
  const auto &cast_nodes = type_to_nodes[af::ascir_op::Cast::Type];
  for (const auto &cast_node : cast_nodes) {
    GE_ASSERT_SUCCESS(CastNodeProc(asc_graph, cast_node));
  }

  GE_ASSERT_SUCCESS(ProcessLoadGatherNodes(asc_graph, type_to_nodes[af::ascir_op::Load::Type]));
  GE_ASSERT_SUCCESS(ProcessLoadGatherNodes(asc_graph, type_to_nodes[af::ascir_op::Gather::Type]));

  // Scalars only need their fp16 / bf16 outputs promoted.
  const auto &scalar_nodes = type_to_nodes[af::ascir_op::Scalar::Type];
  for (const auto &scalar_node : scalar_nodes) {
    GE_ASSERT_SUCCESS(FromDtypeToOtherDtype(scalar_node, DT_BF16, DT_FLOAT));
    GE_ASSERT_SUCCESS(FromDtypeToOtherDtype(scalar_node, DT_FLOAT16, DT_FLOAT));
  }

  GE_ASSERT_SUCCESS(ProcessOtherComputeNodes(asc_graph, type_to_nodes["Other"]));
  GE_ASSERT_SUCCESS(ProcessStoreNodes(asc_graph, type_to_nodes[af::ascir_op::Store::Type]));
  return ge::SUCCESS;
}
} // namespace

// Entry point of the pass. Resets the name counter, then returns early (SUCCESS) when the graph
// contains cube/concat/split nodes or when every node is blacklisted. Otherwise the nodes are
// grouped by type, processed group by group, and the graph is re-sorted by node id.
ge::Status ImprovePrecisionForAscGraph(AscGraph &asc_graph) {
  ResetUniqueNumber();

  optimize::GraphPropertiesCache properties_cache(asc_graph);
  if (ShouldSkipGraph(properties_cache, asc_graph)) {
    return ge::SUCCESS;
  }

  bool everything_blacklisted = false;
  GE_ASSERT_SUCCESS(IsAllNodesInBlacklist(asc_graph, everything_blacklisted));
  if (everything_blacklisted) {
    return ge::SUCCESS;
  }

  auto node_groups = GroupNodesByType(asc_graph);
  GE_ASSERT_SUCCESS(ProcessNodeGroups(asc_graph, node_groups));

  // Ids were reassigned while Casts were inserted; bring the node order back in line with them.
  TopologicalSorting(AscGraphUtils::GetComputeGraph(asc_graph));
  return ge::SUCCESS;
}
} // namespace af::pre_process