* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "high_perf_tiling_code_gen_impl.h"
#include <regex>
#include <string>
#include <utility>
#include <vector>
#include "args_manager.h"
#include "solver_pass_manager.h"
#include "common/checker.h"
#include "autofuse_config/auto_fuse_config.h"
namespace att {
namespace {
// Default iteration-budget assignment that GenSolverBaseClass looks for in the common solver
// base-class text (header + value, i.e. "cfg_iterations = 100"); every occurrence is rewritten
// with the configured AutoFuseConfig max_iter_num.
constexpr ge::char_t kDefaultConfigMaxIterHeader[]{"cfg_iterations = "};
constexpr ge::char_t kDefaultConfigMaxIterValue[]{"100"};
}  // namespace
// Emits the extern function definitions. The high-performance generator adds nothing of its own
// here: it delegates to TilingCodeGenImpl and reports a failure through GE_ASSERT_SUCCESS.
ge::Status HighPerfTilingCodeGenImpl::GenExternFuncDef() {
  const auto base_status = TilingCodeGenImpl::GenExternFuncDef();
  GE_ASSERT_SUCCESS(base_status, "Generate extern func definition failed.");
  return ge::SUCCESS;
}
// Emits the public tiling-implementation functions. It first produces the base-class set, then the
// virtual data-transfer functions. Generation stops at the first step whose status is not SUCCESS.
ge::Status HighPerfTilingCodeGenImpl::GenTilingImplPublicFunc() {
  const auto base_public_status = TilingCodeGenImpl::GenTilingImplPublicFunc();
  GE_ASSERT_SUCCESS(base_public_status, "Generate tiling public func failed.");

  const auto data_transfer_status = GenVirtualDataTransferFuncs();
  GE_ASSERT_SUCCESS(data_transfer_status, "Generate virtual data transfer funcs failed.");
  return ge::SUCCESS;
}
// Emits the helper ("tool") code used by the generated tiling source, in this order:
//   1. the base TilingCodeGenImpl tool functions,
//   2. the struct-copy definitions,
//   3. the cache hash-map definitions.
// Each step is checked with GE_ASSERT_SUCCESS, so generation stops at the first failing step.
// The messages are logged only on failure, so they say "failed". This matches the other
// generators in this file instead of reading like progress notes.
ge::Status HighPerfTilingCodeGenImpl::GenToolFuncs() {
  GE_ASSERT_SUCCESS(TilingCodeGenImpl::GenToolFuncs(), "Generate tool funcs failed.");
  GE_ASSERT_SUCCESS(TilingCodeGenImpl::GenStructCopyDef(), "Generate struct copy def failed.");
  GE_ASSERT_SUCCESS(TilingCodeGenImpl::GenCacheHashMapDef(), "Generate cache hash map def failed.");
  return ge::SUCCESS;
}
namespace {
// Returns a copy of `text` in which every occurrence of the literal `from` is replaced by `to`.
// `from` is matched verbatim, so regex metacharacters in it have no special meaning.
// When `from` ends in a decimal digit, an occurrence that is immediately followed by another digit
// is left untouched. This stops "cfg_iterations = 100" from matching the prefix of a longer literal
// such as "cfg_iterations = 1000", which would otherwise be corrupted into "<to>0".
std::string ReplaceConfigLiteral(const std::string &text, const std::string &from, const std::string &to) {
  if (from.empty()) {
    return text;
  }
  const bool from_ends_in_digit = (from.back() >= '0') && (from.back() <= '9');
  std::string result;
  result.reserve(text.size());
  std::string::size_type copied_until = 0U;
  std::string::size_type hit = text.find(from);
  while (hit != std::string::npos) {
    const std::string::size_type hit_end = hit + from.size();
    const bool followed_by_digit =
        from_ends_in_digit && (hit_end < text.size()) && (text[hit_end] >= '0') && (text[hit_end] <= '9');
    if (followed_by_digit) {
      // Only a prefix of a longer number: keep the text and resume the search one character later.
      result.append(text, copied_until, hit + 1U - copied_until);
      copied_until = hit + 1U;
    } else {
      result.append(text, copied_until, hit - copied_until);
      result.append(to);
      copied_until = hit_end;
    }
    hit = text.find(from, copied_until);
  }
  result.append(text, copied_until, std::string::npos);
  return result;
}
}  // namespace

// Generates the solver base classes shared by every tiling model and appends them to the generated
// output: the declarations go to tiling_head_ and the definitions to tiling_func_.
// Every model's ArgsManager must process successfully; otherwise GE_ASSERT_TRUE reports the
// failure. The default iteration budget in the generated text is replaced by the configured
// AutoFuseConfig max_iter_num.
ge::Status HighPerfTilingCodeGenImpl::GenSolverBaseClass() {
  std::vector<ArgsManager> total_models;
  for (const auto &model_info_iter : tiling_model_info_) {
    ArgsManager args_manager(model_info_iter);
    GE_ASSERT_TRUE(args_manager.Process(false), "Args manager process failed.");
    // The local is not used again, so move it into the vector instead of copying its processed state.
    total_models.emplace_back(std::move(args_manager));
  }
  const std::string basic_solvers_head = SolverPassManager::GenCommonBaseClassesHead(total_models);
  const std::string basic_solvers_func = SolverPassManager::GenCommonBaseClassesFunc(total_models);

  // The generated text is expected to contain the default assignment "cfg_iterations = 100".
  // It is replaced literally (no regex needed). Both strings are built once and reused for the
  // head and func text.
  const std::string default_iter_cfg = std::string(kDefaultConfigMaxIterHeader) + kDefaultConfigMaxIterValue;
  const std::string configured_iter_cfg =
      kDefaultConfigMaxIterHeader + std::to_string(AutoFuseConfig::GetAttStrategyConfig().max_iter_num);

  std::string result_head = ReplaceConfigLiteral(basic_solvers_head, default_iter_cfg, configured_iter_cfg);
  std::string result_func = ReplaceConfigLiteral(basic_solvers_func, default_iter_cfg, configured_iter_cfg);
  tiling_head_.AddLine(result_head);
  tiling_func_.AddLine(result_func);
  return ge::SUCCESS;
}
// Emits the solver class for a single tiling case into tiling_func_. The class text comes from
// SolverPassManager::GenClassPass, which is restricted to this model's tiling case id.
// NOTE(review): unlike GenDoTiling and GenSolverBaseClass, this does not call Process(false) on the
// ArgsManager before handing it to SolverPassManager. Confirm that GenClassPass does not depend on
// processed arguments.
ge::Status HighPerfTilingCodeGenImpl::GenSolverTiling(const ModelInfo &model_info) {
  ArgsManager case_args(model_info);
  SolverPassManager class_pass_manager(case_args, {case_args.GetTilingCaseId()}, config_.tiling_data_type_name);
  tiling_func_.AddLine(class_pass_manager.GenClassPass());
  return ge::SUCCESS;
}
// Generates the DoTiling code for one tiling case. It processes the case's arguments, emits the
// get/set tiling accessors, then forwards the solver's function-pass output to GenDoTilingCommon.
// A processing failure or an accessor-generation failure aborts through the GE_ASSERT_* macros.
ge::Status HighPerfTilingCodeGenImpl::GenDoTiling(const ModelInfo &model_info) {
  // The arguments are processed before the solver pass manager is built from them.
  ArgsManager case_args(model_info);
  GE_ASSERT_TRUE(case_args.Process(false), "Args manager process failed.");
  SolverPassManager func_pass_manager(case_args, {case_args.GetTilingCaseId()}, config_.tiling_data_type_name);

  // On failure, log which schedule group and tiling case broke.
  GE_ASSERT_SUCCESS(GenGetSetTilingImpl(model_info), "Gen get set tiling impl failed, group[%s], case[%u,%s].",
                    model_info.schedule_group_ident.GetItemPrefix().c_str(), model_info.tiling_case_id,
                    model_info.sub_case_tag.c_str());

  return GenDoTilingCommon(model_info, func_pass_manager.GenFuncPass());
}
}