* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gen_model_info.h"
#include <fstream>
#include "common/util/mem_utils.h"
#include "common/autofuse_config/auto_fuse_config_utils.h"
#include "graph/utils/file_utils.h"
#include "expr_gen/generate_tiling_expr.h"
#include "tiling_data_gen/tiling_data_generator.h"
#include "parser/ascend_graph_parser.h"
#include "pass/pass_mgr.h"
#include "nlohmann/json.hpp"
#include "util/base_types_printer.h"
#include "util/duration.h"
#include "api_tiling_gen/gen_api_tiling.h"
#include "base/att_const_values.h"
#include "util/thread_local_context.h"
#include "gen_model_info/expr_gen/arg_list_reorder.h"
#include "reuse_group_utils/reuse_group_utils.h"
#include "schedule_result.h"
#include "ascgraph_info_complete.h"
namespace att {
// File-local constants used while generating and dumping model info.
namespace {
// Exclusive upper bound on the length of a path accepted by GetRealPath().
constexpr uint32_t kPathMax = 4096U;
// Separator appended by GetRealPath() when the resolved directory does not already end with it.
const std::string kFileSeperator = "/";
// Name of the JSON file written by DumpModelInfo().
const std::string kModelInfoFileName = "model_info.json";
// Internal tags used by to_json(SymInfoPtr) to distinguish constant symbols from variable symbols.
constexpr uint32_t kConstType = 1U;
constexpr uint32_t kVarType = 2U;
// Alignment emitted by to_json(SymInfoPtr) when the symbol is not a SymVarInfo.
constexpr uint32_t kDefaultAlignValue = 1U;
// Directory DumpModelInfo() falls back to when the caller passes an empty dump directory.
const std::string kModelInfoFilePath = "./";
}
// Generates the model info of one graph for the given tiling case.
//
// Steps, each guarded by a GE_ASSERT_* check:
//   1. parse the graph into `tuning_space`;
//   2. generate the tiling expressions into `model_info`;
//   3. run every pass registered in ATTPassMgr on `tuning_space`.
//
// graph          : graph to analyse.
// model_info     : output; its tiling_case_id is set before any other work is done.
// tuning_space   : caller-created tuning space, must not be null.
// tiling_case_id : id recorded in `model_info`.
// Returns ge::SUCCESS when every step passes its check.
ge::Status GenerateModelInfo(const af::AscGraph &graph, ModelInfo &model_info, TuningSpacePtr &tuning_space,
                             const uint32_t tiling_case_id) {
  GELOGI("[DFX]Begin to generate model info for graph %s of tiling case id %u", graph.GetName().c_str(), tiling_case_id);
  DURATION_GUARD(DurationType::DURATION_GEN_MODEL_INFO);
  model_info.tiling_case_id = tiling_case_id;
  GE_ASSERT_NOTNULL(tuning_space, "Create tuning space failed.");

  // Step 1: fill the tuning space from the graph.
  att::AscendGraphParser ascend_graph_parser(tuning_space);
  GE_ASSERT_SUCCESS(ascend_graph_parser.GraphParser(graph), "Get tuning space failed.");

  // Step 2: derive the tiling expressions.
  att::GenerateTilingExpr tiling_expr(tuning_space);
  GE_ASSERT_SUCCESS(tiling_expr.Generate(model_info), "Get basic expr constraint failed.");

  // Step 3: run the registered passes and gather the configuration each one reports.
  // NOTE(review): att_config is filled here but never read afterwards in this function.
  ATTConfig att_config;
  std::vector<PassFunc> pass_funcs;
  ATTPassMgr::Instance().GetPassList(pass_funcs);
  for (auto &pass_func : pass_funcs) {
    GE_ASSERT_NOTNULL(pass_func, "Get pass func failed.");
    std::map<std::string, std::string> config_str;
    GE_ASSERT_TRUE(pass_func(tuning_space, config_str), "Run pass failed.");
    for (auto &entry : config_str) {
      GELOGD("Exe pass: config[%s], value[%s].", entry.first.c_str(), entry.second.c_str());
      att_config.config_names.emplace_back(entry.first);
      att_config.config_value.insert(entry);
    }
  }

  GELOGI("[DFX]End to generate model info for graph %s of tiling case id %u", graph.GetName().c_str(), tiling_case_id);
  return ge::SUCCESS;
}
// Verifies that tiling keys are set consistently across `graph_list`: either every graph
// has a non-negative tiling key or none has. An empty list is accepted.
// Returns ge::SUCCESS when the list is consistent; a mismatch trips GE_ASSERT_TRUE.
ge::Status CheckKeyValid(const std::vector<af::AscGraph> &graph_list) {
  if (graph_list.empty()) {
    return ge::SUCCESS;
  }
  // The first graph decides which of the two modes every other graph has to follow.
  const bool has_set_key = (graph_list.front().GetTilingKey() >= 0);
  for (const auto &graph : graph_list) {
    GE_ASSERT_TRUE(has_set_key == (graph.GetTilingKey() >= 0),
                   "If the user has set tiling_case_id for a graph, the tiling_case_id of all graphs must be set.");
  }
  return ge::SUCCESS;
}
// Generates one ModelInfo per graph in `graph_list` and appends it to `model_info_list`.
//
// Tiling keys: a graph with a non-negative tiling key uses that key; otherwise the running
// counter is used. The counter is incremented after every graph.
// Returns ge::SUCCESS when all graphs are processed and `model_info_list` is not empty.
ge::Status GenerateModelInfo(const std::vector<af::AscGraph> &graph_list, std::vector<ModelInfo> &model_info_list) {
  GE_ASSERT_SUCCESS(CheckKeyValid(graph_list));
  uint32_t tiling_key = 0U;
  for (auto &graph : graph_list) {
    if (graph.GetTilingKey() >= 0) {
      // A key supplied on the graph overrides the running counter.
      tiling_key = static_cast<uint32_t>(graph.GetTilingKey());
    }
    // Each graph gets its own tuning space.
    TuningSpacePtr tuning_space = af::MakeShared<TuningSpace>();
    GE_ASSERT_NOTNULL(tuning_space, "Make tuning space failed.");
    ModelInfo model_info;
    GE_ASSERT_SUCCESS(GenerateModelInfo(graph, model_info, tuning_space, tiling_key), "General model info failed.");
    model_info_list.push_back(model_info);
    ++tiling_key;
  }
  GE_ASSERT_TRUE(!model_info_list.empty(), "No graph need convert.");
  return ge::SUCCESS;
}
// Serializes a symbol description to JSON.
//
// Every emitted field starts from a default and is overwritten depending on the dynamic type of `arg`:
//   - SymConstInfo: type becomes "const" and const_value is taken from the symbol;
//   - SymVarInfo  : type is "var"; max value, related scope, align, value range and the
//                   symbol expression are taken from the symbol.
// A symbol that is neither is still emitted with type "var" and default values.
// A null `arg` is handled by GE_CHECK_NOTNULL_JUST_RETURN.
void to_json(nlohmann::json &j, const SymInfoPtr &arg) {
  GE_CHECK_NOTNULL_JUST_RETURN(arg);

  // Defaults used when the corresponding concrete type does not apply.
  auto type = 0U;
  uint32_t const_value = 0U;
  Expr align = af::Symbol(kDefaultAlignValue);
  Expr max_value = af::sym::kSymbolZero;
  std::vector<HardwareDef> related_scope;
  std::pair<int64_t, int64_t> value_range;
  std::string symbol_expr;

  if (const auto const_arg = std::dynamic_pointer_cast<SymConstInfo>(arg)) {
    type = kConstType;
    const_value = const_arg->const_value;
  }

  if (const auto var_arg = std::dynamic_pointer_cast<SymVarInfo>(arg)) {
    type = kVarType;
    max_value = var_arg->max_value;
    related_scope = var_arg->related_scope;
    align = var_arg->align;
    value_range = var_arg->value_range;
    // An invalid symbol expression is emitted as an empty string.
    if (IsValid(arg->symbol_expr)) {
      symbol_expr = Str(arg->symbol_expr);
    }
  }

  // An invalid maximum expression is emitted as the literal text "null".
  const std::string max_value_str = IsValid(max_value) ? Str(max_value) : "null";
  const char *const type_name = (type == kConstType) ? "const" : "var";
  j = nlohmann::json{
      {"type", type_name},
      {"symbol_expr", symbol_expr},
      {"const_value", const_value},
      {"expr_max_value", max_value_str},
      {"related_scope", related_scope},
      {"align", align},
      {"min_value", value_range.first},
      {"max_value", value_range.second},
  };
}
// Serializes an axis description to JSON.
// "orig_axis" and "from_axis" are emitted as the names of the referenced axes, each name
// followed by a comma (so a non-empty list ends with a trailing comma).
// A null `arg` is handled by GE_CHECK_NOTNULL_JUST_RETURN.
void to_json(nlohmann::json &j, const std::shared_ptr<AttAxis> &arg) {
  GE_CHECK_NOTNULL_JUST_RETURN(arg);

  // Concatenates "<name>," for every axis of the given list.
  const auto join_names = [](const auto &axes) {
    std::string joined;
    for (const auto &axis : axes) {
      joined += axis->name;
      joined += ",";
    }
    return joined;
  };

  const std::string orig_axis = join_names(arg->orig_axis);
  const std::string from_axis = join_names(arg->from_axis);
  j = nlohmann::json{
      {"name", arg->name},
      {"axis_pos", arg->axis_pos},
      {"bind_multicore", arg->bind_multicore},
      {"is_last", arg->is_last},
      {"is_node_innerest_dim", arg->is_node_innerest_dim},
      {"size", arg->size},
      {"orig_axis", orig_axis},
      {"from_axis", from_axis},
  };
}
// Serializes an ATTConfig to a JSON object with two members:
//   "config_inputs" - info.config_names
//   "align_config"  - info.config_value
void to_json(nlohmann::json &j, const ATTConfig &info) {
  nlohmann::json config_json = nlohmann::json::object();
  config_json["config_inputs"] = info.config_names;
  config_json["align_config"] = info.config_value;
  j = config_json;
}
// Serializes a HardwareDef through its entry in the HardwareType2Str table.
// The lookup uses .at(), so a value missing from the table is not mapped to a default
// (presumably it throws like std::map::at - confirm the container type).
// NOTE(review): the looked-up value is wrapped in nested single-element braces, which
// nlohmann::json may deduce as a nested array rather than a plain string - confirm the
// intended JSON shape before changing this initializer.
void to_json(nlohmann::json &j, const HardwareDef &type) {
j = nlohmann::json{
{HardwareType2Str.at(type)},
};
}
// Serializes an AxisPosition through its entry in the AxisType2Str table.
// The lookup uses .at(), so a value missing from the table is not mapped to a default
// (presumably it throws like std::map::at - confirm the container type).
// NOTE(review): the looked-up value is wrapped in nested single-element braces, which
// nlohmann::json may deduce as a nested array rather than a plain string - confirm the
// intended JSON shape before changing this initializer.
void to_json(nlohmann::json &j, const AxisPosition &type) {
j = nlohmann::json{
{AxisType2Str.at(type)},
};
}
// Serializes a PipeType through its entry in the PipeType2Str table.
// The lookup uses .at(), so a value missing from the table is not mapped to a default
// (presumably it throws like std::map::at - confirm the container type).
// NOTE(review): the looked-up value is wrapped in nested single-element braces, which
// nlohmann::json may deduce as a nested array rather than a plain string - confirm the
// intended JSON shape before changing this initializer.
void to_json(nlohmann::json &j, const PipeType &type) {
j = nlohmann::json{
{PipeType2Str.at(type)},
};
}
// Serializes the dumped subset of a ModelInfo to a JSON object: tiling case id, hardware
// constraints, equality / less-or-equal expressions, objects and the argument list.
// Other ModelInfo members are not written.
void to_json(nlohmann::json &j, const ModelInfo &info) {
  nlohmann::json model_json = nlohmann::json::object();
  model_json["tiling_case_id"] = info.tiling_case_id;
  model_json["hardware_cons"] = info.hardware_cons;
  model_json["eq_exprs"] = info.eq_exprs;
  model_json["leq_exprs"] = info.leq_exprs;
  model_json["objects"] = info.objects;
  model_json["arg_list"] = info.arg_list;
  j = model_json;
}
// Resolves `path` with af::RealPath() and returns it with a trailing file separator.
//
// Returns an empty string (after a warning log) when `path` is empty or its length is
// kPathMax or more. A path that af::RealPath() resolves to an empty string trips
// GE_ASSERT_TRUE.
std::string GetRealPath(const std::string &path) {
  const bool length_ok = !path.empty() && (path.size() < kPathMax);
  if (!length_ok) {
    GELOGW("Path is size[%zu] exception.", path.size());
    return "";
  }

  std::string root_path = af::RealPath(path.c_str());
  GE_ASSERT_TRUE(!root_path.empty(), "Invalid path: %s", path.c_str());

  // Callers append a file name directly, so make sure the directory ends with the separator.
  const bool needs_separator = !kFileSeperator.empty() && (root_path.back() != kFileSeperator.back());
  if (needs_separator) {
    root_path += kFileSeperator;
  }
  return root_path;
}
// Writes `model_info_list` as JSON ({"model_info": [...]}) to kModelInfoFileName inside
// `dump_dir`, or inside kModelInfoFilePath when `dump_dir` is empty.
//
// The target file path is resolved with realpath() first; when that fails the function
// logs at debug level and returns without writing.
// NOTE(review): realpath() cannot resolve a file that does not exist, so this appears to
// only overwrite an already existing file - confirm that this is intended.
void DumpModelInfo(const std::vector<ModelInfo> &model_info_list, std::string dump_dir = "") {
  nlohmann::json j = nlohmann::json{
      {"model_info", model_info_list},
  };

  if (dump_dir.empty()) {
    dump_dir = kModelInfoFilePath;
  }
  const std::string model_file = GetRealPath(dump_dir) + kModelInfoFileName;

  ge::char_t realpath_file[PATH_MAX] = {0x00};
  if (realpath(model_file.c_str(), realpath_file) == nullptr) {
    GELOGD("Model file path [%s] unfound.", model_file.c_str());
    return;
  }

  // Truncate the existing file and replace its content with the serialized JSON.
  std::ofstream file_stream(realpath_file, std::ios::out | std::ios::trunc);
  if (!file_stream.is_open()) {
    return;
  }
  file_stream << j.dump();
  file_stream.close();
}
void ProcessTilingR(std::vector<ModelInfo> &model_info_list, const ModelInfo &model_info,
std::vector<AttAxisPtr> &tiling_R_arg_list) {
if (tiling_R_arg_list.empty()) {
return;
}
ModelInfo model_info_tiling_R = model_info;
model_info_tiling_R.arg_list = tiling_R_arg_list;
model_info_tiling_R.sub_case_tag = "R";
model_info_list.emplace_back(model_info_tiling_R);
}
// Replaces model_info.sizes with the size variables that
// optimize::AscGraphInfoComplete::AppendOriginalSizeVar() collects for `graph`.
void ProcessGraphOriginalSizeVar(const af::AscGraph &graph, ModelInfo &model_info) {
  optimize::SizeVarSet collected_vars;
  optimize::AscGraphInfoComplete::AppendOriginalSizeVar(graph, collected_vars);

  // Drop whatever was stored before so only the freshly collected variables remain.
  model_info.sizes.clear();
  for (const auto &size_var : collected_vars) {
    model_info.sizes.emplace_back(size_var);
  }
}
// Tells whether the "AxesReorder" tiling algorithm is selected.
// Returns true when the strategy config fails to initialize or when set_env_tiling_algorithm
// is not set; otherwise returns whether the configured tiling_algorithm equals "AxesReorder".
inline bool IsAxesReorderAlgorithm() {
  if (AutoFuseConfig::MutableAttStrategyConfig().Init() != ge::SUCCESS) {
    return true;
  }
  const auto &att_config = AutoFuseConfig::GetAttStrategyConfig();
  if (!att_config.set_env_tiling_algorithm) {
    return true;
  }
  return att_config.tiling_algorithm == "AxesReorder";
}
// Generates the model infos of all graphs in `graph_list` and appends them to `model_info_list`.
//
// For every graph: the base model info is generated, the argument list is reordered when the
// AxesReorder algorithm is active, API tiling info and the original size variables are added,
// and the optional "R" sub-case is appended before the model info itself.
//
// Tiling keys: a graph with a non-negative tiling key uses that key; otherwise the running
// counter is used. The counter is incremented after every graph.
//
// graph_list            : graphs to convert.
// model_info_list       : output list, entries are appended.
// options               : option map; kTilingDataTypeName selects the tiling data type and
//                         the presence of kDumpDebugInfo enables the JSON dump into the
//                         directory given as its value.
// enable_group_parallel : copied into every generated ModelInfo.
// Returns ge::SUCCESS when all graphs are processed and `model_info_list` is not empty.
ge::Status GenerateModelInfo(const std::vector<af::AscGraph> &graph_list, std::vector<ModelInfo> &model_info_list,
                             const std::map<std::string, std::string> &options,
                             bool enable_group_parallel) {
  GE_ASSERT_SUCCESS(CheckKeyValid(graph_list));

  // The tiling data type does not depend on the graph, so resolve it once.
  const auto iter = options.find(kTilingDataTypeName);
  const std::string tiling_data_type = (iter != options.cend()) ? iter->second : "";

  uint32_t tiling_key = 0U;
  for (auto &graph : graph_list) {
    tiling_key = graph.GetTilingKey() >= 0 ? graph.GetTilingKey() : tiling_key;
    ModelInfo model_info;
    model_info.graph_name = graph.GetName();
    model_info.tiling_case_id = tiling_key;
    model_info.enable_group_parallel = enable_group_parallel;
    std::vector<AttAxisPtr> tiling_R_arg_list;
    TuningSpacePtr tuning_space = af::MakeShared<TuningSpace>();
    GE_ASSERT_NOTNULL(tuning_space, "Make tuning space failed.");
    // The tuning space refers to the cache line config stored inside this model info.
    tuning_space->cache_line_config = &model_info.cache_line_config;
    GetThreadLocalContext().SetOption(options);
    GE_ASSERT_SUCCESS(GenerateModelInfo(graph, model_info, tuning_space, tiling_key), "General model info failed.");
    if (IsAxesReorderAlgorithm()) {
      ArgListReorder arg_list_reorder(tuning_space);
      GE_ASSERT_SUCCESS(arg_list_reorder.SortArgList(model_info.arg_list, tiling_R_arg_list), "Sort arg list failed.");
    }
    ApiTilingParams params{graph, tiling_data_type};
    GE_ASSERT_SUCCESS(GetApiTilingInfo(model_info.tiling_case_id, params, model_info.node_name_to_api_code),
                      "Generate api tiling info failed.");
    ProcessGraphOriginalSizeVar(graph, model_info);
    // The "R" sub-case (if any) is stored in front of the model info it was derived from.
    ProcessTilingR(model_info_list, model_info, tiling_R_arg_list);
    model_info_list.emplace_back(model_info);
    tiling_key++;
  }
  GE_ASSERT_TRUE(!model_info_list.empty(), "No graph need convert.");

  const auto dump_iter = options.find(kDumpDebugInfo);
  if (dump_iter != options.end()) {
    std::string dump_dir = dump_iter->second;
    // Calling back() on an empty string is undefined behaviour, so an empty option value is
    // passed through unchanged and DumpModelInfo() applies its default directory.
    if (!dump_dir.empty() && (dump_dir.back() != '/')) {
      dump_dir += "/";
    }
    DumpModelInfo(model_info_list, dump_dir);
  }
  return ge::SUCCESS;
}
// Serializes `model_info_list` as {"model_info": [...]} into `json_info` and additionally
// hands the list to DumpModelInfo() with its default dump directory.
// Always returns ge::SUCCESS.
ge::Status MakeJsonIner(const std::vector<ModelInfo> &model_info_list, std::string &json_info) {
  nlohmann::json root = nlohmann::json::object();
  root["model_info"] = model_info_list;
  json_info = root.dump();
  DumpModelInfo(model_info_list);
  return ge::SUCCESS;
}
// Entry point for JSON generation: forwards to MakeJsonIner(), which writes the serialized
// model infos into `json_info`.
// The inner status is checked with GE_ASSERT_SUCCESS (presumably returning a failure status
// when the check does not hold - confirm against the macro definition); ge::SUCCESS is
// returned otherwise.
ge::Status MakeJson(std::vector<ModelInfo> &model_info_list, std::string &json_info) {
GE_ASSERT_SUCCESS(MakeJsonIner(model_info_list, json_info), "Make json failed.");
return ge::SUCCESS;
}
// Flattens the scheduled results into nested graph lists and collects the score functions.
//
// all_graphs is appended with one entry per element of node_idx_to_scheduled_results; the
// nesting is [asc graph][schedule result][schedule group][impl graph].
// all_graph_score_funcs receives the graph_name_to_score_funcs entries of every schedule
// group (entries for names already present are kept, as std::map::insert does).
// The GE_ASSERT_TRUE check at the end trips when no schedule group holds an impl graph.
ge::Status GetAllSubImplGraphs(const ascir::FusedScheduledResult &schedule_results,
                               std::vector<std::vector<std::vector<std::vector<af::AscGraph>>>> &all_graphs,
                               std::map<std::string, std::string> &all_graph_score_funcs) {
  // Stays true until a schedule group with at least one impl graph is seen.
  bool has_none_graph = true;
  for (const auto &node_results : schedule_results.node_idx_to_scheduled_results) {
    std::vector<std::vector<std::vector<af::AscGraph>>> node_graphs;
    for (const auto &result : node_results) {
      std::vector<std::vector<af::AscGraph>> result_graphs;
      for (const auto &group : result.schedule_groups) {
        result_graphs.push_back(group.impl_graphs);
        const auto &score_funcs = group.graph_name_to_score_funcs;
        all_graph_score_funcs.insert(score_funcs.begin(), score_funcs.end());
        has_none_graph = has_none_graph && group.impl_graphs.empty();
      }
      node_graphs.push_back(result_graphs);
    }
    all_graphs.push_back(node_graphs);
  }
  GE_ASSERT_TRUE(!has_none_graph);
  return ge::SUCCESS;
}
namespace {
// Generates and annotates the model infos of every schedule group of one schedule result.
//
// For each group the model infos are generated into
// out_schedule_groups.groups_tiling_model_info[group id]; every model info then receives its
// schedule group identity, the input/output nodes of `schedule_results` and, when one is
// registered under the model's graph name, its score function. Finally the reuse schedule
// group is initialized for the group.
//
// schedule_groups       : impl graphs of each schedule group.
// all_graph_score_funcs : score functions keyed by graph name.
// schedule_results      : source of the input/output nodes.
// options               : forwarded to GenerateModelInfo().
// out_schedule_groups   : output; also supplies enable_group_parallel.
// asc_graph_id / impl_graph_id : identity recorded in every model info.
// Returns ge::SUCCESS when every group passes its checks.
ge::Status ProcessAndSetScheduleGroupInfo(
    const std::vector<std::vector<af::AscGraph>> &schedule_groups,
    const std::map<std::string, std::string> &all_graph_score_funcs,
    const ascir::FusedScheduledResult &schedule_results,
    const std::map<std::string, std::string> &options,
    att::ParsedScheduleResult &out_schedule_groups,
    size_t asc_graph_id, size_t impl_graph_id) {
  for (size_t schedule_group_id = 0UL; schedule_group_id < schedule_groups.size(); schedule_group_id++) {
    auto &model_info_list = out_schedule_groups.groups_tiling_model_info[schedule_group_id];
    const auto &group_graphs = schedule_groups[schedule_group_id];
    // Name used by both log lines. An owned copy is kept because GetName() may return a
    // temporary, in which case a cached c_str() pointer would dangle.
    std::string first_graph_name = "null";
    if (!group_graphs.empty()) {
      first_graph_name = group_graphs[0].GetName().c_str();
    }
    GELOGI(
        "[DFX]Begin to gen model info for asc graph %zu, schedule result %zu, schedule group %zu, tiling_case size "
        "%zu, graph name %s.",
        asc_graph_id, impl_graph_id, schedule_group_id, group_graphs.size(), first_graph_name.c_str());
    // Both ids are size_t, so they are formatted with %zu.
    GE_ASSERT_SUCCESS(GenerateModelInfo(schedule_groups[schedule_group_id], model_info_list, options,
                                        out_schedule_groups.enable_group_parallel),
                      "Get model info failed, impl graph id = %zu, group id = %zu.", impl_graph_id,
                      schedule_group_id);
    for (auto &model_info : model_info_list) {
      model_info.schedule_group_ident.asc_graph_id = asc_graph_id;
      model_info.schedule_group_ident.impl_graph_id = impl_graph_id;
      model_info.schedule_group_ident.group_id = schedule_group_id;
      model_info.input_nodes = schedule_results.input_nodes;
      model_info.output_nodes = schedule_results.output_nodes;
      auto it = all_graph_score_funcs.find(model_info.graph_name);
      if (it != all_graph_score_funcs.end()) {
        model_info.score_func = it->second;
      }
    }
    // GenerateModelInfo() checks that the list is not empty, so element 0 exists here.
    const auto &ident = model_info_list[0].schedule_group_ident;
    GELOGI("[DFX]End to gen model info for %s tiling_case size %zu, graph name %s.", ident.GetItemPrefix().c_str(),
           group_graphs.size(), first_graph_name.c_str());
    GE_ASSERT_SUCCESS(ReuseGroupUtils::InitReuseScheduleGroup(ident, model_info_list),
                      "Init reuse schedule group failed, impl_graph_id[%zu], group_ident[%s].", impl_graph_id,
                      ident.GetItemPrefix().c_str());
  }
  return ge::SUCCESS;
}
}
// Builds the parsed schedule results (model infos) for every asc graph and schedule result.
//
// The impl graphs are first collected with GetAllSubImplGraphs(). For every schedule result
// that has at least one schedule group whose first group is not empty, the per-result
// attributes (score function, group-parallel flag, ids, variable relations) are copied to the
// output entry and the model infos of its schedule groups are generated. Afterwards all
// reusable groups are merged.
//
// schedule_results    : scheduled results to convert.
// options             : forwarded to model info generation.
// out_all_model_infos : output, indexed by asc graph id and then by impl graph id.
// Returns ge::SUCCESS when every step passes its check.
ge::Status GetModelInfoMap(const ascir::FusedScheduledResult &schedule_results,
                           const std::map<std::string, std::string> &options,
                           FusedParsedScheduleResult &out_all_model_infos) {
  std::vector<std::vector<std::vector<std::vector<af::AscGraph>>>> all_graphs_lists;
  std::map<std::string, std::string> all_graph_score_funcs;
  GE_ASSERT_SUCCESS(GetAllSubImplGraphs(schedule_results, all_graphs_lists, all_graph_score_funcs));

  for (size_t asc_graph_id = 0UL; asc_graph_id < all_graphs_lists.size(); asc_graph_id++) {
    auto &asc_graph_out = out_all_model_infos[asc_graph_id];
    auto &impl_graph_lists = all_graphs_lists[asc_graph_id];
    for (size_t impl_graph_id = 0UL; impl_graph_id < impl_graph_lists.size(); impl_graph_id++) {
      auto &schedule_groups = impl_graph_lists[impl_graph_id];
      // Skip schedule results without groups or whose first group holds no graph.
      const bool has_graphs = !schedule_groups.empty() && !schedule_groups[0].empty();
      if (!has_graphs) {
        continue;
      }

      auto &out_schedule_groups = asc_graph_out[impl_graph_id];
      const auto &schedule_result = schedule_results.node_idx_to_scheduled_results[asc_graph_id][impl_graph_id];
      GELOGD(
          "out_schedule_groups input values: score_func=%s, enable_group_parallel=%d, asc_graph_id=%zu, impl_graph_id=%zu",
          schedule_result.score_func.GetString(), schedule_result.enable_group_parallel, asc_graph_id, impl_graph_id);

      // Copy the attributes of the schedule result to the output entry.
      out_schedule_groups.score_func = schedule_result.score_func.GetString();
      out_schedule_groups.enable_group_parallel = schedule_result.enable_group_parallel;
      out_schedule_groups.asc_graph_id = asc_graph_id;
      out_schedule_groups.impl_graph_id = impl_graph_id;
      out_schedule_groups.var_relations = schedule_result.var_relations;

      GE_ASSERT_SUCCESS(ProcessAndSetScheduleGroupInfo(schedule_groups, all_graph_score_funcs, schedule_results,
                                                       options, out_schedule_groups, asc_graph_id, impl_graph_id));
    }
  }

  GE_ASSERT_SUCCESS(ReuseGroupUtils::MergeAllReusableGroups(all_graphs_lists, out_all_model_infos));
  return ge::SUCCESS;
}
}