/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef AIR_CXX_COMPILER_GRAPH_EAGER_STYLE_EAGER_STYLE_GRAPH_BUILDER_GENERATOR_UTILS_H_
#define AIR_CXX_COMPILER_GRAPH_EAGER_STYLE_EAGER_STYLE_GRAPH_BUILDER_GENERATOR_UTILS_H_
#include <chrono>
#include <fstream>
#include <functional>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "graph/op_desc.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/op_type_utils.h"
#include "history/history_registry_types.h"

namespace ge {
namespace es {
// The helpers below are only declared here; their definitions live in the generator sources.
// NOTE(review): the one-line descriptions are inferred from names and signatures — confirm against the definitions.
// Presumably returns true when code generation should skip ops of type `op_type`.
bool IsOpSkip(const std::string &op_type);
// Presumably returns whether `op` can be generated; when it cannot, `*reason` is expected to point to an
// explanation string.
bool IsOpSupport(const OpDescPtr &op, const char **reason);
// Presumably returns true when `op_type` is listed in `exlude_ops`.
// NOTE(review): `string`/`vector` are unqualified here and presumably resolve through using-declarations
// pulled in by project headers; the parameter name is also misspelled ("exlude").
bool IsOpExclude(const string &op_type, vector<std::string> &exlude_ops);
// Presumably returns true when every (name, IrInputType) entry of `input_infos` is an optional input.
bool IsOpInputsAllOptional(const std::vector<std::pair<std::string, IrInputType>> &input_infos);
// Presumably writes the accumulated contents of `ss` to the file at `file_path`.
void WriteOut(const char *file_path, const std::stringstream &ss);
// Exception used by the generator to reject an operator or attribute it cannot handle
// (see GetIrAttrInfoForName). what() returns the message passed to the constructor.
class NotSupportException : public std::exception {
 public:
  explicit NotSupportException(std::string message) : message_{std::move(message)} {}

  const char *what() const noexcept override { return message_.c_str(); }

 private:
  std::string message_;  // owned copy of the diagnostic text returned by what()
};

// Appends the license banner that heads every generated file to `ss`.
//   is_py == true : '#'-prefixed comment lines followed by a module docstring (Python output).
//   is_py == false: a C-style block comment plus a "GENERATED" notice (C/C++ output).
// NOTE(review): the emitted banner cites License Agreement Version 1.0, while this header's own notice cites
// Version 2.0 — confirm which version generated files should carry.
inline void GenCopyright(std::stringstream &ss, bool is_py = false) {
  constexpr int32_t kBornYear = 2025;  // year stamped into the banner
  const std::string year = std::to_string(kBornYear);

  if (!is_py) {
    ss << R"(/* Copyright (c) )" << year << R"( Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 * ===================================================================================================================*/

/*********************************************************************************************************************
 This file is GENERATED by bin/gen_esb, do not edit it manually
*********************************************************************************************************************/
)";
    return;
  }

  ss << R"(# Copyright (c) )" << year << R"( Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ===================================================================================================================
""" This file is GENERATED by bin/gen_esb, do not edit it manually"""
)";
}
// Describes how one attribute value type is spelled in the different outputs of the generator.
// NOTE(review): the meanings of ir_type / c_api_type / cpp_api_type are inferred from their names — confirm
// against the table that defines the TypeInfo entries.
struct TypeInfo {
  // Whether this entry describes a list type.
  bool is_list_type;

  // String form of `AnyValue::ValueType`; see `kAttrTypesMap` in `ge_attr_value.cc`.
  const char *av_type;
  // Presumably the type name as written in the IR definition.
  const char *ir_type;
  // Presumably the type spelling used in the generated C API.
  const char *c_api_type;
  // Presumably the type spelling used in the generated C++ API.
  const char *cpp_api_type;
};
// Returns the TypeInfo whose av_type matches `av_type` (defined elsewhere). IrAttrInfo stores the returned
// reference, so the entry presumably lives in static storage — confirm.
const TypeInfo &GetTypeInfoByAvType(const std::string &av_type);
// Generator-side view of one IR attribute: its name, whether it is required, and its type description.
// Because `type_info` is a reference member, instances can be copy/move-constructed but not assigned.
struct IrAttrInfo {
  std::string name;
  bool is_required;
  const TypeInfo &type_info;

  // Returns the fully-qualified C++ constant naming this attribute's requiredness, for use in generated code.
  std::string GetRequiredString() const {
    if (!is_required) {
      return "ge::es::CompliantNodeBuilder::kEsAttrOptional";
    }
    return "ge::es::CompliantNodeBuilder::kEsAttrRequired";
  }
};
// Declared here, defined elsewhere. Presumably renders the default value of attribute `attr_name` of `op_desc`
// (whose AnyValue type string is `av_type`) as text for the generated code — TODO confirm against the definition.
std::string GetDefaultValueString(const OpDescPtr &op_desc, const std::string &attr_name, const std::string &av_type);
// Builds the IrAttrInfo for IR attribute `ir_name` of `op_desc`.
//   ir_names_to_type: IR attribute name -> AnyValue type string (e.g. "VT_TENSOR"), as filled by
//                     Operator::GetAllIrAttrNamesAndTypes.
// The attribute is required iff it appears in op_desc->GetRequiredAttrWithType().
// Throws:
//   std::out_of_range   if `ir_name` has no entry in `ir_names_to_type` (from std::map::at);
//   NotSupportException if an optional Tensor attribute has a default tensor that carries data — only an
//                       empty `Tensor()` default is supported.
inline IrAttrInfo GetIrAttrInfoForName(const OpDescPtr &op_desc,
                                       const std::map<AscendString, AscendString> &ir_names_to_type,
                                       const std::string &ir_name) {
  const std::string av_type = ir_names_to_type.at(AscendString(ir_name.c_str())).GetString();
  const bool is_required = op_desc->GetRequiredAttrWithType().count(ir_name) > 0;
  // Only optional Tensor attributes have a default value that needs inspecting.
  if (!is_required && av_type == "VT_TENSOR") {
    ge::ConstGeTensorPtr default_value;
    // The result of GetTensor must be checked: if the lookup fails, `default_value` can remain null, and
    // calling IsTensorDataValid() on it would dereference a null pointer. Reject only when a default tensor
    // is actually present and carries data.
    // NOTE(review): an optional Tensor attr whose default cannot be read is accepted here — confirm intent.
    const bool has_default =
        ge::AttrUtils::GetTensor(op_desc, ir_name, default_value) && (default_value != nullptr);
    if (has_default && default_value->IsTensorDataValid()) {
      throw NotSupportException(
          "Only support 'Tensor()' as default value for Tensor attr, which current op's attr is not using.");
    }
  }
  return {ir_name, is_required, GetTypeInfoByAvType(av_type)};
}

// 返回属性信息:先 required 后 optional,各自内部与 GetIrAttrNames() 顺序一致
inline std::vector<IrAttrInfo> GetAllIrAttrsNamesAndTypeInOrder(const OpDescPtr &op_desc) {
  const auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc);

  std::map<AscendString, AscendString> ir_names_to_type;
  if (op.GetAllIrAttrNamesAndTypes(ir_names_to_type) != GRAPH_SUCCESS) {
    throw std::runtime_error("Failed to get ir names and types");
  }

  std::vector<IrAttrInfo> req_attrs;
  std::vector<IrAttrInfo> optional_attrs;
  for (const auto &ir_name : op_desc->GetIrAttrNames()) {
    IrAttrInfo info = GetIrAttrInfoForName(op_desc, ir_names_to_type, ir_name);
    if (info.is_required) {
      req_attrs.push_back(std::move(info));
    } else {
      optional_attrs.push_back(std::move(info));
    }
  }

  std::vector<IrAttrInfo> ir_name_dts = std::move(req_attrs);
  for (auto &op_attr : optional_attrs) {
    ir_name_dts.emplace_back(std::move(op_attr));
  }
  return ir_name_dts;
}

// 按 GetIrAttrNames() 的 IR 顺序返回属性信息
inline std::vector<IrAttrInfo> GetAllIrAttrsNamesAndTypeInIrOrder(const OpDescPtr &op_desc) {
  const auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc);
  std::map<AscendString, AscendString> ir_names_to_type;
  if (op.GetAllIrAttrNamesAndTypes(ir_names_to_type) != GRAPH_SUCCESS) {
    throw std::runtime_error("Failed to get ir names and types");
  }

  std::vector<IrAttrInfo> result;
  result.reserve(op_desc->GetIrAttrNames().size());
  for (const auto &ir_name : op_desc->GetIrAttrNames()) {
    result.push_back(GetIrAttrInfoForName(op_desc, ir_names_to_type, ir_name));
  }
  return result;
}

// Returns true when `name` equals the name of one of `op_desc`'s IR inputs.
// A direct scan is used, the same as the history-input overload: building a hash set for a single membership
// test only added a per-call allocation and hashing of every input name.
inline bool IsDupNameInInputs(const std::string &name, const OpDescPtr &op_desc) {
  for (const auto &ir_input : op_desc->GetIrInputs()) {
    if (ir_input.first == name) {
      return true;
    }
  }
  return false;
}

inline bool IsDupNameInInputs(const std::string &name, const std::vector<ge::es::history::IrInput> &inputs) {
  for (const auto &input : inputs) {
    if (input.name == name) {
      return true;
    }
  }
  return false;
}

// Classification of an operator's IR outputs, as computed by GetOutputType.
enum class OutputType { kNoOutput, kOneOutput, kMultiOutput, kDynamicOutput };
// Target language of the generated code; the naming helpers use it to pick the Python or C++ keyword set.
enum class GenLanType { GenPy, GenCpp };
// Classifies `op` by its IR outputs. Any dynamic output yields kDynamicOutput; otherwise the result reflects
// whether the op has zero, exactly one, or several IR outputs.
inline OutputType GetOutputType(const OpDescPtr &op) {
  const auto &ir_outputs = op->GetIrOutputs();
  for (const auto &output : ir_outputs) {
    if (output.second == kIrOutputDynamic) {
      return OutputType::kDynamicOutput;
    }
  }
  switch (ir_outputs.size()) {
    case 0:
      return OutputType::kNoOutput;
    case 1:
      return OutputType::kOneOutput;
    default:
      return OutputType::kMultiOutput;
  }
}
// Returns true when `word` is reserved in C++ and therefore cannot be used as an identifier in generated C++
// code. The set covers the keywords of C++11 through C++20 (including `char8_t` and the coroutine/concept
// keywords), the alternative operator tokens (`and`, `bitor`, ...), and the words reserved by the
// Transactional Memory TS (`atomic_*`, `synchronized`). Identifiers that are special only in some contexts
// (`override`, `final`, `import`, `module`) are not reserved and are not listed.
inline bool IsKeyword(const std::string &word) {
  static const std::unordered_set<std::string> keywords = {
      "alignas",   "alignof",      "and",          "and_eq",          "asm",           "atomic_cancel",
      "atomic_commit", "atomic_noexcept", "auto",  "bitand",          "bitor",         "bool",
      "break",     "case",         "catch",        "char",            "char8_t",       "char16_t",
      "char32_t",  "class",        "compl",        "concept",         "const",         "consteval",
      "constexpr", "constinit",    "const_cast",   "continue",        "co_await",      "co_return",
      "co_yield",  "decltype",     "default",      "delete",          "do",            "double",
      "dynamic_cast", "else",      "enum",         "explicit",        "export",        "extern",
      "false",     "float",        "for",          "friend",          "goto",          "if",
      "inline",    "int",          "long",         "mutable",         "namespace",     "new",
      "noexcept",  "not",          "not_eq",       "nullptr",         "operator",      "or",
      "or_eq",     "private",      "protected",    "public",          "register",      "reinterpret_cast",
      "requires",  "return",       "short",        "signed",          "sizeof",        "static",
      "static_assert", "static_cast", "struct",    "switch",          "synchronized",  "template",
      "this",      "thread_local", "throw",        "true",            "try",           "typedef",
      "typeid",    "typename",     "union",        "unsigned",        "using",         "virtual",
      "void",      "volatile",     "wchar_t",      "while",           "xor",           "xor_eq"};

  return keywords.find(word) != keywords.end();
}

// Returns true when `word` is one of Python 3's reserved keywords (the same 35 names as `keyword.kwlist`)
// and so cannot be used as an identifier in generated Python code. Soft keywords such as `match`, `case`
// and `type` are not reserved and are not included.
inline bool IsPyKeyword(const std::string &word) {
  static const std::unordered_set<std::string> kPyKeywords{
      "False", "None",   "True",    "and",      "as",       "assert", "async", "await",  "break",
      "class", "continue", "def",   "del",      "elif",     "else",   "except", "finally", "for",
      "from",  "global", "if",      "import",   "in",       "is",     "lambda", "nonlocal", "not",
      "or",    "pass",   "raise",   "return",   "try",      "while",  "with",  "yield"};
  return kPyKeywords.count(word) != 0;
}

// Returns the identifier used for an IR input in generated code. If the IR name collides with a keyword of
// the target language (Python keywords for GenPy, C++ keywords otherwise), it is prefixed with "in_".
inline std::string InName(const std::string &name, GenLanType type = GenLanType::GenCpp) {
  const bool collides = (type == GenLanType::GenPy) ? IsPyKeyword(name) : IsKeyword(name);
  if (!collides) {
    return name;
  }
  return "in_" + name;
}

// Returns the identifier used for an IR output of `op_desc` in generated code.
//  - If the output shares its name with an IR input, the input naming rule (InName) is applied and the result
//    is prefixed with "ref_".
//  - Otherwise a name that collides with a keyword of the target language is prefixed with "out_".
inline std::string OutName(const std::string &name, const OpDescPtr &op_desc, GenLanType type = GenLanType::GenCpp) {
  if (IsDupNameInInputs(name, op_desc)) {
    return "ref_" + InName(name, type);
  }
  const bool collides = (type == GenLanType::GenPy) ? IsPyKeyword(name) : IsKeyword(name);
  return collides ? "out_" + name : name;
}
// Returns the identifier for a subgraph, prefixed with "subgraph_" when it collides with a C++ keyword.
// NOTE(review): unlike InName/OutName, only C++ keywords are checked here — confirm this is never used for
// Python output.
inline std::string SubgraphName(const std::string &name) {
  return IsKeyword(name) ? "subgraph_" + name : name;
}
// Returns the identifier for a dynamic-subgraph vector: "dynamic_" followed by SubgraphName(name).
inline std::string DynamicSubgraphVectorName(const std::string &name) {
  std::string vector_name("dynamic_");
  vector_name += SubgraphName(name);
  return vector_name;
}
// Returns the identifier used for an IR attribute of `op_desc` in generated code.
//  - If the attribute shares its name with an IR input, the input naming rule (InName) is applied and the result
//    is prefixed with "attr_".
//  - Otherwise a name that collides with a keyword of the target language is prefixed with "attr_".
inline std::string AttrName(const std::string &name, const OpDescPtr &op_desc, GenLanType type = GenLanType::GenCpp) {
  if (IsDupNameInInputs(name, op_desc)) {
    return "attr_" + InName(name, type);
  }
  const bool is_reserved = (type == GenLanType::GenPy) ? IsPyKeyword(name) : IsKeyword(name);
  return is_reserved ? "attr_" + name : name;
}

// Attribute naming based only on the input list of a (history) IR prototype; always uses C++ rules.
//  - An attribute named like one of `inputs` gets InName(name, GenCpp) prefixed with "attr_".
//  - Otherwise a C++ keyword is prefixed with "attr_".
inline std::string AttrName(const std::string &name, const std::vector<ge::es::history::IrInput> &inputs) {
  const bool shadows_input = IsDupNameInInputs(name, inputs);
  if (shadows_input) {
    return "attr_" + InName(name, GenLanType::GenCpp);
  }
  if (IsKeyword(name)) {
    return "attr_" + name;
  }
  return name;
}

// Writes comments for `op` into `h_stream` when the situation calls for them (defined elsewhere).
void GenCommentsIfNeeded(const OpDescPtr &op, std::stringstream &h_stream, bool support_tensor_like = false);

// Writes the comments describing TensorLike inputs of `op` into `h_stream` (defined elsewhere).
void GenTensorLikeInputComments(const OpDescPtr &op, std::stringstream &h_stream);

// Generic include-guard macro builder: derives a guard name from `input`, `prefix` and `suffix`.
std::string MakeGuardFromString(const std::string &input, const std::string &prefix, const std::string &suffix);

// Builds the include-guard macro for an operator's generated header.
std::string MakeGuardFromOp(const std::string &op_type, const std::string &external_defined_prefix = "",
                            const std::string &suffix = "_H_");

// Builds the include-guard macro for a module's generated header.
std::string MakeGuardFromModule(const std::string &module_name, const std::string &external_defined_prefix = "",
                                const std::string &suffix = "_OPS_H_");

// Generic helper that writes one generated file per op.
//   output_dir:   prefix prepended verbatim to every file name (include the trailing path separator).
//   op_types:     ops to write, in order; ops without an entry in `contents` are skipped.
//   contents:     map-like container from op type to a stream exposing str() (e.g. std::stringstream).
//   get_filename: maps an op type to its file name relative to `output_dir`.
// Throws std::runtime_error when a file cannot be opened or written, instead of silently producing nothing.
template <typename ContentMap>
void WritePerOpFiles(const std::string &output_dir, const std::vector<std::string> &op_types,
                     const ContentMap &contents, const std::function<std::string(const std::string &)> &get_filename) {
  for (const auto &op_type : op_types) {
    const auto it = contents.find(op_type);  // single lookup instead of find() followed by at()
    if (it == contents.end()) {
      continue;
    }
    const std::string file_path = output_dir + get_filename(op_type);
    std::ofstream file(file_path);
    if (!file.is_open()) {
      throw std::runtime_error("Failed to open file for writing: " + file_path);
    }
    file << it->second.str();
    file.close();
    if (file.fail()) {  // covers both write errors and a failed close/flush
      throw std::runtime_error("Failed to write file: " + file_path);
    }
  }
}

// Generic helper that writes one generated utility file per name.
//   output_dir:   prefix prepended verbatim to every file name (include the trailing path separator).
//   util_names:   utilities to write, in order; names without an entry in `contents` are skipped.
//   contents:     map-like container from utility name to a stream exposing str() (e.g. std::stringstream).
//   get_filename: maps a utility name to its file name relative to `output_dir`.
// Throws std::runtime_error when a file cannot be opened or written, instead of silently producing nothing.
template <typename ContentMap>
void WritePerUtilFiles(const std::string &output_dir, const std::vector<std::string> &util_names,
                       const ContentMap &contents,
                       const std::function<std::string(const std::string &)> &get_filename) {
  for (const auto &util_name : util_names) {
    const auto it = contents.find(util_name);  // single lookup instead of find() followed by at()
    if (it == contents.end()) {
      continue;
    }
    const std::string file_path = output_dir + get_filename(util_name);
    std::ofstream file(file_path);
    if (!file.is_open()) {
      throw std::runtime_error("Failed to open file for writing: " + file_path);
    }
    file << it->second.str();
    file.close();
    if (file.fail()) {  // covers both write errors and a failed close/flush
      throw std::runtime_error("Failed to write file: " + file_path);
    }
  }
}
}  // namespace es
}  // namespace ge
#endif  // AIR_CXX_COMPILER_GRAPH_EAGER_STYLE_EAGER_STYLE_GRAPH_BUILDER_GENERATOR_UTILS_H_