* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <queue>
#include <string>
#include "graph/utils/graph_utils.h"
#include "graph/symbolizer/symbolic_utils.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "ascir_ops.h"
#include "autoschedule/axis_group.h"
#include "ascir_utils.h"
#include "node_utils.h"
#include "reduce_schedule_case_generator.h"
#include "register/op_def_factory.h"
#include "base/err_msg.h"
#include "graph/symbolizer/symbolic.h"
namespace optimize {
namespace {
// File-local constants. They are never modified, so they are declared constexpr instead of being
// mutable namespace-scope variables.
// Distance from the end of a path used to select the node just before the path's last node.
constexpr size_t TWO = 2UL;
// Largest reduce output rank accepted by CanReduceFuse.
constexpr size_t kMaxFullLoadAxisSize = 3UL;
// Threshold passed to IsNotPartitionReduce for the nodes that follow a reduce node.
constexpr size_t NODE_COUNT_AFTER_REDUCE = 4UL;
// Builds the name "<src name>_to_<dst name>_<type>_<idx>" used for a node that is inserted
// between src_node and dst_node.
//   src_node / dst_node : nodes whose names form the prefix (both are dereferenced, must be non-null).
//   type                : free-form tag placed before the index.
//   idx                 : numeric suffix.
std::string GetNewNodeName(const af::AscNodePtr &src_node, const af::AscNodePtr &dst_node,
                           const std::string &type, int32_t idx) {
  // std::to_string is qualified explicitly: argument-dependent lookup cannot find it for a plain
  // integer argument, so the unqualified call only compiled if some header leaked the name.
  return src_node->GetName() + "_to_" + dst_node->GetName() + "_" + type + "_" + std::to_string(idx);
}
// Copies the node-level attributes and every output tensor attribute of src_node onto dst_node.
//  - The AscNodeAttr `sched` field is copied; `ir_attr` is cloned when the source has one.
//  - For each output index of src_node, the AscTensorAttr group of the matching destination
//    output desc is overwritten with the source output's attr.
// Returns ge::SUCCESS, or an error status when a node, an op desc, a destination output desc or
// its attribute group is null.
Status DoCopyAscNodeTensorAttr(const af::AscNodePtr &src_node, af::AscNodePtr &dst_node) {
  // Reject null inputs up front instead of dereferencing them (same guards as
  // DoCopyWorkspaceTensorAttr).
  GE_ASSERT_NOTNULL(src_node);
  GE_ASSERT_NOTNULL(dst_node);
  auto op_desc = dst_node->GetOpDesc();
  GE_ASSERT_NOTNULL(op_desc);
  auto src_op_desc = src_node->GetOpDesc();
  GE_ASSERT_NOTNULL(src_op_desc);

  auto dst_asc_node_attr = op_desc->GetOrCreateAttrsGroup<af::AscNodeAttr>();
  auto src_asc_node_attr = src_op_desc->GetOrCreateAttrsGroup<af::AscNodeAttr>();
  if (src_asc_node_attr != nullptr && dst_asc_node_attr != nullptr) {
    dst_asc_node_attr->sched = src_asc_node_attr->sched;
    if (src_asc_node_attr->ir_attr) {
      dst_asc_node_attr->ir_attr = src_asc_node_attr->ir_attr->Clone();
    }
  }

  // The output count is loop-invariant; query it once.
  const size_t output_num = src_node->outputs().size();
  for (size_t i = 0U; i < output_num; i++) {
    // Fetch the destination desc once and reuse it for both the check and the attribute lookup.
    auto output_desc = op_desc->MutableOutputDesc(i);
    GE_CHECK_NOTNULL(output_desc);
    auto tensor_attr_group = output_desc->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    *tensor_attr_group = src_node->outputs[i].attr;
  }
  return ge::SUCCESS;
}
// Copies only the dtype of src_node's first output tensor onto the first output tensor of
// workspace_node; no other tensor attribute is touched.
// The GE_ASSERT_* guards reject a null node or a node without outputs before index 0 is accessed.
// Returns ge::SUCCESS when the dtype was copied.
Status DoCopyWorkspaceTensorAttr(const af::AscNodePtr &src_node, af::AscNodePtr &workspace_node) {
GE_ASSERT_NOTNULL(src_node);
GE_ASSERT_NOTNULL(workspace_node);
GE_ASSERT_TRUE(!src_node->outputs().empty());
GE_ASSERT_TRUE(!workspace_node->outputs().empty());
workspace_node->outputs[0].attr.dtype = src_node->outputs[0].attr.dtype;
return ge::SUCCESS;
}
const std::unordered_map<std::string, std::function<ReduceType(const char*)>> reducers = {
{"Max", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Max>{}, n}; }},
{"ArgMax", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::ArgMaxMultiRPhase2>{}, n}; }},
{"ArgMaxMultiRPhase1", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::ArgMaxMultiRPhase2>{}, n}; }},
{"ArgMaxMultiRPhase2", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::ArgMaxMultiRPhase2>{}, n}; }},
{"Sum", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Sum>{}, n}; }},
{"Mean", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Sum>{}, n}; }},
{"Min", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Min>{}, n}; }},
{"Prod", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Prod>{}, n}; }},
{"Any", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::Any>{}, n}; }},
{"All", [](const char* n) { return ReduceType{std::in_place_type_t<af::ascir_op::All>{}, n}; }}
};
// Breadth-first inspection of the nodes reachable from reduce_node's outputs.
// Returns true when every condition below holds for all dequeued nodes, false otherwise:
//   * a node with more than one data input only consumes nodes already seen by this walk;
//   * the number of dequeued nodes never exceeds `threshold`;
//   * every node that is neither a Store nor a buffer-type API node is elementwise, and is
//     dequeued while the running count is still at most `threshold - 1`.
// Store nodes and buffer-type API nodes are counted but their outputs are not followed.
bool IsNotPartitionReduce(const af::AscNodePtr &reduce_node, size_t threshold) {
  std::unordered_set<const af::Node*> visited;
  visited.insert(reduce_node.get());

  // Seed the walk with the direct consumers of the reduce node.
  std::queue<af::NodePtr> pending;
  for (const auto &first_level_node : reduce_node->GetOutNodes()) {
    pending.push(first_level_node);
    visited.insert(first_level_node.get());
  }

  size_t node_count = 0UL;
  while (!pending.empty()) {
    const auto current_node = pending.front();
    pending.pop();

    // A multi-input node must be fed exclusively by nodes this walk has already reached.
    if (current_node->GetInDataNodesSize() > 1UL) {
      for (const auto &in_current_node : current_node->GetInDataNodes()) {
        if (visited.count(in_current_node.get()) == 0UL) {
          GELOGW("Node [%s] has multiple inputs, with input node [%s] not being a post-reduction node.",
                 current_node->GetNamePtr(), in_current_node->GetNamePtr());
          return false;
        }
      }
    }

    ++node_count;
    if (node_count > threshold) {
      GELOGW("The total count of nodes after the reduce node[%s](including the store node) is above the threshold[%zu].",
             reduce_node->GetNamePtr(), threshold);
      return false;
    }

    const auto &asc_current_node = std::dynamic_pointer_cast<af::AscNode>(current_node);
    GE_ASSERT_NOTNULL(asc_current_node);
    const bool stop_here = af::ops::IsOps<af::ascir_op::Store>(asc_current_node) ||
                           asc_current_node->attr.api.type == af::ApiType::kAPITypeBuffer;
    if (stop_here) {
      continue;
    }

    if (!ScheduleUtils::IsElewise(asc_current_node)) {
      GELOGW("The node[%s] after the reduce node[%s] is not elewise type.",
             asc_current_node->GetNamePtr(), reduce_node->GetNamePtr());
      return false;
    }
    if (node_count > threshold - 1UL) {
      GELOGW("The count of nodes after the reduce node[%s] is above the threshold[%zu].",
             reduce_node->GetNamePtr(), threshold - 1UL);
      return false;
    }

    // Enqueue each successor the first time it is encountered.
    for (const auto &next_node : current_node->GetOutAllNodes()) {
      if (visited.insert(next_node.get()).second) {
        pending.push(next_node);
      }
    }
  }
  return true;
}
}
// Generates the general (kCommon) schedule tasks for optimize_graph.
// Every graph produced by GenerateGeneralCase becomes one ScheduleTask whose grouped_graphs are
// filled by ScheduleGroupGraphPartitioner::PartitionByConnectivity using node_order_.
// The finished tasks are appended to `tasks`.
Status ReducePartitionCaseGenerator::GeneratorGeneralTask(ascir::HintGraph &optimize_graph,
                                                          std::vector<ScheduleTask> &tasks) {
  std::vector<std::string> score_funcs;
  std::vector<ascir::ImplGraph> optimize_graphs;
  GE_CHK_STATUS_RET(GenerateGeneralCase(optimize_graph, optimize_graphs, score_funcs), "GenerateScheduleCases failed");

  // Guarantee one score function slot per generated graph.
  const size_t case_num = optimize_graphs.size();
  score_funcs.resize(case_num);

  for (size_t case_idx = 0U; case_idx < case_num; ++case_idx) {
    const auto &case_graph = optimize_graphs[case_idx];
    ScheduleTask general_task{case_graph, {}, score_funcs[case_idx], {}, ReduceTemplateType::kCommon};
    GE_CHK_STATUS_RET(
        ScheduleGroupGraphPartitioner::PartitionByConnectivity(case_graph, general_task.grouped_graphs, node_order_),
        "Failed to partition graph");
    tasks.emplace_back(std::move(general_task));
  }
  return ge::GRAPH_SUCCESS;
}
// Generates the all-load (kAllLoad) schedule tasks for optimize_graph.
// Nothing is generated when CanReduceFuse rejects the graph. Otherwise every graph produced by
// GenerateAllLoadCase becomes one ScheduleTask, partitioned by connectivity using node_order_,
// and is appended to `tasks`.
Status ReducePartitionCaseGenerator::GeneratorAllLoadTask(ascir::HintGraph &optimize_graph,
                                                          std::vector<ScheduleTask> &tasks) {
  if (!CanReduceFuse(optimize_graph)) {
    return ge::GRAPH_SUCCESS;
  }

  std::vector<std::string> score_funcs;
  std::vector<ascir::ImplGraph> optimize_graphs;
  GE_CHK_STATUS_RET(GenerateAllLoadCase(optimize_graph, optimize_graphs, score_funcs), "GenerateScheduleCases failed");

  // Guarantee one score function slot per generated graph.
  const size_t case_num = optimize_graphs.size();
  score_funcs.resize(case_num);

  for (size_t case_idx = 0U; case_idx < case_num; ++case_idx) {
    const auto &case_graph = optimize_graphs[case_idx];
    ScheduleTask all_load_task{case_graph, {}, score_funcs[case_idx], {}, ReduceTemplateType::kAllLoad};
    GE_CHK_STATUS_RET(
        ScheduleGroupGraphPartitioner::PartitionByConnectivity(case_graph, all_load_task.grouped_graphs, node_order_),
        "Failed to partition graph");
    tasks.emplace_back(std::move(all_load_task));
  }
  return ge::GRAPH_SUCCESS;
}
// Derives kRCore tasks from the kCommon tasks already present in `tasks`.
// For each kCommon task, every grouped graph is either
//   * copied unchanged (renamed "<name>_r_multicore") when it holds no reduce node, or
//   * split by RMulticorePhase2Graph::Construct into a phase-1 and a phase-2 graph, which are
//     appended in that order.
// The task's map records, for every split, the index of the phase-1 graph in the new list mapped
// to the index of its phase-2 graph. A task in which no graph was split yields no new task.
// The new tasks are appended to `tasks` once all existing tasks have been examined.
Status ReducePartitionCaseGenerator::GeneratorRCoreTask(ascir::HintGraph &optimize_graph,
                                                        std::vector<ScheduleTask> &tasks) const {
  std::vector<ScheduleTask> new_tasks;
  for (const auto &task : tasks) {
    if (task.reduce_type != ReduceTemplateType::kCommon) {
      continue;
    }

    std::vector<::ascir::ImplGraph> rcore_graphs;
    std::map<size_t, std::vector<size_t>> phase_links;
    size_t split_count = 0;
    for (size_t i = 0; i < task.grouped_graphs.size(); i++) {
      GE_ASSERT_TRUE(IsGroupGraphLegal(task.grouped_graphs[i]));
      const auto &origin_graph = task.grouped_graphs[i];
      const std::string name_prefix = origin_graph.GetName() + "_r_multicore";

      if (!HasReduce(origin_graph)) {
        // No reduce inside: keep the graph as it is, under the multicore name.
        ::ascir::ImplGraph graph(name_prefix.c_str());
        graph.CopyFrom(origin_graph);
        rcore_graphs.emplace_back(std::move(graph));
        continue;
      }

      // Work on a private copy and locate the first reduce node in it.
      ascir::ImplGraph phase_graph((name_prefix + "_phase_graph").c_str());
      phase_graph.CopyFrom(origin_graph);
      af::AscNodePtr reduce_node;
      for (const auto &node : phase_graph.GetAllNodes()) {
        if (ScheduleUtils::IsReduce(node)) {
          reduce_node = node;
          break;
        }
      }
      GE_CHECK_NOTNULL(reduce_node);
      GE_ASSERT_TRUE(IsNotPartitionReduce(reduce_node, NODE_COUNT_AFTER_REDUCE));

      ascir::ImplGraph phase_1_graph((name_prefix + "_phase_1_graph").c_str());
      ascir::ImplGraph phase_2_graph((name_prefix + "_phase_2_graph").c_str());
      GE_CHK_STATUS_RET(RMulticorePhase2Graph(phase_2_graph, phase_1_graph, phase_graph, reduce_node).Construct());
      ascir::utils::DumpGraph(phase_2_graph, "phase2graph_construct");
      rcore_graphs.emplace_back(std::move(phase_1_graph));
      rcore_graphs.emplace_back(std::move(phase_2_graph));

      // Every earlier split added one extra graph, shifting the position of this phase-1 graph.
      const size_t phase_1_index = i + split_count;
      phase_links[phase_1_index] = {phase_1_index + 1};
      ++split_count;
    }

    if (split_count == 0) {
      continue;
    }
    ScheduleTask new_task{optimize_graph, rcore_graphs, task.score_func, phase_links, ReduceTemplateType::kRCore};
    new_tasks.push_back(new_task);
  }
  tasks.insert(tasks.end(), new_tasks.begin(), new_tasks.end());
  return ge::GRAPH_SUCCESS;
}
// Entry point that fills `tasks` for optimize_graph.
//   * Norm-like reduce graph (autoschedule::IsNormLikeReduceGraph): only all-load tasks.
//   * Any other graph: general tasks, then R-core tasks derived from them, then all-load tasks.
// `options` is currently unused.
Status ReducePartitionCaseGenerator::GeneratorTask(ascir::HintGraph &optimize_graph, std::vector<ScheduleTask> &tasks,
                                                   const OptimizerOptions &options) {
  (void)options;
  if (autoschedule::IsNormLikeReduceGraph(optimize_graph)) {
    GELOGI("Graph %s satisfies norm-like reduce conditions, only generate AllLoad tasks", optimize_graph.GetName().c_str());
    GE_CHK_STATUS_RET(GeneratorAllLoadTask(optimize_graph, tasks));
    return ge::GRAPH_SUCCESS;
  }

  GELOGI("Graph %s does not satisfy norm-like reduce conditions, use general strategy", optimize_graph.GetName().c_str());
  GE_CHK_STATUS_RET(GeneratorGeneralTask(optimize_graph, tasks));
  GE_CHK_STATUS_RET(GeneratorRCoreTask(optimize_graph, tasks));
  GE_CHK_STATUS_RET(GeneratorAllLoadTask(optimize_graph, tasks));
  return ge::GRAPH_SUCCESS;
}
// Intentionally a no-op: none of the arguments is read or modified and ge::GRAPH_SUCCESS is
// always returned. Task generation in this file goes through GeneratorTask instead.
Status ReducePartitionCaseGenerator::Generate([[maybe_unused]] ascir::HintGraph &graph,
[[maybe_unused]] std::vector<ascir::ImplGraph> &graphs,
[[maybe_unused]] std::vector<std::string> &score_functions) {
return ge::GRAPH_SUCCESS;
}
// Builds the general schedule case for `graph`.
// A graph without any reduce node produces nothing. Otherwise a copy of the graph is partitioned
// in three steps (post-reduce fusion, norm loops, nodes feeding several reduces):
//   * if any step partitioned the copy (partition_ is set), node_order_ is sorted by op id and
//     the partitioned copy is appended to `graphs`; `score_functions` is resized to match;
//   * otherwise node_order_ is cleared and the original graph is appended to `graphs`.
// Returns the first failing status of the partition steps, else ge::GRAPH_SUCCESS.
Status ReducePartitionCaseGenerator::GenerateGeneralCase(ascir::HintGraph &graph,
                                                         std::vector<ascir::ImplGraph> &graphs,
                                                         std::vector<std::string> &score_functions) {
  if (!HasReduce(graph)) {
    return ge::GRAPH_SUCCESS;
  }
  ascir::ImplGraph optimize_graph(graph.GetName().c_str());
  optimize_graph.CopyFrom(graph);

  // Collect (start, end) pairs of norm loops and remember nodes without data consumers.
  std::vector<std::pair<af::AscNodePtr, af::AscNodePtr>> loop_start_end;
  for (auto node : optimize_graph.GetAllNodes()) {
    if (node->GetOutDataNodes().empty()) {
      node_order_.emplace_back(node);
    }
    if (node->GetOutNodes().size() <= 1UL) {
      continue;
    }
    std::vector<af::AscNodePtr> loop_ends;
    // Propagate the status instead of discarding it.
    GE_CHK_STATUS_RET(FindNormLoop(node, loop_ends));
    for (const auto &end : loop_ends) {
      loop_start_end.emplace_back(node, end);
    }
  }
  // Handle the loops in ascending order of the op id of their end node.
  std::sort(loop_start_end.begin(), loop_start_end.end(),
            [](const std::pair<af::AscNodePtr, af::AscNodePtr> &lhs,
               const std::pair<af::AscNodePtr, af::AscNodePtr> &rhs) {
              return lhs.second->GetOpDescBarePtr()->GetId() < rhs.second->GetOpDescBarePtr()->GetId();
            });

  GE_CHK_STATUS_RET(ReducePartitionPostFusion(optimize_graph));
  GE_CHK_STATUS_RET(PartitionNorm(optimize_graph, loop_start_end));
  GE_CHK_STATUS_RET(ReducePartitionMultipleCitations(optimize_graph));

  if (partition_) {
    // Both comparator parameters are const references: a comparator handed to std::sort must not
    // be able to modify the elements it compares.
    std::sort(node_order_.begin(), node_order_.end(), [](const af::AscNodePtr &lhs, const af::AscNodePtr &rhs) {
      return lhs->GetOpDescBarePtr()->GetId() < rhs->GetOpDescBarePtr()->GetId();
    });
    ascir::utils::DumpGraph(graph, "before_partition");
    ascir::utils::DumpGraph(optimize_graph, "after_partition");
    graphs.emplace_back(optimize_graph);
    score_functions.resize(graphs.size());
  } else {
    node_order_.clear();
    graphs.emplace_back(graph);
  }
  return ge::GRAPH_SUCCESS;
}
// Builds the all-load schedule case: when `graph` contains a reduce node, node_order_ is cleared
// and the graph itself is appended to `graphs`. A graph without a reduce node produces nothing.
// `score_functions` is unused. Always returns ge::GRAPH_SUCCESS.
Status ReducePartitionCaseGenerator::GenerateAllLoadCase(ascir::HintGraph &graph,
                                                         std::vector<ascir::ImplGraph> &graphs,
                                                         const std::vector<std::string> &score_functions) {
  (void)score_functions;
  if (HasReduce(graph)) {
    node_order_.clear();
    graphs.emplace_back(graph);
  }
  return ge::GRAPH_SUCCESS;
}
// Separates reduce nodes that are fed by a common multi-output node.
// Nothing is done when the graph already holds at most one reduce node (IsGroupGraphLegal).
// Otherwise every node with more than one out node is visited in descending op-id order. For each
// of its out nodes that leads to a reduce node, the edge to that out node is partitioned through
// PartitionByNode when the reduce reached differs from the reduces already recorded for the node.
// Returns the first failing status, else ge::GRAPH_SUCCESS.
Status ReducePartitionCaseGenerator::ReducePartitionMultipleCitations(ascir::ImplGraph &impl_graph) {
  if (IsGroupGraphLegal(impl_graph)) {
    return ge::GRAPH_SUCCESS;
  }

  std::vector<af::AscNodePtr> multi_output_nodes;
  for (auto node : impl_graph.GetAllNodes()) {
    if (node->GetOutNodes().size() > 1UL) {
      multi_output_nodes.emplace_back(node);
    }
  }
  // Both comparator parameters are const references: a comparator handed to std::sort must not
  // be able to modify the elements it compares.
  std::sort(multi_output_nodes.begin(), multi_output_nodes.end(),
            [](const af::AscNodePtr &lhs, const af::AscNodePtr &rhs) {
              return lhs->GetOpDescBarePtr()->GetId() > rhs->GetOpDescBarePtr()->GetId();
            });

  for (auto &node : multi_output_nodes) {
    std::set<af::AscNodePtr> reduce_nodes;
    for (const auto &output_node : node->GetOutNodes()) {
      af::AscNodePtr out_asc_node = std::dynamic_pointer_cast<af::AscNode>(output_node);
      // The downcast result is dereferenced by FindOutputReduce; fail cleanly if it is null.
      GE_ASSERT_NOTNULL(out_asc_node);
      af::AscNodePtr reduce_node = nullptr;
      if (!FindOutputReduce(out_asc_node, reduce_node)) {
        continue;
      }
      if (!reduce_nodes.empty() && reduce_nodes.find(reduce_node) == reduce_nodes.end()) {
        // Report a failed partition to the caller instead of silently dropping the status.
        GE_CHK_STATUS_RET(PartitionByNode(node, out_asc_node, impl_graph));
      }
      reduce_nodes.emplace(reduce_node);
    }
  }
  return ge::GRAPH_SUCCESS;
}
// Depth-first search for a reduce node starting at `node` and following out nodes.
// On success `reduce_node` is set to the first reduce node found and true is returned; the
// search stops there. When no reduce node is reachable, `reduce_node` is left untouched and
// false is returned.
bool ReducePartitionCaseGenerator::FindOutputReduce(const af::AscNodePtr &node, af::AscNodePtr &reduce_node) {
  if (ScheduleUtils::IsReduce(node)) {
    reduce_node = node;
    return true;
  }
  for (const auto &successor : node->GetOutNodes()) {
    if (FindOutputReduce(std::dynamic_pointer_cast<af::AscNode>(successor), reduce_node)) {
      return true;
    }
  }
  return false;
}
// Cuts the graph right after src_node by routing its result through a workspace.
// Four nodes are added: a Store and a Workspace on the producer side, and a Workspace plus a Load
// on the consumer side (both workspaces carry the same name "<src>_Workspace"). Every data
// consumer of src_node is re-wired to the Load; src_node then feeds only the Store.
// Also sets partition_ and appends src_node to node_order_.
Status ReducePartitionCaseGenerator::PartitionReduce(af::AscNodePtr &src_node, ascir::ImplGraph &impl_graph) {
  partition_ = true;
  node_order_.emplace_back(src_node);

  const std::string base_name = src_node->GetName();
  const std::string workspace_name = base_name + "_Workspace";
  af::ascir_op::Workspace workspace_pre(workspace_name.c_str());
  af::ascir_op::Workspace workspace_post(workspace_name.c_str());
  af::ascir_op::Load load((base_name + "_Load").c_str());
  af::ascir_op::Store store((base_name + "_Store").c_str());
  auto workspace_pre_node = impl_graph.AddNode(workspace_pre);
  auto workspace_post_node = impl_graph.AddNode(workspace_post);
  auto load_node = impl_graph.AddNode(load);
  auto store_node = impl_graph.AddNode(store);

  // Load and Store inherit the attributes of src_node; each workspace takes the dtype of its
  // neighbour.
  GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(src_node, load_node));
  GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(src_node, store_node));
  GE_CHK_STATUS_RET(DoCopyWorkspaceTensorAttr(store_node, workspace_pre_node));
  GE_CHK_STATUS_RET(DoCopyWorkspaceTensorAttr(load_node, workspace_post_node));

  // Move every consumer edge from src_node to the Load node, keeping the output index.
  for (const auto &out_anchor : src_node->GetAllOutDataAnchors()) {
    GE_CHK_BOOL_EXEC(out_anchor != nullptr,
                     REPORT_INNER_ERR_MSG("E18888", "out data anchor is null, node:%s.", src_node->GetName().c_str());
                     return ge::GRAPH_FAILED, "[Check][Param] Out data anchor is null, node:%s",
                     src_node->GetName().c_str());
    const auto src_out_idx = out_anchor->GetIdx();
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL(peer_in_anchor);
      auto dst_node = peer_in_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(dst_node, "peer node is null, src node: %s", src_node->GetNamePtr());
      const auto dst_in_anchor = dst_node->GetInAnchor(peer_in_anchor->GetIdx());
      GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(src_node->GetOutAnchor(src_out_idx), dst_in_anchor));
      GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(load_node->GetOutAnchor(src_out_idx), dst_in_anchor));
    }
  }

  // src -> Store -> Workspace(pre) and Workspace(post) -> Load.
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(src_node->GetOutAnchor(0UL), store_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(store_node->GetOutAnchor(0UL), workspace_pre_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(workspace_post_node->GetOutAnchor(0UL), load_node->GetInAnchor(0UL)));
  return ge::GRAPH_SUCCESS;
}
// Visits every reduce node of impl_graph and calls PartitionReduce on those for which
// IsNotPartitionReduce(node, NODE_COUNT_AFTER_REDUCE) returns false.
// Returns the first failing status, else ge::GRAPH_SUCCESS.
Status ReducePartitionCaseGenerator::ReducePartitionPostFusion(ascir::ImplGraph &impl_graph) {
  for (auto node : impl_graph.GetAllNodes()) {
    const bool keep_fused =
        !ScheduleUtils::IsReduce(node) || IsNotPartitionReduce(node, NODE_COUNT_AFTER_REDUCE);
    if (keep_fused) {
      continue;
    }
    GE_CHK_STATUS_RET(PartitionReduce(node, impl_graph));
  }
  return ge::GRAPH_SUCCESS;
}
// Cuts the data edges from src_node to dst_node. Also sets partition_ and appends src_node to
// node_order_.
//   * Load source: delegated to PartitionLoad.
//   * Scalar-like source: delegated to PartitionScalar.
//   * Anything else: for every output anchor of src_node a Store/Workspace pair is attached to
//     the source and a Workspace/Load pair is created (both workspaces carry the same name); the
//     peers owned by a node named like dst_node are re-wired to the new Load.
Status ReducePartitionCaseGenerator::PartitionByNode(af::AscNodePtr &src_node, af::AscNodePtr &dst_node,
                                                     ascir::ImplGraph &impl_graph) {
  partition_ = true;
  node_order_.emplace_back(src_node);
  if (ScheduleUtils::IsLoad(src_node)) {
    return PartitionLoad(src_node, dst_node, impl_graph);
  }
  if (ScheduleUtils::IsScalarLikeNode(src_node)) {
    return PartitionScalar(src_node, dst_node, impl_graph);
  }

  for (const auto &out_anchor : src_node->GetAllOutDataAnchors()) {
    GE_CHK_BOOL_EXEC(out_anchor != nullptr,
                     REPORT_INNER_ERR_MSG("E18888", "out data anchor is null, node:%s.", src_node->GetName().c_str());
                     return ge::GRAPH_FAILED, "[Check][Param] Out data anchor is null, node:%s",
                     src_node->GetName().c_str());
    const auto out_idx = out_anchor->GetIdx();

    const std::string workspace_name = GetNewNodeName(src_node, dst_node, "Workspace", out_idx);
    af::ascir_op::Workspace workspace_pre(workspace_name.c_str());
    af::ascir_op::Workspace workspace_post(workspace_name.c_str());
    af::ascir_op::Load load(GetNewNodeName(src_node, dst_node, "Load", out_idx).c_str());
    af::ascir_op::Store store(GetNewNodeName(src_node, dst_node, "Store", out_idx).c_str());
    auto workspace_pre_node = impl_graph.AddNode(workspace_pre);
    auto workspace_post_node = impl_graph.AddNode(workspace_post);
    auto load_node = impl_graph.AddNode(load);
    auto store_node = impl_graph.AddNode(store);

    GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(src_node, load_node));
    GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(src_node, store_node));
    GE_CHK_STATUS_RET(DoCopyWorkspaceTensorAttr(store_node, workspace_pre_node));
    GE_CHK_STATUS_RET(DoCopyWorkspaceTensorAttr(load_node, workspace_post_node));

    // Only the peers owned by dst_node (matched by name) are moved over to the new Load.
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL(peer_in_anchor);
      GE_CHK_BOOL_EXEC(peer_in_anchor->GetOwnerNodeBarePtr() != nullptr,
                       REPORT_INNER_ERR_MSG("E18888", "Peer in node:%s is null", src_node->GetName().c_str());
                       return ge::GRAPH_FAILED, "Peer in node:%s is null", src_node->GetName().c_str());
      const bool feeds_dst = peer_in_anchor->GetOwnerNodeBarePtr()->GetName() == dst_node->GetName();
      if (!feeds_dst) {
        continue;
      }
      const auto dst_in_anchor = dst_node->GetInAnchor(peer_in_anchor->GetIdx());
      GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(src_node->GetOutAnchor(out_idx), dst_in_anchor));
      GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(load_node->GetOutAnchor(0UL), dst_in_anchor));
    }

    // src -> Store -> Workspace(pre) and Workspace(post) -> Load.
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(src_node->GetOutAnchor(out_idx), store_node->GetInAnchor(0UL)));
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(store_node->GetOutAnchor(0UL), workspace_pre_node->GetInAnchor(0UL)));
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(workspace_post_node->GetOutAnchor(0UL), load_node->GetInAnchor(0UL)));
  }
  return ge::GRAPH_SUCCESS;
}
// Cuts the edge from the Load node src_node to dst_node by duplicating the load chain.
// A copy of the Load's input node (Data or Workspace, named "copy_from_<input>") and a copy of
// the Load itself (named "copy_from_<src>") are added; the first data edge from src_node to a
// node named like dst_node is then replaced by an edge from the copied Load.
// Returns an error status when the Load has no input, when the input is neither Data nor
// Workspace, or when copying attributes or editing edges fails; ge::GRAPH_SUCCESS otherwise
// (including when no matching edge exists).
Status ReducePartitionCaseGenerator::PartitionLoad(af::AscNodePtr &src_node, af::AscNodePtr &dst_node,
                                                   ascir::ImplGraph &impl_graph) {
  // Guard the .at(0) below: a Load without inputs would otherwise throw.
  GE_ASSERT_TRUE(!src_node->GetInNodes().empty(), "Load node should have at least 1 input");
  auto load_input_node = src_node->GetInNodes().at(0UL);
  auto load_input_asc_node = std::dynamic_pointer_cast<af::AscNode>(load_input_node);
  GE_ASSERT_NOTNULL(load_input_asc_node);
  GE_ASSERT_TRUE(af::ops::IsOps<af::ascir_op::Data>(load_input_asc_node) || af::ops::IsOps<af::ascir_op::Workspace>(load_input_asc_node));

  af::ascir_op::Load load(("copy_from_" + src_node->GetName()).c_str());
  af::AscNodePtr new_load_input_node;
  if (af::ops::IsOps<af::ascir_op::Data>(load_input_asc_node)) {
    af::ascir_op::Data data(("copy_from_" + load_input_asc_node->GetName()).c_str());
    new_load_input_node = impl_graph.AddNode(data);
  } else {
    af::ascir_op::Workspace workspace(("copy_from_" + load_input_asc_node->GetName()).c_str());
    new_load_input_node = impl_graph.AddNode(workspace);
  }
  auto load_node = impl_graph.AddNode(load);

  // Attribute copies can fail; propagate the status instead of ignoring it.
  GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(load_input_asc_node, new_load_input_node));
  GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(src_node, load_node));

  for (const auto &out_anchor : src_node->GetAllOutDataAnchors()) {
    GE_CHK_BOOL_EXEC(out_anchor != nullptr,
                     REPORT_INNER_ERR_MSG("E18888", "out data anchor is null, node:%s.", src_node->GetName().c_str());
                     return ge::GRAPH_FAILED, "[Check][Param] Out data anchor is null, node:%s", src_node->GetName().c_str());
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL(peer_in_anchor);
      GE_CHK_BOOL_EXEC(peer_in_anchor->GetOwnerNodeBarePtr() != nullptr,
                       REPORT_INNER_ERR_MSG("E18888", "Peer in node:%s is null", src_node->GetName().c_str());
                       return ge::GRAPH_FAILED, "Peer in node:%s is null", src_node->GetName().c_str());
      if (peer_in_anchor->GetOwnerNodeBarePtr()->GetName() == dst_node->GetName()) {
        // Only the first matching edge is re-wired; the function returns right after it.
        GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(src_node->GetOutAnchor(out_anchor->GetIdx()),
                                                     dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
        GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(new_load_input_node->GetOutAnchor(0UL),
                                                  load_node->GetInAnchor(0UL)));
        GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(load_node->GetOutAnchor(0UL),
                                                  dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
        return ge::GRAPH_SUCCESS;
      }
    }
  }
  return ge::GRAPH_SUCCESS;
}
// Cuts the edge from the scalar-like node src_node to dst_node by duplicating the scalar.
// A copy named "copy_from_<src>" is added (ScalarData when the source is ScalarData, Scalar
// otherwise) with the attributes of src_node; the first data edge from src_node to a node named
// like dst_node is then replaced by an edge from the copy.
// Returns an error status when copying attributes or editing edges fails; ge::GRAPH_SUCCESS
// otherwise (including when no matching edge exists).
Status ReducePartitionCaseGenerator::PartitionScalar(af::AscNodePtr &src_node, af::AscNodePtr &dst_node,
                                                     ascir::ImplGraph &impl_graph) {
  af::AscNodePtr scalar_node;
  if (af::ops::IsOps<af::ascir_op::ScalarData>(src_node)) {
    af::ascir_op::ScalarData scalar_data(("copy_from_" + src_node->GetName()).c_str());
    scalar_node = impl_graph.AddNode(scalar_data);
  } else {
    af::ascir_op::Scalar scalar(("copy_from_" + src_node->GetName()).c_str());
    scalar_node = impl_graph.AddNode(scalar);
  }
  // The attribute copy can fail; propagate the status instead of ignoring it.
  GE_CHK_STATUS_RET(DoCopyAscNodeTensorAttr(src_node, scalar_node));

  for (const auto &out_anchor : src_node->GetAllOutDataAnchors()) {
    GE_CHK_BOOL_EXEC(out_anchor != nullptr,
                     REPORT_INNER_ERR_MSG("E18888", "out data anchor is null, node:%s.", src_node->GetName().c_str());
                     return ge::GRAPH_FAILED, "[Check][Param] Out data anchor is null, node:%s", src_node->GetName().c_str());
    for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL(peer_in_anchor);
      GE_CHK_BOOL_EXEC(peer_in_anchor->GetOwnerNodeBarePtr() != nullptr,
                       REPORT_INNER_ERR_MSG("E18888", "Peer in node:%s is null", src_node->GetName().c_str());
                       return ge::GRAPH_FAILED, "Peer in node:%s is null", src_node->GetName().c_str());
      if (peer_in_anchor->GetOwnerNodeBarePtr()->GetName() == dst_node->GetName()) {
        // Only the first matching edge is re-wired; the function returns right after it.
        GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(src_node->GetOutAnchor(out_anchor->GetIdx()),
                                                     dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
        GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(scalar_node->GetOutAnchor(0UL),
                                                  dst_node->GetInAnchor(peer_in_anchor->GetIdx())));
        return ge::GRAPH_SUCCESS;
      }
    }
  }
  return ge::GRAPH_SUCCESS;
}
// Walks backwards from `node` through its in nodes.
// Returns false as soon as `start` is reached on any input path, and true when no input path of
// `node` leads back to `start` (a node without inputs that is not `start` yields true).
bool ReducePartitionCaseGenerator::IsInputNodePartitioned(const std::shared_ptr<af::Node>& start,
                                                          const std::shared_ptr<af::Node>& node) {
  if (node == start) {
    return false;
  }
  for (const auto &in_node : node->GetInNodes()) {
    if (!IsInputNodePartitioned(start, in_node)) {
      return false;
    }
  }
  return true;
}
// Breadth-first walk over the data successors of `start`. Every successor that is met again
// after it has already been visited is appended to `ends` as a candidate; afterwards all entries
// of `ends` for which IsNorm(start, end) is false are removed.
// Always returns ge::GRAPH_SUCCESS.
Status ReducePartitionCaseGenerator::FindNormLoop(const af::AscNodePtr &start, std::vector<af::AscNodePtr> &ends) {
  std::set<af::NodePtr> visited{start};
  std::queue<af::NodePtr> pending;
  pending.push(start);
  while (!pending.empty()) {
    const auto node = pending.front();
    pending.pop();
    for (auto &out_node : node->GetOutDataNodes()) {
      if (visited.insert(out_node).second) {
        // First encounter: keep walking from this node later.
        pending.push(out_node);
      } else {
        // Reached a second time: several paths converge here.
        ends.emplace_back(std::dynamic_pointer_cast<af::AscNode>(out_node));
      }
    }
  }

  const auto not_norm = [start, this](const af::AscNodePtr &end) { return !IsNorm(start, end); };
  ends.erase(std::remove_if(ends.begin(), ends.end(), not_norm), ends.end());
  return ge::GRAPH_SUCCESS;
}
// Recursive depth-first enumeration of the paths from `start` to `end` along out nodes.
//   path      : nodes on the branch currently being explored; restored before returning.
//   all_paths : receives a copy of `path` every time the branch reaches `end`.
// The search continues past `end` through its own out nodes.
void ReducePartitionCaseGenerator::FindAllPath(const af::AscNodePtr& start, const af::AscNodePtr& end,
                                               std::vector<af::AscNodePtr> &path,
                                               std::vector<std::vector<af::AscNodePtr>> &all_paths) {
  path.push_back(start);
  if (start == end) {
    all_paths.push_back(path);
  }
  for (const auto &successor : start->GetOutNodes()) {
    FindAllPath(std::dynamic_pointer_cast<af::AscNode>(successor), end, path, all_paths);
  }
  path.pop_back();
}
// Decides whether (start, end) forms a norm loop. Returns true when both hold:
//   * at least one path from `start` to `end` contains a reduce node, and
//   * the paths enter `end` through more than one distinct predecessor.
bool ReducePartitionCaseGenerator::IsNorm(const af::AscNodePtr &start, const af::AscNodePtr &end) {
  std::vector<std::vector<af::AscNodePtr>> all_paths;
  std::vector<af::AscNodePtr> path;
  FindAllPath(start, end, path, all_paths);

  bool has_reduce_on_path = false;
  std::set<af::AscNodePtr> end_in_nodes;
  for (const auto &one_path : all_paths) {
    const bool path_has_reduce = std::any_of(one_path.begin(), one_path.end(), [](const af::AscNodePtr &node) {
      return ScheduleUtils::IsReduce(node);
    });
    if (path_has_reduce) {
      has_reduce_on_path = true;
    }
    // The element before the last one is the node through which this path enters `end`.
    const size_t before_end_idx = one_path.size() - TWO;
    end_in_nodes.insert(one_path.at(before_end_idx));
  }
  return has_reduce_on_path && end_in_nodes.size() > 1;
}
// For every (start, end) loop, cuts each input edge of the end node whose producer can be traced
// back to the loop start (IsInputNodePartitioned returns false) by calling PartitionByNode.
// Returns the first failing status, else ge::GRAPH_SUCCESS.
Status ReducePartitionCaseGenerator::PartitionNorm(ascir::ImplGraph &impl_graph, std::vector<std::pair<af::AscNodePtr,
                                                   af::AscNodePtr>> &loop_start_end) {
  for (auto [loop_start, loop_end] : loop_start_end) {
    for (auto &in_node : loop_end->GetInNodes()) {
      const bool already_cut = IsInputNodePartitioned(loop_start, in_node);
      if (!already_cut) {
        af::AscNodePtr src_node = std::dynamic_pointer_cast<af::AscNode>(in_node);
        GE_CHK_STATUS_RET(PartitionByNode(src_node, loop_end, impl_graph));
      }
    }
  }
  return ge::GRAPH_SUCCESS;
}
// Returns true when impl_graph contains at least one node for which ScheduleUtils::IsReduce
// holds; the scan stops at the first such node.
bool ReducePartitionCaseGenerator::HasReduce(const ascir::ImplGraph &impl_graph) {
  bool found_reduce = false;
  for (const auto &node : impl_graph.GetAllNodes()) {
    if (ScheduleUtils::IsReduce(node)) {
      found_reduce = true;
      break;
    }
  }
  return found_reduce;
}
// Checks every reduce node of impl_graph and returns true only when all of them satisfy:
//   * the first input and the first output have the same number of repeats (asserted);
//   * the output has between 1 and kMaxFullLoadAxisSize repeats;
//   * the first repeat of input and output is statically proven equal.
// A graph without reduce nodes returns true.
bool ReducePartitionCaseGenerator::CanReduceFuse(const ascir::ImplGraph &impl_graph) {
  for (const auto &node : impl_graph.GetAllNodes()) {
    if (!ScheduleUtils::IsReduce(node)) {
      continue;
    }
    // The repeats are only read, so bind them by reference instead of copying both vectors.
    const auto &input_repeats = node->inputs[0].attr.repeats;
    const auto &output_repeats = node->outputs[0].attr.repeats;
    GE_ASSERT_TRUE(input_repeats.size() == output_repeats.size());
    if (output_repeats.empty() || (output_repeats.size() > kMaxFullLoadAxisSize)) {
      return false;
    }
    if (af::SymbolicUtils::StaticCheckEq(input_repeats[0], output_repeats[0]) != af::TriBool::kTrue) {
      return false;
    }
  }
  return true;
}
// A group graph is legal when it contains at most one reduce node.
bool ReducePartitionCaseGenerator::IsGroupGraphLegal(const ascir::ImplGraph &impl_graph) {
  int reduce_total = 0;
  for (const auto &node : impl_graph.GetAllNodes()) {
    if (!ScheduleUtils::IsReduce(node)) {
      continue;
    }
    ++reduce_total;
  }
  return !(reduce_total > 1);
}
// Builds the phase-1 / phase-2 graphs for reduce_node and completes the phase-2 graph.
// Steps, in order:
//   1. Create the phase-2 reduce op from the `reducers` table (keyed by the reduce node type),
//      with sched axes {0, 1}.
//   2. For ArgMax, remember the dtype of the reduce node's first input before the graph is
//      restructured.
//   3. CompletePhaseGraph, then CreateVarAxis.
//   4. Complete the "<phase2>_workspace" and "<phase2>_load" nodes with the workspace dtype
//      (ArgMax: the remembered input dtype; otherwise the reduce node's first output dtype),
//      then the ArgMax index nodes.
//   5. Starting at the phase-2 reduce node, walk all data successors breadth-first and complete
//      each of them as a post-reduce node, keeping its own output dtype.
// Returns an error status when the reduce type is unknown, an expected node is missing, or any
// step fails.
Status RMulticorePhase2Graph::Construct() {
  // Look the factory up once and reuse the iterator instead of searching the table twice.
  const auto reducer_iter = reducers.find(reduce_node->GetType());
  GE_ASSERT_TRUE(reducer_iter != reducers.end());
  ReduceType phase2graph_reduce =
      reducer_iter->second((phase2graph.GetName() + "_" + reduce_node->GetName() + "_reduce").c_str());
  std::visit([](auto&& reduce_op) {
    reduce_op.attr.sched.axis = {0, 1};
  }, phase2graph_reduce);

  const bool is_argmax = (reduce_node->GetType() == "ArgMax");
  af::AscTensorDataType argmax_input_dtype;
  if (is_argmax) {
    // Read before CompletePhaseGraph, which restructures the graph around the reduce node.
    GE_ASSERT_TRUE(!reduce_node->inputs().empty(), "ArgMax node should have at least 1 input");
    argmax_input_dtype = reduce_node->inputs[0].attr.dtype;
  }
  GE_CHK_STATUS_RET(CompletePhaseGraph(phase2graph_reduce));
  GE_CHK_STATUS_RET(CreateVarAxis());

  auto workspace_node = phase2graph.FindNode((phase2graph.GetName() + "_workspace").c_str());
  GE_ASSERT_NOTNULL(workspace_node);
  workspace_node->attr.sched.axis = {0, 1};
  af::AscTensorDataType workspace_dtype;
  if (is_argmax) {
    workspace_dtype = argmax_input_dtype;
  } else {
    workspace_dtype = reduce_node->outputs[0].attr.dtype;
  }
  GE_CHK_STATUS_RET(CompleteNodeAttr(workspace_node, true, workspace_dtype));

  auto load_node = phase2graph.FindNode((phase2graph.GetName() + "_load").c_str());
  GE_ASSERT_NOTNULL(load_node);
  load_node->attr.sched.axis = {0, 1};
  GE_CHK_STATUS_RET(CompleteNodeAttr(load_node, true, workspace_dtype));
  GE_CHK_STATUS_RET(SetupArgMaxIndexNodes(reduce_node, phase2graph));

  auto reduce_node_parse2graph = phase2graph.FindNode((phase2graph.GetName() + "_" + reduce_node->GetName() + "_reduce").c_str());
  GE_ASSERT_NOTNULL(reduce_node_parse2graph);
  std::set<af::NodePtr> visited{reduce_node_parse2graph};
  std::list<af::NodePtr> next_nodes{reduce_node_parse2graph};
  while (!next_nodes.empty()) {
    const auto node = next_nodes.front();
    next_nodes.pop_front();
    auto asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
    // The downcast result is dereferenced right below; fail cleanly if it is null.
    GE_ASSERT_NOTNULL(asc_node);
    asc_node->attr.sched.axis = {0, 1};
    GE_CHK_STATUS_RET(CompleteNodeAttr(asc_node, false, asc_node->outputs[0].attr.dtype));
    for (auto &out_node : node->GetOutDataNodes()) {
      if (visited.find(out_node) == visited.cend()) {
        next_nodes.emplace_back(out_node);
        visited.emplace(out_node);
      }
    }
  }
  return ge::GRAPH_SUCCESS;
}
// Resets the axis and size-var lists of the phase-2 graph and recreates them from scratch:
// two size vars ("Rm_org_size", "A_org_size") stored in the members of the same name, and two
// axes ("Rm", "A") created from them, in that order.
// Fails when the compute graph or its AscGraphAttr group cannot be obtained.
Status RMulticorePhase2Graph::CreateVarAxis() {
auto compute_graph = af::AscGraphUtils::GetComputeGraph(phase2graph);
GE_ASSERT_NOTNULL(compute_graph);
auto attr = compute_graph->GetOrCreateAttrsGroup<af::AscGraphAttr>();
GE_ASSERT_NOTNULL(attr);
// Drop whatever axes / size vars the graph had before the new ones are created.
attr->axis.clear();
attr->size_vars.clear();
Rm_org_size = phase2graph.CreateSizeVar("Rm_org_size");
A_org_size = phase2graph.CreateSizeVar("A_org_size");
// NOTE(review): other code in this file addresses these axes as {0, 1}; presumably the ids
// follow creation order (Rm first, then A) - confirm before reordering.
phase2graph.CreateAxis("Rm", Rm_org_size);
phase2graph.CreateAxis("A", A_org_size);
return ge::GRAPH_SUCCESS;
}
// Fills the first output tensor of `node` for the two-axis phase-2 layout.
//   data_type     : dtype written to the output.
//   before_reduce : true  -> repeats {Rm_org_size, A_org_size}, strides {A_org_size, 1};
//                   false -> repeats {1, A_org_size},           strides {0, 1}.
// The tensor axes are always set to {0, 1}. Always returns ge::GRAPH_SUCCESS.
Status RMulticorePhase2Graph::CompleteNodeAttr(af::AscNodePtr &node, bool before_reduce,
                                               const af::AscTensorDataType& data_type) {
  node->outputs[0].attr.dtype = data_type;
  node->outputs[0].attr.axis = {0, 1};
  if (!before_reduce) {
    // Layout after the reduce: the first axis has collapsed to a single element.
    node->outputs[0].attr.strides = {af::ops::Zero, af::ops::One};
    node->outputs[0].attr.repeats = {af::ops::One, A_org_size};
    return ge::GRAPH_SUCCESS;
  }
  // Layout before the reduce: full Rm x A tensor.
  node->outputs[0].attr.strides = {A_org_size, af::ops::One};
  node->outputs[0].attr.repeats = {Rm_org_size, A_org_size};
  return ge::GRAPH_SUCCESS;
}
// Gives the first output tensor of `node` the pre-reduce layout: repeats
// {Rm_org_size, A_org_size}, strides {A_org_size, 1} and axes {0, 1}. The dtype is not touched.
// Always returns ge::GRAPH_SUCCESS.
Status RMulticorePhase2Graph::CompleteNodeAttrBeforeReduce(af::AscNodePtr &node) {
  node->outputs[0].attr.repeats = {Rm_org_size, A_org_size};
  node->outputs[0].attr.strides = {A_org_size, af::ops::One};
  node->outputs[0].attr.axis = {0, 1};
  return ge::GRAPH_SUCCESS;
}
// For an ArgMax reduce node, completes the optional index nodes of the phase-2 graph
// ("<phase2>_workspace_index", then "<phase2>_load_index"): each node that exists gets sched
// axes {0, 1} and the pre-reduce tensor layout. Missing nodes are skipped.
// Does nothing for any other reduce type.
Status RMulticorePhase2Graph::SetupArgMaxIndexNodes(const af::AscNodePtr &reduce_node,
                                                    ascir::ImplGraph &phase2graph) {
  if (reduce_node->GetType() != "ArgMax") {
    return ge::GRAPH_SUCCESS;
  }
  for (const char *suffix : {"_workspace_index", "_load_index"}) {
    auto index_node = phase2graph.FindNode((phase2graph.GetName() + suffix).c_str());
    if (index_node == nullptr) {
      continue;
    }
    index_node->attr.sched.axis = {0, 1};
    GE_CHK_STATUS_RET(CompleteNodeAttrBeforeReduce(index_node));
  }
  return ge::GRAPH_SUCCESS;
}
// Splits phase_graph around the reduce: PartitionByReduce inserts phase2graph_reduce and yields
// a node order, the graph is then partitioned by connectivity, and the two resulting sub-graphs
// are copied into phase1graph (first) and phase2graph (second).
// Fails unless exactly two sub-graphs are produced.
Status RMulticorePhase2Graph::CompletePhaseGraph(ReduceType &phase2graph_reduce) {
  std::vector<::ascir::ImplGraph> sub_optimize_graphs;
  std::vector<af::AscNodePtr> node_order;
  GE_ASSERT_GRAPH_SUCCESS(PartitionByReduce(phase_graph, phase2graph_reduce, node_order));
  GE_ASSERT_GRAPH_SUCCESS(SetNodeOrder(node_order));
  GE_ASSERT_GRAPH_SUCCESS(ScheduleGroupGraphPartitioner::PartitionByConnectivity(phase_graph, sub_optimize_graphs, node_order));
  GE_ASSERT_EQ(sub_optimize_graphs.size(), 2UL);
  // Exactly two graphs are present here, so front() is index 0 and back() is index 1.
  phase1graph.CopyFrom(sub_optimize_graphs.front());
  phase2graph.CopyFrom(sub_optimize_graphs.back());
  return ge::GRAPH_SUCCESS;
}
// Initialises the two output tensors of an ArgMax phase-1 node from output_attr:
//   output 0 : copy of output_attr with the dtype of input_attr;
//   output 1 : copy of output_attr with dtype ge::DT_INT64.
// Fails when an output desc or its AscTensorAttr group is missing.
static Status SetupArgMaxPhase1OutputAttrs(const af::AscNodePtr &phase1_node,
                                           const af::AscTensorAttr &input_attr,
                                           const af::AscTensorAttr &output_attr) {
  // Both outputs belong to the same op desc; fetch it once.
  auto op_desc = phase1_node->GetOpDesc();
  {
    // Output 0 keeps the input dtype.
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(0));
    auto tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    *tensor_attr_group = output_attr;
    tensor_attr_group->dtype = input_attr.dtype;
  }
  {
    // Output 1 is always DT_INT64.
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(1));
    auto tensor_attr_group = op_desc->MutableOutputDesc(1)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    *tensor_attr_group = output_attr;
    tensor_attr_group->dtype = ge::DT_INT64;
  }
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Copy the node-level AscNodeAttr data (sched and, if present, ir_attr) between nodes.
 *
 * The attribute group is fetched (or created) on both nodes. If either group is unavailable
 * the function returns without touching anything. ir_attr is cloned only when the source
 * has one; otherwise the destination's ir_attr is left as it is.
 *
 * @param dst_node  Node receiving the attributes.
 * @param src_node  Node the attributes are read from.
 */
static void CopyNodeLevelAttrs(const af::AscNodePtr &dst_node, const af::AscNodePtr &src_node) {
  auto target_attr = dst_node->GetOpDesc()->GetOrCreateAttrsGroup<af::AscNodeAttr>();
  auto origin_attr = src_node->GetOpDesc()->GetOrCreateAttrsGroup<af::AscNodeAttr>();
  if (origin_attr == nullptr || target_attr == nullptr) {
    return;
  }
  target_attr->sched = origin_attr->sched;
  if (origin_attr->ir_attr != nullptr) {
    target_attr->ir_attr = origin_attr->ir_attr->Clone();
  }
}
/**
 * @brief Replace an ArgMax node by an ArgMaxMultiRPhase1 node that feeds two workspace stores.
 *
 * Steps performed on phase_graph:
 *  1. Add "<argmax name>_phase1" (ArgMaxMultiRPhase1), copy the node-level attributes of
 *     argmax_node onto it and configure its two outputs via SetupArgMaxPhase1OutputAttrs.
 *  2. Move every connected data input of argmax_node to the same input index of the new node.
 *  3. Add "<argmax name>_Store_index" (Store). store_node and the new store both receive the
 *     node-level attributes and output tensor attributes of argmax_node; store_node's output
 *     additionally takes the dtype of the ArgMax input.
 *  4. Wire phase1:0 -> store_node -> workspace_pre_node and
 *     phase1:1 -> index store -> workspace_pre_index_node_out.
 *  5. Remove argmax_node from the compute graph without relinking.
 *
 * @param phase_graph                   Graph being rewritten.
 * @param argmax_node                   ArgMax node to replace.
 * @param store_node                    Store node that receives output 0 of the new node.
 * @param workspace_pre_node            Workspace node fed by store_node.
 * @param workspace_pre_index_node_out  Workspace node fed by the newly created index store.
 * @return ge::GRAPH_SUCCESS on success, otherwise the first failing status (including a
 *         null node/descriptor or a failure of SetupArgMaxPhase1OutputAttrs).
 */
Status ReplaceArgMaxInPhase1(ascir::ImplGraph &phase_graph,
                             const af::AscNodePtr &argmax_node,
                             af::AscNodePtr &store_node,
                             af::AscNodePtr &workspace_pre_node,
                             af::AscNodePtr &workspace_pre_index_node_out) {
  // Every node below is dereferenced, so reject null pointers up front.
  GE_ASSERT_NOTNULL(argmax_node);
  GE_ASSERT_NOTNULL(store_node);
  GE_ASSERT_NOTNULL(workspace_pre_node);
  GE_ASSERT_NOTNULL(workspace_pre_index_node_out);

  af::ascir_op::ArgMaxMultiRPhase1 argmax_phase1((argmax_node->GetName() + "_phase1").c_str());
  auto new_argmax_phase1_node = phase_graph.AddNode(argmax_phase1);
  GE_ASSERT_NOTNULL(new_argmax_phase1_node);
  CopyNodeLevelAttrs(new_argmax_phase1_node, argmax_node);

  // NOTE(review): both attrs are bound by reference, not copied. saved_input_attr is still
  // read after the input edges have been moved below - confirm the referenced object
  // outlives the edge removal.
  const auto &saved_input_attr = argmax_node->inputs[0].attr;
  const auto &saved_output_attr = argmax_node->outputs[0].attr;
  // Propagate the failure instead of discarding it: a half-configured phase-1 node must not
  // be wired into the graph.
  GE_CHK_STATUS_RET(SetupArgMaxPhase1OutputAttrs(new_argmax_phase1_node, saved_input_attr, saved_output_attr));

  // Re-home the data inputs of the ArgMax node onto the phase-1 node.
  for (const auto &in_anchor : argmax_node->GetAllInDataAnchors()) {
    GE_ASSERT_NOTNULL(in_anchor);
    auto src_anchor = in_anchor->GetPeerOutAnchor();
    if (src_anchor == nullptr) {
      continue;  // unconnected input, nothing to move
    }
    GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(src_anchor, in_anchor));
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(src_anchor, new_argmax_phase1_node->GetInAnchor(in_anchor->GetIdx())));
  }

  af::ascir_op::Store store_index((argmax_node->GetName() + "_Store_index").c_str());
  auto new_store_index_node = phase_graph.AddNode(store_index);
  GE_ASSERT_NOTNULL(new_store_index_node);
  {
    // Store fed by phase-1 output 0: ArgMax output layout with the ArgMax input dtype.
    CopyNodeLevelAttrs(store_node, argmax_node);
    auto op_desc = store_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(0));
    auto tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    *tensor_attr_group = saved_output_attr;
    tensor_attr_group->dtype = saved_input_attr.dtype;
  }
  {
    // Store fed by phase-1 output 1: plain copy of the ArgMax output attributes.
    CopyNodeLevelAttrs(new_store_index_node, argmax_node);
    auto op_desc = new_store_index_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc->MutableOutputDesc(0));
    auto tensor_attr_group = op_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(tensor_attr_group);
    *tensor_attr_group = saved_output_attr;
  }
  GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(store_node, workspace_pre_node));
  GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(new_store_index_node, workspace_pre_index_node_out));

  // phase1:0 -> store -> workspace, phase1:1 -> index store -> index workspace.
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(new_argmax_phase1_node->GetOutAnchor(0UL),
                                            store_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(store_node->GetOutAnchor(0UL),
                                            workspace_pre_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(new_argmax_phase1_node->GetOutAnchor(1UL),
                                            new_store_index_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(new_store_index_node->GetOutAnchor(0UL),
                                            workspace_pre_index_node_out->GetInAnchor(0UL)));

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(phase_graph);
  GE_ASSERT_GRAPH_SUCCESS(af::GraphUtils::RemoveNodeWithoutRelink(compute_graph, argmax_node));
  return ge::GRAPH_SUCCESS;
}
// Snapshot of a reduce node's AscNodeAttr, passed to SetupArgMaxPhase2Attrs.
struct ReduceSavedNodeAttrInfo {
// Pointer to the source node's attribute group; SetupArgMaxPhase2Attrs only tests it
// against nullptr to decide whether attr_value should be applied.
const af::AscNodeAttr *attr_ptr;
// Copy of the attribute group (left default-constructed by HandleArgMaxPartition when
// attr_ptr is null).
af::AscNodeAttr attr_value;
};
// Phase-2 nodes that carry the ArgMax index tensor: workspace -> load -> reduce input 1.
struct ArgMaxIndexNodes {
// Workspace node read in phase 2 (created as "<graph>_workspace_index").
af::AscNodePtr index_workspace_post_node;
// Load node fed by the workspace (created as "<graph>_load_index").
af::AscNodePtr index_load_node;
};
// Phase-1 tail nodes created by PartitionByReduce: reduce result -> store -> workspace.
struct ReducePhase1Nodes {
// Store node that receives output 0 of the phase-1 reduce.
af::AscNodePtr store_node;
// Workspace node fed by store_node.
af::AscNodePtr workspace_pre_node;
};
// Phase-2 head nodes created by PartitionByReduce: workspace -> load -> new reduce input 0.
struct ReducePhase2Nodes {
// Workspace node read at the start of phase 2.
af::AscNodePtr workspace_post_node;
// Load node fed by workspace_post_node.
af::AscNodePtr load_node;
// Reduce node of phase 2; takes over the consumers of the original reduce node.
af::AscNodePtr new_reduce_node;
};
// Graph-level inputs of HandleArgMaxPartition. Holds references, so an instance must not
// outlive the graph and vector it was built from.
struct ArgMaxPartitionGraphContext {
// Graph to which the index workspace/load nodes are added.
ascir::ImplGraph &impl_graph;
// Node list that receives the phase-1 index workspace node.
std::vector<af::AscNodePtr> &node_order;
// Prefix used for the names of the created index nodes.
std::string graph_name;
};
/**
 * @brief Configure node and tensor attributes of the phase-2 ArgMax sub-graph.
 *
 * - The new reduce node receives the saved sched / ir_attr (only when saved_attr.attr_ptr is
 *   non-null) and its output 0 becomes a copy of reduce_output_attr with dtype ge::DT_INT64.
 * - Output 0 of the value workspace and value load nodes takes the dtype of
 *   reduce_input_attr.
 * - Output 0 of the index workspace and index load nodes is set to ge::DT_INT64.
 *
 * @param saved_attr          Saved node-level attributes of the original reduce node.
 * @param phase2_nodes        Value workspace, value load and new reduce node of phase 2.
 * @param reduce_input_attr   Tensor attribute whose dtype is used for the value path.
 * @param reduce_output_attr  Tensor attribute copied onto the new reduce node's output.
 * @param index_nodes         Index workspace and index load node of phase 2.
 * @return ge::GRAPH_SUCCESS, or a failure status when an output descriptor or attribute
 *         group is unavailable.
 */
static Status SetupArgMaxPhase2Attrs(const ReduceSavedNodeAttrInfo &saved_attr,
                                     const ReducePhase2Nodes &phase2_nodes,
                                     const af::AscTensorAttr &reduce_input_attr,
                                     const af::AscTensorAttr &reduce_output_attr,
                                     const ArgMaxIndexNodes &index_nodes) {
  // Sets the dtype on output 0 of `node`; shared by the four workspace/load nodes.
  const auto assign_output_dtype = [](const af::AscNodePtr &node, const auto &dtype) -> Status {
    const auto node_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(node_desc->MutableOutputDesc(0));
    auto out_attr = node_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
    GE_ASSERT_NOTNULL(out_attr);
    out_attr->dtype = dtype;
    return ge::GRAPH_SUCCESS;
  };

  // New reduce node: node-level attributes first, then the output tensor.
  const auto reduce_desc = phase2_nodes.new_reduce_node->GetOpDesc();
  if (saved_attr.attr_ptr != nullptr) {
    auto reduce_node_attr = reduce_desc->GetOrCreateAttrsGroup<af::AscNodeAttr>();
    if (reduce_node_attr != nullptr) {
      reduce_node_attr->sched = saved_attr.attr_value.sched;
      if (saved_attr.attr_value.ir_attr) {
        reduce_node_attr->ir_attr = saved_attr.attr_value.ir_attr->Clone();
      }
    }
  }
  GE_CHECK_NOTNULL(reduce_desc->MutableOutputDesc(0));
  auto reduce_out_attr = reduce_desc->MutableOutputDesc(0)->GetOrCreateAttrsGroup<af::AscTensorAttr>();
  GE_ASSERT_NOTNULL(reduce_out_attr);
  *reduce_out_attr = reduce_output_attr;
  reduce_out_attr->dtype = ge::DT_INT64;

  // Value path keeps the dtype of the reduce input.
  GE_CHK_STATUS_RET(assign_output_dtype(phase2_nodes.workspace_post_node, reduce_input_attr.dtype));
  GE_CHK_STATUS_RET(assign_output_dtype(phase2_nodes.load_node, reduce_input_attr.dtype));

  // Index path is always int64.
  GE_CHK_STATUS_RET(assign_output_dtype(index_nodes.index_workspace_post_node, ge::DT_INT64));
  GE_CHK_STATUS_RET(assign_output_dtype(index_nodes.index_load_node, ge::DT_INT64));
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Wire up the phase-2 ArgMax sub-graph.
 *
 * Every anchor in argmax_out_edges is disconnected from argmax_out_anchor and attached to
 * output 0 of the new reduce node. Afterwards the two input chains are connected:
 *   value workspace -> value load -> new reduce input 0
 *   index workspace -> index load -> new reduce input 1
 *
 * @param argmax_out_anchor          Output anchor the consumers are currently attached to.
 * @param argmax_out_edges           Consumer input anchors to move to the new reduce node.
 * @param phase2_nodes               Value workspace, value load and new reduce node.
 * @param index_workspace_post_node  Workspace node of the index chain.
 * @param index_load_node            Load node of the index chain.
 * @return ge::GRAPH_SUCCESS, or the status of the first edge operation that fails.
 */
static Status ConnectArgMaxPhase2Edges(af::OutDataAnchorPtr argmax_out_anchor,
                                       const std::vector<af::InDataAnchorPtr> &argmax_out_edges,
                                       const ReducePhase2Nodes &phase2_nodes,
                                       const af::AscNodePtr &index_workspace_post_node,
                                       const af::AscNodePtr &index_load_node) {
  const auto &final_reduce = phase2_nodes.new_reduce_node;
  const auto &value_workspace = phase2_nodes.workspace_post_node;
  const auto &value_load = phase2_nodes.load_node;

  // Hand the former ArgMax consumers over to the phase-2 reduce node.
  for (const auto &consumer_in : argmax_out_edges) {
    GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(argmax_out_anchor, consumer_in));
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(final_reduce->GetOutAnchor(0UL), consumer_in));
  }

  // Value chain feeds input 0.
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(value_workspace->GetOutAnchor(0UL), value_load->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(value_load->GetOutAnchor(0UL), final_reduce->GetInAnchor(0UL)));

  // Index chain feeds input 1.
  GE_CHK_STATUS_RET(
      af::GraphUtils::AddEdge(index_workspace_post_node->GetOutAnchor(0UL), index_load_node->GetInAnchor(0UL)));
  GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(index_load_node->GetOutAnchor(0UL), final_reduce->GetInAnchor(1UL)));
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Split an ArgMax reduce into its phase-1 and phase-2 parts.
 *
 * Creates the index workspace pair (both named "<graph_name>_workspace_index") and the index
 * load node ("<graph_name>_load_index"), records the phase-1 index workspace in
 * ctx.node_order, then:
 *  1. remembers the consumers of reduce_node output 0 and a copy of its node attributes,
 *  2. replaces reduce_node in phase 1 via ReplaceArgMaxInPhase1,
 *  3. configures the phase-2 attributes via SetupArgMaxPhase2Attrs and copies the workspace
 *     tensor attributes for the value and index chains,
 *  4. connects the phase-2 edges via ConnectArgMaxPhase2Edges.
 *
 * @param reduce_node   Original ArgMax node.
 * @param phase1_nodes  Store / workspace nodes of phase 1.
 * @param phase2_nodes  Workspace / load / new reduce nodes of phase 2.
 * @param ctx           Graph, node-order list and name prefix used for the new nodes.
 * @return ge::GRAPH_SUCCESS on success, otherwise the first failing status (including a
 *         null input node or a node that could not be added to the graph).
 */
static Status HandleArgMaxPartition(const af::AscNodePtr &reduce_node,
                                    ReducePhase1Nodes &phase1_nodes,
                                    ReducePhase2Nodes &phase2_nodes,
                                    const ArgMaxPartitionGraphContext &ctx) {
  // All of these are dereferenced here or in the helpers called below.
  GE_ASSERT_NOTNULL(reduce_node);
  GE_ASSERT_NOTNULL(phase1_nodes.store_node);
  GE_ASSERT_NOTNULL(phase1_nodes.workspace_pre_node);
  GE_ASSERT_NOTNULL(phase2_nodes.workspace_post_node);
  GE_ASSERT_NOTNULL(phase2_nodes.load_node);
  GE_ASSERT_NOTNULL(phase2_nodes.new_reduce_node);

  // The pre/post workspace nodes intentionally share one name.
  af::ascir_op::Workspace index_workspace_pre((ctx.graph_name + "_workspace_index").c_str());
  af::ascir_op::Workspace index_workspace_post((ctx.graph_name + "_workspace_index").c_str());
  af::AscNodePtr index_workspace_pre_node = ctx.impl_graph.AddNode(index_workspace_pre);
  GE_ASSERT_NOTNULL(index_workspace_pre_node);
  af::AscNodePtr index_workspace_post_node = ctx.impl_graph.AddNode(index_workspace_post);
  GE_ASSERT_NOTNULL(index_workspace_post_node);
  // Only record the node once it is known to be valid, so node_order never holds a null entry.
  ctx.node_order.emplace_back(index_workspace_pre_node);
  af::ascir_op::Load index_load((ctx.graph_name + "_load_index").c_str());
  af::AscNodePtr index_load_node = ctx.impl_graph.AddNode(index_load);
  GE_ASSERT_NOTNULL(index_load_node);

  // NOTE(review): bound by reference and read after ReplaceArgMaxInPhase1 has rewired and
  // removed reduce_node - confirm the referenced attrs stay valid for that long.
  const auto &saved_reduce_input_attr = reduce_node->inputs[0].attr;
  const auto &saved_reduce_output_attr = reduce_node->outputs[0].attr;

  // Capture the consumers of output 0 before the graph is modified.
  af::OutDataAnchorPtr argmax_out_anchor = reduce_node->GetOutDataAnchor(0UL);
  std::vector<af::InDataAnchorPtr> argmax_out_edges;
  if (argmax_out_anchor != nullptr) {
    for (const auto &peer_in_anchor : argmax_out_anchor->GetPeerInDataAnchors()) {
      GE_ASSERT_NOTNULL(peer_in_anchor);
      argmax_out_edges.push_back(peer_in_anchor);
    }
  }

  // Copy the node-level attributes now; they are applied to the phase-2 reduce node later.
  auto src_asc_node_attr = reduce_node->GetOpDesc()->GetOrCreateAttrsGroup<af::AscNodeAttr>();
  af::AscNodeAttr saved_node_attr;
  if (src_asc_node_attr != nullptr) {
    saved_node_attr = *src_asc_node_attr;
  }

  GE_CHK_STATUS_RET(ReplaceArgMaxInPhase1(ctx.impl_graph, reduce_node, phase1_nodes.store_node,
                                          phase1_nodes.workspace_pre_node, index_workspace_pre_node));

  ReduceSavedNodeAttrInfo saved_attr_info = {src_asc_node_attr, saved_node_attr};
  ArgMaxIndexNodes index_nodes = {index_workspace_post_node, index_load_node};
  GE_CHK_STATUS_RET(SetupArgMaxPhase2Attrs(saved_attr_info, phase2_nodes, saved_reduce_input_attr,
                                           saved_reduce_output_attr, index_nodes));
  GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(phase2_nodes.load_node, phase2_nodes.workspace_post_node));
  GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(index_load_node, index_workspace_post_node));
  GE_CHK_STATUS_RET(ConnectArgMaxPhase2Edges(argmax_out_anchor, argmax_out_edges,
                                             phase2_nodes, index_workspace_post_node,
                                             index_load_node));
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Insert the workspace hand-over around the reduce node, splitting it into two phases.
 *
 * Nodes added to impl_graph:
 *  - two Workspace nodes, both named "<phase2graph name>_workspace" (phase-1 sink and
 *    phase-2 source),
 *  - a Load node "<phase2graph name>_load",
 *  - a Store node "<phase1graph name>Store",
 *  - the phase-2 reduce node built from the operator held in phase2graph_reduce.
 * The phase-1 workspace node is appended to node_order.
 *
 * For an "ArgMax" reduce the rewiring is delegated to HandleArgMaxPartition. Otherwise the
 * attributes of reduce_node are copied onto the new nodes, every consumer of reduce_node is
 * re-attached to the same output index of the new reduce node, and the chains
 *   reduce_node:0 -> store -> workspace(pre)   and   workspace(post) -> load -> new reduce:0
 * are connected.
 *
 * @param impl_graph          Graph that is rewritten.
 * @param phase2graph_reduce  Variant holding the reduce operator to add for phase 2.
 * @param node_order          Receives the phase-1 workspace node(s).
 * @return ge::GRAPH_SUCCESS on success, otherwise the first failing status (including a
 *         node that could not be added to the graph).
 */
Status RMulticorePhase2Graph::PartitionByReduce(ascir::ImplGraph &impl_graph,
                                                ReduceType &phase2graph_reduce,
                                                std::vector<af::AscNodePtr> &node_order) {
  GE_ASSERT_NOTNULL(reduce_node);

  af::ascir_op::Workspace workspace_pre((phase2graph.GetName() + "_workspace").c_str());
  af::ascir_op::Workspace workspace_post((phase2graph.GetName() + "_workspace").c_str());
  af::ascir_op::Load load((phase2graph.GetName() + "_load").c_str());
  af::ascir_op::Store store((phase1graph.GetName() + "Store").c_str());

  af::AscNodePtr new_reduce_node;
  std::visit([&new_reduce_node, &impl_graph](auto &&reduce_op) {
    new_reduce_node = impl_graph.AddNode(reduce_op);
  }, phase2graph_reduce);
  GE_ASSERT_NOTNULL(new_reduce_node);

  auto workspace_pre_node = impl_graph.AddNode(workspace_pre);
  GE_ASSERT_NOTNULL(workspace_pre_node);
  // Record the node only after validation so node_order never contains a null entry.
  node_order.emplace_back(workspace_pre_node);
  auto workspace_post_node = impl_graph.AddNode(workspace_post);
  GE_ASSERT_NOTNULL(workspace_post_node);
  auto load_node = impl_graph.AddNode(load);
  GE_ASSERT_NOTNULL(load_node);
  auto store_node = impl_graph.AddNode(store);
  GE_ASSERT_NOTNULL(store_node);

  if (reduce_node->GetType() == "ArgMax") {
    ReducePhase1Nodes phase1_nodes = {store_node, workspace_pre_node};
    ReducePhase2Nodes phase2_nodes = {workspace_post_node, load_node, new_reduce_node};
    ArgMaxPartitionGraphContext ctx = {impl_graph, node_order, phase2graph.GetName()};
    GE_CHK_STATUS_RET(HandleArgMaxPartition(reduce_node, phase1_nodes, phase2_nodes, ctx));
  } else {
    GE_ASSERT_GRAPH_SUCCESS(DoCopyAscNodeTensorAttr(reduce_node, new_reduce_node));
    GE_ASSERT_GRAPH_SUCCESS(DoCopyAscNodeTensorAttr(reduce_node, store_node));
    GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(reduce_node, workspace_pre_node));
    GE_ASSERT_GRAPH_SUCCESS(DoCopyWorkspaceTensorAttr(load_node, workspace_post_node));

    // Move every consumer of the original reduce node to the phase-2 reduce node. This must
    // happen before reduce_node:0 is connected to the store, otherwise the store edge would
    // be moved as well.
    for (const auto &reduce_out_anchor : reduce_node->GetAllOutDataAnchors()) {
      GE_ASSERT_NOTNULL(reduce_out_anchor);
      for (const auto &peer_in_anchor : reduce_out_anchor->GetPeerInDataAnchors()) {
        GE_ASSERT_NOTNULL(peer_in_anchor);
        auto reduce_out_node = peer_in_anchor->GetOwnerNode();
        GE_ASSERT_NOTNULL(reduce_out_node);
        GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(reduce_node->GetOutAnchor(reduce_out_anchor->GetIdx()),
                                                     reduce_out_node->GetInAnchor(peer_in_anchor->GetIdx())));
        GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(new_reduce_node->GetOutAnchor(reduce_out_anchor->GetIdx()),
                                                  reduce_out_node->GetInAnchor(peer_in_anchor->GetIdx())));
      }
    }

    // Phase 1 tail: reduce -> store -> workspace.
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(reduce_node->GetOutAnchor(0UL),
                                              store_node->GetInAnchor(0UL)));
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(store_node->GetOutAnchor(0UL),
                                              workspace_pre_node->GetInAnchor(0UL)));
    // Phase 2 head: workspace -> load -> new reduce.
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(workspace_post_node->GetOutAnchor(0UL),
                                              load_node->GetInAnchor(0UL)));
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(load_node->GetOutAnchor(0UL),
                                              new_reduce_node->GetInAnchor(0UL)));
  }
  return ge::GRAPH_SUCCESS;
}
/**
 * @brief Append the sink nodes reachable from the phase-2 reduce node to node_order.
 *
 * The node named "<phase2graph name>_<reduce node name>_reduce" is looked up in phase_graph
 * and used as the start of a breadth-first walk along output data edges. Every visited node
 * that has no output data nodes is appended to node_order, in visiting order.
 *
 * @param node_order  List that receives the sink nodes; existing entries are kept.
 * @return ge::SUCCESS on success; a failure status if the reduce node cannot be found, a
 *         null node is encountered, or a sink node is not an af::AscNode.
 */
Status RMulticorePhase2Graph::SetNodeOrder(std::vector<af::AscNodePtr> &node_order) {
  const auto new_reduce_node =
      phase_graph.FindNode((phase2graph.GetName() + "_" + reduce_node->GetName() + "_reduce").c_str());
  GE_ASSERT_NOTNULL(new_reduce_node);

  std::set<af::NodePtr> visited{new_reduce_node};
  std::queue<af::NodePtr> pending;
  pending.push(new_reduce_node);
  while (!pending.empty()) {
    const auto node = pending.front();
    pending.pop();
    GE_ASSERT_NOTNULL(node);

    // Query the successors once; the result is used both for the sink test and the walk.
    auto out_nodes = node->GetOutDataNodes();
    if (out_nodes.empty()) {
      // A failed cast would otherwise put a null entry into node_order.
      auto asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
      GE_ASSERT_NOTNULL(asc_node);
      node_order.emplace_back(asc_node);
      continue;
    }
    for (const auto &out_node : out_nodes) {
      // emplace().second is true only for nodes seen for the first time.
      if (visited.emplace(out_node).second) {
        pending.push(out_node);
      }
    }
  }
  return ge::SUCCESS;
}
}