* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "base_template_generator.h"
namespace optimize {
/**
 * @brief Pick the cases a strategy should derive its new templates from.
 *
 * @param mode            kBaseCase selects only tiling_cases; kAppendCase selects tiling_cases followed by
 *                        generated_cases.
 * @param tiling_cases    Cases supplied by the caller.
 * @param generated_cases Cases produced so far; only consulted for kAppendCase.
 * @return A copy of the selected cases. For any other mode a warning is logged and an empty vector is returned.
 */
std::vector<autoschedule::AutoScheduleOutput> BaseTemplateGenerator::GetBasedCasesByGenMode(
    const GenerationMode mode, const std::vector<autoschedule::AutoScheduleOutput> &tiling_cases,
    const std::vector<autoschedule::AutoScheduleOutput> &generated_cases) {
  if (mode == GenerationMode::kBaseCase) {
    return tiling_cases;
  }
  if (mode == GenerationMode::kAppendCase) {
    std::vector<autoschedule::AutoScheduleOutput> merged;
    // Single allocation for both ranges; order is tiling cases first, then generated ones.
    merged.reserve(tiling_cases.size() + generated_cases.size());
    merged.insert(merged.end(), tiling_cases.begin(), tiling_cases.end());
    merged.insert(merged.end(), generated_cases.begin(), generated_cases.end());
    return merged;
  }
  // "%u" expects an unsigned int; an enumeration value is not one, so convert explicitly instead of relying on
  // how the enum happens to be passed to the logging call.
  GELOGW("Unknown generation mode: %u", static_cast<unsigned int>(mode));
  return {};
}
/**
 * @brief Run one strategy over every based case and collect the templates it produces.
 *
 * For each based case a new output is created (named via strategy.GenName), seeded with a copy of the based
 * case's scheduled graph and var relations, and handed to strategy.Generate. A based case for which the
 * strategy does not report ge::SUCCESS is skipped without failing the whole call.
 *
 * @param strategy        Strategy used to name, generate, score and judge the new template.
 * @param origin_graph    Graph forwarded unchanged to every strategy callback.
 * @param based_cases     Cases to derive from. Must not be the same vector as generated_cases, because
 *                        generated_cases grows while based_cases is being iterated.
 * @param generated_cases Receives every successfully generated case (appended).
 * @param drop_case_names Receives the graph name of each based case the strategy asks to drop.
 * @return ge::SUCCESS once all based cases were visited. NOTE(review): GE_ASSERT_TRUE presumably returns a
 *         failure status early when the graph copy fails - confirm against the macro definition.
 */
af::Status BaseTemplateGenerator::Generate(BaseTemplate &strategy, const af::AscGraph &origin_graph,
                                           const std::vector<autoschedule::AutoScheduleOutput> &based_cases,
                                           std::vector<autoschedule::AutoScheduleOutput> &generated_cases,
                                           std::unordered_set<std::string> &drop_case_names) {
  for (const auto &based_case : based_cases) {
    const auto &base_graph = based_case.scheduled_graph;
    autoschedule::AutoScheduleOutput generated_output(strategy.GenName(base_graph.GetName()).c_str());
    auto &new_graph = generated_output.scheduled_graph;

    // Seed the new output from the based case before the strategy rewrites it.
    GE_ASSERT_TRUE(generated_output.scheduled_graph.CopyFrom(based_case.scheduled_graph));
    generated_output.var_relations_ = based_case.var_relations_;

    const bool is_generated = (strategy.Generate(origin_graph, base_graph, new_graph) == ge::SUCCESS);
    if (!is_generated) {
      // A strategy that cannot handle this case is not an error; move on to the next one.
      GELOGD("Generate template failed, %s.", new_graph.GetName().c_str());
      continue;
    }

    // Only override the score function when the strategy actually provides one.
    const auto new_score_func = strategy.GetScoreFunc(origin_graph, new_graph);
    if (!new_score_func.empty()) {
      generated_output.score_func = new_score_func;
    }

    GELOGD("Generate template success, %s.", new_graph.GetName().c_str());
    generated_cases.push_back(generated_output);

    const bool drop_based_case = strategy.NeedDropBasedCase(origin_graph, base_graph, new_graph);
    if (!drop_based_case) {
      continue;
    }
    GELOGD("New template is better than original general template, drop it, %s.", base_graph.GetName().c_str());
    drop_case_names.emplace(base_graph.GetName());
  }
  return ge::SUCCESS;
}
/**
 * @brief Apply every registered strategy and fold the results back into tiling_cases.
 *
 * Each strategy in strategies_ is run against the cases chosen by GetBasedCasesByGenMode for its generation
 * mode. Afterwards all generated cases are appended to tiling_cases, and every case whose scheduled-graph name
 * was recorded for dropping is removed from the combined list.
 *
 * @param origin_graph Graph forwarded to each strategy.
 * @param tiling_cases In/out: the existing cases on entry; on success, existing plus generated cases minus the
 *                     dropped ones. Left untouched when there are no strategies or nothing was generated.
 * @return ge::SUCCESS on completion. NOTE(review): GE_CHECK_NOTNULL / GE_ASSERT_SUCCESS presumably return a
 *         failure status early for a null strategy or a failed Generate call - confirm against the macros.
 */
af::Status BaseTemplateGenerator::GenerateTemplates(const af::AscGraph &origin_graph,
                                                    std::vector<autoschedule::AutoScheduleOutput> &tiling_cases) {
  if (strategies_.empty()) {
    GELOGD("No template strategies found.");
    return ge::SUCCESS;
  }

  std::vector<autoschedule::AutoScheduleOutput> generated_cases;
  std::unordered_set<std::string> drop_case_names;
  for (const auto &strategy : strategies_) {
    GE_CHECK_NOTNULL(strategy);
    // The helper returns a fresh vector, so appending to generated_cases inside Generate cannot invalidate it.
    const std::vector<autoschedule::AutoScheduleOutput> based_cases =
        GetBasedCasesByGenMode(strategy->GetGenerationMode(), tiling_cases, generated_cases);
    GE_ASSERT_SUCCESS(Generate(*strategy, origin_graph, based_cases, generated_cases, drop_case_names));
  }

  const bool nothing_to_apply = generated_cases.empty() && drop_case_names.empty();
  if (nothing_to_apply) {
    return ge::SUCCESS;
  }

  // Append first: the drop filter below is applied to the combined list of existing and generated cases.
  tiling_cases.insert(tiling_cases.end(), generated_cases.begin(), generated_cases.end());
  if (drop_case_names.empty()) {
    return ge::SUCCESS;
  }

  std::vector<autoschedule::AutoScheduleOutput> kept_cases;
  kept_cases.reserve(tiling_cases.size());
  for (const auto &tiling_case : tiling_cases) {
    const bool is_dropped = drop_case_names.find(tiling_case.scheduled_graph.GetName()) != drop_case_names.end();
    if (is_dropped) {
      continue;
    }
    kept_cases.push_back(tiling_case);
  }
  tiling_cases.swap(kept_cases);
  return ge::SUCCESS;
}
}