* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "optimize.h"
#include <queue>
#include "attr_utils.h"
#include "ascir_ops.h"
#include "ascir_ops_utils.h"
#include "ascir_utils.h"
#include "autoschedule/autoschedule.h"
#include "graph_properties_cache.h"
#include "fused_graph/fused_graph_unfolder.h"
#include "fused_graph/fused_graph_modifier.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "graph/utils/graph_utils.h"
#include "task_generator/schedule_task_generator.h"
#include "fusion/autofuse_attrs.h"
#include "buffer_allocate/buf_que_allocator.h"
#include "ascgraph_info_complete.h"
#include "schedule_utils.h"
#include "common_utils.h"
#include "node_utils.h"
#include "optimize/graph_pass/pass_runner_handler.h"
#include "graph/symbolizer/symbolic_utils.h"
#include "optimize/graph_completeness/dtype_consistency.h"
#include "pre_process/pre_process.h"
using namespace ascir;
using namespace optimize;
using namespace af::ascir_op;
using namespace af::ops;
namespace {
// Attribute key read from AscGraph-type nodes to obtain their serialized AscGraph string.
const char *const kAttrAscGraph = "ascgraph";
// Sentinel used for "no node id recorded yet" (compared against first_load_id in this file).
constexpr int64_t kInvalidNodeId = -1;
// Maximum number of characters TruncateGraphName keeps from a graph name.
constexpr size_t kMaxGraphNameLength = 60UL;
/**
 * Caps a graph name at kMaxGraphNameLength characters.
 *
 * @param name graph name to shorten.
 * @return the leading kMaxGraphNameLength characters of name, or name itself when it already fits.
 */
std::string TruncateGraphName(const std::string &name) {
  // substr clamps the requested count to the available length, so short names come back whole.
  return name.substr(0UL, kMaxGraphNameLength);
}
/**
 * Builds "<original_name>_S<result_idx>G<group_idx>C<impl_idx>".
 *
 * @param original_name name to decorate.
 * @param result_idx    index written after "_S".
 * @param group_idx     index written after "G".
 * @param impl_idx      index written after "C".
 * @return the decorated name.
 */
std::string GenerateIndexedGraphName(const std::string &original_name,
                                     size_t result_idx,
                                     size_t group_idx,
                                     size_t impl_idx) {
  std::string indexed_name = original_name;
  indexed_name += "_S";
  indexed_name += std::to_string(result_idx);
  indexed_name += "G";
  indexed_name += std::to_string(group_idx);
  indexed_name += "C";
  indexed_name += std::to_string(impl_idx);
  return indexed_name;
}
/**
 * Gives every impl graph a unique name by appending its (result, group, impl) indices and
 * re-keys the matching entry of the group's graph_name_to_score_funcs map to the new name.
 *
 * @param scheduled_results scheduled results whose impl graphs are renamed in place.
 * @return af::SUCCESS on success; the GE_ASSERT_NOTNULL macro returns early when an impl graph
 *         has no underlying compute graph.
 */
Status FinalizeIndexedGraphs(std::vector<ascir::ScheduledResult> &scheduled_results) {
  for (size_t result_idx = 0UL; result_idx < scheduled_results.size(); ++result_idx) {
    auto &scheduled_result = scheduled_results[result_idx];
    for (size_t group_idx = 0UL; group_idx < scheduled_result.schedule_groups.size(); ++group_idx) {
      auto &schedule_group = scheduled_result.schedule_groups[group_idx];
      for (size_t impl_idx = 0UL; impl_idx < schedule_group.impl_graphs.size(); ++impl_idx) {
        auto &impl_graph = schedule_group.impl_graphs[impl_idx];
        const std::string old_name = impl_graph.GetName();
        const std::string new_name = GenerateIndexedGraphName(old_name, result_idx, group_idx, impl_idx);
        auto compute_graph = af::AscGraphUtils::GetComputeGraph(impl_graph);
        GE_ASSERT_NOTNULL(compute_graph);
        compute_graph->SetName(new_name);
        GELOGD("Rename graph: [%s] -> [%s]", old_name.c_str(), new_name.c_str());
        // Move the score-func entry (if any) from the old key to the new one without copying the value.
        auto node = schedule_group.graph_name_to_score_funcs.extract(old_name);
        if (!node.empty()) {
          node.key() = new_name;
          schedule_group.graph_name_to_score_funcs.insert(std::move(node));
          // The node handle is empty after a successful insert, so log from new_name rather
          // than calling node.key() on it (undefined behaviour on an empty handle).
          GELOGD("Update score func key: [%s] -> [%s]", old_name.c_str(), new_name.c_str());
        }
      }
    }
  }
  return af::SUCCESS;
}
/**
 * Checks whether the loop axes at positions pre_id_idx and post_id_idx can be treated as one
 * contiguous axis for every output tensor of every non-buffer node.
 *
 * For each output the two axes must both be present, must sit next to each other (pre directly
 * before post), and the stride of the pre axis must statically equal stride * repeat of the post axis.
 *
 * @param graph       graph whose nodes are inspected.
 * @param pre_id_idx  position of the outer axis inside each node's sched.axis list.
 * @param post_id_idx position of the inner axis inside each node's sched.axis list.
 * @return true when every inspected output satisfies the conditions above.
 */
bool IsAxisContinuous(const af::AscGraph &graph, const int64_t pre_id_idx, const int64_t post_id_idx) {
  for (const auto &node : graph.GetAllNodes()) {
    if (ScheduleUtils::IsBuffer(node)) {
      continue;
    }
    const auto outer_axis_id = node->attr.sched.axis[pre_id_idx];
    const auto inner_axis_id = node->attr.sched.axis[post_id_idx];
    for (auto &out_tensor : node->outputs()) {
      const auto axes_begin = out_tensor->attr.axis.begin();
      const auto axes_end = out_tensor->attr.axis.end();
      const auto outer_pos = std::find(axes_begin, axes_end, outer_axis_id);
      const auto inner_pos = std::find(axes_begin, axes_end, inner_axis_id);
      if ((outer_pos == axes_end) || (inner_pos == axes_end)) {
        return false;
      }
      const auto outer_dim = std::distance(axes_begin, outer_pos);
      const auto inner_dim = std::distance(axes_begin, inner_pos);
      // Extent (in elements) covered by one full sweep of the inner axis.
      const auto inner_extent = out_tensor->attr.strides[inner_dim] * out_tensor->attr.repeats[inner_dim];
      if (outer_dim + 1 != inner_dim) {
        return false;
      }
      const auto stride_matches = af::SymbolicUtils::StaticCheckEq(out_tensor->attr.strides[outer_dim], inner_extent);
      if (stride_matches != af::TriBool::kTrue) {
        return false;
      }
    }
  }
  return true;
}
/**
 * Joins consecutive pairs into chains: a pair (a, b) extends the current chain when the chain
 * currently ends with a, otherwise it starts a new chain [a, b].
 *
 * @param potential_axis ordered list of (first, second) index pairs.
 * @return the chains in encounter order; empty when potential_axis is empty.
 */
std::vector<std::vector<int64_t>> MergeContinuousPairs(const std::vector<std::pair<int64_t, int64_t>> &potential_axis) {
  std::vector<std::vector<int64_t>> chains;
  for (const auto &[head, tail] : potential_axis) {
    const bool extends_last_chain = !chains.empty() && (chains.back().back() == head);
    if (extends_last_chain) {
      chains.back().push_back(tail);
    } else {
      chains.emplace_back(std::vector<int64_t>{head, tail});
    }
  }
  return chains;
}
// Collects the loop-axis positions (indices into a node's sched.axis list) whose stride is
// statically zero in every output tensor of every non-buffer node.
// Returns an empty set when no position qualifies, and also when every position qualifies
// (see the final size comparison), so callers never get a request to drop all axes.
// NOTE(review): the GE_ASSERT_* macros return early from this function; what value they produce
// for an unordered_set return type depends on the macro definition - confirm.
std::unordered_set<size_t> IdentifyZeroStrideAxisIndices(const ascir::ImplGraph &owner_graph) {
std::vector<bool> is_zero_stride_axis;
bool include_reduce = false;
// First pass: size the flag vector from the first non-buffer node and detect reduce nodes.
for (const auto &node : owner_graph.GetAllNodes()) {
GE_ASSERT_NOTNULL(node);
if (!ScheduleUtils::IsBuffer(node) && is_zero_stride_axis.empty()) {
is_zero_stride_axis.resize(node->attr.sched.axis.size(), true);
}
if (ScheduleUtils::IsReduce(node)) {
include_reduce = true;
}
}
// Second pass: clear the flag of every loop position that has a non-zero stride somewhere.
for (const auto &node : owner_graph.GetAllNodes()) {
GE_ASSERT_NOTNULL(node);
if (ScheduleUtils::IsBuffer(node)) {
continue;
}
// The first tensor axis is protected when the node carries "_keep_first_axis" or the graph has a reduce node.
bool keep_first_axis = false;
(void) af::AttrUtils::GetBool(node->GetOpDesc(), "_keep_first_axis", keep_first_axis);
keep_first_axis = keep_first_axis || include_reduce;
const auto &loop_axes = node->attr.sched.axis;
for (size_t loop_idx = 0UL; loop_idx < loop_axes.size(); ++loop_idx) {
bool has_non_zero_stride = false;
for (const auto &output : node->outputs()) {
// Every loop axis is required to appear in every output tensor's axis list.
auto iter = std::find(output->attr.axis.begin(), output->attr.axis.end(), loop_axes[loop_idx]);
GE_ASSERT_TRUE(iter != output->attr.axis.end());
auto axis_index = static_cast<size_t>(std::distance(output->attr.axis.begin(), iter));
GE_ASSERT_TRUE(axis_index < output->attr.strides.size(), "axis index [%zu] is out of range, max_size:[%zu].",
axis_index, output->attr.strides.size());
// Anything not provably zero counts as non-zero; a protected first axis also counts as non-zero.
if (af::SymbolicUtils::StaticCheckEq(output->attr.strides[axis_index], af::sym::kSymbolZero) !=
af::TriBool::kTrue || (keep_first_axis && (axis_index == 0))) {
has_non_zero_stride = true;
break;
}
}
if (has_non_zero_stride) {
// NOTE(review): is_zero_stride_axis was sized from the first non-buffer node only; this write
// assumes every non-buffer node has no more loop axes than that node - confirm.
is_zero_stride_axis[loop_idx] = false;
}
}
}
std::unordered_set<size_t> zero_stride_axis_indices;
for (size_t i = 0UL; i < is_zero_stride_axis.size(); ++i) {
if (is_zero_stride_axis[i]) {
zero_stride_axis_indices.emplace(i);
}
}
// If every position qualified, report nothing instead of all positions.
if (zero_stride_axis_indices.size() == is_zero_stride_axis.size()) {
return {};
}
return zero_stride_axis_indices;
}
/**
 * Fills dir_father_nodes with one "node id -> father id" entry per node that has output nodes.
 *
 * The father is the first output node visited, except that a Store or Output consumer makes the
 * node its own father. Existing entries are never overwritten.
 *
 * @param impl_graph       graph to scan.
 * @param dir_father_nodes map that receives the entries.
 * @return af::SUCCESS, or the early return of GE_CHECK_NOTNULL on a null node / op desc.
 */
Status GetDirectFatherNode(const af::AscGraph &impl_graph, std::map<int64_t, int64_t> &dir_father_nodes) {
  for (const auto &node : impl_graph.GetAllNodes()) {
    GE_CHECK_NOTNULL(node);
    GE_CHECK_NOTNULL(node->GetOpDesc());
    const auto self_id = node->GetOpDesc()->GetId();
    for (const auto &out_node : node->GetOutAllNodes()) {
      GE_CHECK_NOTNULL(out_node);
      GE_CHECK_NOTNULL(out_node->GetOpDesc());
      const bool feeds_store_or_output = af::ops::IsOps<Store>(out_node) || af::ops::IsOps<Output>(out_node);
      if (feeds_store_or_output) {
        // A node writing to Store/Output terminates the chain: it is its own father.
        dir_father_nodes.emplace(self_id, self_id);
      } else {
        dir_father_nodes.emplace(self_id, out_node->GetOpDesc()->GetId());
      }
    }
  }
  return af::SUCCESS;
}
/**
 * Follows father links from node_id until reaching an id that is its own father, then rewrites
 * every id on the walked path to point directly at that root.
 *
 * Ids missing from the map are inserted by operator[] with father 0, as a side effect of lookup.
 *
 * @param dir_father_nodes father map, updated in place.
 * @param node_id          id to resolve.
 * @return the root id of node_id.
 */
int64_t FindRoot(std::map<int64_t, int64_t> &dir_father_nodes, const int64_t node_id) {
  // Pass 1: walk up to the root.
  int64_t root_id = node_id;
  for (;;) {
    const int64_t father_id = dir_father_nodes[root_id];
    if (father_id == root_id) {
      break;
    }
    root_id = father_id;
  }
  // Pass 2: point every node on the path straight at the root.
  int64_t cursor = node_id;
  while (cursor != root_id) {
    int64_t &father_slot = dir_father_nodes[cursor];
    const int64_t next_id = father_slot;
    father_slot = root_id;
    cursor = next_id;
  }
  return root_id;
}
/**
 * Tells whether two node ids resolve to the same root in the father map.
 *
 * @param dir_father_nodes father map; FindRoot may rewrite entries while resolving.
 * @param node0_id         first node id.
 * @param node1_id         second node id.
 * @return true when both ids share one root.
 */
bool HasSameComputeNode(std::map<int64_t, int64_t> &dir_father_nodes, int64_t node0_id, int64_t node1_id) {
  // Resolve in a fixed order (node0 first) because each lookup may mutate the map.
  const int64_t root_of_node0 = FindRoot(dir_father_nodes, node0_id);
  const int64_t root_of_node1 = FindRoot(dir_father_nodes, node1_id);
  return root_of_node0 == root_of_node1;
}
/**
 * Groups the ids of all Load nodes by the queue id of their first output.
 *
 * @param impl_graph            graph to scan.
 * @param loads_with_same_queid receives, per queue id, the Load node ids in visiting order.
 * @return af::SUCCESS, or the early return of GE_CHECK_NOTNULL on a null node / op desc.
 */
Status GetMapFromQueId2LoadId(const af::AscGraph &impl_graph,
                              std::map<int64_t, vector<int64_t>> &loads_with_same_queid) {
  for (const auto &node : impl_graph.GetAllNodes()) {
    GE_CHECK_NOTNULL(node);
    GE_CHECK_NOTNULL(node->GetOpDesc());
    if (af::ops::IsOps<Load>(node)) {
      // operator[] creates the (empty) bucket on first sight of a queue id.
      auto &load_ids = loads_with_same_queid[node->outputs[0].attr.que.id];
      load_ids.emplace_back(node->GetOpDesc()->GetId());
    }
  }
  return af::SUCCESS;
}
// Finds Load nodes (and their buffer inputs) that should be moved right behind the first Load.
// need_forward_nodes_id receives "node id -> 0-based order in which it was recorded".
// first_load_id is set to the id of the first Load visited when it comes in as kInvalidNodeId.
// Scanning stops once num_of_load_need_adjust reaches load_thresh; the first Load found here
// counts towards that threshold.
Status SearchNodesNeedForward(const af::AscGraph &impl_graph, std::map<int64_t, int64_t> &need_forward_nodes_id,
int64_t &first_load_id) {
std::map<int64_t, int64_t> dir_father_nodes;
GE_ASSERT_SUCCESS(GetDirectFatherNode(impl_graph, dir_father_nodes));
constexpr size_t load_thresh = 2UL;
size_t num_of_load_need_adjust = 0UL;
// Running counter that provides the recorded order of advanced nodes.
size_t index_data_and_load = 0UL;
std::map<int64_t, vector<int64_t>> loads_with_same_queid;
GE_ASSERT_SUCCESS(GetMapFromQueId2LoadId(impl_graph, loads_with_same_queid));
for (const auto &node : impl_graph.GetAllNodes()) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
if (num_of_load_need_adjust >= load_thresh) {
GELOGD("The number of loads that need to be brought forward is %zu, which reaches threshold %zu.",
num_of_load_need_adjust, load_thresh);
break;
}
// The first Load is the anchor: it is not moved itself but is counted.
if (first_load_id == kInvalidNodeId && af::ops::IsOps<Load>(node)) {
first_load_id = node->GetOpDesc()->GetId();
num_of_load_need_adjust++;
continue;
}
// A later Load qualifies when it is the only Load on its queue id and resolves to the same
// root as the first Load in the father map.
if (first_load_id != kInvalidNodeId && af::ops::IsOps<Load>(node) && node->GetOpDesc()->GetId() > first_load_id &&
loads_with_same_queid[node->outputs[0].attr.que.id].size() == 1UL &&
HasSameComputeNode(dir_father_nodes, node->GetOpDesc()->GetId(), first_load_id)) {
// Buffer inputs placed after the first Load are recorded before the Load that consumes them.
for (const auto &in_node : node->GetInAllNodes()) {
GE_CHECK_NOTNULL(in_node);
GE_CHECK_NOTNULL(in_node->GetOpDesc());
if (ScheduleUtils::IsBuffer(std::dynamic_pointer_cast<af::AscNode>(in_node)) &&
in_node->GetOpDesc()->GetId() > first_load_id) {
GELOGD("Input node %s is after first load, needs to be advanced.", in_node->GetNamePtr());
need_forward_nodes_id.insert(std::make_pair(in_node->GetOpDesc()->GetId(), index_data_and_load++));
}
}
// NOTE(review): the buffer inputs above stay recorded even when the Load itself is skipped here - confirm intended.
if (ScheduleUtils::IsOutNodeWithMultiInputs(node)) {
continue;
}
GELOGD("Node %s needs to be advanced.", node->GetNamePtr());
need_forward_nodes_id.insert(std::make_pair(node->GetOpDesc()->GetId(), index_data_and_load++));
num_of_load_need_adjust++;
}
}
return af::SUCCESS;
}
// Renumbers op-desc ids so that the recorded nodes follow the first Load directly.
// A recorded node with order k receives id first_load_id + k + 1; every other node visited
// between the first Load and the largest recorded id receives consecutive ids starting at
// first_load_id + <number of recorded nodes> + 1. Nodes with id <= first_load_id are untouched.
// Does nothing when no node was recorded or first_load_id is kInvalidNodeId.
// NOTE(review): the early break relies on GetAllNodes() visiting nodes in ascending id order - confirm.
Status DoSeqAdjustment(const af::AscGraph &impl_graph, const std::map<int64_t, int64_t> &need_forward_nodes_id,
const int64_t &first_load_id) {
const size_t need_forward_nodes_num = need_forward_nodes_id.size();
// size_t is unsigned, so "<= 0UL" is simply an emptiness check.
if (need_forward_nodes_num <= 0UL || first_load_id == kInvalidNodeId) {
return af::SUCCESS;
}
// First id handed to nodes that are pushed back to make room for the advanced ones.
int64_t start_index_other_nodes = first_load_id + static_cast<int64_t>(need_forward_nodes_num) + 1;
for (const auto &node : impl_graph.GetAllNodes()) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
if (node->GetOpDesc()->GetId() <= first_load_id) {
continue;
}
// rbegin() of the ordered map holds the largest recorded id.
if (node->GetOpDesc()->GetId() > need_forward_nodes_id.rbegin()->first) {
GELOGD("No need to adjust node after node %" PRId64 ".", need_forward_nodes_id.rbegin()->first);
break;
}
if (need_forward_nodes_id.find(node->GetOpDesc()->GetId()) != need_forward_nodes_id.end()) {
GELOGD("Move node with id %" PRId64 " and name %s forward.", node->GetOpDesc()->GetId(),
node->GetNamePtr());
node->GetOpDesc()->SetId(first_load_id + need_forward_nodes_id.at(node->GetOpDesc()->GetId()) + 1);
continue;
}
node->GetOpDesc()->SetId(start_index_other_nodes++);
}
return af::SUCCESS;
}
/**
 * Stores the score function of a schedule output in the group's graph_name_to_score_funcs map,
 * keyed by the scheduled graph's name.
 *
 * @param schedule_output      schedule output providing the graph and its score function.
 * @param schedule_group       group whose score-function map is updated.
 * @param should_skip_registry when true (the default) nothing is stored.
 * @return the graph name that was registered, or an empty string when the score function is
 *         empty or registration was skipped.
 */
std::string RegisterScoreFuncInScheduleGroup(const autoschedule::AutoScheduleOutput &schedule_output,
                                             ScheduleGroup &schedule_group, const bool should_skip_registry = true) {
  const auto &score_func = schedule_output.score_func;
  if (score_func.empty()) {
    return std::string();
  }
  const std::string template_name = schedule_output.scheduled_graph.GetName();
  if (!should_skip_registry) {
    schedule_group.graph_name_to_score_funcs[template_name] = score_func;
    GELOGD("The score func of template [%s] is [%s].", template_name.c_str(), score_func.c_str());
    return template_name;
  }
  GELOGD("Not a valid case, skip register template score func of graph [%s].", template_name.c_str());
  return std::string();
}
/**
 * Appends to every scheduled result one new schedule group holding a copy of each scheduled graph.
 *
 * @param schedule_outputs      outputs whose scheduled graphs are copied.
 * @param scheduled_results_cur results that each receive a freshly built group.
 * @return af::SUCCESS; GE_ASSERT_TRUE returns early when a graph copy fails.
 */
af::Status CopyImplGraphs(const std::vector<autoschedule::AutoScheduleOutput> &schedule_outputs,
                          std::vector<ascir::ScheduledResult> &scheduled_results_cur) {
  for (auto &target_result : scheduled_results_cur) {
    ScheduleGroup new_group;
    new_group.impl_graphs.reserve(schedule_outputs.size());
    for (const auto &result : schedule_outputs) {
      ascir::ImplGraph copied_graph(result.scheduled_graph.GetName().c_str());
      GE_ASSERT_TRUE(copied_graph.CopyFrom(result.scheduled_graph));
      new_group.impl_graphs.push_back(std::move(copied_graph));
      // Called with the default should_skip_registry; the returned name is not needed here.
      (void)RegisterScoreFuncInScheduleGroup(result, new_group);
    }
    target_result.schedule_groups.emplace_back(std::move(new_group));
  }
  return af::SUCCESS;
}
/**
 * Axis re-merging is allowed only for graphs without gather, reduce and cube computations.
 *
 * @param impl_graph graph to classify.
 * @return true when the graph's property cache reports none of the three.
 */
bool CanDoReMergeAxis(const af::AscGraph &impl_graph) {
  GraphPropertiesCache properties(impl_graph);
  const bool has_blocking_compute = properties.HasGather() || properties.HasReduce() || properties.HasCube();
  return !has_blocking_compute;
}
/**
 * Drops the score functions of the named graphs unless the result set is "simple"
 * (exactly one scheduled result holding exactly one schedule group).
 *
 * @param scheduled_results  results whose groups' score-function maps are pruned in place.
 * @param scored_graph_names names of the graphs whose score functions are to be removed.
 */
void FilterComplexTilingDataScoreFuncs(std::vector<::ascir::ScheduledResult> &scheduled_results,
                                       const std::set<std::string> &scored_graph_names) {
  if (scored_graph_names.empty()) {
    return;
  }
  const bool is_simple_tiling_data =
      (scheduled_results.size() == 1UL) && (scheduled_results[0].schedule_groups.size() == 1UL);
  if (is_simple_tiling_data) {
    GELOGD("Autoschedule score func for graph simple tiling data: %zu results, %zu groups", scheduled_results.size(),
           scheduled_results[0].schedule_groups.size());
    return;
  }
  for (auto &result : scheduled_results) {
    for (auto &group : result.schedule_groups) {
      auto &score_funcs = group.graph_name_to_score_funcs;
      for (auto it = score_funcs.begin(); it != score_funcs.end();) {
        const bool is_scored = scored_graph_names.find(it->first) != scored_graph_names.end();
        if (!is_scored) {
          ++it;
          continue;
        }
        GELOGD("Clear autoschedule score func for graph [%s] in complex tiling data: %zu results, %zu groups",
               it->first.c_str(), scheduled_results.size(), result.schedule_groups.size());
        // erase returns the iterator following the removed entry.
        it = score_funcs.erase(it);
      }
    }
  }
}
/**
 * Verifies that every output tensor of every non-buffer node has axis, repeats and strides
 * lists of the same length.
 *
 * @param graph graph to validate.
 * @return af::SUCCESS when all outputs are consistent; GE_ASSERT_TRUE returns early otherwise.
 */
Status CheckGraphValidity(const af::AscGraph &graph) {
  for (const auto &node : graph.GetAllNodes()) {
    if (ScheduleUtils::IsBuffer(node)) {
      continue;
    }
    const size_t output_count = node->outputs().size();
    for (size_t out_idx = 0UL; out_idx < output_count; ++out_idx) {
      const auto &tensor_attr = node->outputs[out_idx].attr;
      const size_t axis_num = tensor_attr.axis.size();
      const size_t repeat_num = tensor_attr.repeats.size();
      const size_t stride_num = tensor_attr.strides.size();
      GE_ASSERT_TRUE(axis_num == repeat_num && repeat_num == stride_num,
                     "Output tensor[%zu] of node [%s] has mismatched sizes: axis=%zu, repeat=%zu, stride=%zu.", out_idx,
                     node->GetNamePtr(), axis_num, repeat_num, stride_num);
    }
  }
  return af::SUCCESS;
}
/**
 * A node is a removable dangling node when it is not Data / ScalarData / Output / Workspace
 * and has neither data nor control successors.
 *
 * @param node node to test; nullptr yields false.
 * @return true when the node may be removed.
 */
bool IsRemovableDanglingNode(const af::NodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  const bool is_protected_type =
      IsOps<Data>(node) || IsOps<ScalarData>(node) || IsOps<Output>(node) || IsOps<Workspace>(node);
  if (is_protected_type) {
    return false;
  }
  const bool has_data_successor = node->GetOutDataNodesSize() != 0U;
  return !has_data_successor && (node->GetOutControlNodesSize() == 0U);
}
/**
 * A dangling input node is a Data or ScalarData node without data and control successors.
 *
 * @param node node to test; nullptr yields false.
 * @return true when the node is an unconnected input.
 */
bool IsDanglingInputNode(const af::NodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  const bool is_input_type = IsOps<Data>(node) || IsOps<ScalarData>(node);
  if (!is_input_type) {
    return false;
  }
  const bool has_data_successor = node->GetOutDataNodesSize() != 0U;
  return !has_data_successor && (node->GetOutControlNodesSize() == 0U);
}
/**
 * Collects the graph's direct nodes that act as outputs: Output nodes, and Workspace nodes
 * that have no data successors.
 *
 * @param graph        graph to scan; must not be null.
 * @param output_nodes cleared, then filled with the matching nodes in visiting order.
 * @return af::SUCCESS, or the early return of GE_ASSERT_NOTNULL on a null graph / node.
 */
Status GetOutputNodes(const af::ComputeGraphPtr &graph, std::vector<af::NodePtr> &output_nodes) {
  GE_ASSERT_NOTNULL(graph);
  output_nodes.clear();
  for (const auto &node : graph->GetDirectNode()) {
    GE_ASSERT_NOTNULL(node);
    const bool is_output = IsOps<Output>(node);
    const bool is_terminal_workspace = !is_output && IsOps<Workspace>(node) && (node->GetOutDataNodesSize() == 0U);
    if (!is_output && !is_terminal_workspace) {
      continue;
    }
    output_nodes.push_back(node);
  }
  return af::SUCCESS;
}
/**
 * Adds a control edge from every dangling input node to the first output node of the graph.
 * Logs a warning and succeeds without changes when the graph has no output nodes.
 *
 * @param graph graph to patch; must not be null.
 * @return af::SUCCESS, or the early return of one of the check macros.
 */
Status LinkDanglingInputNodesToOutput(const af::ComputeGraphPtr &graph) {
  GE_ASSERT_NOTNULL(graph);
  std::vector<af::NodePtr> output_nodes;
  GE_CHK_STATUS_RET(GetOutputNodes(graph, output_nodes), "Get output nodes failed, graph:[%s].",
                    graph->GetName().c_str());
  if (output_nodes.empty()) {
    GELOGW("Graph [%s] does not contain valid output nodes, skip linking dangling input nodes.",
           graph->GetName().c_str());
    return af::SUCCESS;
  }
  // All dangling inputs are anchored on the same (first) output node.
  const auto &first_output_node = output_nodes.front();
  for (const auto &node : graph->GetDirectNode()) {
    GE_ASSERT_NOTNULL(node);
    if (IsDanglingInputNode(node)) {
      GE_ASSERT_GRAPH_SUCCESS(af::GraphUtils::AddEdge(node->GetOutControlAnchor(),
                                                      first_output_node->GetInControlAnchor()));
      GELOGD("Add extra control edge between input node[%s] and output node[%s] in graph[%s].", node->GetNamePtr(),
             first_output_node->GetNamePtr(), graph->GetName().c_str());
    }
  }
  return af::SUCCESS;
}
/**
 * Repeatedly removes removable dangling nodes until a full scan finds none, because removing
 * one node can leave its producers dangling in turn.
 *
 * @param graph graph to clean; must not be null.
 * @return af::SUCCESS, or the early return of one of the assert macros.
 */
Status RemoveDanglingNodes(const af::ComputeGraphPtr &graph) {
  GE_ASSERT_NOTNULL(graph);
  size_t total_removed = 0UL;
  for (;;) {
    // Collect first, remove afterwards, so the node list is not modified while being scanned.
    std::vector<af::NodePtr> nodes_to_remove;
    for (const auto &node : graph->GetDirectNode()) {
      GE_ASSERT_NOTNULL(node);
      if (IsRemovableDanglingNode(node)) {
        nodes_to_remove.push_back(node);
      }
    }
    if (nodes_to_remove.empty()) {
      break;
    }
    for (const auto &node : nodes_to_remove) {
      GELOGW("Remove dangling node [%s], type [%s], graph [%s].", node->GetNamePtr(), node->GetTypePtr(),
             graph->GetName().c_str());
      GE_ASSERT_GRAPH_SUCCESS(graph->RemoveNode(node));
      ++total_removed;
    }
  }
  GELOGI("Remove dangling nodes end, graph [%s], removed node num [%zu].", graph->GetName().c_str(), total_removed);
  return af::SUCCESS;
}
// AscGraph overload: removes dangling nodes from the underlying compute graph, then links the
// remaining dangling input nodes to an output node, and finally re-sorts the graph topologically.
// The three steps run in this order; each check macro is expected to return early on failure.
Status RemoveDanglingNodes(af::AscGraph &graph) {
auto compute_graph = af::AscGraphUtils::GetComputeGraph(graph);
GE_ASSERT_NOTNULL(compute_graph);
GE_CHK_STATUS_RET(RemoveDanglingNodes(compute_graph), "Remove dangling nodes failed, graph:[%s].",
graph.GetName().c_str());
GE_CHK_STATUS_RET(LinkDanglingInputNodesToOutput(compute_graph), "Link dangling input nodes failed, graph:[%s].",
graph.GetName().c_str());
GE_ASSERT_GRAPH_SUCCESS(ScheduleUtils::TopologicalSorting(graph));
return af::SUCCESS;
}
}
// Constructs an optimizer whose options_ member is initialized from the given options.
Optimizer::Optimizer(const OptimizerOptions &options) : options_(options) {}
/**
 * Optimizes a fused compute graph and fills fused_scheduled_result.
 *
 * When options_.graph_type is GraphType::kFusedAscBackend the work is delegated to
 * OptimizeFusedAscBackend. Otherwise every node of type kAscGraphNodeType is deserialized into
 * an AscGraph; several graphs are unfolded into one hint graph, a single graph is used directly,
 * and the hint graph is passed to the AscGraph overload of Optimize.
 *
 * @param fused_graph            fused graph to optimize; must not be null.
 * @param fused_scheduled_result receives the schedule results, the fused graph name and the
 *                               original size variables.
 * @return af::SUCCESS on success; the check macros return early on any failure.
 */
Status Optimizer::Optimize(const af::ComputeGraphPtr &fused_graph,
                           ascir::FusedScheduledResult &fused_scheduled_result) {
  // Guard before the first dereference below.
  GE_ASSERT_NOTNULL(fused_graph);
  GELOGI("Fused graph optimize in, graph_name:[%s].", fused_graph->GetName().c_str());
  ascir::utils::FusedGraphNameGuard guard(fused_graph->GetName());
  ascir::utils::DumpComputeGraph(fused_graph, "BaseFusedGraph");
  if (options_.graph_type == GraphType::kFusedAscBackend) {
    return OptimizeFusedAscBackend(fused_graph, fused_scheduled_result);
  }
  std::map<af::Node *, af::AscGraph> asc_backend_to_ascgraph;
  SizeVarSet original_var_set;
  for (auto &node : fused_graph->GetDirectNodePtr()) {
    GE_ASSERT_NOTNULL(node);
    if (node->GetType() == kAscGraphNodeType) {
      const std::string *serialized_ascgraph = af::AttrUtils::GetStr(node->GetOpDescBarePtr(), kAttrAscGraph);
      GE_ASSERT_NOTNULL(serialized_ascgraph, "Failed to get serialized ascgraph attr from node:[%s].",
                        node->GetNamePtr());
      std::string graph_name = node->GetName() + "_ascgraph";
      af::AscGraph ascgraph(graph_name.c_str());
      // The message names the function that is actually called (it previously said DeserializeFromBinary).
      GE_CHK_STATUS_RET(af::AscGraphUtils::DeserializeFromReadable(*serialized_ascgraph, ascgraph),
                        "DeserializeFromReadable failed, graph:[%s].", fused_graph->GetName().c_str());
      ascgraph.SetGraphType(af::AscGraphType::kImplGraph);
      GE_CHK_STATUS_RET(AscGraphInfoComplete::CompleteApiInfo(ascgraph), "CompleteApiInfo failed");
      AscGraphInfoComplete::AppendOriginalSizeVar(ascgraph, original_var_set);
      ascir::utils::DumpGraph(ascgraph, "AfterDeserialize");
      asc_backend_to_ascgraph.emplace(node, ascgraph);
    }
  }
  GE_ASSERT_TRUE(!asc_backend_to_ascgraph.empty(), "The fused graph [%s] is invalid, which has none AscBackend node.",
                 fused_graph->GetName().c_str());
  af::AscGraph hint_graph(fused_graph->GetName().c_str());
  if (asc_backend_to_ascgraph.size() > 1UL) {
    GE_CHK_STATUS_RET(FusedGraphUnfolder::UnfoldFusedGraph(fused_graph, asc_backend_to_ascgraph, hint_graph),
                      "Failed to unfold graph[%s].", fused_graph->GetName().c_str());
  } else {
    // Exactly one AscGraph: no unfolding needed.
    hint_graph = asc_backend_to_ascgraph.begin()->second;
  }
  auto owner_graph = af::AscGraphUtils::GetComputeGraph(hint_graph);
  GE_ASSERT_NOTNULL(owner_graph);
  owner_graph->SetName(ascgen_utils::GenValidName(fused_graph->GetName()));
  GE_ASSERT_SUCCESS(Optimize(hint_graph, fused_scheduled_result), "optimize failed, graph:[%s].",
                    hint_graph.GetName().c_str());
  fused_scheduled_result.fused_graph_name = fused_graph->GetName().c_str();
  fused_scheduled_result.origin_vars.assign(original_var_set.begin(), original_var_set.end());
  return af::SUCCESS;
}
/**
 * Optimizes a fused graph whose AscBackend nodes carry their AscGraph in AutoFuseAttrs.
 *
 * Steps: clean every AscBackend node's AscGraph of dangling nodes and collect its size
 * variables; insert workspaces between the sub graphs; optimize each sub graph with
 * OptimizeForHintGraph; allocate buffers/queues for the whole result; finally convert starting
 * outputs to workspaces for every scheduled result.
 *
 * @param fused_graph            fused graph to optimize; must not be null.
 * @param fused_scheduled_result receives one list of scheduled results per AscBackend node,
 *                               the fused graph name and the original size variables.
 * @return af::SUCCESS on success; the check macros return early on any failure.
 */
Status Optimizer::OptimizeFusedAscBackend(const af::ComputeGraphPtr &fused_graph,
                                          ascir::FusedScheduledResult &fused_scheduled_result) const {
  GE_ASSERT_NOTNULL(fused_graph);
  std::map<af::Node *, af::AscGraph> asc_backend_to_ascgraph;
  SizeVarSet original_var_set;
  for (auto &node : fused_graph->GetDirectNodePtr()) {
    GE_ASSERT_NOTNULL(node);
    if (node->GetType() == kAscBackendType) {
      // Check the op desc before using it, as the helpers in this file do.
      const auto op_desc = node->GetOpDesc();
      GE_ASSERT_NOTNULL(op_desc, "Node %s has no op desc", node->GetName().c_str());
      const auto fuse_attr = op_desc->GetAttrsGroup<af::AutoFuseAttrs>();
      GE_ASSERT_NOTNULL(fuse_attr, "Node %s has no AutoFuseAttrs", node->GetName().c_str());
      auto fuse_asc_graph = fuse_attr->GetAscGraph();
      GE_ASSERT_NOTNULL(fuse_asc_graph, "Cannot get ascgraph from ascbc node:[%s].", node->GetNamePtr());
      ascir::utils::DumpGraph(*fuse_asc_graph, "AutoFuseBeforeRemoveDanglingNodes");
      GE_CHK_STATUS_RET(RemoveDanglingNodes(*fuse_asc_graph), "Remove dangling nodes failed, graph:[%s].",
                        fuse_asc_graph->GetName().c_str());
      ::ascir::utils::DumpGraph(*fuse_asc_graph, "AutoFuseBeforeOptimize");
      AscGraphInfoComplete::AppendOriginalSizeVar(*fuse_asc_graph, original_var_set);
      asc_backend_to_ascgraph.emplace(node, *fuse_asc_graph);
    }
  }
  GE_ASSERT_TRUE(!asc_backend_to_ascgraph.empty(), "The fused graph [%s] is invalid, which has none AscBackend node.",
                 fused_graph->GetName().c_str());
  GE_ASSERT_SUCCESS(FusedGraphModifier::SubgraphConnectionsToWorkspace(fused_graph, asc_backend_to_ascgraph),
                    "Failed to add workspace between ascgraphs.");
  fused_scheduled_result.fused_graph_name = fused_graph->GetName().c_str();
  for (auto &node : fused_graph->GetDirectNodePtr()) {
    // The graph was handed to SubgraphConnectionsToWorkspace above, so re-check every node.
    GE_ASSERT_NOTNULL(node);
    if (node->GetType() == kAscBackendType) {
      auto it = asc_backend_to_ascgraph.find(node);
      GE_ASSERT_TRUE(it != asc_backend_to_ascgraph.end());
      std::vector<::ascir::ScheduledResult> sub_results;
      GE_ASSERT_SUCCESS(OptimizeForHintGraph(it->second, sub_results), "optimize failed, graph:[%s].",
                        it->second.GetName().c_str());
      fused_scheduled_result.node_idx_to_scheduled_results.emplace_back(std::move(sub_results));
    }
  }
  fused_scheduled_result.origin_vars.assign(original_var_set.begin(), original_var_set.end());
  GE_CHK_STATUS_RET(BufQueAllocator().AllocBufQue(fused_scheduled_result));
  GELOGI("AllocBufQue end");
  for (auto &scheduled_results : fused_scheduled_result.node_idx_to_scheduled_results) {
    for (auto &result : scheduled_results) {
      GE_ASSERT_SUCCESS(FusedGraphModifier::ChangeStartingOutputToWorkspace(result.schedule_groups),
                        "Change starting output to workspace failed.");
    }
  }
  ascir::utils::DumpScheduleResult(fused_scheduled_result, "AutoFuseAfterOptimize");
  return af::SUCCESS;
}
/**
 * Runs buffer/queue allocation for a single impl graph by wrapping it in a minimal
 * FusedScheduledResult (one node entry, one scheduled result, one schedule group).
 *
 * @param graph      unused.
 * @param impl_graph graph placed into the temporary result that is handed to the allocator.
 * @return af::SUCCESS, or the allocator's failure status.
 */
Status Optimizer::BufQueAlloc(const ascir::HintGraph &graph, ascir::ImplGraph &impl_graph) const {
  (void)graph;
  FusedScheduledResult fused_scheduled_result;
  auto &per_node_results = fused_scheduled_result.node_idx_to_scheduled_results;
  per_node_results.resize(1UL);
  auto &single_node_results = per_node_results[0UL];
  single_node_results.resize(1UL);
  auto &groups = single_node_results[0UL].schedule_groups;
  groups.resize(1UL);
  groups[0].impl_graphs.emplace_back(impl_graph);
  GE_CHK_STATUS_RET(BufQueAllocator().AllocBufQue(fused_scheduled_result), "AllocBufQue failed");
  return af::SUCCESS;
}
// Vector overload: runs the single-graph BufQueAlloc on each impl graph in order.
// GE_CHK_STATUS_RET is expected to return early on the first non-success status.
Status Optimizer::BufQueAlloc(const ascir::HintGraph &graph, std::vector<ascir::ImplGraph> &impl_graphs) const {
for (auto &impl_graph : impl_graphs) {
GE_CHK_STATUS_RET(this->BufQueAlloc(graph, impl_graph), "AllocBufQue failed");
}
return af::SUCCESS;
}
// Runs autoschedule::PassRunnerHandler::RunPasses on the impl graph and returns its status.
Status Optimizer::GraphPass(ascir::ImplGraph &impl_graph) const {
return autoschedule::PassRunnerHandler::RunPasses(impl_graph);
}
// Collects pairs of adjacent axis positions (pre, post) that must not be merged, based on
// node-type specific rules for Concat/Split, Load and Gather nodes.
// Results are added to non_continuous_pair; existing entries are kept.
Status Optimizer::GetNonContinuousAxisPairBySpecialRule(ascir::ImplGraph &impl_graph,
std::set<std::pair<int64_t, int64_t>> &non_continuous_pair) {
for (const auto &node : impl_graph.GetAllNodes()) {
if (ScheduleUtils::IsConcat(node) || (ScheduleUtils::IsSplit(node))) {
const std::vector<af::Expression> &input_repeats = node->inputs[0].attr.repeats;
const std::vector<af::Expression> &output_repeats = node->outputs[0].attr.repeats;
// NOTE(review): the message says "concat" but this branch also handles split nodes.
GE_ASSERT_TRUE((input_repeats.size() == output_repeats.size()),
"The output dim cnt [%zu] of concat mismatch with input dim cnt [%zu].", output_repeats.size(),
input_repeats.size());
// Find the first dim whose input and output repeats are not statically equal, and
// accumulate the product of the repeats in front of it.
af::Expression pre_size = af::sym::kSymbolOne;
uint32_t concat_dim{0U};
for (uint32_t i = 0U; i < input_repeats.size(); ++i) {
if (af::SymbolicUtils::StaticCheckEq(input_repeats[i], output_repeats[i]) != af::TriBool::kTrue) {
concat_dim = i;
break;
}
pre_size = pre_size * input_repeats[i];
}
// Only block merging across the concat/split dim when the leading dims are not all of size one.
if ((concat_dim > 0U) &&
(af::SymbolicUtils::StaticCheckEq(pre_size, af::sym::kSymbolOne) != af::TriBool::kTrue )) {
non_continuous_pair.emplace(concat_dim - 1, concat_dim);
}
bool no_merge_first_axis = false;
(void) af::AttrUtils::GetBool(node->GetOpDesc(), "_keep_first_axis", no_merge_first_axis);
if (no_merge_first_axis) {
non_continuous_pair.emplace(0, 1);
}
}
if (ScheduleUtils::IsLoad(node)) {
auto strides = node->outputs[0U].attr.strides;
// Scan from the innermost dim outwards, skipping zero strides; if the first non-zero stride
// is not statically one, the pair in front of that dim is marked as non-continuous.
for (int64_t i = static_cast<int64_t>(strides.size() - 1); i >= 0; --i) {
if (af::SymbolicUtils::StaticCheckEq(strides[i], af::sym::kSymbolZero) == af::TriBool::kTrue) {
continue;
}
if (af::SymbolicUtils::StaticCheckEq(strides[i], af::sym::kSymbolOne) == af::TriBool::kTrue ) {
break;
} else {
// NOTE(review): i is int64_t but is printed with %ld; other logs in this file use PRId64.
GELOGD("Node [%s] is last axis load, axis:[%ld].", node->GetNamePtr(), i);
// When i is 0 this records (-1, 0), which matches no candidate pair built from positions >= 0.
non_continuous_pair.emplace(i - 1, i);
break;
}
}
continue;
}
if (ScheduleUtils::IsGather(node)) {
int64_t attr_axis = -1;
// NOTE(review): the result of GetNodeIrAttrValue is ignored; if the attribute is missing,
// attr_axis stays -1 and the second branch below still records a pair - confirm intended.
ScheduleUtils::GetNodeIrAttrValue(node, "axis", attr_axis);
// Gather on the last (non-first) axis of the first input: block merging with the axis before it.
if (static_cast<size_t>(attr_axis) == node->inputs[0].attr.repeats.size() - 1 && attr_axis != 0) {
non_continuous_pair.emplace(attr_axis - 1, attr_axis);
}
// Gather on any other axis: shift the position by (number of repeats of the second input - 1)
// and block merging with the axis after it.
if (static_cast<size_t>(attr_axis) != node->inputs[0].attr.repeats.size() - 1) {
auto indices = node->inputs[1].attr.repeats;
attr_axis = attr_axis + indices.size() - 1;
non_continuous_pair.emplace(attr_axis, attr_axis + 1);
}
}
}
return af::SUCCESS;
}
/**
 * Drops the loop axes reported by IdentifyZeroStrideAxisIndices from every non-buffer node:
 * the node's sched.axis list is filtered by position, and each output tensor keeps only the
 * axis / stride / repeat entries whose axis id survived.
 *
 * Graphs that contain a gather computation are left untouched.
 *
 * @param owner_graph graph to rewrite in place.
 * @return af::SUCCESS; the assert macros return early on a null node or inconsistent tensor attrs.
 */
Status Optimizer::RemoveAllZeroStrideLoopAxis(ascir::ImplGraph &owner_graph) {
  if (ScheduleUtils::HasComputeType(owner_graph, af::ComputeType::kComputeGather)) {
    return af::SUCCESS;
  }
  const std::unordered_set<size_t> removable_positions = IdentifyZeroStrideAxisIndices(owner_graph);
  if (removable_positions.empty()) {
    return af::SUCCESS;
  }
  for (const auto &node : owner_graph.GetAllNodes()) {
    GE_ASSERT_NOTNULL(node);
    if (ScheduleUtils::IsBuffer(node)) {
      continue;
    }
    // Keep the loop axes whose position is not scheduled for removal.
    std::vector<int64_t> kept_axis_ids;
    const auto &loop_axis_ids = node->attr.sched.axis;
    for (size_t pos = 0UL; pos < loop_axis_ids.size(); ++pos) {
      if (removable_positions.find(pos) == removable_positions.end()) {
        kept_axis_ids.push_back(loop_axis_ids[pos]);
      }
    }
    node->attr.sched.axis = kept_axis_ids;
    const auto is_kept = [&kept_axis_ids](const auto axis_id) {
      return std::find(kept_axis_ids.begin(), kept_axis_ids.end(), axis_id) != kept_axis_ids.end();
    };
    for (const auto &output : node->outputs()) {
      std::vector<int64_t> filtered_axis;
      std::vector<af::Expression> filtered_repeats;
      std::vector<af::Expression> filtered_strides;
      for (size_t i = 0UL; i < output->attr.axis.size(); ++i) {
        const auto axis_id = output->attr.axis[i];
        if (!is_kept(axis_id)) {
          continue;
        }
        filtered_axis.push_back(axis_id);
        GE_ASSERT_TRUE(i < output->attr.strides.size());
        GE_ASSERT_TRUE(i < output->attr.repeats.size());
        filtered_strides.push_back(output->attr.strides[i]);
        filtered_repeats.push_back(output->attr.repeats[i]);
      }
      output->attr.axis = std::move(filtered_axis);
      output->attr.strides = std::move(filtered_strides);
      output->attr.repeats = std::move(filtered_repeats);
    }
  }
  return af::SUCCESS;
}
/**
 * Merges runs of adjacent loop axes that are contiguous in every output tensor.
 *
 * Candidate pairs (i, i + 1) are built from the loop-axis count of the first non-buffer node,
 * pairs forbidden by GetNonContinuousAxisPairBySpecialRule or rejected by IsAxisContinuous are
 * dropped, the survivors are chained with MergeContinuousPairs, and each chain is merged and
 * applied to every non-buffer node. Unless the graph has a gather or cube_type is kUBFuse, the
 * merged axes are then turned into original axes and unused axes are removed.
 *
 * @param impl_graph graph to rewrite in place.
 * @param cube_type  cube template type of the graph.
 * @return af::SUCCESS on success; the assert macros return early on failure.
 */
Status Optimizer::MergeContinuousAxis(ascir::ImplGraph &impl_graph, ascir::CubeTemplateType cube_type) {
  auto all_axis = impl_graph.GetAllAxis();
  if (all_axis.size() <= 1UL) {
    return af::SUCCESS;
  }
  std::vector<std::pair<int64_t, int64_t>> potential_axis_idx;
  for (const auto &node : impl_graph.GetAllNodes()) {
    if (!ScheduleUtils::IsBuffer(node)) {
      const auto &axis_ids = node->attr.sched.axis;
      // "i + 1 < size" instead of "i < size - 1": the latter underflows for an empty axis list.
      if (!axis_ids.empty()) {
        potential_axis_idx.reserve(axis_ids.size() - 1UL);
      }
      for (size_t i = 0UL; i + 1UL < axis_ids.size(); ++i) {
        potential_axis_idx.emplace_back(i, i + 1);
      }
      // Only the first non-buffer node defines the candidate positions.
      break;
    }
  }
  std::set<std::pair<int64_t, int64_t>> non_continuous_pair;
  GE_ASSERT_SUCCESS(GetNonContinuousAxisPairBySpecialRule(impl_graph, non_continuous_pair));
  potential_axis_idx.erase(
      std::remove_if(potential_axis_idx.begin(), potential_axis_idx.end(),
                     [&non_continuous_pair](const auto &pair) { return non_continuous_pair.count(pair) > 0; }),
      potential_axis_idx.end());
  // Erase-remove keeps relative order and avoids touching iterators invalidated by vector::erase.
  potential_axis_idx.erase(
      std::remove_if(potential_axis_idx.begin(), potential_axis_idx.end(),
                     [&impl_graph](const auto &pair) { return !IsAxisContinuous(impl_graph, pair.first, pair.second); }),
      potential_axis_idx.end());
  std::vector<std::vector<int64_t>> merged_axis_indexes = MergeContinuousPairs(potential_axis_idx);
  // Nodes sharing the same source axis ids reuse one merged axis.
  std::map<std::vector<af::AxisId>, af::AxisId> from_id_to_merged_id;
  std::vector<af::AxisPtr> new_merged_axes;
  for (auto node : impl_graph.GetAllNodes()) {
    if (ScheduleUtils::IsBuffer(node)) {
      continue;
    }
    std::vector<af::AxisId> node_merged_id;
    for (const auto &from_idx : merged_axis_indexes) {
      std::vector<af::AxisId> from_ids;
      for (const int64_t index : from_idx) {
        from_ids.emplace_back(node->attr.sched.axis[index]);
      }
      auto iter = from_id_to_merged_id.find(from_ids);
      if (iter == from_id_to_merged_id.end()) {
        auto merged_axis = impl_graph.MergeAxis(from_ids);
        new_merged_axes.push_back(merged_axis);
        node_merged_id.push_back(merged_axis->id);
        from_id_to_merged_id[from_ids] = merged_axis->id;
      } else {
        node_merged_id.push_back(iter->second);
      }
    }
    for (auto axis_id : node_merged_id) {
      GELOGD("Apply merged axis id [%ld] to node:[%s].", axis_id, node->GetNamePtr());
      GE_ASSERT_TRUE(impl_graph.ApplyMerge(node, axis_id), "Failed to apply merged axis id %ld to node:[%s].", axis_id,
                     node->GetNamePtr());
    }
  }
  int64_t attr_axis = -1;
  int64_t param_size = -1;
  bool has_gather = ScheduleUtils::GetGatherParams(impl_graph, attr_axis, param_size);
  if ((!has_gather) && (cube_type != ascir::CubeTemplateType::kUBFuse)) {
    // Turn the merged axes into plain original axes and drop the axes they replaced.
    for (const auto &axis : new_merged_axes) {
      axis->type = af::Axis::Type::kAxisTypeOriginal;
      axis->from.clear();
    }
    GE_ASSERT_SUCCESS(ScheduleUtils::RemoveUnusedAxes(impl_graph), "Failed to remove unused axes");
  }
  return af::SUCCESS;
}
/**
 * Runs the optimization pipeline for one hint graph and appends the scheduling outcome.
 *
 * The hint graph is normalized, marked as an impl graph and validated in place. A working copy is then
 * taken and sent through dtype consistency, the graph passes, zero-stride loop-axis removal and, when the
 * copy holds no cube compute node, continuous-axis merging. Every schedule task generated from the copy
 * is auto-scheduled in order, and FinalizeIndexedGraphs runs last.
 *
 * @param hint_graph graph to optimize; it is modified by the preparation steps
 * @param scheduled_results receives the results produced by all schedule tasks
 * @return af::SUCCESS when every stage succeeds; a failing stage returns early through its check macro
 */
Status Optimizer::OptimizeForHintGraph(af::AscGraph &hint_graph,
                                       std::vector<ascir::ScheduledResult> &scheduled_results) const {
  // Stage 1: prepare and validate the incoming graph in place.
  ScheduleUtils::NormalizeAxisIds(hint_graph);
  hint_graph.SetGraphType(af::AscGraphType::kImplGraph);
  ascir::utils::DumpPyCode(hint_graph);
  GE_CHK_STATUS_RET(AscGraphInfoComplete::CompleteApiInfo(hint_graph), "CompleteApiInfo failed");
  GE_ASSERT_SUCCESS(CheckGraphValidity(hint_graph),
                    "Check graph validity failed, graph:[%s].", hint_graph.GetName().c_str());
  const bool is_fused_backend = (options_.graph_type == GraphType::kFusedAscBackend);
  if (is_fused_backend) {
    GE_CHK_STATUS_RET(af::pre_process::PreProcess::Run(hint_graph), "Pre process failed");
    ascir::utils::DumpGraph(hint_graph, "AfterPreProcess");
  }

  // Stage 2: all further rewriting happens on a copy that carries a length-limited name.
  const std::string working_name = TruncateGraphName(hint_graph.GetName());
  ascir::ImplGraph working_graph(working_name.c_str());
  GE_ASSERT_TRUE(working_graph.CopyFrom(hint_graph));
  GE_CHK_STATUS_RET(DtypeConsistency::EnsureDtypeConsistency(working_graph), "Failed to ensure dtype consistency");
  ascir::utils::DumpGraph(working_graph, "AfterDtypeConsistency");
  GE_CHK_STATUS_RET(GraphPass(working_graph), "Run graph passes failed");
  GE_CHK_STATUS_RET(AscGraphInfoComplete::CompleteApiInfo(working_graph), "CompleteApiInfo failed");
  GE_ASSERT_GRAPH_SUCCESS(ScheduleUtils::TopologicalSorting(working_graph));
  utils::DumpGraph(working_graph, "AfterGraphPass");

  // Stage 3: axis clean-up. Merging is skipped for graphs that contain a cube compute node.
  GE_ASSERT_SUCCESS(RemoveAllZeroStrideLoopAxis(working_graph), "Remove All zero stride axis failed.");
  const bool has_cube_node = ScheduleUtils::HasComputeType(working_graph, af::ComputeType::kComputeCube);
  if (!has_cube_node) {
    GE_ASSERT_SUCCESS(MergeContinuousAxis(working_graph), "Merge continuous axes failed.");
  }
  ascir::utils::DumpGraph(working_graph, "AfterMergeAxis");

  // Stage 4: split into schedule tasks and schedule each of them.
  std::vector<ScheduleTask> tasks;
  GE_CHK_STATUS_RET(ScheduleTaskGenerator::GenerateTasks(working_graph, tasks, options_),
                    "Generate tasks failed");
  size_t task_idx = 0U;
  for (auto &task : tasks) {
    GE_CHK_STATUS_RET(AutoScheduler(hint_graph, task, scheduled_results), "AutoScheduler task[%zu] failed",
                      task_idx);
    GELOGI("AutoScheduler task[%zu] end", task_idx);
    ++task_idx;
  }
  GE_ASSERT_SUCCESS(FinalizeIndexedGraphs(scheduled_results));
  return af::SUCCESS;
}
* 用于调整load执行顺序,具体步骤如下:
* 1. 获取节点之间的连边关系
* 2. 按序遍历所有节点,从第二个load开始,先判断其是否存在内存共复用,若存在则跳过;
* 若不存在则采用并查集判定其与首load是否具有公共计算节点,若有则前移;若无则不作处理;
* 直到达到最大可调序load个数或遍历结束即退出
* 3. 对load前移后的图进行reorder
* @param impl_graph
* @return
*/
Status Optimizer::LoadOpSeqAdjust(const af::AscGraph &impl_graph) {
std::map<int64_t, int64_t> need_forward_nodes_id;
int64_t first_load_id = kInvalidNodeId;
GE_ASSERT_SUCCESS(SearchNodesNeedForward(impl_graph, need_forward_nodes_id, first_load_id));
GE_ASSERT_SUCCESS(DoSeqAdjustment(impl_graph, need_forward_nodes_id, first_load_id));
const auto &compute_graph = af::AscGraphUtils::GetComputeGraph(impl_graph);
GE_ASSERT_NOTNULL(compute_graph);
compute_graph->ReorderByNodeId();
return af::SUCCESS;
}
/**
 * Top-level optimization of one hint graph into a fused scheduled result.
 *
 * Removes dangling nodes, optimizes the graph into slot 0 of node_idx_to_scheduled_results, runs
 * BufQueAllocator::AllocBufQue, and for GraphType::kAscGraph records the graph name and the original
 * size variables. It then tries to enable group parallel, runs the load sequence adjustment and dumps
 * the final result.
 *
 * @param hint_graph graph to optimize; modified in place
 * @param fused_scheduled_result output container filled by this call
 * @return af::SUCCESS on success; a failing stage returns early through its check macro
 */
Status Optimizer::Optimize(af::AscGraph &hint_graph, FusedScheduledResult &fused_scheduled_result) {
  ascir::utils::DumpGraph(hint_graph, "AutoFuseBeforeRemoveDanglingNodes");
  GE_CHK_STATUS_RET(RemoveDanglingNodes(hint_graph), "Remove dangling nodes failed, graph:[%s].",
                    hint_graph.GetName().c_str());
  ascir::utils::DumpGraph(hint_graph, "AutoFuseBeforeOptimize");

  auto &results_per_node = fused_scheduled_result.node_idx_to_scheduled_results;
  results_per_node.resize(1UL);

  // The original size variables are gathered before the graph goes through optimization.
  SizeVarSet origin_size_vars;
  AscGraphInfoComplete::AppendOriginalSizeVar(hint_graph, origin_size_vars);
  GE_ASSERT_SUCCESS(OptimizeForHintGraph(hint_graph, results_per_node[0UL]),
                    "Failed to optimize for graph:[%s].", hint_graph.GetName().c_str());
  GE_CHK_STATUS_RET(BufQueAllocator().AllocBufQue(fused_scheduled_result));

  const bool is_plain_asc_graph = (options_.graph_type == GraphType::kAscGraph);
  if (is_plain_asc_graph) {
    fused_scheduled_result.fused_graph_name = hint_graph.GetName().c_str();
    fused_scheduled_result.origin_vars.assign(origin_size_vars.begin(), origin_size_vars.end());
  }
  GELOGI("AllocBufQue end");

  TryEnableGroupParallel(fused_scheduled_result);
  ExecSeqAdvancedOfLoad(fused_scheduled_result);
  ascir::utils::DumpScheduleResult(fused_scheduled_result, "AutoFuseAfterOptimize");
  return af::SUCCESS;
}
/**
 * Tells whether the group at `index` is a reduce first stage of the task.
 *
 * @param index group index looked up in schedule_task.groups_relations_in
 * @param schedule_task task that owns the group
 * @return true only when the task's reduce type is ReduceTemplateType::kRCore and `index` has an entry in
 *         groups_relations_in
 */
bool Optimizer::IsReduceFirstStage(size_t index, ScheduleTask &schedule_task) {
  const bool is_rcore_template = (schedule_task.reduce_type == ReduceTemplateType::kRCore);
  // The relation lookup is only evaluated for the kRCore template, exactly like the guard it replaces.
  return is_rcore_template &&
         (schedule_task.groups_relations_in.find(index) != schedule_task.groups_relations_in.end());
}
/**
 * Stores the variable relations produced by group `index` for every destination listed for that group.
 *
 * For each destination `dst` found under `index` in schedule_task.groups_relations_in,
 * schedule_result.var_relations[dst][index] is overwritten with `var_relations`. Nothing happens when
 * `index` has no entry.
 *
 * @param index source group index
 * @param var_relations relations to record (copied)
 * @param schedule_task task providing groups_relations_in
 * @param schedule_result result whose var_relations table is updated
 */
void Optimizer::RefreshGroupRelation(size_t index, std::map<std::string, af::Expression> &var_relations,
                                     ScheduleTask &schedule_task, ScheduledResult &schedule_result) const {
  const auto relation_it = schedule_task.groups_relations_in.find(index);
  if (relation_it != schedule_task.groups_relations_in.end()) {
    for (auto dst_group : relation_it->second) {
      schedule_result.var_relations[dst_group][index] = var_relations;
    }
  }
}
/**
 * Appends a cube group to every current result without running the auto scheduler.
 *
 * The grouped graph is copied into a new graph named by ascgen_utils::GenValidName, wrapped into a
 * single-graph ScheduleGroup, and that group is appended to each entry of `scheduled_results_cur`.
 *
 * @param scheduled_results_cur results that each receive a copy of the group
 * @param grouped_graph graph to copy
 * @return af::SUCCESS on success; returns early through the assert macro when the copy fails
 */
static Status ProcessCubeSchedules(std::vector<ascir::ScheduledResult> &scheduled_results_cur,
                                   ascir::ImplGraph &grouped_graph) {
  const auto valid_name = ascgen_utils::GenValidName(grouped_graph.GetName());
  ascir::Graph graph_copy(valid_name.c_str());
  GE_ASSERT_TRUE(graph_copy.CopyFrom(grouped_graph));

  ScheduleGroup cube_group{{graph_copy}, {}};
  for (auto &result : scheduled_results_cur) {
    result.schedule_groups.emplace_back(cube_group);
  }
  return af::SUCCESS;
}
/**
 * Packs all tiling cases of one group into a single ScheduleGroup and appends it to every current result.
 *
 * Each scheduled graph is added to the group and then passed to RegisterScoreFuncInScheduleGroup; the
 * graph name that helper returns is collected into `scored_graph_names`.
 *
 * @param schedule_outputs tiling cases produced by the auto scheduler for this group
 * @param scheduled_results_cur results that each receive a copy of the assembled group
 * @param schedule_task supplies has_load_store_conversion for the registration call
 * @param scored_graph_names accumulates the names returned by the registration helper
 * @return af::SUCCESS
 */
static Status ProcessNonReduceSchedules(const std::vector<autoschedule::AutoScheduleOutput> &schedule_outputs,
                                        std::vector<ascir::ScheduledResult> &scheduled_results_cur,
                                        ScheduleTask &schedule_task,
                                        std::set<std::string> &scored_graph_names) {
  ScheduleGroup tiling_group;
  tiling_group.impl_graphs.reserve(schedule_outputs.size());
  for (const auto &tiling_case : schedule_outputs) {
    // The graph is added first, then the registration helper sees the group that already holds it.
    tiling_group.impl_graphs.emplace_back(tiling_case.scheduled_graph);
    scored_graph_names.insert(
        RegisterScoreFuncInScheduleGroup(tiling_case, tiling_group, schedule_task.has_load_store_conversion));
  }
  std::for_each(scheduled_results_cur.begin(), scheduled_results_cur.end(),
                [&tiling_group](ascir::ScheduledResult &result) {
                  result.schedule_groups.emplace_back(tiling_group);
                });
  return af::SUCCESS;
}
/**
 * Appends the initial result of a schedule task to `scheduled_results_cur`.
 *
 * The new entry takes its score function and cube type from the task, and is_reduce_mem_reuse is set
 * exactly when the task's reduce type is ReduceTemplateType::kAllLoad.
 *
 * @param scheduled_results_cur vector that receives the new entry
 * @param schedule_task task the entry is derived from
 * @return af::SUCCESS
 */
Status Optimizer::InitializeScheduledResults(std::vector<ascir::ScheduledResult> &scheduled_results_cur,
                                             ScheduleTask &schedule_task) {
  const bool reuses_reduce_mem = (schedule_task.reduce_type == ReduceTemplateType::kAllLoad);

  ::ascir::ScheduledResult initial_result;
  initial_result.cube_type = schedule_task.cube_type;
  initial_result.is_reduce_mem_reuse = reuses_reduce_mem;
  initial_result.score_func = schedule_task.score_func.c_str();
  scheduled_results_cur.emplace_back(initial_result);
  return af::SUCCESS;
}
/**
 * Schedules every grouped graph of one schedule task and appends the resulting candidates.
 *
 * Each group is prepared (API info completion, optional zero-stride removal and axis re-merge, size
 * variable rebuild) and then handled in one of four ways:
 *  - a group holding a cube compute node is copied into every current result without auto scheduling;
 *  - a reduce first stage forks the current results once per tiling case, so the number of results is
 *    multiplied by the number of tiling cases;
 *  - any other group of a kRCore task goes through CopyImplGraphs;
 *  - every remaining group puts all its tiling cases into one ScheduleGroup shared by all results.
 *
 * @param hint_graph unused
 * @param schedule_task task whose grouped graphs are scheduled; the graphs are modified in place
 * @param scheduled_results receives the results built for this task; FilterComplexTilingDataScoreFuncs is
 *                          applied to the whole vector afterwards
 * @return af::SUCCESS on success; a failing step returns early through its check macro
 */
Status Optimizer::AutoScheduler([[maybe_unused]]const HintGraph &hint_graph, ScheduleTask &schedule_task,
                                std::vector<ascir::ScheduledResult> &scheduled_results) const {
  std::vector<ascir::ScheduledResult> scheduled_results_cur;
  GE_ASSERT_SUCCESS(InitializeScheduledResults(scheduled_results_cur, schedule_task));
  std::set<std::string> scored_graph_names;
  size_t next_group_index = 0UL;
  for (auto &grouped_graph : schedule_task.grouped_graphs) {
    // Take the index first so it always equals the group's position in grouped_graphs. Incrementing at
    // the end of the body was skipped by the cube path's `continue`, which shifted the index of every
    // later group used for groups_relations_in lookups and var_relations keys.
    const size_t index = next_group_index++;

    GE_CHK_STATUS_RET(AscGraphInfoComplete::CompleteApiInfo(grouped_graph), "CompleteApiInfo failed");
    if (CanDoReMergeAxis(grouped_graph)) {
      GE_ASSERT_SUCCESS(RemoveAllZeroStrideLoopAxis(grouped_graph), "Remove All zero stride axis failed.");
      GE_ASSERT_SUCCESS(MergeContinuousAxis(grouped_graph, schedule_task.cube_type),
                        "Merge continuous axes failed.");
    }

    // Rebuild the graph's size variables from the expressions AppendOriginalSizeVar collects.
    GE_ASSERT_SUCCESS(ScheduleUtils::ClearAllSizeVar(grouped_graph));
    SizeVarSet original_var_set;
    AscGraphInfoComplete::AppendOriginalSizeVar(grouped_graph, original_var_set);
    for (const auto &exp : original_var_set) {
      GE_ASSERT_GRAPH_SUCCESS(grouped_graph.CreateSizeVar(exp));
    }
    ascir::utils::DumpGraph(grouped_graph, "BeforeAutoSchedule");
    GELOGI("AutoScheduler start: %s", grouped_graph.GetName().c_str());

    // Cube groups bypass the auto scheduler entirely.
    if (ScheduleUtils::HasComputeType(grouped_graph, af::ComputeType::kComputeCube)) {
      GE_ASSERT_SUCCESS(ProcessCubeSchedules(scheduled_results_cur, grouped_graph));
      continue;
    }

    const bool is_reduce_first_stage = IsReduceFirstStage(index, schedule_task);
    std::vector<autoschedule::AutoScheduleOutput> schedule_outputs;
    auto scheduler = autoschedule::AutoSchedule(grouped_graph, schedule_outputs, is_reduce_first_stage,
                                                schedule_task.reduce_type, schedule_task.cube_type);
    GE_CHK_STATUS_RET(scheduler.DoAutoSchedule(), "Failed to do schedule, graph:[%s].",
                      grouped_graph.GetName().c_str());
    GE_ASSERT_TRUE(!schedule_outputs.empty(), "Failed to gen tiling case for graph:[%s].",
                   grouped_graph.GetName().c_str());
    GELOGI("AutoScheduler end: %s, number of tiling cases = %zu", grouped_graph.GetName().c_str(),
           schedule_outputs.size());

    if (is_reduce_first_stage) {
      // Fork: every (tiling case, current result) pair becomes one new result.
      std::vector<ascir::ScheduledResult> scheduled_results_tmp;
      for (auto &schedule_output : schedule_outputs) {
        ScheduleGroup schedule_group = {{schedule_output.scheduled_graph}, {}};
        for (auto &d : scheduled_results_cur) {
          // Temporarily extend `d`, snapshot it, then drop the group again so `d` can be reused for the
          // next tiling case.
          d.schedule_groups.emplace_back(schedule_group);
          RefreshGroupRelation(index, schedule_output.var_relations_, schedule_task, d);
          scheduled_results_tmp.emplace_back(d);
          d.schedule_groups.pop_back();
        }
        // NOTE(review): this registration targets the local `schedule_group` after its copies were
        // already stored above, and the returned name is not added to `scored_graph_names`. If the
        // helper only mutates the group it is given, the call has no effect on the results. Confirm
        // the intent before changing it.
        RegisterScoreFuncInScheduleGroup(schedule_output, schedule_group);
      }
      scheduled_results_tmp.swap(scheduled_results_cur);
    } else if (schedule_task.reduce_type == ReduceTemplateType::kRCore) {
      GE_ASSERT_SUCCESS(CopyImplGraphs(schedule_outputs, scheduled_results_cur));
    } else {
      GE_ASSERT_SUCCESS(
          ProcessNonReduceSchedules(schedule_outputs, scheduled_results_cur, schedule_task, scored_graph_names));
    }
  }
  scheduled_results.insert(scheduled_results.end(), scheduled_results_cur.begin(), scheduled_results_cur.end());
  FilterComplexTilingDataScoreFuncs(scheduled_results, scored_graph_names);
  return af::SUCCESS;
}
/**
 * Decides, per scheduled result, whether its schedule groups may run in parallel.
 *
 * Only acts when there is exactly one entry in node_idx_to_scheduled_results and no workspace node.
 * A result gets enable_group_parallel = true when it has more than one schedule group and no variable
 * relations. On the first result whose cube type is CubeTemplateType::kCommon the flag is cleared for
 * that result and the function returns.
 *
 * @param fused_scheduled_result results to update in place
 */
void Optimizer::TryEnableGroupParallel(FusedScheduledResult &fused_scheduled_result) {
  auto &results_per_node = fused_scheduled_result.node_idx_to_scheduled_results;
  const bool eligible = (results_per_node.size() == 1UL) && fused_scheduled_result.workspace_nodes.empty();
  if (!eligible) {
    return;
  }
  for (auto &candidate : results_per_node.front()) {
    if (candidate.cube_type == CubeTemplateType::kCommon) {
      candidate.enable_group_parallel = false;
      // NOTE(review): this leaves the loop, so results after the first kCommon one keep whatever flag
      // they already had, and the outcome depends on result order. Confirm `return` is intended
      // rather than `continue`.
      return;
    }
    const bool has_multiple_groups = candidate.schedule_groups.size() > 1UL;
    candidate.enable_group_parallel = has_multiple_groups && candidate.var_relations.empty();
  }
}
void Optimizer::ExecSeqAdvancedOfLoad(const FusedScheduledResult &fused_scheduled_result) {
for (auto &scheduled_results : fused_scheduled_result.node_idx_to_scheduled_results) {
for (auto &schedule_result : scheduled_results) {
if (schedule_result.cube_type == ascir::CubeTemplateType::kUBFuse) {
continue;
}
for (auto &schedule_group : schedule_result.schedule_groups) {
for (auto &impl_graph : schedule_group.impl_graphs) {
LoadOpSeqAdjust(impl_graph);
}
}
}
}
}