* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file op_cfg_generator.cpp
* \brief
*/
#include <fstream>
#include <sstream>
#include <cctype>
#include <unordered_map>
#include <limits.h>
#include "ascendc_tool_log.h"
#include "op_cfg_generator.h"
#include "op_build_params.h"
#include "op_dtype_name_utils.h"
namespace ops {
std::string CfgGenerator::GetDataTypeName(const ge::DataType& type) const { return FindCfgDataTypeName(type); }
std::string CfgGenerator::GetParamTypeName(uint32_t paramType) const
{
std::vector<std::string> ptstrs = {"ignore", "optional", "required", "dynamic", "optional"};
if (paramType >= ptstrs.size()) {
return std::string("unknow");
}
return ptstrs[paramType];
}
void CfgGenerator::GetParamFormats(std::vector<ge::Format>& formats, std::string& fmtstr) const
{
fmtstr = "";
for (auto fmt : formats) {
fmtstr += std::string(ge::GetFormatName(fmt)) + ",";
}
fmtstr.resize(fmtstr.size() - 1);
}
void CfgGenerator::GetParamDataTypes(std::vector<ge::DataType>& types, std::string& tpstr) const
{
tpstr = "";
for (auto type : types) {
tpstr += GetDataTypeName(type) + ",";
}
tpstr.resize(tpstr.size() - 1);
}
bool inline CfgDtypeFormatCheck(
const std::string& paramterName, const std::string& firstArgName, const std::string& secondArgName,
const size_t firstArgSize, const size_t secondArgSize)
{
if (firstArgSize > 0 && secondArgSize > 0 && firstArgSize != secondArgSize) {
ASCENDLOGE(
"InitValue parameter :%s %s and %s size do not match, %s size is %zu, %s size is %zu", paramterName.c_str(),
firstArgName.c_str(), secondArgName.c_str(), firstArgName.c_str(), firstArgSize, secondArgName.c_str(),
secondArgSize);
Generator::SetErrorMessage("CfgDtypeFormatCheck: InitValue size do not match dtype, View plog for details");
return false;
}
return true;
}
void GenSingleInitValueTypeAndValue(std::ofstream& outfile, const ScalarVar& scalar)
{
switch (scalar.scalar_type) {
case ScalarType::UINT64:
outfile << "\"type\": \"uint64\", \"value\": " << std::to_string(scalar.scalar_num.value_u64);
break;
case ScalarType::INT64:
outfile << "\"type\": \"int64\", \"value\": " << std::to_string(scalar.scalar_num.value_i64);
break;
case ScalarType::UINT32:
outfile << "\"type\": \"uint32\", \"value\": " << std::to_string(scalar.scalar_num.value_u64);
break;
case ScalarType::INT32:
outfile << "\"type\": \"int32\", \"value\": " << std::to_string(scalar.scalar_num.value_i64);
break;
case ScalarType::UINT16:
outfile << "\"type\": \"uint16\", \"value\": " << std::to_string(scalar.scalar_num.value_u64);
break;
case ScalarType::INT16:
outfile << "\"type\": \"int16\", \"value\": " << std::to_string(scalar.scalar_num.value_i64);
break;
case ScalarType::UINT8:
outfile << "\"type\": \"uint8\", \"value\": " << std::to_string(scalar.scalar_num.value_u64);
break;
case ScalarType::INT8:
outfile << "\"type\": \"int8\", \"value\": " << std::to_string(scalar.scalar_num.value_i64);
break;
case ScalarType::FLOAT32:
outfile << "\"type\": \"float32\", \"value\": " << std::to_string(scalar.scalar_num.value_f32);
break;
case ScalarType::FLOAT16:
outfile << "\"type\": \"float16\", \"value\": " << std::to_string(scalar.scalar_num.value_f32);
break;
default:
ASCENDLOGE("InitValue(ScalarVar) use unsupport type, please check!");
Generator::SetErrorMessage(
"GenSingleInitValueTypeAndValue: InitValue(ScalarVar) use unsupport type, View plog for details");
break;
}
return;
}
void CfgGenerator::GenVectorInitValue(
std::ofstream& outfile, OpParamDef& def, const std::vector<ge::DataType>& dataTypeVec,
const std::string& typeName) const
{
auto& scalarVec = def.GetInitValueList();
if (!CfgDtypeFormatCheck(
def.GetParamName().GetString(), "initValue", typeName, scalarVec.size(), dataTypeVec.size())) {
return;
}
std::unordered_map<ge::DataType, ScalarVar> typeToScalars;
for (size_t scalarIndex = 0; scalarIndex < scalarVec.size(); ++scalarIndex) {
if (typeToScalars.find(dataTypeVec[scalarIndex]) != typeToScalars.end()) {
auto& judgeScalar = typeToScalars[dataTypeVec[scalarIndex]];
if (judgeScalar == scalarVec[scalarIndex]) {
continue;
}
ASCENDLOGE(
"InitValue(std::vector<ScalarVar>) should ensure that same type has the same initValue "
"Scalar, paramName: %s, dtype %s has more than two different initValue",
def.GetParamName().GetString(), GetDataTypeName(dataTypeVec[scalarIndex]).c_str());
Generator::SetErrorMessage(
"GenVectorInitValue: initValue of the same type are not the same, View plog for details");
return;
} else {
typeToScalars.emplace(dataTypeVec[scalarIndex], scalarVec[scalarIndex]);
ASCENDLOGI(
"InitValue push back, dtype: %s initValue type: %u, initValue num: %lu\n",
GetDataTypeName(dataTypeVec[scalarIndex]).c_str(),
static_cast<uint32_t>(scalarVec[scalarIndex].scalar_type), scalarVec[scalarIndex].scalar_num.value_u64);
}
}
bool commmaFlag = true;
for (auto& typeToScalar : typeToScalars) {
if (commmaFlag) {
commmaFlag = false;
} else {
outfile << ", ";
}
outfile << "\"" << GetDataTypeName(typeToScalar.first) << "\": { ";
GenSingleInitValueTypeAndValue(outfile, typeToScalar.second);
outfile << " }";
}
}
void CfgGenerator::GenInitValue(
std::ofstream& outfile, const std::string& type, const size_t ind, OpParamDef& def) const
{
bool hasGenInitValue = false;
if (def.GetInitValueType() != InitValueType::INIT_VALUE_DEFAULT) {
switch (def.GetInitValueType()) {
case InitValueType::INIT_VALUE_UINT64_T:
hasGenInitValue = true;
if (def.GetInitValue().value_u64 != 0) {
ASCENDLOGW(
"Parameter: %s initValue is %lu, when the parameter type is not uint64_t, undefined "
"behavior may occur.",
def.GetParamName().GetString(), def.GetInitValue().value_u64);
Generator::SetErrorMessage(
"GenInitValue: The parameter type of initvalue is not uint64_t, View plog for details");
}
ASCENDLOGI(
"The initial value of the paramete: %s is set to %lu by InitValue(uint64_t).",
def.GetParamName().GetString(), def.GetInitValue().value_u64);
outfile << type << ind << ".initValue=" << std::to_string(def.GetInitValue().value_u64) << std::endl;
break;
default:
ASCENDLOGE("InitValue(uint64_t) only support uint64_t now, please check!");
Generator::SetErrorMessage(
"GenInitValue: InitValue(uint64_t) only support uint64_t now, View plog for details");
break;
}
}
if (def.GetInitValueList().size() > 0) {
if (hasGenInitValue) {
ASCENDLOGW(
"Parameter: %s set initvalue in different ways at the same time results in undefined behavior.",
def.GetParamName().GetString());
Generator::SetErrorMessage("GenInitValue: Set initvalue in different ways at the same time results in "
"undefined behavior, View plog for details");
return;
}
if (def.GetInitValueList().size() == 1) {
ASCENDLOGI(
"The initValue of the paramete: %s is set by InitValue(ScalarVar).", def.GetParamName().GetString());
auto scalar = def.GetInitValueList()[0];
outfile << type << ind << ".initValue={ \"is_list\" : false, ";
GenSingleInitValueTypeAndValue(outfile, scalar);
outfile << "}" << std::endl;
} else {
ASCENDLOGI(
"The initValue of the parameter: %s is set by InitValue(std::vector<ScalarVar>).",
def.GetParamName().GetString());
outfile << type << ind << ".initValue={ \"is_list\" : true, ";
if (def.IsDtype()) {
GenVectorInitValue(outfile, def, def.GetOriginDataTypes(), "type");
} else {
GenVectorInitValue(outfile, def, def.GetDataTypesList(), "typeList");
}
outfile << "}" << std::endl;
}
}
return;
}
void CfgGenerator::GenParamInfo(std::ofstream& outfile, std::vector<OpParamDef>& param, bool isOutput) const
{
std::string type = (isOutput ? "output" : "input");
std::string tpstr;
std::string fmtstr;
size_t ind = 0;
for (auto def : param) {
outfile << type << ind << ".name=" << def.GetParamName().GetString() << std::endl;
if (def.GetDataTypes().size() > 0) {
GetParamDataTypes(def.GetDataTypes(), tpstr);
outfile << type << ind << ".dtype=" << tpstr << std::endl;
}
if (def.IsSetDtypeForBin()) {
GetParamDataTypes(def.GetDataTypesForBin(), tpstr);
outfile << type << ind << ".for_bin_dtype=" << tpstr << std::endl;
}
if (def.GetFormats().size() > 0) {
GetParamFormats(def.GetFormats(), fmtstr);
outfile << type << ind << ".format=" << fmtstr << std::endl;
}
if (def.IsSetFormatForBin()) {
GetParamFormats(def.GetFormatsForBin(), tpstr);
outfile << type << ind << ".for_bin_format=" << tpstr << std::endl;
}
if (def.GetUnknownShapeFormats().size() > 0) {
GetParamFormats(def.GetUnknownShapeFormats(), fmtstr);
outfile << type << ind << ".unknownshape_format=" << fmtstr << std::endl;
}
outfile << type << ind << ".shape=all" << std::endl;
outfile << type << ind << ".paramType=" << GetParamTypeName(def.GetParamType()) << std::endl;
if (def.GetParamType() == VIRTUAL) {
outfile << type << ind << ".virtual=true" << std::endl;
}
if (def.GetValueDepend().GetLength() > 0) {
outfile << type << ind << ".valueDepend=" << def.GetValueDepend().GetString() << std::endl;
}
GenInitValue(outfile, type, ind, def);
if (def.IsOutputShapeDependOnCompute() == true) {
outfile << type << ind << ".outputShapeDependOnCompute=true" << std::endl;
}
ind++;
}
}
void CfgGenerator::GenAttrInfo(std::ofstream& outfile, std::vector<OpAttrDef>& attrs) const
{
std::string attrList = "";
if (attrs.size() <= 0) {
return;
}
for (auto attr : attrs) {
attrList += std::string(attr.GetName().GetString()) + ",";
}
attrList.resize(attrList.size() - 1);
outfile << "attr.list=" << attrList << std::endl;
for (auto attr : attrs) {
outfile << "attr_" << attr.GetName().GetString() << ".type=" << attr.GetCfgDataType().GetString() << std::endl;
outfile << "attr_" << attr.GetName().GetString() << ".value=all" << std::endl;
outfile << "attr_" << attr.GetName().GetString()
<< ".paramType=" << (attr.IsRequired() ? "required" : "optional") << std::endl;
if (!attr.IsRequired()) {
outfile << "attr_" << attr.GetName().GetString()
<< ".defaultValue=" << attr.GetAttrDefaultVal("[]").GetString() << std::endl;
}
}
}
void CfgGenerator::GenImplFile(std::ofstream& outfile, std::string& opType, OpAICoreConfig& aicoreConfig) const
{
std::string snakeName = ConvertToSnakeCase(opType);
std::map<ge::AscendString, ge::AscendString>& cfgInfo = aicoreConfig.GetCfgInfo();
auto it = cfgInfo.find("opFile.value");
if (it == cfgInfo.cend()) {
outfile << "opFile.value=" << snakeName << std::endl;
}
it = cfgInfo.find("opInterface.value");
if (it == cfgInfo.cend()) {
outfile << "opInterface.value=" << snakeName << std::endl;
}
}
void CfgGenerator::GenExtendInfo(std::ofstream& outfile, OpAICoreConfig& aicoreConfig, const bool enableFallBack) const
{
bool hasSetAclnnSupport = false;
bool hasSetPrebuildPattern = false;
bool hasSetCoreType = false;
bool hasSetJitCompile = false;
std::string sourceFlag = "";
const char* sourcePackage = std::getenv("ENABLE_SOURCE_PACKAGE");
if (sourcePackage != nullptr && strlen(sourcePackage) != 0) {
sourceFlag = std::string(sourcePackage);
}
std::map<ge::AscendString, ge::AscendString>& cfgInfo = aicoreConfig.GetCfgInfo();
for (auto& key : aicoreConfig.GetCfgKeys()) {
outfile << key.GetString() << "=" << cfgInfo[key].GetString() << std::endl;
if (std::string(key.GetString()) == "aclnnSupport.value") {
hasSetAclnnSupport = true;
}
if (std::string(key.GetString()) == "prebuildPattern.value") {
hasSetPrebuildPattern = true;
}
if (std::string(key.GetString()) == "coreType.value") {
hasSetCoreType = true;
}
if (std::string(key.GetString()) == "jitCompile.flag") {
hasSetCoreType = true;
}
}
if (enableFallBack && !hasSetAclnnSupport) {
outfile << "aclnnSupport.value=support_aclnn" << std::endl;
}
if (!hasSetPrebuildPattern) {
outfile << "prebuildPattern.value=Opaque" << std::endl;
}
if (!hasSetCoreType) {
outfile << "coreType.value=AiCore" << std::endl;
}
if (sourceFlag == "False" && !hasSetJitCompile) {
outfile << "jitCompile.flag=static_false,dynamic_false" << std::endl;
}
}
void CfgGenerator::GenMC2Info(std::ofstream& outfile, std::vector<ge::AscendString>& mc2Grps) const
{
uint32_t index = 0;
std::string mc2Cfg = "";
for (; index < mc2Grps.size(); index++) {
auto name = std::string("mc2_context_") + static_cast<char>(0x30 + index);
mc2Cfg += name + ",";
}
mc2Cfg.resize(mc2Cfg.size() - 1);
outfile << "mc2.ctx=" << mc2Cfg << std::endl;
}
void Split(const std::string& str, const char delimiter, std::vector<std::string>& result)
{
std::stringstream ss(str);
std::string tmp;
while (std::getline(ss, tmp, delimiter)) {
result.push_back(tmp);
}
}
void CfgGenerator::ParseSingleComputeUnitOfOp(
OpDef& opsDef, std::string& opType, OpAICoreConfig& aicoreConfig, bool enableFallBack, std::ofstream& outfile) const
{
outfile << "[" << opsDef.GetOpType().GetString() << "]" << std::endl;
std::vector<OpParamDef> inputs = opsDef.GetMergeInputs(aicoreConfig);
std::vector<OpParamDef> outputs = opsDef.GetMergeOutputs(aicoreConfig);
GenParamInfo(outfile, inputs, false);
GenParamInfo(outfile, outputs, true);
GenAttrInfo(outfile, opsDef.GetAttrs());
GenExtendInfo(outfile, aicoreConfig, enableFallBack);
GenImplFile(outfile, opType, aicoreConfig);
std::vector<ge::AscendString> mc2Grps = opsDef.MC2().GetHcclGroups();
if (mc2Grps.size() != 0) {
GenMC2Info(outfile, mc2Grps);
}
}
void CfgGenerator::GetOutFilePtr(
std::string& genPath, std::string& socVer, std::ofstream& outfile, const std::string resolvedGenPath,
std::map<std::string, std::string>& cfgFileStreams) const
{
std::string cfgFileName = genPath + "/aic-" + socVer + "-ops-info.ini";
std::string realCfgFile = std::string(resolvedGenPath) + "/aic-" + socVer + "-ops-info.ini";
auto iter = cfgFileStreams.find(cfgFileName);
if (iter == cfgFileStreams.end()) {
cfgFileStreams.emplace(cfgFileName, cfgFileName);
outfile.open(realCfgFile, std::ofstream::out | std::ofstream::trunc);
} else {
outfile.open(realCfgFile, std::ofstream::out | std::ofstream::app);
}
}
void CfgGenerator::GenAllOpCfgWithoutComputeUint(
const std::vector<std::string>& ops, std::string& genPath, const std::string resolvedGenPath,
std::map<std::string, std::string>& cfgFileStreams) const
{
for (auto op : ops) {
OpDef opsDef = OpDefFactory::OpDefCreate(op.c_str());
bool enableFallBack = opsDef.IsEnableFallBack();
for (auto& aicoreItem : opsDef.AICore().GetAICoreConfigs()) {
std::string socVer = aicoreItem.first.GetString();
OpAICoreConfig aicoreConfig = aicoreItem.second;
std::ofstream outfile;
GetOutFilePtr(genPath, socVer, outfile, resolvedGenPath, cfgFileStreams);
ParseSingleComputeUnitOfOp(opsDef, op, aicoreConfig, enableFallBack, outfile);
outfile.close();
}
}
}
opbuild::Status CfgGenerator::GenerateCode(void)
{
std::string genPath;
ASCENDLOGI("Cfg GenerateCode called!");
Generator::GetGenPath(genPath);
char resolvedGenPath[PATH_MAX] = {0};
if (realpath(genPath.c_str(), resolvedGenPath) == nullptr) {
ASCENDLOGE("Generate Path %s is invalid!", genPath.c_str());
std::cerr << "Path: " << genPath << " is not valid!" << std::endl;
return opbuild::OPBUILD_FAILED;
}
std::vector<std::string> ops = this->GetAllOp();
std::map<std::string, std::string> cfgFileStreams;
auto computeUnitCfg = opbuild::Params::GetInstance().Optional("compute_unit");
ASCENDLOGI("Compute unit cfg is %s ", computeUnitCfg.c_str());
if (computeUnitCfg.size() == 0) {
GenAllOpCfgWithoutComputeUint(ops, genPath, std::string(resolvedGenPath), cfgFileStreams);
ASCENDLOGI("Cfg GenerateCode end!");
return opbuild::OPBUILD_SUCCESS;
}
std::vector<std::string> computeUnits;
Split(computeUnitCfg, ';', computeUnits);
for (auto op : ops) {
OpDef opsDef = OpDefFactory::OpDefCreate(op.c_str());
bool enableFallBack = opsDef.IsEnableFallBack();
auto allAICoreConfig = opsDef.AICore().GetAICoreConfigs();
for (uint32_t i = 0; i < computeUnits.size(); ++i) {
auto socVer = computeUnits[i];
if (VALID_SOC_SET.find(socVer) == VALID_SOC_SET.end()) {
ASCENDLOGW("Invlid soc version %s\n of ASCEND_COMPUTE_UNIT", socVer.c_str());
continue;
}
ge::AscendString curCompUnit(socVer.c_str());
auto cfgIter = allAICoreConfig.find(curCompUnit);
OpAICoreConfig aicoreConfig;
if (cfgIter == allAICoreConfig.end()) {
aicoreConfig = OpAICoreConfig(socVer.c_str());
} else {
aicoreConfig = cfgIter->second;
}
std::ofstream outfile;
GetOutFilePtr(genPath, socVer, outfile, std::string(resolvedGenPath), cfgFileStreams);
ParseSingleComputeUnitOfOp(opsDef, op, aicoreConfig, enableFallBack, outfile);
outfile.close();
}
}
ASCENDLOGI("Cfg Generator of compute unit config %s end", computeUnitCfg.c_str());
return opbuild::OPBUILD_SUCCESS;
}
CfgGenerator::CfgGenerator(std::vector<std::string>& ops) : Generator(ops) { ASCENDLOGI("Stub Generator construct!"); }
static opbuild::Status CfgGeneratorBuilder(std::vector<std::string>& ops)
{
CfgGenerator g(ops);
return g.GenerateCode();
}
static void AddCfgGenerator(void) __attribute__((constructor));
void AddCfgGenerator(void) { GeneratorFactory::AddBuilder("cfg", CfgGeneratorBuilder); }
}