/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "ascgraph_info_complete.h"
#include <map>
#include <queue>
#include "ascir_ops.h"
#include "ascendc_ir_def.h"
#include "graph/symbolizer/symbolic.h"
#include "graph/attribute_group/attr_group_shape_env.h"
#include "ascir_ops_utils.h"

using namespace af::ascir_op;

namespace optimize {
namespace {
static Status GetNodeIrAttrOffset(const af::NodePtr &node, af::Expression &offset) {
  auto asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
  GE_ASSERT_NOTNULL(asc_node);
  GE_ASSERT_NOTNULL(asc_node->attr.ir_attr);
  return asc_node->attr.ir_attr->GetAttrValue("offset", offset);
}

void InsertFreeSymbolsIntoVarSet(const af::Expression &exp, SizeVarSet &size_vars) {
  std::vector<af::Expression> free_symbols = exp.FreeSymbols();
  size_vars.insert(free_symbols.begin(), free_symbols.end());
}

void CompleteDataApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeBuffer;
  node->attr.api.unit = af::ComputeUnit::kUnitNone;
}

void CompleteLoadApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeCompute;
  node->attr.api.unit = af::ComputeUnit::kUnitMTE2;
}

void CompleteStoreApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeCompute;
  node->attr.api.unit = af::ComputeUnit::kUnitMTE2;
}

void CompleteElewiseApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeCompute;
  node->attr.api.unit = af::ComputeUnit::kUnitVector;
}

void CompleteBroadcastApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeCompute;
  node->attr.api.unit = af::ComputeUnit::kUnitVector;
}

void CompleteReduceApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeCompute;
  node->attr.api.unit = af::ComputeUnit::kUnitVector;
}

void CompleteConcatApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeCompute;
  node->attr.api.unit = af::ComputeUnit::kUnitVector;
}

void CompleteSplitApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeCompute;
  node->attr.api.unit = af::ComputeUnit::kUnitVector;
}

void CompleteGatherApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeCompute;
  node->attr.api.unit = af::ComputeUnit::kUnitMTE2;
}

void CompleteCubeApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeCompute;
  node->attr.api.unit = af::ComputeUnit::kUnitCube;
}
}  // namespace

using CompleteApiInfoFunc = std::function<void(af::AscNodePtr &)>;
struct Completer {
  CompleteApiInfoFunc complete_api_info;
};

void CompleteTransposeApiInfo(af::AscNodePtr &node) {
  node->attr.api.type = af::ApiType::kAPITypeCompute;
  node->attr.api.unit = af::ComputeUnit::kUnitVector;
}

static const std::map<std::string, af::ComputeType> kOpTypeToComputeType = {
    {Workspace::Type, af::ComputeType::kComputeInvalid},      {Data::Type, af::ComputeType::kComputeInvalid},
    {Scalar::Type, af::ComputeType::kComputeInvalid},         {Output::Type, af::ComputeType::kComputeInvalid},
    {IndexExpr::Type, af::ComputeType::kComputeInvalid},

    {Load::Type, af::ComputeType::kComputeLoad},           {Store::Type, af::ComputeType::kComputeStore},

    {Sum::Type, af::ComputeType::kComputeReduce},          {Max::Type, af::ComputeType::kComputeReduce},
    {ArgMax::Type, af::ComputeType::kComputeReduce},       {Mean::Type, af::ComputeType::kComputeReduce},         {Min::Type, af::ComputeType::kComputeReduce},
    {Prod::Type, af::ComputeType::kComputeReduce},         {All::Type, af::ComputeType::kComputeReduce},
    {Any::Type, af::ComputeType::kComputeReduce},

    {Broadcast::Type, af::ComputeType::kComputeBroadcast},
    {RemovePad::Type, af::ComputeType::kComputeElewise},
    {Pad::Type, af::ComputeType::kComputeElewise},
    {Round::Type, af::ComputeType::kComputeElewise},

    {Cast::Type, af::ComputeType::kComputeElewise},        {Abs::Type, af::ComputeType::kComputeElewise},
    {Neg::Type, af::ComputeType::kComputeElewise},         {Exp::Type, af::ComputeType::kComputeElewise},
    {Sqrt::Type, af::ComputeType::kComputeElewise},        {Rsqrt::Type, af::ComputeType::kComputeElewise},
    {Relu::Type, af::ComputeType::kComputeElewise},        {Reciprocal::Type, af::ComputeType::kComputeElewise},
    {Erf::Type, af::ComputeType::kComputeElewise},         {Sign::Type, af::ComputeType::kComputeElewise},
    {Tanh::Type, af::ComputeType::kComputeElewise},        {Isnan::Type, af::ComputeType::kComputeElewise},
    {IsFinite::Type, af::ComputeType::kComputeElewise},    {Ln::Type, af::ComputeType::kComputeElewise},
    {LogicalNot::Type, af::ComputeType::kComputeElewise},

    {Add::Type, af::ComputeType::kComputeElewise},         {Sub::Type, af::ComputeType::kComputeElewise},
    {Mul::Type, af::ComputeType::kComputeElewise},         {Div::Type, af::ComputeType::kComputeElewise},
    {TrueDiv::Type, af::ComputeType::kComputeElewise},     {Minimum::Type, af::ComputeType::kComputeElewise},
    {Maximum::Type, af::ComputeType::kComputeElewise},     {LogicalOr::Type, af::ComputeType::kComputeElewise},
    {LogicalAnd::Type, af::ComputeType::kComputeElewise},

    {Ge::Type, af::ComputeType::kComputeElewise},          {Eq::Type, af::ComputeType::kComputeElewise},
    {Ne::Type, af::ComputeType::kComputeElewise},          {Gt::Type, af::ComputeType::kComputeElewise},
    {Le::Type, af::ComputeType::kComputeElewise},          {Lt::Type, af::ComputeType::kComputeElewise},
    {Broadcast::Type, af::ComputeType::kComputeElewise},   {Sigmoid::Type, af::ComputeType::kComputeElewise},
    {Concat::Type, af::ComputeType::kComputeConcat},       {Gather::Type, af::ComputeType::kComputeGather},

    {Where::Type, af::ComputeType::kComputeElewise},       {Select::Type, af::ComputeType::kComputeElewise},
    {ClipByValue::Type, af::ComputeType::kComputeElewise}, {Pow::Type, af::ComputeType::kComputeElewise},
    {Transpose::Type, af::ComputeType::kComputeTranspose},
    {BitwiseAnd::Type, af::ComputeType::kComputeElewise},  {LeakyRelu::Type, af::ComputeType::kComputeElewise},
    {FloorDiv::Type, af::ComputeType::kComputeElewise},    {Gelu::Type, af::ComputeType::kComputeElewise},
    {Axpy::Type, af::ComputeType::kComputeElewise},
    {Split::Type, af::ComputeType::kComputeSplit},
    {MatMul::Type, af::ComputeType::kComputeCube},         {MatMulBias::Type, af::ComputeType::kComputeCube},
    {MatMulOffset::Type, af::ComputeType::kComputeCube},   {MatMulOffsetBias::Type, af::ComputeType::kComputeCube},
    {BatchMatMul::Type, af::ComputeType::kComputeCube},    {BatchMatMulBias::Type, af::ComputeType::kComputeCube},
    {BatchMatMulOffset::Type, af::ComputeType::kComputeCube},
    {BatchMatMulOffsetBias::Type, af::ComputeType::kComputeCube},
    {Conv2D::Type, af::ComputeType::kComputeCube},
    {Conv2DBias::Type, af::ComputeType::kComputeCube},
    {Conv2DOffset::Type, af::ComputeType::kComputeCube},
    {Conv2DOffsetBias::Type, af::ComputeType::kComputeCube},
};

static const std::map<af::ComputeType, Completer> kComputeTypeToCompleter = {
    {af::ComputeType::kComputeInvalid, {&CompleteDataApiInfo}},
    {af::ComputeType::kComputeLoad, {&CompleteLoadApiInfo}},
    {af::ComputeType::kComputeStore, {&CompleteStoreApiInfo}},
    {af::ComputeType::kComputeReduce, {&CompleteReduceApiInfo}},
    {af::ComputeType::kComputeBroadcast, {&CompleteBroadcastApiInfo}},
    {af::ComputeType::kComputeElewise, {&CompleteElewiseApiInfo}},
    {af::ComputeType::kComputeConcat, {&CompleteConcatApiInfo}},
    {af::ComputeType::kComputeGather, {&CompleteGatherApiInfo}},
    {af::ComputeType::kComputeTranspose, {&CompleteTransposeApiInfo}},
    {af::ComputeType::kComputeSplit, {&CompleteSplitApiInfo}},
    {af::ComputeType::kComputeCube, {&CompleteCubeApiInfo}},
};

Status AscGraphInfoComplete::CompleteApiInfo(const af::AscGraph &optimize_graph) {
  for (auto node : optimize_graph.GetAllNodes()) {
    auto node_compute_type = &node->attr.api.compute_type;
    if (*node_compute_type >= af::ComputeType::kComputeInvalid) {
      auto item = kOpTypeToComputeType.find(node->GetType());
      if (item != kOpTypeToComputeType.end()) {
        *node_compute_type = item->second;
      }
    }
    auto it = kComputeTypeToCompleter.find(*node_compute_type);
    GE_ASSERT_TRUE((it != kComputeTypeToCompleter.end()), "CompleteApiInfo unsupported node name:[%s], type: [%s].",
                   node->GetNamePtr(), node->GetTypePtr());
    it->second.complete_api_info(node);
  }
  return ge::SUCCESS;
}

void AscGraphInfoComplete::AppendOriginalSizeVar(const af::AscGraph &graph, SizeVarSet &size_vars) {
  auto axes = graph.GetAllAxis();
  for (const auto &axis : axes) {
    InsertFreeSymbolsIntoVarSet(axis->size, size_vars);
  }
  auto all_nodes = graph.GetAllNodes();
  for (const auto &node : all_nodes) {
    if (!af::ops::IsOps<Nddma>(node) && !af::ops::IsOps<Store>(node) && !af::ops::IsOps<Load>(node) &&
        !af::ops::IsOps<Gather>(node)) {
      continue;
    }

    af::Expression cur_load_offset;
    if (GetNodeIrAttrOffset(node, cur_load_offset) == ge::SUCCESS) {
      InsertFreeSymbolsIntoVarSet(cur_load_offset, size_vars);
    }

    if (af::ops::IsOps<Gather>(node)) {
      for (const auto &exp : node->inputs[0].attr.repeats) {
        InsertFreeSymbolsIntoVarSet(exp, size_vars);
      }
    }

    for (const auto &exp : node->outputs[0].attr.repeats) {
      InsertFreeSymbolsIntoVarSet(exp, size_vars);
    }
    for (const auto &exp : node->outputs[0].attr.strides) {
      InsertFreeSymbolsIntoVarSet(exp, size_vars);
    }
  }
}
}  // namespace optimize