* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIR_CXX_COMPILER_GRAPH_EAGER_STYLE_EAGER_STYLE_GRAPH_BUILDER_GENERATOR_UTILS_H_
#define AIR_CXX_COMPILER_GRAPH_EAGER_STYLE_EAGER_STYLE_GRAPH_BUILDER_GENERATOR_UTILS_H_
#include <sstream>
#include <chrono>
#include <fstream>
#include <vector>
#include "graph/op_desc.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/op_type_utils.h"
#include "history/history_registry_types.h"
namespace ge {
namespace es {
// Generator helper predicates and output routine. Only the declarations live in this header;
// the exact criteria are decided by the definitions, which are not visible here.

// Reports whether ops of type `op_type` are skipped by the generator.
bool IsOpSkip(const std::string &op_type);

// Reports whether `op` is supported by the generator.
// NOTE(review): `reason` is an out-parameter; presumably it receives an explanatory string when the
// op is not supported — confirm against the definition.
bool IsOpSupport(const OpDescPtr &op, const char **reason);

// Reports whether `op_type` is excluded with respect to the list `exclude_ops`.
// NOTE(review): the list is taken by non-const reference; whether it is modified is not visible here.
bool IsOpExclude(const std::string &op_type, std::vector<std::string> &exclude_ops);

// Reports whether every entry of `input_infos` (pairs of input name and IR input type) is optional.
// NOTE(review): the result for an empty list is decided by the definition — confirm.
bool IsOpInputsAllOptional(const std::vector<std::pair<std::string, IrInputType>> &input_infos);

// Emits the content accumulated in `ss` to `file_path`.
// NOTE(review): presumably writes (and overwrites) the file at that path — confirm error handling
// against the definition.
void WriteOut(const char *file_path, const std::stringstream &ss);
// Exception type thrown by the generator helpers in this header when an op uses a construct they
// do not handle. The message passed at construction is owned by the exception and exposed via what().
class NotSupportException : public std::exception {
 public:
  // Takes the message by value and moves it into the exception.
  explicit NotSupportException(std::string msg) : msg_{std::move(msg)} {}

  // Returns the stored message as a NUL-terminated string; valid while this object is alive.
  const char *what() const noexcept override { return msg_.data(); }

 private:
  std::string msg_;
};
// Appends the license banner and the "GENERATED by bin/gen_esb" notice to `ss`.
//   ss    : stream that receives the banner; existing content is kept, the banner is appended.
//   is_py : true  -> banner written with '#' line comments and a triple-quoted notice (Python);
//           false -> banner written as C-style block comments (default).
// The raw string literals below are emitted verbatim, so their continuation lines and the closing
// delimiter must stay at column 0.
inline void GenCopyright(std::stringstream &ss, bool is_py = false) {
  const int32_t copyright_year = 2025;
  const std::string year_text = std::to_string(copyright_year);
  if (!is_py) {
    ss << R"(/* Copyright (c) )" << year_text << R"( Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
* ===================================================================================================================*/
/*********************************************************************************************************************
This file is GENERATED by bin/gen_esb, do not edit it manually
*********************************************************************************************************************/
)";
    return;
  }
  ss << R"(# Copyright (c) )" << year_text << R"( Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ===================================================================================================================
""" This file is GENERATED by bin/gen_esb, do not edit it manually"""
)";
}
// Describes one attribute value type and the type spellings associated with it.
// Entries are looked up by GetTypeInfoByAvType() and referenced from IrAttrInfo::type_info.
// The strings are plain pointers; this struct does not manage their lifetime.
struct TypeInfo {
// NOTE(review): presumably true when the attribute holds a list of values — confirm against the type table.
bool is_list_type;
// NOTE(review): presumably the attribute value type name used as lookup key (e.g. "VT_TENSOR") — confirm.
const char *av_type;
// NOTE(review): presumably the type spelling used in the IR definition — confirm.
const char *ir_type;
// NOTE(review): presumably the parameter type used in the generated C API — confirm.
const char *c_api_type;
// NOTE(review): presumably the parameter type used in the generated C++ API — confirm.
const char *cpp_api_type;
};
// Returns the TypeInfo entry for the attribute value type name `av_type` (e.g. "VT_TENSOR").
// NOTE(review): the result is returned by reference and stored in IrAttrInfo::type_info, so the
// entries presumably have static lifetime; the behavior for an unknown name is decided by the
// definition, which is not visible in this header — confirm both.
const TypeInfo &GetTypeInfoByAvType(const std::string &av_type);
// One IR attribute of an op: its name, whether it is required, and a reference to its type
// description. Kept as an aggregate (brace-initialized by GetIrAttrInfoForName); the member order
// must not change. The reference member means instances can be copied/moved but not assigned.
struct IrAttrInfo {
  std::string name;
  bool is_required;
  const TypeInfo &type_info;

  // Returns the fully qualified constant name that the generated code uses to mark this attribute
  // as required or optional.
  std::string GetRequiredString() const {
    if (is_required) {
      return "ge::es::CompliantNodeBuilder::kEsAttrRequired";
    }
    return "ge::es::CompliantNodeBuilder::kEsAttrOptional";
  }
};
// Returns a string form of the default value of attribute `attr_name` on `op_desc`, where `av_type`
// is the attribute value type name (e.g. "VT_TENSOR").
// NOTE(review): the exact format of the returned text and the behavior when the attribute has no
// default are decided by the definition, which is not visible in this header — confirm.
std::string GetDefaultValueString(const OpDescPtr &op_desc, const std::string &attr_name, const std::string &av_type);
// Builds the IrAttrInfo for the IR attribute `ir_name` of `op_desc`.
//   op_desc          : op description that owns the attribute; must not be null.
//   ir_names_to_type : map from IR attribute name to attribute value type name (e.g. "VT_TENSOR").
//   ir_name          : name of the attribute to describe.
// Returns {ir_name, is_required, type info looked up via GetTypeInfoByAvType}.
// Throws std::out_of_range (from std::map::at) when `ir_name` is missing from `ir_names_to_type`,
// and NotSupportException when an optional tensor attribute carries a default with valid tensor data.
inline IrAttrInfo GetIrAttrInfoForName(const OpDescPtr &op_desc,
                                       const std::map<AscendString, AscendString> &ir_names_to_type,
                                       const std::string &ir_name) {
  const std::string av_type = ir_names_to_type.at(AscendString(ir_name.c_str())).GetString();
  const bool is_required = op_desc->GetRequiredAttrWithType().count(ir_name) > 0;
  // Only optional tensor attributes can carry a default value that needs checking.
  if (!is_required && (av_type == "VT_TENSOR")) {
    ge::ConstGeTensorPtr default_value;
    // The status of the lookup is deliberately not used: when it fails the pointer stays null,
    // which the check below treats as "no default data" instead of dereferencing it.
    (void)ge::AttrUtils::GetTensor(op_desc, ir_name, default_value);
    if ((default_value != nullptr) && default_value->IsTensorDataValid()) {
      throw NotSupportException("Only support 'Tensor()' as default value for Tensor attr, which current op's attr is not using.");
    }
  }
  return {ir_name, is_required, GetTypeInfoByAvType(av_type)};
}
inline std::vector<IrAttrInfo> GetAllIrAttrsNamesAndTypeInOrder(const OpDescPtr &op_desc) {
const auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc);
std::map<AscendString, AscendString> ir_names_to_type;
if (op.GetAllIrAttrNamesAndTypes(ir_names_to_type) != GRAPH_SUCCESS) {
throw std::runtime_error("Failed to get ir names and types");
}
std::vector<IrAttrInfo> req_attrs;
std::vector<IrAttrInfo> optional_attrs;
for (const auto &ir_name : op_desc->GetIrAttrNames()) {
IrAttrInfo info = GetIrAttrInfoForName(op_desc, ir_names_to_type, ir_name);
if (info.is_required) {
req_attrs.push_back(std::move(info));
} else {
optional_attrs.push_back(std::move(info));
}
}
std::vector<IrAttrInfo> ir_name_dts = std::move(req_attrs);
for (auto &op_attr : optional_attrs) {
ir_name_dts.emplace_back(std::move(op_attr));
}
return ir_name_dts;
}
inline std::vector<IrAttrInfo> GetAllIrAttrsNamesAndTypeInIrOrder(const OpDescPtr &op_desc) {
const auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc);
std::map<AscendString, AscendString> ir_names_to_type;
if (op.GetAllIrAttrNamesAndTypes(ir_names_to_type) != GRAPH_SUCCESS) {
throw std::runtime_error("Failed to get ir names and types");
}
std::vector<IrAttrInfo> result;
result.reserve(op_desc->GetIrAttrNames().size());
for (const auto &ir_name : op_desc->GetIrAttrNames()) {
result.push_back(GetIrAttrInfoForName(op_desc, ir_names_to_type, ir_name));
}
return result;
}
// Reports whether `name` equals the name of one of the IR inputs of `op_desc`.
//   name    : candidate identifier (e.g. an output or attribute name).
//   op_desc : op description whose IR inputs are inspected; must not be null.
// A single membership query does not justify building a hash set of every input name on each
// call, so the inputs are scanned directly and the scan stops at the first match.
inline bool IsDupNameInInputs(const std::string &name, const OpDescPtr &op_desc) {
  for (const auto &ir_input : op_desc->GetIrInputs()) {
    // ir_input.first is the input name.
    if (ir_input.first == name) {
      return true;
    }
  }
  return false;
}
inline bool IsDupNameInInputs(const std::string &name, const std::vector<ge::es::history::IrInput> &inputs) {
for (const auto &input : inputs) {
if (input.name == name) {
return true;
}
}
return false;
}
// Output category of an op as computed by GetOutputType(): no IR outputs, exactly one IR output,
// more than one IR output, or at least one dynamic IR output (which takes precedence over the count).
enum class OutputType { kNoOutput, kOneOutput, kMultiOutput, kDynamicOutput };
// Target language of the generated code. The naming helpers (InName, OutName, AttrName) use it to
// choose between the Python keyword table (GenPy) and the C++ keyword table (GenCpp).
enum class GenLanType { GenPy, GenCpp };
// Classifies the IR outputs of `op` (must not be null).
// Any dynamic output yields kDynamicOutput regardless of how many outputs there are; otherwise the
// result depends only on the number of IR outputs (0, 1, or more).
inline OutputType GetOutputType(const OpDescPtr &op) {
  const auto &ir_outputs = op->GetIrOutputs();
  for (auto it = ir_outputs.begin(); it != ir_outputs.end(); ++it) {
    if (it->second == kIrOutputDynamic) {
      return OutputType::kDynamicOutput;
    }
  }
  switch (op->GetIrOutputs().size()) {
    case 0U:
      return OutputType::kNoOutput;
    case 1U:
      return OutputType::kOneOutput;
    default:
      return OutputType::kMultiOutput;
  }
}
// Reports whether `word` is a reserved C++ word, so that generated identifiers can be renamed
// before they are emitted into C++ code.
// The table holds the C++ keywords up to C++20 (including `char8_t`), the alternative operator
// spellings (and, or, not, ...) and the transactional-memory TS words already listed here
// (atomic_cancel, atomic_commit, atomic_noexcept, synchronized).
// The comparison is case-sensitive; an empty string is not a keyword.
inline bool IsKeyword(const std::string &word) {
  // Built once on first use and shared by all later calls.
  static const std::unordered_set<std::string> keywords = {
      "alignas",      "alignof",       "and",           "and_eq",          "asm",
      "atomic_cancel", "atomic_commit", "atomic_noexcept", "auto",          "bitand",
      "bitor",        "bool",          "break",         "case",            "catch",
      "char",         "char8_t",       "char16_t",      "char32_t",        "class",
      "compl",        "concept",       "const",         "consteval",       "constexpr",
      "constinit",    "const_cast",    "continue",      "co_await",        "co_return",
      "co_yield",     "decltype",      "default",       "delete",          "do",
      "double",       "dynamic_cast",  "else",          "enum",            "explicit",
      "export",       "extern",        "false",         "float",           "for",
      "friend",       "goto",          "if",            "inline",          "int",
      "long",         "mutable",       "namespace",     "new",             "noexcept",
      "not",          "not_eq",        "nullptr",       "operator",        "or",
      "or_eq",        "private",       "protected",     "public",          "register",
      "reinterpret_cast", "requires",  "return",        "short",           "signed",
      "sizeof",       "static",        "static_assert", "static_cast",     "struct",
      "switch",       "synchronized",  "template",      "this",            "thread_local",
      "throw",        "true",          "try",           "typedef",         "typeid",
      "typename",     "union",         "unsigned",      "using",           "virtual",
      "void",         "volatile",      "wchar_t",       "while",           "xor",
      "xor_eq"};
  return keywords.find(word) != keywords.end();
}
// Reports whether `word` is one of the Python keywords listed below, so that generated identifiers
// can be renamed before they are emitted into Python code.
// The comparison is case-sensitive ("True" is listed, "true" is not).
inline bool IsPyKeyword(const std::string &word) {
  // Built once on first use and shared by all later calls.
  static const std::unordered_set<std::string> reserved_words = {
      "False",  "None",   "True",     "and",      "as",
      "assert", "async",  "await",    "break",    "class",
      "continue", "def",  "del",      "elif",     "else",
      "except", "finally", "for",     "from",     "global",
      "if",     "import", "in",       "is",       "lambda",
      "nonlocal", "not",  "or",       "pass",     "raise",
      "return", "try",    "while",    "with",     "yield"};
  return reserved_words.count(word) != 0U;
}
// Returns the identifier to use for an input called `name` in generated code.
// When `name` is a keyword of the target language (`type`, C++ by default) it is prefixed with
// "in_"; otherwise it is returned unchanged.
inline std::string InName(const std::string &name, GenLanType type = GenLanType::GenCpp) {
  const bool is_reserved = (type == GenLanType::GenPy) ? IsPyKeyword(name) : IsKeyword(name);
  if (!is_reserved) {
    return name;
  }
  return "in_" + name;
}
// Returns the identifier to use for an output called `name` in generated code.
// An output that shares its name with an IR input of `op_desc` becomes "ref_" + InName(name, type).
// Otherwise a keyword of the target language (`type`, C++ by default) is prefixed with "out_",
// and any other name is returned unchanged.
inline std::string OutName(const std::string &name, const OpDescPtr &op_desc, GenLanType type = GenLanType::GenCpp) {
  if (IsDupNameInInputs(name, op_desc)) {
    std::string ref_name("ref_");
    ref_name += InName(name, type);
    return ref_name;
  }
  const bool is_reserved = (type == GenLanType::GenPy) ? IsPyKeyword(name) : IsKeyword(name);
  if (is_reserved) {
    return "out_" + name;
  }
  return name;
}
// Returns the identifier to use for a subgraph called `name`: prefixed with "subgraph_" when the
// name is a C++ keyword, unchanged otherwise. Only the C++ keyword table is consulted.
inline std::string SubgraphName(const std::string &name) {
  return IsKeyword(name) ? "subgraph_" + name : name;
}
// Returns "dynamic_" followed by SubgraphName(name).
inline std::string DynamicSubgraphVectorName(const std::string &name) {
  std::string vector_name("dynamic_");
  vector_name += SubgraphName(name);
  return vector_name;
}
// Returns the identifier to use for an attribute called `name` in generated code.
// An attribute that shares its name with an IR input of `op_desc` becomes "attr_" + InName(name, type).
// Otherwise a keyword of the target language (`type`, C++ by default) is prefixed with "attr_",
// and any other name is returned unchanged.
inline std::string AttrName(const std::string &name, const OpDescPtr &op_desc, GenLanType type = GenLanType::GenCpp) {
  const std::string prefix("attr_");
  if (IsDupNameInInputs(name, op_desc)) {
    return prefix + InName(name, type);
  }
  const bool is_reserved = (type == GenLanType::GenPy) ? IsPyKeyword(name) : IsKeyword(name);
  if (!is_reserved) {
    return name;
  }
  return prefix + name;
}
// Returns the C++ identifier to use for an attribute called `name`, given an explicit input list.
// An attribute that shares its name with an entry of `inputs` becomes "attr_" + InName(name) for
// C++; otherwise a C++ keyword is prefixed with "attr_" and any other name is returned unchanged.
inline std::string AttrName(const std::string &name, const std::vector<ge::es::history::IrInput> &inputs) {
  std::string prefixed("attr_");
  if (IsDupNameInInputs(name, inputs)) {
    prefixed += InName(name, GenLanType::GenCpp);
    return prefixed;
  }
  if (!IsKeyword(name)) {
    return name;
  }
  prefixed += name;
  return prefixed;
}
// Appends comments for `op` to `h_stream` when they are needed.
// NOTE(review): the conditions under which comments are emitted, and the effect of
// `support_tensor_like` (default false), are decided by the definition, which is not visible in
// this header — confirm.
void GenCommentsIfNeeded(const OpDescPtr &op, std::stringstream &h_stream, bool support_tensor_like = false);
// Appends comments about tensor-like inputs of `op` to `h_stream`.
// NOTE(review): presumably used for ops generated with tensor-like input support — confirm.
void GenTensorLikeInputComments(const OpDescPtr &op, std::stringstream &h_stream);
// Builds a guard string from `input`, combined with `prefix` and `suffix`.
// NOTE(review): presumably an include-guard macro name; how `input` is normalized is decided by
// the definition, which is not visible in this header — confirm.
std::string MakeGuardFromString(const std::string &input, const std::string &prefix, const std::string &suffix);
// Builds a guard string for the op type `op_type`. `external_defined_prefix` defaults to empty and
// `suffix` defaults to "_H_".
std::string MakeGuardFromOp(const std::string &op_type, const std::string &external_defined_prefix = "",
const std::string &suffix = "_H_");
// Builds a guard string for the module `module_name`. `external_defined_prefix` defaults to empty
// and `suffix` defaults to "_OPS_H_".
std::string MakeGuardFromModule(const std::string &module_name, const std::string &external_defined_prefix = "",
const std::string &suffix = "_OPS_H_");
// Writes one file per op type.
//   output_dir   : prefix prepended verbatim to each file name (no path separator is added here).
//   op_types     : op types to process, in order.
//   contents     : map-like container (find/end/at) from op type to a stream-like value with str().
//   get_filename : maps an op type to the file name to use.
// Op types without an entry in `contents` are skipped. The target file is opened with the default
// std::ofstream mode; a failed open is not reported.
template <typename ContentMap>
void WritePerOpFiles(const std::string &output_dir, const std::vector<std::string> &op_types,
                     const ContentMap &contents, const std::function<std::string(const std::string &)> &get_filename) {
  for (const auto &op_type : op_types) {
    if (contents.find(op_type) == contents.end()) {
      continue;  // nothing was generated for this op type
    }
    const std::string target_path = output_dir + get_filename(op_type);
    std::ofstream out_file(target_path);
    out_file << contents.at(op_type).str();
    out_file.close();
  }
}
// Writes one file per utility name.
//   output_dir   : prefix prepended verbatim to each file name (no path separator is added here).
//   util_names   : utility names to process, in order.
//   contents     : map-like container (find/end/at) from utility name to a stream-like value with str().
//   get_filename : maps a utility name to the file name to use.
// Names without an entry in `contents` are skipped. The target file is opened with the default
// std::ofstream mode; a failed open is not reported.
template <typename ContentMap>
void WritePerUtilFiles(const std::string &output_dir, const std::vector<std::string> &util_names,
                       const ContentMap &contents, const std::function<std::string(const std::string &)> &get_filename) {
  for (auto name_it = util_names.begin(); name_it != util_names.end(); ++name_it) {
    const std::string &util_name = *name_it;
    const bool has_content = (contents.find(util_name) != contents.end());
    if (has_content) {
      const std::string target_path = output_dir + get_filename(util_name);
      // The stream is closed by its destructor at the end of this scope.
      std::ofstream out_file(target_path);
      out_file << contents.at(util_name).str();
    }
  }
}
}
}
#endif