/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef OPTIMIZE_SCHEDULE_UTILS_H_
#define OPTIMIZE_SCHEDULE_UTILS_H_

#include "ascendc_ir/ascendc_ir_core/ascendc_ir_def.h"
#include "graph/symbolizer/symbolic_utils.h"
#include "asc_graph_utils.h"
#include "ascir_ops_utils.h"
#include "ascgen_log.h"
#include "ascir.h"
#include "ascir_ops.h"
#include "common/platform_context.h"

namespace optimize {
class ScheduleUtils {
 public:
  static af::ComputeType GetComputeType(const af::AscNodePtr &node) {
    return node->attr.api.compute_type;
  }

  // 后端默认采用逆dfs的拓扑排序方式
  static Status TopologicalSorting(af::AscGraph &graph);

  static bool IsElewise(const af::AscNodePtr &node) {
    return node->attr.api.compute_type == af::ComputeType::kComputeElewise;
  }

  static bool IsBroadcast(const af::AscNodePtr &node) {
    return node->attr.api.compute_type == af::ComputeType::kComputeBroadcast;
  }

  static bool IsReduce(const af::AscNodePtr &node) {
    return node->attr.api.compute_type == af::ComputeType::kComputeReduce;
  }

  static bool IsTranspose(const af::AscNodePtr &node) {
    return node->attr.api.compute_type == af::ComputeType::kComputeTranspose;
  }

  static bool IsLoad(const af::AscNodePtr &node) {
    return node->attr.api.compute_type == af::ComputeType::kComputeLoad;
  }

  static bool IsStore(const af::AscNodePtr &node) {
    return node->attr.api.compute_type == af::ComputeType::kComputeStore;
  }

  static bool IsConcat(const af::AscNodePtr &node) {
    return node->attr.api.compute_type == af::ComputeType::kComputeConcat;
  }

  static bool IsSplit(const af::AscNodePtr &node) {
    return node->attr.api.compute_type == af::ComputeType::kComputeSplit;
  }

  static bool IsGather(const af::AscNodePtr &node) {
    return node->attr.api.compute_type == af::ComputeType::kComputeGather;
  }

  static bool IsBuffer(const af::AscNodePtr &node) {
    return node->attr.api.type == af::ApiType::kAPITypeBuffer;
  }

  static bool IsCompute(const af::AscNodePtr &node) {
    return node->attr.api.type == af::ApiType::kAPITypeCompute;
  }

  static bool IsCube(const af::AscNodePtr &node) {
    return node->attr.api.compute_type == af::ComputeType::kComputeCube;
  }

  static bool HasComputeType(const ascir::ImplGraph &impl_graph, const af::ComputeType compute_type);

  static bool IsIOBuffer(const af::NodePtr &node) {
    return af::ops::IsOps<af::ascir_op::Scalar>(node) || IsDataInput(node) ||
           af::ops::IsOps<af::ascir_op::Output>(node);
  }

  static bool IsDataInput(const af::NodePtr &node) {
    return af::ops::IsOps<af::ascir_op::Data>(node) || af::ops::IsOps<af::ascir_op::ScalarData>(node);
  }

  static bool IsDataInput(const af::Node *const node) {
    return af::ops::IsOps<af::ascir_op::Data>(node) || af::ops::IsOps<af::ascir_op::ScalarData>(node);
  }

  static bool IsScalarLikeNode(const af::NodePtr &node) {
    return af::ops::IsOps<af::ascir_op::Scalar>(node) || af::ops::IsOps<af::ascir_op::ScalarData>(node);
  }

  static bool IsScalarLikeNode(const af::Node *const node) {
    return af::ops::IsOps<af::ascir_op::Scalar>(node) || af::ops::IsOps<af::ascir_op::ScalarData>(node);
  }

  static bool IsConstantScalar(const af::Node *const node) {
    return af::ops::IsOps<af::ascir_op::Scalar>(node) || af::ops::IsOps<af::ascir_op::IndexExpr>(node);
  }

  static bool IsRemovePad(const af::NodePtr &node) {
    return af::ops::IsOps<af::ascir_op::RemovePad>(node);
  }

  template <typename T>
  static bool IsNodeSupportDataType(const af::DataType data_type) {
    std::string npu_arch;
    GE_ASSERT_SUCCESS(ge::PlatformContext::GetInstance().GetCurrentPlatformString(npu_arch));
    std::vector exp_dtypes{data_type};
    if (T::InferDataType({data_type}, exp_dtypes, npu_arch) != af::SUCCESS) {
      GELOGD("%s not support dtype=%s", T::Type, af::TypeUtils::DataTypeToSerialString(data_type).c_str());
      return false;
    }
    return true;
  }

  // 获取 npu_arch 并调用 InferDataType
  template <typename OpType>
  static Status CallAscirInferDataType(const std::vector<af::DataType> &input_dtypes,
                                       std::vector<af::DataType> &expect_output_dtypes) {
    std::string npu_arch;
    GE_ASSERT_SUCCESS(ge::PlatformContext::GetInstance().GetCurrentPlatformString(npu_arch));
    return OpType::InferDataType(input_dtypes, expect_output_dtypes, npu_arch);
  }

  static bool GetGatherParams(af::AscGraph &graph, int64_t &attr_axis, int64_t &params_size) {
    bool has_gather = false;
    for (const auto &node : graph.GetAllNodes()) {
      if (!IsGather(node)) {
        continue;
      }
      has_gather = true;
      auto input_axis = node->inputs[1].attr.axis;
      GE_ASSERT_TRUE(!input_axis.empty(), "input attr is invalid.");
      auto output_axis = node->outputs[0].attr.axis;
      for (auto iter = output_axis.begin(); iter != output_axis.end(); ++iter) {
        auto axis = graph.FindAxis(*iter);
        GE_ASSERT_NOTNULL(axis);
        if (axis->id == input_axis[0] ||
            std::find(axis->from.begin(), axis->from.end(), input_axis[0]) != axis->from.end()) {
          attr_axis = std::distance(output_axis.begin(), iter);
          GELOGD("Axis index:[%ld], axis_id:[%ld].", attr_axis, axis->id);
        }
      }
      params_size = node->inputs[0].attr.repeats.size();
      break;
    }
    return has_gather;
  }

  /**
   * 从ComputeGraph上获取AscGraphAttr属性,如果graph或者获取到的ascGraphAttr是空,则抛出异常。
   */
  static af::AscGraphAttr *GetOrCreateGraphAttrsGroup(const af::ComputeGraphPtr &graph) {
    GE_CHECK_NOTNULL_EXEC(graph, return nullptr;);
    auto attr = graph->GetOrCreateAttrsGroup<af::AscGraphAttr>();
    GE_CHECK_NOTNULL_EXEC(attr, return nullptr;);
    return attr;
  }

  /**
   * 给定node,根据其所属Graph的类型,获取循环轴。
   * 若是HintGraph,则获取图属性中的axis信息;
   * 否则ImplGraph,则直接从node的属性中获取axis信息
   */
  static Status GetLoopAxis(const af::AscNode &node, std::vector<int64_t> &axes) {
    auto attr = GetOrCreateGraphAttrsGroup(node.GetOwnerComputeGraph());
    GE_CHECK_NOTNULL(attr, "Get ascgraph type failed, attr is null.");
    if (attr->type == af::AscGraphType::kImplGraph) {
      GELOGD("GetLoopAxis Impl graph, node.name=%s", node.GetNamePtr());
      axes.insert(axes.end(), node.attr.sched.axis.begin(), node.attr.sched.axis.end());
      return af::SUCCESS;
    }
    axes.clear();
    axes.reserve(attr->axis.size());
    for (const auto &graph_axis : attr->axis) {
      GELOGD("[GetLoopAxis] node %s[%s] axis.id=%ld, axis.size=%s", node.GetTypePtr(), node.GetNamePtr(),
             graph_axis->id, graph_axis->size.Str().get());
      axes.push_back(graph_axis->id);
    }
    GELOGD("[GetLoopAxis] attr.axis.size=%zu, ids.size=%zu", attr->axis.size(), axes.size());
    return af::SUCCESS;
  }

  /**
   * 从HintGraph的图属性中,获取repeats,直接获取每个轴的size即可。
   */
  static Status GetLoopRepeats(const af::AscNode &node, std::vector<ascir::SizeExpr> &repeats) {
    auto attr = GetOrCreateGraphAttrsGroup(node.GetOwnerComputeGraph());
    GE_CHECK_NOTNULL(attr, "Get ascgraph type failed, attr is null.");
    repeats.clear();
    repeats.reserve(attr->axis.size());
    for (const auto &axis : attr->axis) {
      GELOGD("[GetLoopRepeats] node %s[%s] axis.id=%ld, axis.size=%s", node.GetTypePtr(), node.GetNamePtr(), axis->id,
             axis->size.Str().get());
      repeats.push_back(axis->size);
    }
    GELOGD("[GetLoopRepeats] attr.axis.size=%zu, repeats.size=%zu", attr->axis.size(), repeats.size());
    return af::SUCCESS;
  }

  /**
   * 从HintGraph的图属性中,获取strides信息。需要根据repeat从后往前计算出来,尾轴的stride=1,然后依次累乘repeat
   */
  static Status GetReduceInputStrides(af::AscNode &node, std::vector<ascir::SizeExpr> &strides) {
    auto attr = GetOrCreateGraphAttrsGroup(node.GetOwnerComputeGraph());
    GE_CHECK_NOTNULL(attr, "Get ascgraph type failed, attr is null.");
    if (attr->type == af::AscGraphType::kImplGraph) {
      strides = node.inputs[0].attr.strides;
      return af::SUCCESS;
    }
    strides.resize(attr->axis.size());
    ascir::SizeExpr basic_stride = af::Symbol(1);
    for (int64_t i = static_cast<int64_t>(attr->axis.size()) - 1; i >= 0; --i) {
      if (af::SymbolicUtils::StaticCheckEq(attr->axis[i]->size, af::sym::kSymbolOne) == af::TriBool::kTrue) {
        strides[i] = af::Symbol(0);
      } else {
        strides[i] = basic_stride;
        basic_stride = basic_stride * attr->axis[i]->size;
      }
      GELOGD("[GetLoopStrides] node %s[%s] axis.id=%ld, axis.strides=%s", node.GetTypePtr(), node.GetNamePtr(),
             attr->axis[i]->id, strides[i].Str().get());
    }
    GELOGD("[GetLoopStrides] attr.axis.size=%zu", attr->axis.size());
    return af::SUCCESS;
  }

  template <typename T>
  static Status GetNodeIrAttrValue(const af::NodePtr &node, const string &attr_name, T &attr_value) {
    auto asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
    GE_ASSERT_NOTNULL(asc_node);
    GE_ASSERT_NOTNULL(asc_node->attr.ir_attr);
    asc_node->attr.ir_attr->GetAttrValue(attr_name, attr_value);  // 允许获取失败
    return af::SUCCESS;
  }

  static Status GetNodeIrAttrIndex(const af::NodePtr &node, int64_t &index) {
    return GetNodeIrAttrValue(node, "index", index);
  }

  static Status GetNodeIrAttrOffset(const af::NodePtr &node, af::Expression &offset) {
    auto asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
    GE_ASSERT_NOTNULL(asc_node);
    GE_ASSERT_NOTNULL(asc_node->attr.ir_attr);
    return asc_node->attr.ir_attr->GetAttrValue("offset", offset);
  }

  static bool IsAxisStrideAllZero(const std::vector<ascir::AxisId> &origin_ids,
                                  const std::vector<ascir::SizeExpr> &axis_strides,
                                  const std::vector<ascir::AxisId> &axis_ids) {
    for (uint64_t i = 0; i < axis_ids.size(); i++) {
      for (uint64_t j = 0; j < origin_ids.size(); j++) {
        if (axis_ids[i] != origin_ids[j]) {
          continue;
        }
        if (af::SymbolicUtils::StaticCheckEq(axis_strides[j], af::ops::Zero) == af::TriBool::kTrue) {
          return true;
        }
      }
    }
    return false;
  }

  static bool IsBroadcastNeedMemUnique(const ascir::NodeView &node, const std::vector<ascir::AxisId> &axis_ids) {
    if (IsBuffer(node) || node->GetInDataNodesSize() == 0) {
      return false;
    }
    if (IsLoad(node)) {
      auto output = node->outputs[0];
      bool is_all_zero = IsAxisStrideAllZero(output.attr.axis, output.attr.strides, axis_ids);
      return is_all_zero;
    }

    for (const auto &in : node->inputs()) {
      GE_ASSERT_NOTNULL(in);
      auto prev_node = std::dynamic_pointer_cast<af::AscNode>(in->anchor.GetOwnerNode());
      if (IsBroadcastNeedMemUnique(prev_node, axis_ids)) {
        return true;
      }
    }
    return false;
  }

  static bool IsNextNodeRemovePad(const ascir::NodeView &node);

  static bool IsContinuesBroadcast(const std::vector<af::Expression> &in_repeats,
                                   const std::vector<af::Expression> &out_repeats);

  static bool IsContinuesStrides(const std::vector<af::Expression> &repeats,
                                 const std::vector<af::Expression> &strides);

  static bool IsContinuesVecStrides(const ascir::NodeView &node);

  static bool IsVectorizedAxisContinuousInGM(const af::AscTensorAttr &output_tensor);

  static bool IsLastAxisSliceLoad(const af::AscNodePtr &node);

  static bool NotNeedAlignVectorStride(const af::AscGraph &graph);

  static bool IsIntervalBroadcast(const std::vector<af::Expression> &in_repeats,
                                  const std::vector<af::Expression> &out_repeats);

  static bool GetTailAxisDataSize(const af::AscNodePtr &node, uint32_t &size);
  static bool IsTailAxisLessThan(const af::AscNodePtr &node, const uint32_t value);
  static bool IsTailAxisAlignedBy(const af::AscNodePtr &node, const uint32_t align_bytes=32);

  static bool IsStaticShape(const ascir::NodeView &node);

  static bool IsStaticGraph(const af::AscGraph &graph);

  static Status GetNonBrcInputTensor(const ascir::NodeView &node, const size_t index,
                                     std::unique_ptr<af::AscTensor> &tensor);

  static bool IsComputeNodes(const ascir::NodeView &node) {
    for (auto &out_node : node->GetOutDataNodes()) {
      auto cur_node = std::dynamic_pointer_cast<af::AscNode>(out_node);
      if (cur_node->attr.api.type == af::ApiType::kAPITypeCompute && !IsBroadcast(cur_node)) {
        return true;
      }
    }
    return false;
  }

  static Status RemoveUnusedAxes(af::AscGraph &graph);

  static void NormalizeAxisIds(const af::AscGraph &graph);

  static Status GetVectorRepeats(const std::vector<af::Expression> &repeats, const std::vector<int64_t> &axis,
                                 const std::vector<int64_t> &vector_axis, std::vector<af::Expression> &vector_repeats);
  static Status GetNodeInputVectorRepeats(const ascir::NodeView &node, std::vector<af::Expression> &vector_repeats);
  static Status GetNodeOutVectorRepeats(const ascir::NodeView &node, std::vector<af::Expression> &vec_repeats);
  static Status GetConcatDim(const af::AscNodePtr &node, size_t &concat_dim);

  static std::string AxesToString(const std::vector<af::AxisPtr> &axes);
  static bool IsNormStruct(const ascir::ImplGraph &implGraph);
  static bool IsReduceArFullLoad(const ascir::ImplGraph &implGraph);
  static bool HasSameInput(const af::AscNodePtr &node);
  static bool IsLastAxisReduce(const ascir::ImplGraph &impl_graph);
  static bool IsScalarBroadcastNode(const ascir::NodeView &node);
  static bool IsScalarBrc(const af::AscNodePtr &node);
  static Status SwapInputIndex(const ascir::NodeView &node, const int32_t idx1, const int32_t idx2);
  static Status GetInputForTranspose(af::AscNode &node, std::vector<ascir::AxisId> &input_axis);
  template<typename T>
  static af::AscNodePtr FindFirstNodeOfType(const af::AscGraph &graph) {
    for (const auto &node : graph.GetAllNodes()) {
      if (af::ops::IsOps<T>(node)) {
        return node;
      }
    }
    return nullptr;
  }
  static Status RemoveNode(const ascir::ImplGraph &impl_graph, const af::AscNodePtr &node,
    const af::OutDataAnchorPtr &pre_out_anchor);
  static Status RemoveNodeDst(const ascir::ImplGraph &impl_graph, const af::AscNodePtr &node,
    const af::InDataAnchorPtr &next_in_anchor);
  static bool FindContinuesBroadcastNode(const ascir::NodeView &node, vector<af::AscNodePtr> &continues_brc_nodes);

  static Status AddRemovePadAfter(af::AscGraph &graph, const af::AscNodePtr &node, af::AscNodePtr &remove_pad_node,
    const int32_t idx = 0);
  static bool IsOutNodeWithMultiInputs(const af::AscNodePtr &node);
  static Status ResolveDiffDim(const af::AscNodePtr &node, size_t &diff_dim, bool &is_first_dim);
  static Status RecalculateStridesFromRepeats(const std::vector<af::Expression> &repeats,
                                       std::vector<af::Expression> &strides);
  static bool IsNeedDiscontinuousAligned(const af::AscTensorAttr &attr);
  static Status ClearAllSizeVar(const af::AscGraph &graph);
  // 判断节点的Micro API是否支持Scalar输入,用于scalar_broadcast优化
  static bool IsMicroApiSupportsScalarInput(const af::AscNodePtr &node);
  static void GenerateStrides(const std::vector<ge::Expression> &repeats, std::vector<ge::Expression> &strides);
};
}  // namespace optimize

#endif  // OPTIMIZE_SCHEDULE_UTILS_H_