/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
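/*!
 * \file template_argument.cpp
 * \brief Validates ASCENDC_TPL_* template-argument declarations and selections, and encodes tiling
 *        parameters into a tiling key.
 */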
#include "utils/tiling/template_argument.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <csignal>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
using namespace std;
namespace {
// Formats a list of values for diagnostics: {1, 2, 3} -> "[1, 2, 3]", {} -> "[]".
// The separator is emitted only between elements, so no dangling ", " appears before the closing bracket.
std::string Vec2Str(const std::vector<uint64_t> &vec)
{
    std::string resStr = "[";
    const char *sep = "";
    for (const auto &val : vec) {
        resStr += sep;
        resStr += std::to_string(val);
        sep = ", ";
    }
    resStr += "]";
    return resStr;
}
}
namespace AscendC {
const std::array<const char*, 7> TPL_TYPE_STR_ARR = {
"DTYPE", "FORMAT", "UINT", "BOOL", "KERNEL_TYPE", "DETERMINISTIC", "SHARED_KERNEL_TYPE"
};
static bool CheckParamStructValid(ParamStruct ¶mStruct)
{
if (paramStruct.paramType >= TPL_TYPE_STR_ARR.size()) {
printf("[ERROR] ASCENDC_TPL_*_%s: %s type value is invalid! Type value should be in [0, 1, 2, 3, 4, 5, 6]\n",
paramStruct.macroType,
paramStruct.name);
return false;
}
if (paramStruct.vals.empty()) {
printf("[ERROR] Values of ASCENDC_TPL_%s_%s: %s is empty!\n",
TPL_TYPE_STR_ARR[paramStruct.paramType], paramStruct.macroType, paramStruct.name);
return false;
}
std::unordered_set<uint64_t> valueSet(paramStruct.vals.begin(), paramStruct.vals.end());
if (paramStruct.vals.size() != valueSet.size()) {
printf("[ERROR] Existing duplicate values in ASCENDC_TPL_%s_%s: %s! Duplicate values: %s\n", TPL_TYPE_STR_ARR[paramStruct.paramType],
paramStruct.macroType, paramStruct.name, Vec2Str(paramStruct.vals).c_str());
return false;
}
return true;
}
static bool ParseTplUintValue(ParamStruct ¶mStruct)
{
if (paramStruct.paramType != ASCENDC_TPL_UINT) {
return CheckParamStructValid(paramStruct);
}
if (paramStruct.vals.size() < 2U) {
printf("[ERROR] Value length of ASCENDC_TPL_UINT_%s: %s is less than 2, parse value failed.\n",
paramStruct.macroType, paramStruct.name);
return false;
}
uint8_t uiFlag = static_cast<uint8_t>(paramStruct.vals[0]);
if (uiFlag == ASCENDC_TPL_UI_LIST) {
paramStruct.vals.erase(paramStruct.vals.begin());
return CheckParamStructValid(paramStruct);
}
if (uiFlag != ASCENDC_TPL_UI_RANGE && uiFlag != ASCENDC_TPL_UI_MIX) {
printf("[ERROR] ASCENDC_TPL_UINT_%s: %s UI flag is invalid,"
"it should be in [UI_RANGE, UI_LIST, UI_MIX]!\n",
paramStruct.macroType, paramStruct.name);
return false;
}
std::vector<uint64_t> extendVal;
size_t rangeNum = paramStruct.vals[1];
if (rangeNum * VAL_PAIR > paramStruct.vals.size() - VAL_START) {
printf("[ERROR] ASCENDC_TPL_UINT_%s: %s range declare exceed actual length!\n", paramStruct.macroType,
paramStruct.name);
return false;
}
size_t elementNum = 0;
for (size_t i = 1; i <= rangeNum; i++) {
elementNum += paramStruct.vals[i * VAL_PAIR + 1] - paramStruct.vals[VAL_PAIR * i] + 1;
}
if (uiFlag == ASCENDC_TPL_UI_MIX) {
elementNum += paramStruct.vals.size() - (VAL_START + rangeNum * VAL_PAIR);
}
extendVal.reserve(elementNum);
for (size_t i = 1; i <= rangeNum; i++) {
for (size_t j = paramStruct.vals[VAL_PAIR * i]; j <= paramStruct.vals[i * VAL_PAIR + 1]; j++) {
extendVal.emplace_back(j);
}
}
if (uiFlag == ASCENDC_TPL_UI_MIX) {
size_t mixStart = VAL_START + rangeNum * VAL_PAIR;
std::copy(paramStruct.vals.begin() + mixStart, paramStruct.vals.end(), std::back_inserter(extendVal));
}
paramStruct.vals = std::move(extendVal);
return CheckParamStructValid(paramStruct);
}
static bool CheckSelectParamValid(const TilingDeclareParams &declareParams, const ParamStruct &selectParam)
{
for (const auto &declareParam : declareParams) {
if (declareParam.name == selectParam.name) {
auto declareType = declareParam.paramType;
auto declareVals = declareParam.vals;
if (declareType != selectParam.paramType) {
printf("[ERROR] ASCENDC_TPL_%s_SEL: %s has different macro type!\n", TPL_TYPE_STR_ARR[declareParam.paramType], selectParam.name);
return false;
}
for (auto val : selectParam.vals) {
if (std::find(declareVals.begin(), declareVals.end(), val) == declareVals.cend()) {
printf("[ERROR] ASCENDC_TPL_%s_SEL %s value %lu does not exist in ASCENDC_TPL_%s_DECL, "
"please check it!\n", TPL_TYPE_STR_ARR[declareParam.paramType], selectParam.name, val, TPL_TYPE_STR_ARR[declareParam.paramType]);
return false;
}
}
return true;
}
}
return false;
}
static bool CheckSelectParamsValid(const TilingDeclareParams &declareParams, const TilingSelectParams &selectParams)
{
for (const auto &selectParamVec : selectParams) {
for (const auto &selectParam : selectParamVec) {
if(selectParam.paramType == ASCENDC_TPL_KERNEL_TYPE || selectParam.paramType == ASCENDC_TPL_DETERMINISTIC) {
continue;
}
if (!CheckSelectParamValid(declareParams, selectParam)) {
return false;
}
}
}
return true;
}
static bool CheckInputTilingParamValid(const std::vector<uint64_t> &tilingParams, TilingSelectParams &selectParamsList,
const TilingDeclareParams &declareParams)
{
if (tilingParams.size() != declareParams.size()) {
printf("[ERROR] Number of arguments in GET_TPL_TILING_KEY is not the same as ASCENDC_TPL_ARGS_DECL.\n");
return false;
}
std::unordered_map<std::string, size_t> nameOrder;
for (size_t i = 0; i < declareParams.size(); ++i) {
nameOrder[declareParams[i].name] = i;
}
for (auto &selectParams : selectParamsList) {
auto newEnd = std::remove_if(selectParams.begin(), selectParams.end(),
[](const ParamStruct ¶m) {
return param.paramType == ASCENDC_TPL_KERNEL_TYPE || param.paramType == ASCENDC_TPL_DETERMINISTIC;
});
selectParams.erase(newEnd, selectParams.end());
std::sort(selectParams.begin(), selectParams.end(), [&nameOrder](const ParamStruct &a, const ParamStruct &b) {
auto itA = nameOrder.find(a.name);
auto itB = nameOrder.find(b.name);
if (itA == nameOrder.cend() || itB == nameOrder.cend()) {
return false;
}
return itA->second < itB->second;
});
bool matchFlag = true;
for (std::size_t i = 0;i < tilingParams.size();i++){
if(std::find(selectParams[i].vals.begin(), selectParams[i].vals.end(), tilingParams[i]) == selectParams[i].vals.cend()){
matchFlag = false;
break;
}
}
if(matchFlag){
return true;
}
}
return false;
}
static bool CheckParamValid(const std::vector<uint64_t> &tilingParams, TilingSelectParams &selectParamsVec,
const TilingDeclareParams &declareParams)
{
if (!CheckSelectParamsValid(declareParams, selectParamsVec)) {
printf("[ERROR] ASCEND_TPL_SEL format is incorrect, please check!\n");
return false;
}
if (!CheckInputTilingParamValid(tilingParams, selectParamsVec, declareParams)) {
printf("[ERROR] GET_TPL_TILING_KEY args:%s cannot be found in ASCENDC_TPL_ARGS_SEL.\n",
Vec2Str(tilingParams).c_str());
return false;
}
return true;
}
static uint64_t EncodeParam(uint64_t val, const ParamStruct ¶mStruct)
{
if (paramStruct.paramType == ASCENDC_TPL_DTYPE || paramStruct.paramType == ASCENDC_TPL_FORMAT ||
paramStruct.paramType == ASCENDC_TPL_BOOL || paramStruct.paramType == ASCENDC_TPL_SHARED_KERNEL_TYPE) {
return static_cast<uint64_t>(val);
} else if (paramStruct.paramType == ASCENDC_TPL_UINT) {
auto iter = std::find(paramStruct.vals.cbegin(), paramStruct.vals.cend(), val);
uint64_t index = iter - paramStruct.vals.cbegin();
return index;
} else {
return INVALID_TILING_KEY;
}
}
uint64_t EncodeTilingKey(TilingDeclareParams declareParams, TilingSelectParams selectParamsVec,
std::vector<uint64_t> tilingParams)
{
if (declareParams.empty() || selectParamsVec.empty()) {
printf("[ERROR] ASCENDC_TPL params is empty!\n");
return INVALID_TILING_KEY;
}
std::unordered_set<std::string> declareParamNameSet;
for (auto &declareParam : declareParams) {
declareParamNameSet.insert(declareParam.name);
if (!ParseTplUintValue(declareParam)) {
printf("[ERROR] ASCENDC_TPL_DECL:%s parses value failed!\n", declareParam.name);
return INVALID_TILING_KEY;
}
}
for (auto &selectParams : selectParamsVec) {
std::unordered_set<std::string> selectParamNameSet{};
for (auto &selectParam : selectParams) {
if(selectParam.paramType == ASCENDC_TPL_KERNEL_TYPE || selectParam.paramType == ASCENDC_TPL_DETERMINISTIC) {
continue;
}
std::string name = selectParam.name;
selectParamNameSet.insert(name);
if (declareParamNameSet.count(name) == 0 || !ParseTplUintValue(selectParam)) {
printf("[ERROR] ASCENDC_TPL_SEL : %s parses value failed!"
" Please check macro define name or values.\n", name.c_str());
return INVALID_TILING_KEY;
}
}
if (declareParamNameSet.size() != selectParamNameSet.size()) {
printf("[ERROR] ASCENDC_TPL_SEL : The parameters in ASCENDC_TPL_ARGS_SEL "
"do not fully include those in ASCENDC_TPL_ARGS_DECL.\n");
return INVALID_TILING_KEY;
}
}
if (!CheckParamValid(tilingParams, selectParamsVec, declareParams)) {
return INVALID_TILING_KEY;
}
uint8_t totalBits = 0;
uint64_t tilingKey = 0;
for (size_t i = 0; i < declareParams.size(); ++i) {
auto bitWidth = declareParams.at(i).bitWidth;
auto paramBit = EncodeParam(tilingParams.at(i), declareParams.at(i));
if (paramBit == INVALID_TILING_KEY) {
return INVALID_TILING_KEY;
}
tilingKey |= (paramBit << totalBits);
totalBits += bitWidth;
}
if (totalBits > MAX_BITS_NUM) {
printf("[ERROR] Tiling Key exceeds maximum 64 bit, please adjust ASCENDC_TPL_ARGS_DECL bitWidth"
"accordingly\n");
return INVALID_TILING_KEY;
}
return tilingKey;
}
}