* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ATT_TILING_CODE_GEN_IMPL_H_
#define ATT_TILING_CODE_GEN_IMPL_H_
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "code_printer.h"
#include "base/model_info.h"
#include "generator_config.h"
#include "tiling_data_gen/tiling_data_generator.h"
#include "extra_info_gen/extra_info_generator.h"
#include "util/duration.h"
#include "gen_model_info/api_tiling_gen/gen_api_tiling.h"
#include "cache/operator_level_cache_gen.h"
#include "cache/group_level_cache_gen.h"
namespace att {
/**
 * Bundle of optional extra arguments accepted by TilingCodeGenImpl::GenTilingTail.
 *
 * Every field is value-initialized by default, so a default-constructed instance carries no extra
 * information. The four-argument constructor takes each field by value and moves it into the
 * corresponding member.
 */
struct GenTilingTailImplExtParams {
  std::unordered_map<std::string, std::string> cache_reuse_info{};
  VarRelations var_relations{};
  EnableGroupParallels enable_group_parallels{};
  TensorIdSet workspace_tensor_id_set{};

  GenTilingTailImplExtParams() = default;

  // Sink constructor: arguments are moved into the members, in member declaration order.
  GenTilingTailImplExtParams(std::unordered_map<std::string, std::string> reuse_info,
                             VarRelations relations,
                             EnableGroupParallels group_parallels,
                             TensorIdSet workspace_ids)
      : cache_reuse_info(std::move(reuse_info)),
        var_relations(std::move(relations)),
        enable_group_parallels(std::move(group_parallels)),
        workspace_tensor_id_set(std::move(workspace_ids)) {}
};
/**
 * Abstract base class that declares the tiling code generation steps.
 *
 * The class is abstract: GenSolverBaseClass, GenSolverTiling and GenDoTiling are pure virtual and
 * must be supplied by a derived class. All other member functions are only declared here; their
 * definitions are not part of this header.
 *
 * The public entry points (GenTilingHead / GenTilingTail / GenTiling) receive a
 * std::map<std::string, std::string> by reference (tiling_res) and return a ge::Status.
 *
 * NOTE(review): tiling_model_info_ is a const reference member. Presumably it is bound to the
 * constructor argument, in which case that object must outlive this instance - confirm against the
 * constructor definition.
 */
class TilingCodeGenImpl {
  // Nested-map aliases used throughout the generator signatures; the innermost mapped value is a
  // pair of strings.
  // NOTE(review): "Namepspace" looks like a misspelling of "Namespace". The name is kept unchanged
  // because out-of-line definitions presumably refer to it.
  using AscGraphNamepspaceMap = std::map<size_t, std::map<size_t, std::pair<std::string, std::string>>>;
  using FusedGraphNamespaceMap = std::map<size_t, AscGraphNamepspaceMap>;

 public:
  TilingCodeGenImpl(const std::string &op_name, const TilingCodeGenConfig &config,
                    const TilingModelInfo &tiling_model_info, const ScoreFuncs &score_funcs, const bool is_uniq_group);
  virtual ~TilingCodeGenImpl() = default;

  // Public generation entry points. Results are reported through tiling_res.
  ge::Status GenTilingHead(std::map<std::string, std::string> &tiling_res,
                           const EnableGroupParallels &enable_group_parallels = {});
  ge::Status GenTilingTail(std::map<std::string, std::string> &tiling_res,
                           GenTilingTailImplExtParams ext_params = {});
  ge::Status GenTiling(std::map<std::string, std::string> &tiling_res,
                       std::unordered_map<std::string, std::string> cache_reuse_info = {},
                       uint32_t cache_capacity = 0,
                       const EnableGroupParallels &enable_group_parallels = {});

  // Stores a copy of group_nums in schedule_result_group_nums_.
  void SetScheduleResultGroupNums(const std::map<std::pair<size_t, size_t>, size_t> &group_nums) {
    schedule_result_group_nums_ = group_nums;
  }
  uint32_t GetGroupNumForCurrentScheduleResult(const std::pair<size_t, size_t> &schedule_result_key) const;

 protected:
  ge::Status CheckImplPtr(const std::string &indent);
  ge::Status GetReuseVarNames(std::map<std::string, std::string> &var_names_to_reuse_var_name);
  ge::Status GenStructCopyDef();
  ge::Status GenCacheHashMapDef();
  ge::Status GenDurationBeginCode(const TilingFuncDurationType type, const std::string &indent);
  ge::Status GenDurationEndCode(const TilingFuncDurationType type, const std::string &indent);
  ge::Status ObtainInnerParams(std::map<std::string, std::set<std::string>> &hardware_map,
                               FusedGraphNamespaceMap &namespace_map);
  ge::Status GenGetTilingForAllInitLines(bool pgo = false);
  ge::Status GenGetResultSummary(const size_t asc_graph_id);
  ge::Status GenGetTilingForScheduleResult();
  ge::Status GenGetTilingForAllSchedulesResults(const uint32_t asc_graph_id,
                                                const AscGraphNamepspaceMap &asc_graph_map);
  ge::Status GenFusedScheduleResultsGetTilingDefine(const FusedGraphNamespaceMap &namespace_map);
  ge::Status GenEnableGroupParallelFunctions(const FusedGraphNamespaceMap &namespace_map);
  ge::Status GenEnableGroupParallelInvoke(size_t asc_graph_id, const AscGraphNamepspaceMap &asc_graph_namespace_map);
  ge::Status GenEnableGroupParallelPgoInvoke(const std::string &tiling_name, bool is_pointer,
                                             const std::string &indent, std::string &invoke_code);
  ge::Status GenPGOFusedScheduleResultsGetTilingDefine(const FusedGraphNamespaceMap &namespace_map);
  ge::Status GenPGOByCoreNumFusedScheduleResultsGetTilingDefine(const FusedGraphNamespaceMap &namespace_map);
  // NOTE(review): namespace_map is taken by value here (the nested map is copied on each call),
  // unlike the sibling declarations that take it by const reference. Left unchanged because the
  // out-of-line definition must match this signature.
  ge::Status GenPGOByCoreNumSearchTilingKeyCollectTilingData(FusedGraphNamespaceMap namespace_map);
  void GenGetScoreFuncs(const size_t asc_graph_id, const AscGraphNamepspaceMap &namespace_map);
  ge::Status GenPGOGetTilingForAll();
  void GenGetScoreFuncsCalling(const size_t asc_graph_id, const AscGraphNamepspaceMap &namespace_map);
  void GenCacheInit();
  void GenSetHardwareCodes(const std::string &group_prefix, const std::set<std::string> &hardware_names);
  ge::Status GenScheduleGroupDoTiling(std::string &check_cond, const std::string &hardware_param,
                                      const std::string &schedule_result_prefix);
  void GenGetScheduleResultTail(const std::map<size_t, std::pair<std::string, std::string>> &graph_info);
  void GenUpdateWorkspace(const size_t asc_graph_id, const size_t impl_graph_id);
  ge::Status GenDoGroupTilingFunction(
      const size_t asc_graph_id,
      const size_t impl_graph_id,
      const std::map<size_t, std::pair<std::string, std::string>> &graph_info);
  ge::Status GenGetScheduleResultPerfAndTail(
      const size_t asc_graph_id,
      const size_t impl_graph_id,
      const std::map<size_t, std::pair<std::string, std::string>> &graph_info);
  ge::Status GenGetScheduleResult(
      const size_t asc_graph_id, const size_t impl_graph_id,
      const std::map<size_t, std::pair<std::string, std::string>> &graph_info,
      const std::map<std::string, std::set<std::string>> &hardware_map);
  ge::Status GenUpdatePerf(const size_t asc_graph_id, const size_t impl_graph_id,
                           const std::vector<std::string> &groups_perf,
                           const std::vector<std::string> &groups_block_num,
                           const std::vector<std::string> &assign_max_block_num);
  void GenGetMaxScoreIndex(const AscGraphNamepspaceMap &namespace_map);
  void GenScheduleResultGetTilingCalling(const std::string &index, const std::string &ident = "");
  ge::Status GenGetAllSchedulesResults(const AscGraphNamepspaceMap &namespace_map);
  void GenPGOUpdateTilingInfo(const size_t asc_graph_id, const size_t impl_graph_id);
  void GenFillOtherGroupsGetTiling(const size_t asc_graph_id, const size_t impl_graph_id,
                                   const std::map<size_t, std::pair<std::string, std::string>> &graph_info,
                                   const std::pair<size_t, std::pair<std::string, std::string>> &group_info,
                                   const std::map<std::string, std::set<std::string>> &hardware_map);
  ge::Status GenPGOGetScheduleResultPerGroup(const size_t asc_graph_id, const size_t impl_graph_id,
                                             const std::map<size_t, std::pair<std::string, std::string>> &graph_info,
                                             const std::pair<size_t, std::pair<std::string, std::string>> &group_info,
                                             const std::map<std::string, std::set<std::string>> &hardware_map);
  ge::Status GenPGOScheduleGroupSearchEntry(const size_t asc_graph_id, const size_t impl_graph_id,
                                            const std::map<size_t, std::pair<std::string, std::string>> &graph_info,
                                            const std::map<std::string, std::set<std::string>> &hardware_map,
                                            const std::pair<size_t, std::pair<std::string, std::string>> &group_info,
                                            const std::string &result_name);
  ge::Status GenPGOGetScheduleResult(const size_t asc_graph_id, const size_t impl_graph_id,
                                     const std::map<size_t, std::pair<std::string, std::string>> &graph_info,
                                     const std::map<std::string, std::set<std::string>> &hardware_map);
  void GenPGOGetAllSchedulesResults(const size_t asc_graph_id, const AscGraphNamepspaceMap &namespace_map);
  void GenPGOByCoreNumGetScheduleResult(const size_t asc_graph_id, const size_t impl_graph_id,
                                        const std::map<size_t, std::pair<std::string, std::string>> &graph_info,
                                        const std::map<std::string, std::set<std::string>> &hardware_map,
                                        const std::map<size_t, std::map<size_t, std::map<std::string, ge::Expression>>> &var_relation);
  std::string GenLaunchLikeInputOutputDef(bool is_define = true);
  std::string GenInputOutputVoidCasts();
  void GenPGOMultiGroupBlockDimList(const FusedGraphNamespaceMap &namespace_map, std::string &block_dim_list_arg);
  ge::Status GenCastReuseTilingDataCode(const ReuseScheduleGroupInfo &reuse_info, const ReuseScheduleGroupInfo &info);
  bool IsScheduleResultEnableParallel(const size_t asc_graph_id, const size_t impl_graph_id) const;
  bool GenUpdateCurPerfAndBlockByGroupIfNeeded(const size_t asc_graph_id, const AscGraphNamepspaceMap &asc_graph_map) const;
  bool HitSmallShapePattern(ArgsManager &args_manager) const;
  ge::Status GenGetTilingWithCaseId(bool is_tail = false);
  std::string GetGetTilingParamDefines(bool use_cache, bool use_workspace, std::string &cache_define_head,
                                       std::string &cache_define_func, std::string &cache_used) const;
  void GenGetTilingFunctionSignature(const std::string &workspace_define, const std::string &cache_define_func,
                                     const std::string &cache_define_head);
  ge::Status GenGetTilingFunctionBody(bool use_cache, bool is_tail, const std::string &cache_used);
  ge::Status GenGetTilingKeyCall(const std::string &cache_used);
  ge::Status GenDurationCode(bool is_begin);
  ge::Status GenOperatorCacheSaveCode(bool need_operator_cache);
  // The map key is spelled std::string explicitly (as everywhere else in this class) so that the
  // header does not rely on a using-declaration leaking in from another header.
  ge::Status ValidateForceTilingCase(const std::map<std::string, int32_t> &group_tiling_case_ids,
                                     int32_t min_tiling_case_size) const;
  ge::Status ValidateSingleModeForceTilingCase(int32_t min_tiling_case_size) const;
  ge::Status ValidateGroupModeForceTilingCase(const std::map<std::string, int32_t> &group_tiling_case_ids) const;
  ge::Status GenHardwareSummary(const ModelInfo &model_info);
  ge::Status GenHardwareJudge(const ModelInfo &model_info);
  ge::Status GenInputSummary(const ModelInfo &model_info);
  ge::Status GenCalcScore(const ModelInfo &model_info);
  void GenCalcScoreVarsDefine();
  ge::Status GenAllSameScoreTilingCases(std::map<std::string, std::vector<const ModelInfo *>> &same_args_name_to_graphs,
                                        const std::vector<std::string> &ordered_assemble_args_name);
  void InitTilingUpperBound(const std::vector<Expr> &hardware_args, const ArgsManager &args_manager,
                            const HardwareDef &hardware_def, std::map<std::string, bool> &visited);
  ge::Status GenSmallShapeTiling(const ModelInfo &model_info);

  // Customization points. The three pure virtual functions below make this class abstract; the
  // remaining virtual functions are declared with definitions supplied elsewhere.
  virtual ge::Status GenSolverBaseClass() = 0;
  virtual ge::Status GenSolverTiling(const ModelInfo &model_info) = 0;
  virtual ge::Status GenDoTilingCommon(const ModelInfo &model_info, const std::pair<std::string, std::string> &codes);
  virtual ge::Status GenDoTiling(const ModelInfo &model_info) = 0;
  virtual ge::Status GenGetTilingDataFromCopy();
  virtual ge::Status GenFindCacheAndSaveCache();
  virtual ge::Status GenUpdateBetterTiling();
  virtual ge::Status GenSelectBetterTilingBasedOnObjAndUbRatio();
  // NOTE(review): default arguments on a virtual function are selected by the static type at the
  // call site, so any override should repeat exactly these defaults.
  virtual ge::Status GenFindPerfBetterTilingbyCaseId(bool enable_group_parallel_optimize = false,
                                                     bool add_core_num_param = false, uint32_t group_num = 1);
  virtual ge::Status GenSearchAllTilingbyCaseId();
  virtual ge::Status GenGetTilingKey();
  virtual ge::Status GenPGOSearchTilingKey();
  void GenPGOSearchTilingKeyUniqGroupBatch();
  virtual ge::Status ValidateSingleResultAndGroup();
  virtual ge::Status GenGetTilingbyCaseId();
  virtual ge::Status GenPGODefaultTiling();
  virtual ge::Status GenPGOTilingCase(const ModelInfo& model_info);
  virtual ge::Status GenPGOGetTilingbyCaseId();
  virtual ge::Status GenerateInputParamsAndTiling();
  virtual ge::Status GenPGOByCoreNumSearchTilingKeySingleGroup();
  virtual ge::Status GenPGOByCoreNumSearchTilingKey();
  virtual ge::Status GenPGOByCoreNumTilingForAll();
  void GenPGOByCoreNumDoTiling(const std::pair<size_t, std::pair<std::string, std::string>> &group_info,
                               const uint32_t group_index, const size_t asc_graph_id, const size_t impl_graph_id);
  void GenPGOByCoreNumGetAllSchedulesResults(const size_t asc_graph_id, const AscGraphNamepspaceMap &namespace_map);
  virtual ge::Status GenExtraParamCode(const ModelInfo &model_info, std::string &pass_code);
  virtual ge::Status GenGetSetTilingImpl(const ModelInfo &model_info);
  virtual ge::Status GenExternFuncDef();
  virtual ge::Status GenMacroInclude();
  virtual ge::Status GenToolFuncs();
  virtual ge::Status GenTilingImplPublicFunc();
  ge::Status GenVirtualDataTransferFuncs();
  virtual ge::Status GenTilingCaseImpl(const ModelInfo &model_info);
  void GenInductorExecutePGOSolver(const ModelInfo &model_info);
  virtual ge::Status GenPreTiling(const ModelInfo &model_info);
  virtual ge::Status GenDoApiTiling(const ModelInfo &model_info);
  virtual ge::Status GenExtraEvalFunc(const ModelInfo &model_info);
  virtual ge::Status GenExtraTilingData(const ModelInfo &model_info);
  virtual ge::Status GenExtraSummaryInfo(const ModelInfo &model_info, const ArgsManager &args_manager, std::string &case_info_str);
  virtual ge::Status GenPipeTypeObj(const ModelInfo &model_info);
  virtual ge::Status GenMemoryParamCode(const ModelInfo &model_info);
  virtual ge::Status GenExtraTilingFuncImpl(const ModelInfo &model_info);
  virtual ge::Status GenExtraTilingFuncInvoke(const ModelInfo &model_info);
  virtual ge::Status GenHardwareCons(const ModelInfo &model_info);
  virtual ge::Status GenGetObj(const ModelInfo &model_info);
  void GenArrangeBlockOffsetsDeclarations(const FusedGraphNamespaceMap &namespace_map);
  void GenDoGroupTilingGetTilingCalls(const std::map<size_t, std::pair<std::string, std::string>> &graph_info);
  void GenDoGroupTilingFailureHandler(const std::map<size_t, std::pair<std::string, std::string>> &graph_info);
  void GenGroupParallelFirstTiling(const size_t impl_graph_id);
  void GenGroupParallelSecondTiling(const size_t impl_graph_id,
                                    const std::map<size_t, std::pair<std::string, std::string>> &graph_info);
  void GenGroupParallelFirstTilingDecls(const std::map<size_t, std::pair<std::string, std::string>> &graph_info);
  ge::Status GenSingleGroupScheduleResult(
      const size_t asc_graph_id, const size_t impl_graph_id,
      const std::map<size_t, std::pair<std::string, std::string>> &graph_info,
      const std::map<std::string, std::set<std::string>> &hardware_map);
  static std::string GenPerfUpdateCode(const std::vector<std::string> &groups_perf,
                                       const std::vector<std::string> &groups_block_num,
                                       const std::string &indent);
  void GenBestPerfUpdateCode(const size_t asc_graph_id, const size_t impl_graph_id,
                             const std::vector<std::string> &assign_max_block_num,
                             const std::string &indent);
  ge::Status GenConflictGroupHelpers(const size_t asc_graph_id,
                                     const size_t impl_graph_id,
                                     const std::map<size_t, std::pair<std::string, std::string>> &graph_info);
  ge::Status GenConflictGroupHelper(const ModelInfo &model_info,
                                    const std::string &group_item_prefix);
  std::string GenConflictGroupInvoke(const size_t asc_graph_id,
                                     const size_t impl_graph_id,
                                     size_t group_id,
                                     const std::string &group_item_prefix) const;
  std::pair<std::string, bool> GenConflictExprContextCode(const ModelInfo &model_info,
                                                          const ge::Expression &expr,
                                                          std::set<std::string> &declared_symbols) const;
  static std::string GenMixedPerfUpdateCode(const std::vector<std::string> &groups_perf,
                                            const std::vector<std::string> &groups_block_num,
                                            const std::vector<std::string> &groups_conflict_flags,
                                            const std::string &indent);

  // State shared with derived classes. Declaration order is significant (it fixes both the object
  // layout and the member initialization order) and must not be rearranged casually.
  ge::CodePrinter tiling_data_;
  ge::CodePrinter tiling_func_;
  ge::CodePrinter tiling_head_;
  std::string op_name_;
  TilingCodeGenConfig config_;
  ExtraInfoConfig extra_info_config_;
  TilingDataGenerator tiling_data_manager_;
  ExtraInfoGenerator extra_info_generator_;
  // Non-owning reference; see the lifetime note in the class comment.
  const TilingModelInfo &tiling_model_info_;
  bool is_uniq_group_{true};
  bool hardware_has_ub_{false};
  ScoreFuncs score_funcs_;
  std::unordered_map<std::string, std::string> cache_reuse_info_{};
  VarRelations var_relations_{};
  EnableGroupParallels enable_group_parallels_{};
  TensorIdSet workspace_tensor_id_set_{};
  uint32_t cache_capacity_{0};
  bool with_reuse_info_{false};
  std::string arrange_code_;
  // Written by SetScheduleResultGroupNums().
  std::map<std::pair<size_t, size_t>, size_t> schedule_result_group_nums_;
  std::unique_ptr<cache::OperatorLevelCacheGen> operator_level_cache_gen_;
  std::unique_ptr<cache::GroupLevelCacheGen> group_level_cache_gen_;

 private:
  // Internal helpers, not accessible to derived classes.
  static bool IsConflictCacheLineConfig(const CacheLineConfig &cfg);
  ge::Status GenExpressionMacro();
  ge::Status GetRelatedHardware(std::map<std::string, std::string> &hardware_info);
  ge::Status GenDurationCommonCode();
  ge::Status GenDurationPrintCode(const std::string &indent);
  ge::Status GenDurationClearCode(const std::string &indent);
  static std::string GenPerformanceAdjustmentCode(bool enable_group_parallel_optimize, bool add_core_num_param,
                                                  uint32_t group_num, bool is_uniq_group);
  static std::string GenLogOutputCodeWithUb(const bool is_uniq_group);
  ge::Status GenFindPerfBetterTilingbyCaseIdWithUb(bool enable_group_parallel_optimize, bool add_core_num_param,
                                                   uint32_t group_num, bool is_uniq_group);
  ge::Status GenFindPerfBetterTilingbyCaseIdWithoutUb(bool enable_group_parallel_optimize, bool add_core_num_param,
                                                      uint32_t group_num);
  ge::Status GenProtectedVars();
  ge::Status GenBaseTilingData(std::map<std::string, std::string> &type_name_to_definition);
  ge::Status GenHeaderCodesHead();
  ge::Status GenHeaderCodesTail();
  ge::Status GenHeaderCodesBody();
  ge::Status GenHeaderCodesSummaryBody();
  ge::Status GenHeaderInclude();
  ge::Status GenHeaderVarsDef();
  ge::Status GenScheduleGroupTilingHead();
  ge::Status GenScheduleGroupTilingTail();
  ge::Status GenGetTiling();
  ge::Status GenTilingImplBaseClass();
  ge::Status GenCommonFrameWork();
  ge::Status GenCommonStruct();
  ge::Status GenEvalFunc(const ModelInfo &model_info);
  ge::Status GenTilingSummary(const ModelInfo &model_info);
  ge::Status GenPostTiling(const ModelInfo &model_info);
  ge::Status GenImplPtr();
  ge::Status GenGetPerf();
  ge::Status GenGetSummary();
  ge::Status GenReuseGroupTilingWrapperGetTiling(
      const std::string &cur_prefix, const std::string &reuse_prefix, const ReuseScheduleGroupInfo &reuse_info,
      std::map<ScheduleGroupIdent, ReuseScheduleGroupInfo>::const_iterator iter);
  ge::Status GenReuseGroupTilingWrapperGetPerf(
      const std::string &cur_prefix, const std::string &reuse_prefix, const ReuseScheduleGroupInfo &reuse_info,
      std::map<ScheduleGroupIdent, ReuseScheduleGroupInfo>::const_iterator iter);
  ge::Status GenReuseGroupTilingWrapperGetSummary(
      const std::string &cur_prefix, const std::string &reuse_prefix, const ReuseScheduleGroupInfo &reuse_info,
      std::map<ScheduleGroupIdent, ReuseScheduleGroupInfo>::const_iterator iter);
  ge::Status GenReuseGroupTilingWrapper(std::map<std::string, std::string> &tiling_res);
  ge::Status GenPGOReuseGroupTilingWrapper();
  ge::Status GenTilingKeyFunc();
  void GenTilingHeadMultiGroup();
  size_t CollectInputVarsSize() const;
  ge::Status GenGetTilingImpl();
  ge::Status GenIsStaticShape();
  ge::Status GenTilingFuncCallEntrance();
  ge::Status GenGeneralTiling(const ModelInfo &model_info);
  ge::Status GenVariableAnnotation(const ArgsManager &args_manager);
  ge::Status GenGroupCacheLookupCode();
  ge::Status GenTemplateIterationLogic();
  ge::Status GenOpLog(const std::string &indent, const std::string &log);
  ge::Status GenOpLog(const std::string &indent, const std::string &uniq_log, const std::string &sched_log);
};
// Shared-ownership pointer alias for TilingCodeGenImpl (and, through the base pointer, its derived classes).
using TilingCodeGenImplPtr = std::shared_ptr<TilingCodeGenImpl>;
}
#endif