* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "concat_schedule_case_generator.h"
#include <queue>
#include "graph/ascendc_ir/ascendc_ir_core/ascendc_ir.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "ascir/meta/ascir_utils.h"
#include "ascir/meta/ascir_ops_utils.h"
#include "optimize/schedule_utils.h"
#include "optimize/task_generator/concat_group_partitioner.h"
#include "optimize/task_generator/concat_score_function_generator.h"
#include "optimize/task_generator/concat_inputs_unification_pass.h"
#include "platform/platform_factory.h"
#include "util/mem_utils.h"
#include "task_generator/cast_optimization_pass.h"
namespace optimize {
namespace {
// Largest concat input count for which the un-split graph is still offered as a template.
constexpr uint32_t kMaxInputNum = 48U;
// Template count at which GenerateScoreFunctions also emits the small-tail check for the last graph.
constexpr size_t kTemplateSizeAll = 3UL;
// BackendSpec::concat_alg value that switches on the small-tail code paths in this file.
constexpr int32_t kConcatAlgTranspose = 0;
// Divided by the element size in AddTemplateIfCanFitInOneKernel to get the max allowed last-dim size.
constexpr int64_t kSmallDimSizeThreshold = 256;

// Pushes every data predecessor and data successor of `node` that has not been seen yet onto `nodes`,
// recording it in `visited_nodes` so that each neighbour is queued at most once.
void CollectInAndOutNodes(const af::NodePtr &node, std::set<af::Node *> &visited_nodes,
                          std::queue<af::NodePtr> &nodes) {
  const auto enqueue_if_new = [&visited_nodes, &nodes](const af::NodePtr &candidate) {
    const bool first_visit = visited_nodes.emplace(candidate.get()).second;
    if (first_visit) {
      nodes.emplace(candidate);
    }
  };
  // Producers first, then consumers - the same order the queue was filled in before.
  for (const auto &producer : node->GetInDataNodes()) {
    enqueue_if_new(producer);
  }
  for (const auto &consumer : node->GetOutDataNodes()) {
    enqueue_if_new(consumer);
  }
}
}  // namespace
// Produces the templates for a concat that Generate classified as "first dim".
//
// The grouped (split) template path is taken when the node has more than one input, the backend has no
// small-tail support, and either the output has a single axis or concat_dim_ is positive. In that path a
// single-axis graph first gets an extra leading axis (concat_dim_ becomes 1) and, if it qualifies, the
// unmodified graph is added as a one-kernel template. Otherwise the concat is rewritten in place into
// stores and `graph` itself is the only template appended to `graphs`.
Status ConcatFusionCaseGenerator::AddTemplatesForFirstDimConcat(const af::AscNodePtr &concat_node,
                                                                ascir::HintGraph &graph,
                                                                std::vector<ascir::ImplGraph> &graphs) {
  const bool single_axis_output = (concat_node->outputs[0].attr.repeats.size() == 1UL);
  const bool use_split_templates = (concat_node->inputs.Size() != 1U) && (!support_small_tail_) &&
                                   (single_axis_output || (concat_dim_ > 0));
  if (!use_split_templates) {
    // Fallback: eliminate the concat directly on the caller's graph.
    GE_CHK_STATUS_RET(Prepare(concat_node, concat_dim_), "Prepare failed");
    GE_CHK_STATUS_RET(ConvertConcatToStores(graph, concat_node), "ConvertConcatToStores failed");
    graphs.emplace_back(graph);
    GELOGI("concat on first dim, num_inputs = %u, 1 template was generated", concat_node->inputs.Size());
    return af::SUCCESS;
  }

  if (single_axis_output) {
    GE_ASSERT_SUCCESS(InsertAxis(graph), "Failed to insert axis for graph:[%s].", graph.GetName().c_str());
    // The inserted axis sits in front, so the concat axis moves to index 1.
    concat_dim_ = 1;
    GE_ASSERT_SUCCESS(AddTemplateIfCanFitInOneKernel(concat_node, graph, graphs));
  }
  GE_ASSERT_SUCCESS(AddTemplateForSplitConcat(graph, graphs));
  GE_ASSERT_SUCCESS(MarkNoMergeFirstAxis(graphs));
  return af::SUCCESS;
}
// Entry point: derives the candidate implementation graphs (templates) for the first Concat node found
// in `graph` and, where applicable, the score functions used to choose between them.
//
// Returns SUCCESS without touching `graphs` when the graph contains no Concat node. The first-dim case is
// delegated entirely to AddTemplatesForFirstDimConcat (no cast optimisation or score functions run there).
Status ConcatFusionCaseGenerator::Generate(ascir::HintGraph &graph,
                                           std::vector<ascir::ImplGraph> &graphs,
                                           std::vector<std::string> &score_functions) {
  bool has_unsupported_op = false;
  const auto found_concats = FindConcatNodes(graph, &has_unsupported_op);
  if (found_concats.empty()) {
    return af::SUCCESS;
  }
  // Only the first concat node drives template generation.
  auto &concat_node = found_concats.front();

  bool is_first_dim = false;
  GE_CHK_STATUS_RET(ScheduleUtils::ResolveDiffDim(concat_node, concat_dim_, is_first_dim), "ResolveConcatDim failed");
  GE_ASSERT_SUCCESS(AddExtraShapeEnv(concat_node, concat_dim_));
  GE_ASSERT_SUCCESS(SplitDataForDifferentConcatDim(graph),
                    "Failed to split data for graph:[%s].", graph.GetName().c_str());

  const auto backend_spec = BackendSpec::GetInstance();
  GE_ASSERT_NOTNULL(backend_spec);
  support_small_tail_ = backend_spec->concat_alg == kConcatAlgTranspose;

  if (is_first_dim) {
    return AddTemplatesForFirstDimConcat(concat_node, graph, graphs);
  }

  const bool all_ops_supported = !has_unsupported_op;
  if (all_ops_supported && (concat_node->inputs.Size() <= kMaxInputNum)) {
    // The untouched graph is the first candidate template.
    graphs.emplace_back(graph);
    if (support_small_tail_ && ascir::utils::UseSmallTailConcatApi(*concat_node)) {
      GELOGI("match small tail pattern, 1 template was generated");
      return af::SUCCESS;
    }
  }

  GE_ASSERT_SUCCESS(AddTemplateForSplitConcat(graph, graphs));
  if (all_ops_supported && NeedDynSmallTailTemplate(concat_node)) {
    GE_ASSERT_SUCCESS(AddTemplateForSmallTail(graph, graphs));
  }

  GE_ASSERT_SUCCESS(RunCastOptimizationPass(graphs));
  if (!support_small_tail_) {
    GE_ASSERT_SUCCESS(ConcatInputUnificationPass::Run(graphs));
  }
  GE_ASSERT_SUCCESS(GenerateScoreFunctions(graphs, concat_dim_, score_functions));
  return af::SUCCESS;
}
// Removes `concat_node` from `graph` by rewriting it into per-input stores: resolves the concat
// dimension, records the per-input offsets, then performs the conversion.
Status ConcatFusionCaseGenerator::EliminateConcat(ascir::HintGraph &graph, const af::AscNodePtr &concat_node) {
  // ResolveDiffDim also reports whether this is a first-dim concat; that flag is not needed here.
  bool first_dim_flag = false;
  GE_CHK_STATUS_RET(ScheduleUtils::ResolveDiffDim(concat_node, concat_dim_, first_dim_flag), "ResolveConcatDim failed");
  GE_CHK_STATUS_RET(Prepare(concat_node, concat_dim_), "Prepare failed");
  GE_CHK_STATUS_RET(ConvertConcatToStores(graph, concat_node), "ConvertConcatToStores failed");
  return af::SUCCESS;
}
// Appends the "<graph>_group_concat" template to `graphs`.
//
// Works on a copy of `graph`. The concat in the copy is first offered to SplitConcats; if it was split
// into groups, previously collected templates are dropped when the output size on the concat axis is a
// constant expression and KeepOriginGraph() says the original is not worth keeping. If no split
// happened, the concat is converted into plain stores instead.
Status ConcatFusionCaseGenerator::AddTemplateForSplitConcat(const ascir::HintGraph &graph, std::vector<ascir::ImplGraph> &graphs) {
  const auto grouped_name = graph.GetName() + "_group_concat";
  ascir::ImplGraph grouped_graph(grouped_name.c_str());
  GE_ASSERT_TRUE(grouped_graph.CopyFrom(graph));

  auto concat_node = FindConcatNodes(grouped_graph).front();
  GE_CHK_STATUS_RET(Prepare(concat_node, concat_dim_), "Prepare failed");

  split_concat_ = false;
  GE_CHK_STATUS_RET(SplitConcats(grouped_graph, concat_node, split_concat_), "SplitConcats failed");
  GELOGI("Concat on non-first dim, split concat into groups templates generated, split = %d",
         static_cast<int32_t>(split_concat_));

  if (!split_concat_) {
    GE_CHK_STATUS_RET(ConvertConcatToStores(grouped_graph, concat_node), "ConvertConcatToStores failed");
  } else if (concat_node->outputs[0].attr.repeats[concat_dim_].IsConstExpr() && (!KeepOriginGraph(concat_node))) {
    // The grouped template supersedes everything generated so far.
    graphs.clear();
  }
  graphs.emplace_back(grouped_graph);
  return af::SUCCESS;
}
// Tells whether an extra "force small tail" template should be generated for `concat_node`.
//
// All of the following must hold: the backend supports small tail, the output element size equals that
// of uint16_t or uint32_t, the node has at most kMaxInputNum inputs, and the output stride of the axis
// just before the concat axis is not a constant expression.
//
// Returns false when concat_dim_ is 0 or lies beyond the stride list, because there is then no
// preceding axis to inspect (the unguarded `concat_dim_ - 1` index would wrap around).
bool ConcatFusionCaseGenerator::NeedDynSmallTailTemplate(const af::AscNodePtr &concat_node) const {
  const auto dtype_size = GetSizeByDataType(concat_node->outputs[0].attr.dtype);
  GE_WARN_ASSERT(dtype_size > 0);
  const auto &output_strides = concat_node->outputs[0].attr.strides;
  const auto concat_dim = static_cast<size_t>(concat_dim_);
  if ((concat_dim == 0UL) || (concat_dim > output_strides.size())) {
    return false;
  }
  return support_small_tail_ &&
         ((dtype_size == sizeof(uint16_t)) || (dtype_size == sizeof(uint32_t))) &&
         (concat_node->inputs.Size() <= kMaxInputNum) &&
         (!output_strides[concat_dim - 1UL].IsConstExpr());
}
// Appends a copy of `graph` named "<graph>_force_small_tail" whose first Concat node carries the
// boolean attribute "_concat_small_tail" set to true.
Status ConcatFusionCaseGenerator::AddTemplateForSmallTail(const ascir::HintGraph &graph,
                                                          std::vector<ascir::ImplGraph> &graphs) {
  GELOGI("exists dynamic dim after concat_dim, generate force small tail template");
  const auto variant_name = graph.GetName() + "_force_small_tail";
  ascir::ImplGraph small_tail_graph(variant_name.c_str());
  GE_ASSERT_TRUE(small_tail_graph.CopyFrom(graph));

  auto tagged_concat = FindConcatNodes(small_tail_graph).front();
  GE_ASSERT_TRUE(af::AttrUtils::SetBool(tagged_concat->GetOpDesc(), "_concat_small_tail", true));

  graphs.emplace_back(small_tail_graph);
  return af::SUCCESS;
}
// Fills `score_functions` (one slot per template) when more than one template exists.
//
// With small-tail support and no recompute: the first slot gets the generic concat score function and,
// when exactly kTemplateSizeAll templates exist, the last slot gets the small-tail check.
// Without small-tail support: only when the concat was not split, nothing is recomputed, the node has at
// most 16 inputs and the output block is small, the first slot gets the "score one" function.
// In every other situation `score_functions` is left untouched.
Status ConcatFusionCaseGenerator::GenerateScoreFunctions(const std::vector<ascir::ImplGraph> &graphs,
                                                         size_t concat_dim,
                                                         std::vector<std::string> &score_functions) const {
  // A single template needs no scoring.
  GE_CHK_BOOL_RET_STATUS_NOLOG((graphs.size() > 1U), af::SUCCESS);

  if (support_small_tail_) {
    if (has_recompute_) {
      return af::SUCCESS;
    }
    score_functions.resize(graphs.size());
    const auto concat_node = FindConcatNodes(graphs.front()).front();
    GE_CHK_STATUS_RET(
        ConcatScoreFunctionGenerator(graphs.front(), concat_node, concat_dim).Generate(score_functions.front()),
        "Failed to generate score func for ub_concat");
    if (graphs.size() == kTemplateSizeAll) {
      GE_CHK_STATUS_RET(ConcatScoreFunctionGenerator(graphs.back(), concat_node, concat_dim)
                            .GenerateForCheckSmallTail(score_functions.back()),
                        "Failed to generate score func for small_tail_concat");
    }
    return af::SUCCESS;
  }

  if (split_concat_) {
    return af::SUCCESS;
  }
  const auto concat_node = FindConcatNodes(graphs.front()).front();
  constexpr uint32_t kMaxNumInputs = 16;
  const bool prefer_first_template =
      (!has_recompute_) && (concat_node->inputs.Size() <= kMaxNumInputs) && IsSmallBlock(concat_node, concat_dim_);
  if (prefer_first_template) {
    score_functions.resize(graphs.size());
    ConcatScoreFunctionGenerator::GenerateScoreOne(score_functions.front());
  }
  return af::SUCCESS;
}
// Returns every Concat node of `owner_graph`, in the order GetAllNodes() yields them.
//
// When `has_unsupported_op` is non-null it is reset to false and then set to true as soon as a
// non-Concat node is recognised as Transpose or Reduce; later nodes are no longer inspected for that.
std::vector<af::AscNodePtr> ConcatFusionCaseGenerator::FindConcatNodes(const ascir::HintGraph &owner_graph,
                                                                       bool *has_unsupported_op) {
  const bool track_unsupported = (has_unsupported_op != nullptr);
  if (track_unsupported) {
    *has_unsupported_op = false;
  }

  std::vector<af::AscNodePtr> result;
  for (const auto &node : owner_graph.GetAllNodes()) {
    if (af::ops::IsOps<af::ascir_op::Concat>(node)) {
      result.emplace_back(node);
      continue;
    }
    if ((!track_unsupported) || (*has_unsupported_op)) {
      // Nothing to report, or an unsupported op was already found.
      continue;
    }
    const auto asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
    if (asc_node == nullptr) {
      continue;
    }
    if (ScheduleUtils::IsTranspose(asc_node) || ScheduleUtils::IsReduce(asc_node)) {
      *has_unsupported_op = true;
      GELOGI("graph contains Transpose/Reduce node %s", node->GetNamePtr());
    }
  }
  return result;
}
// Replaces input `in_index` of `concat_node` with a store-based sub-graph.
//
// Inputs are bucketed by their size on the concat axis: the first time a size is seen, a new graph axis
// named "<concat axis name>_<group_idx>" is created and remembered in `repeat_to_axis_id`; inputs with
// the same size reuse that axis.
Status ConcatFusionCaseGenerator::ConvertSingleInput(ascir::HintGraph &owner_graph, const af::AscNodePtr &concat_node,
                                                     size_t in_index, size_t group_idx,
                                                     ConcatDimAxisMap &repeat_to_axis_id) {
  const auto &in_anchors = concat_node->GetAllInDataAnchors();
  const auto &graph_axes = owner_graph.GetAllAxis();
  const auto &target_in_anchor = in_anchors.at(in_index);
  auto piece_repeat = concat_node->inputs[in_index].attr.repeats[concat_dim_];

  const bool axis_already_known = (repeat_to_axis_id.find(piece_repeat) != repeat_to_axis_id.end());
  if (!axis_already_known) {
    const auto axis_name = graph_axes[concat_dim_]->name + "_" + std::to_string(group_idx);
    auto created_axis = owner_graph.CreateAxis(axis_name, piece_repeat);
    repeat_to_axis_id[piece_repeat] = created_axis.id;
    GELOGD("Create axis [%s, %ld] for input repeat=[%s].", created_axis.name.c_str(), created_axis.id,
           af::SymbolicUtils::ToString(piece_repeat).c_str());
  }

  auto axis_for_input = owner_graph.FindAxis(repeat_to_axis_id[piece_repeat]);
  GE_ASSERT_NOTNULL(axis_for_input);
  return ReplaceWithStore(concat_node, target_in_anchor, *axis_for_input);
}
// Rewrites `concat_node` into one store-based sub-graph per input, then removes the concat together
// with the original post-concat nodes and re-sorts the graph topologically.
//
// Inputs are processed from the last to the first; the running counter (0 for the last input) is the
// group index handed to ConvertSingleInput for axis naming.
Status ConcatFusionCaseGenerator::ConvertConcatToStores(ascir::HintGraph &owner_graph,
                                                        const af::AscNodePtr &concat_node) {
  ConcatGroupPartitioner partitioner(concat_node, concat_dim_);
  GE_ASSERT_SUCCESS(partitioner.RecomputeDiffAxes());
  has_recompute_ = partitioner.HasRecompute();
  GE_ASSERT_SUCCESS(PrepareForModifyingGraph(concat_node));

  ConcatDimAxisMap repeat_to_axis_id;
  const auto input_count = concat_node->GetAllInDataAnchors().size();
  for (size_t remaining = input_count; remaining > 0UL; --remaining) {
    const auto in_index = remaining - 1UL;
    const auto group_idx = input_count - remaining;
    GE_ASSERT_SUCCESS(ConvertSingleInput(owner_graph, concat_node, in_index, group_idx, repeat_to_axis_id),
                      "ProcessSingleInput failed in ConvertConcatToStores");
  }

  GE_CHK_STATUS_RET(RemoveUnusedNodes(concat_node, post_concat_nodes_), "RemoveUnusedNodes failed");
  GE_ASSERT_GRAPH_SUCCESS(ScheduleUtils::TopologicalSorting(owner_graph));
  ascir::utils::DumpGraph(owner_graph, "AfterConvertConcatToStore");
  return af::SUCCESS;
}
// Tries to break `concat_node` into groups of adjacent inputs.
//
// The partitioner proposes the groups. Nothing is changed (and `split` is left untouched) when it yields
// at most one group or exactly one group per input. Otherwise groups are processed from the last to the
// first: a single-input group becomes a store-based sub-graph, a larger group becomes its own smaller
// Concat node. Finally the original concat and post-concat nodes are removed, the graph is re-sorted and
// `split` is set to true.
Status ConcatFusionCaseGenerator::SplitConcats(ascir::HintGraph &owner_graph, const af::AscNodePtr &concat_node,
                                               bool &split) {
  std::vector<ConcatGroupPartitioner::ConcatGroup> groups;
  ConcatGroupPartitioner partitioner(concat_node, concat_dim_);
  GE_ASSERT_SUCCESS(partitioner.PartitionGroups(groups));
  if ((groups.size() <= 1U) || (groups.size() == concat_node->inputs.Size())) {
    return af::SUCCESS;
  }
  has_recompute_ = partitioner.HasRecompute();
  GE_ASSERT_SUCCESS(PrepareForModifyingGraph(concat_node));
  for (size_t i = 0U; i < groups.size(); ++i) {
    const auto &group = groups[i];
    GELOGI("group[%zu] start = %ld, end = %ld", i, group.start, group.end);
  }
  ConcatDimAxisMap repeat_to_axis_id;
  for (size_t i = 0U; i < groups.size(); ++i) {
    // Walk the groups back to front; `i` doubles as the group index used for axis naming.
    const auto &group = groups[groups.size() - i - 1];
    if (group.end - group.start == 1) {
      GE_ASSERT_SUCCESS(ConvertSingleInput(owner_graph, concat_node, group.start, i, repeat_to_axis_id),
                        "ConvertSingleInput failed in SplitConcats");
    } else {
      GE_CHK_STATUS_RET(ReplaceWithConcat(owner_graph, concat_node, group.start, group.end));
    }
  }
  GE_CHK_STATUS_RET(RemoveUnusedNodes(concat_node, post_concat_nodes_), "RemoveUnusedNodes failed");
  GE_ASSERT_GRAPH_SUCCESS(ScheduleUtils::TopologicalSorting(owner_graph));
  ascir::utils::DumpGraph(owner_graph, "AfterSplitConcat");
  split = true;
  return af::SUCCESS;
}
// Records, for every input of `concat_node`, its starting offset along axis `concat_dim` (the running sum
// of the preceding inputs' sizes on that axis) and remembers the id of the concat axis.
//
// The offset list is rebuilt from scratch on every call: it is later indexed by input index, so entries
// left over from an earlier call would otherwise shadow the ones computed for this node.
Status ConcatFusionCaseGenerator::Prepare(const af::AscNodePtr &concat_node, size_t concat_dim) {
  GE_ASSERT_TRUE(concat_dim < concat_node->attr.sched.axis.size());
  concat_dim_offsets_.clear();
  af::Expression dim_offset = af::ops::Zero;
  for (const auto &in_anchor : concat_node->GetAllInDataAnchorsPtr()) {
    GE_ASSERT_NOTNULL(in_anchor);
    concat_dim_offsets_.emplace_back(dim_offset);
    dim_offset = dim_offset + concat_node->inputs[static_cast<uint32_t>(in_anchor->GetIdx())].attr.repeats[concat_dim];
  }
  concat_axis_id_ = concat_node->attr.sched.axis[concat_dim];
  return af::SUCCESS;
}
// Breadth-first walk over everything connected to `start_node` through data edges (in either
// direction). Every visited node whose API type is not "buffer" gets `new_axis_ids` as its schedule
// axes and as the axis list of each of its outputs. Fails if a visited node is not an AscNode.
Status ConcatFusionCaseGenerator::PropagateAxisChanges(af::Node *start_node,
                                                       const std::vector<ascir::AxisId> &new_axis_ids) {
  std::set<af::Node *> seen{start_node};
  std::queue<af::Node *> pending;
  pending.emplace(start_node);

  const auto schedule_visit = [&seen, &pending](af::Node *neighbour) {
    if (seen.emplace(neighbour).second) {
      pending.emplace(neighbour);
    }
  };

  while (!pending.empty()) {
    const auto current = dynamic_cast<af::AscNode *>(pending.front());
    pending.pop();
    GE_ASSERT_NOTNULL(current);

    const bool is_buffer = (current->attr.api.type == af::ApiType::kAPITypeBuffer);
    if (!is_buffer) {
      current->attr.sched.axis = new_axis_ids;
      for (const auto &out_tensor: current->outputs()) {
        out_tensor->attr.axis = new_axis_ids;
      }
    }

    // Consumers are queued before producers.
    for (const auto &consumer : current->GetOutDataNodes()) {
      schedule_visit(consumer.get());
    }
    for (const auto &producer : current->GetInDataNodes()) {
      schedule_visit(producer.get());
    }
  }
  return af::SUCCESS;
}
// Detaches one concat input and feeds it into a freshly cloned copy of the post-concat nodes that runs
// on `replace_axis` instead of the concat axis.
//
// Steps: propagate the replaced axis list across the connected graph, unlink the concat input, clone the
// post-concat nodes for this input, connect the former producer to the clones' entry anchors, and
// re-point each of those anchors to a clone when its producer itself was cloned.
Status ConcatFusionCaseGenerator::ReplaceWithStore(const af::AscNodePtr &concat_node,
                                                   const af::InDataAnchorPtr &concat_in_anchor,
                                                   const af::Axis &replace_axis) {
  const auto input_idx = concat_in_anchor->GetIdx();
  const auto &producer_anchor = concat_in_anchor->GetPeerOutAnchor();

  std::vector<ascir::AxisId> rewritten_axis_ids = concat_node->attr.sched.axis;
  rewritten_axis_ids[concat_dim_] = replace_axis.id;
  GE_CHK_STATUS_RET(PropagateAxisChanges(concat_node.get(), rewritten_axis_ids),
                    "PropagateAxisChanges failed in ReplaceWithStore");

  GE_ASSERT_NOTNULL(producer_anchor);
  concat_in_anchor->UnlinkAll();

  // The producer must be an AscNode; the pointer itself is not used any further.
  auto producer_node = dynamic_cast<af::AscNode *>(producer_anchor->GetOwnerNodeBarePtr());
  GE_ASSERT_NOTNULL(producer_node);

  std::vector<af::InDataAnchorPtr> clone_entry_anchors;
  std::unordered_map<std::string, af::NodePtr> name_to_new_node;
  GE_ASSERT_SUCCESS(
      CloneNonConcatNodes(replace_axis, input_idx, clone_entry_anchors, rewritten_axis_ids, name_to_new_node));

  for (const auto &entry_anchor : clone_entry_anchors) {
    GE_ASSERT_GRAPH_SUCCESS(af::GraphUtils::AddEdge(producer_anchor, entry_anchor));
    GE_ASSERT_SUCCESS(ReconnectIfShareSameAncestor(name_to_new_node, entry_anchor));
  }
  return af::SUCCESS;
}
// Moves the inputs [start, end) of `concat_node` onto a new, smaller Concat node named
// "<concat name>_<start>_<end>" and wires that node into a cloned copy of the post-concat nodes.
//
// A new axis "<concat axis name>_ss_<start>" sized like the new concat's output on the concat axis
// replaces the concat axis for the new node and its clones.
// Fails if the attributes of the new concat op cannot be derived (for example `end` exceeds the input
// count), if any graph edit fails, or if a required anchor is missing.
Status ConcatFusionCaseGenerator::ReplaceWithConcat(ascir::ImplGraph &owner_graph, const af::AscNodePtr &concat_node,
                                                    size_t start, size_t end) {
  auto suffix = "_" + std::to_string(start) + "_" + std::to_string(end);
  af::ascir_op::Concat concat_op((concat_node->GetName() + suffix).c_str());
  // Do not continue with a half-initialised op when the attribute setup is rejected.
  GE_ASSERT_SUCCESS(SetConcatOpAttr(concat_op, concat_node, concat_dim_, start, end));
  auto new_concat_node = owner_graph.AddNode(concat_op);
  GE_ASSERT_NOTNULL(new_concat_node);
  GELOGD("split concat [%zu, %zu), output repeats = %s", start, end,
         af::ToString(new_concat_node->outputs[0].attr.repeats).c_str());
  // Re-home the edges back to front: old input (i - 1) becomes new input (i - start - 1).
  for (size_t i = end; i > start; --i) {
    auto concat_in_anchor = concat_node->GetInDataAnchor(static_cast<int32_t>(i) - 1);
    GE_CHECK_NOTNULL(concat_in_anchor);
    auto peer_out_anchor = concat_in_anchor->GetPeerOutAnchor();
    GE_CHECK_NOTNULL(peer_out_anchor);
    GE_CHK_STATUS_RET(af::GraphUtils::RemoveEdge(peer_out_anchor, concat_in_anchor), "Failed to RemoveEdge");
    GE_CHK_STATUS_RET(af::GraphUtils::AddEdge(peer_out_anchor, new_concat_node->GetInDataAnchor(i - start - 1)),
                      "Failed to AddEdge");
  }
  const auto &output_repeats = new_concat_node->outputs[0].attr.repeats;
  ascir::Axis concat_axis = *(owner_graph.GetAllAxis().at(concat_axis_id_));
  auto new_concat_axis =
      owner_graph.CreateAxis(concat_axis.name + "_ss_" + std::to_string(start), output_repeats[concat_dim_]);
  std::vector<ascir::AxisId> new_axis_ids = concat_node->attr.sched.axis;
  new_axis_ids[concat_dim_] = new_concat_axis.id;
  GE_CHK_STATUS_RET(PropagateAxisChanges(concat_node.get(), new_axis_ids),
                    "PropagateAxisChanges failed in ReplaceWithConcat");
  GELOGD("New axis %s, size = %s", new_concat_axis.name.c_str(),
         af::SymbolicUtils::ToString(new_concat_axis.size).c_str());
  std::vector<af::InDataAnchorPtr> dst_in_anchors;
  GE_ASSERT_SUCCESS(ReplaceAxis(new_concat_node, concat_dim_, new_concat_axis, new_axis_ids));
  std::unordered_map<std::string, af::NodePtr> name_to_new_node;
  GE_ASSERT_SUCCESS(
      CloneNonConcatNodes(new_concat_axis, start, dst_in_anchors, new_axis_ids, name_to_new_node));
  // Inputs produced by a node that was just cloned must read from the clone instead.
  for (const auto &in_anchor : new_concat_node->GetAllInDataAnchors()) {
    GE_ASSERT_SUCCESS(ReconnectIfShareSameAncestor(name_to_new_node, in_anchor));
  }
  for (const auto &peer_in_anchor : dst_in_anchors) {
    GE_ASSERT_GRAPH_SUCCESS(af::GraphUtils::AddEdge(new_concat_node->GetOutDataAnchor(0), peer_in_anchor));
  }
  return af::SUCCESS;
}
// Initialises `concat_op` as a concat over inputs [start, end) of `concat_node`.
//
// The output repeats are those of input `start` with the concat axis enlarged by the remaining inputs of
// the range; strides are recomputed contiguously from the innermost axis outwards. Node attributes,
// dtype and axis list are taken over from `concat_node`, and `end - start` dynamic "x" inputs are
// registered. Fails when `end` exceeds the number of inputs.
af::Status ConcatFusionCaseGenerator::SetConcatOpAttr(af::ascir_op::Concat &concat_op,
                                                      const af::AscNodePtr &concat_node,
                                                      size_t concat_dim,
                                                      size_t start,
                                                      size_t end) {
  GE_ASSERT_TRUE(end <= concat_node->inputs.Size());

  auto merged_repeats = concat_node->inputs[start].attr.repeats;
  for (size_t input_idx = start + 1U; input_idx < end; ++input_idx) {
    merged_repeats[concat_dim] =
        merged_repeats[concat_dim] + concat_node->inputs[input_idx].attr.repeats[concat_dim];
  }

  // Innermost axis gets stride one; each outer axis multiplies in the size of the axis just inside it.
  af::Expression running_stride = af::sym::kSymbolOne;
  std::vector<af::Expression> merged_strides(merged_repeats.size(), af::sym::kSymbolOne);
  for (size_t remaining = merged_repeats.size(); remaining > 0U; --remaining) {
    const size_t dim = remaining - 1U;
    merged_strides[dim] = running_stride;
    running_stride = (running_stride * merged_repeats[dim]);
  }

  const auto &source_output_attr = concat_node->outputs[0].attr;
  concat_op.attr = concat_node->attr;
  concat_op.y.dtype = source_output_attr.dtype;
  *concat_op.y.repeats = merged_repeats;
  *concat_op.y.strides = merged_strides;
  *concat_op.y.axis = source_output_attr.axis;
  concat_op.DynamicInputRegister("x", end - start);
  return af::SUCCESS;
}
// Deletes `concat_node` and then each node in `nodes` from the compute graph that owns the concat.
// Stops at the first removal that fails.
Status ConcatFusionCaseGenerator::RemoveUnusedNodes(const af::AscNodePtr &concat_node,
                                                    const std::vector<af::AscNodePtr> &nodes) {
  auto compute_graph = concat_node->GetOwnerComputeGraph();
  GE_ASSERT_NOTNULL(compute_graph);

  GE_CHK_STATUS_RET(compute_graph->RemoveNode(concat_node), "Failed to remote node: %s",
                    concat_node->GetNamePtr());
  for (const auto &obsolete_node : nodes) {
    GE_CHK_STATUS_RET(compute_graph->RemoveNode(obsolete_node), "Failed to remote node: %s",
                      obsolete_node->GetNamePtr());
  }
  return af::SUCCESS;
}
// Gives every consumer of a data-input node its own producer.
//
// For each data-input node whose output 0 feeds several anchors, the first consumer keeps the original
// node; every further consumer (index idx >= 1) is rewired to a new Data/ScalarData node named
// "<original name>_<idx>" that copies the original's node attributes and output-0 tensor attributes.
Status ConcatFusionCaseGenerator::SplitDataForDifferentConcatDim(ascir::ImplGraph &owner_graph) {
  for (const auto &node : owner_graph.GetAllNodes()) {
    GE_ASSERT_NOTNULL(node);
    if (!ScheduleUtils::IsDataInput(node)) {
      continue;
    }
    auto shared_out_anchor = node->GetOutDataAnchor(0);
    GE_ASSERT_NOTNULL(shared_out_anchor);
    // Snapshot of the consumers, taken before any edge is changed.
    auto consumer_anchors = shared_out_anchor->GetPeerInDataAnchors();

    for (size_t idx = 1UL; idx < consumer_anchors.size(); ++idx) {
      std::string clone_name = node->GetName() + std::string("_") + std::to_string(idx);
      af::AscNodePtr clone_node;
      if (af::ops::IsOps<af::ascir_op::ScalarData>(node)) {
        af::ascir_op::ScalarData scalar_data(clone_name.c_str());
        clone_node = owner_graph.AddNode(scalar_data);
      } else {
        af::ascir_op::Data data(clone_name.c_str());
        clone_node = owner_graph.AddNode(data);
      }
      GE_ASSERT_NOTNULL(clone_node);
      clone_node->attr = node->attr;
      clone_node->outputs[0].attr = node->outputs[0].attr;

      auto clone_out_anchor = clone_node->GetOutDataAnchor(0);
      GE_ASSERT_NOTNULL(clone_out_anchor);
      GE_ASSERT_SUCCESS(af::GraphUtils::RemoveEdge(shared_out_anchor, consumer_anchors.at(idx)));
      GE_ASSERT_SUCCESS(af::GraphUtils::AddEdge(clone_out_anchor, consumer_anchors.at(idx)));
    }
  }
  return af::SUCCESS;
}
// Appends to `nodes` every node reachable from the data successors of `concat_node` by following data
// edges in both directions (the concat itself is excluded and never crossed), then sorts the whole of
// `nodes` by ascending op-desc id. Fails if a reached node is not an AscNode.
af::Status ConcatFusionCaseGenerator::CollectBackwardNodes(const af::NodePtr &concat_node,
                                                           std::vector<af::AscNodePtr> &nodes) {
  std::set<af::Node *> seen{concat_node.get()};
  std::queue<af::NodePtr> frontier;
  for (const auto &consumer : concat_node->GetOutDataNodes()) {
    if (seen.emplace(consumer.get()).second) {
      frontier.push(consumer);
    }
  }

  while (!frontier.empty()) {
    // Take ownership of the front element so it stays valid while the queue grows.
    const af::NodePtr current = frontier.front();
    frontier.pop();
    auto current_asc = std::dynamic_pointer_cast<af::AscNode>(current);
    GE_ASSERT_NOTNULL(current_asc);
    nodes.emplace_back(current_asc);
    CollectInAndOutNodes(current, seen, frontier);
  }

  const auto by_op_id = [](const af::AscNodePtr &lhs, const af::AscNodePtr &rhs) -> bool {
    return lhs->GetOpDesc()->GetId() < rhs->GetOpDesc()->GetId();
  };
  std::sort(nodes.begin(), nodes.end(), by_op_id);
  return af::SUCCESS;
}
// Adds to `nodes` every Load node reachable from the data predecessors of `concat_node` by following
// data edges in both directions (the concat itself is excluded and never crossed).
// Fails if a reached node is not an AscNode.
Status ConcatFusionCaseGenerator::CollectReachableLoadNodes(const af::NodePtr &concat_node,
                                                            std::set<af::AscNodePtr> &nodes) {
  std::set<af::Node *> seen{concat_node.get()};
  std::queue<af::NodePtr> frontier;
  for (const auto &producer : concat_node->GetInDataNodes()) {
    if (seen.emplace(producer.get()).second) {
      frontier.push(producer);
    }
  }

  while (!frontier.empty()) {
    // Take ownership of the front element so it stays valid while the queue grows.
    const af::NodePtr current = frontier.front();
    frontier.pop();
    auto current_asc = std::dynamic_pointer_cast<af::AscNode>(current);
    GE_ASSERT_NOTNULL(current_asc);
    if (af::ops::IsOps<af::ascir_op::Load>(current_asc)) {
      nodes.emplace(current_asc);
    }
    CollectInAndOutNodes(current, seen, frontier);
  }
  return af::SUCCESS;
}
// Clones every recorded post-concat node for the input (or input group) starting at `index`.
//
// Clones are named "<original>_split_<index>" (no suffix for index 0) and get their concat axis replaced
// by `new_axis`. Cloned Store nodes, and cloned Load nodes that are not in the set reachable from the
// concat inputs, get an offset of (concat-axis offset of `index`) * (their output stride on the concat
// axis). `in_anchors` receives the clone-side anchors that were fed by the concat's output 0, and
// `name_to_new_node` maps original node names to their clones. Edges between the originals are then
// mirrored between the clones.
//
// Fails when no post-concat node was recorded, when a clone cannot be created, or when `index` has no
// recorded concat-axis offset.
Status ConcatFusionCaseGenerator::CloneNonConcatNodes(const af::Axis &new_axis,
                                                      size_t index,
                                                      std::vector<af::InDataAnchorPtr> &in_anchors,
                                                      const std::vector<ascir::AxisId> &new_axis_ids,
                                                      std::unordered_map<std::string, af::NodePtr> &name_to_new_node) {
  GE_ASSERT_TRUE(!post_concat_nodes_.empty());
  ascir::ImplGraph owner_graph("owner_graph");
  GE_ASSERT_SUCCESS(af::AscGraphUtils::FromComputeGraph(post_concat_nodes_.front()->GetOwnerComputeGraph(), owner_graph));
  std::string suffix;
  if (index != 0UL) {
    suffix = "_split_" + std::to_string(index);
  }
  std::unordered_map<std::string, af::NodePtr> all_new_nodes;
  for (const auto &asc_node : post_concat_nodes_) {
    const auto &op_desc = af::GraphUtils::CopyOpDesc(asc_node->GetOpDesc(), nullptr);
    GE_CHECK_NOTNULL(op_desc);
    op_desc->SetName(asc_node->GetName() + suffix);
    af::Operator op = af::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
    auto dst_new_node = owner_graph.AddNode(op);
    // AddNode results are null-checked everywhere else in this file; do the same before dereferencing.
    GE_ASSERT_NOTNULL(dst_new_node);
    all_new_nodes[dst_new_node->GetName()] = dst_new_node;
    name_to_new_node[asc_node->GetName()] = dst_new_node;
    GE_ASSERT_TRUE(af::AscGraph::CopyAscNodeTensorAttr(asc_node, dst_new_node),
                   "DoCopyAscNodeTensorAttr failed, node = %s[%s]",
                   asc_node->GetNamePtr(), asc_node->GetTypePtr());
    if (dst_new_node->GetType() == af::ascir_op::Store::Type) {
      GE_ASSERT_TRUE(index < concat_dim_offsets_.size(), "index = %zu, out of range [0, %zu) of concat dim offsets",
                     index, concat_dim_offsets_.size());
      const auto offset = concat_dim_offsets_[index] * dst_new_node->outputs[0].attr.strides[concat_dim_];
      const auto ir_attr = dst_new_node->attr.ir_attr->DownCastTo<af::ascir_op::Store::AscStoreIrAttrDef>();
      GE_ASSERT_NOTNULL(ir_attr);
      GE_CHK_STATUS_RET(ir_attr->SetOffset(offset), "Failed to set offset to %s", dst_new_node->GetNamePtr());
      GELOGI("Store node: %s added, offset = %s", dst_new_node->GetName().c_str(), offset.Serialize().get());
    } else if ((dst_new_node->GetType() == af::ascir_op::Load::Type) &&
               (reachable_load_nodes_.find(asc_node) == reachable_load_nodes_.end())) {
      // Loads that feed the concat inputs keep their offset; only the other loads are shifted.
      GE_ASSERT_TRUE(index < concat_dim_offsets_.size(), "index = %zu, out of range [0, %zu) of concat dim offsets",
                     index, concat_dim_offsets_.size());
      const auto offset = concat_dim_offsets_[index] * dst_new_node->outputs[0].attr.strides[concat_dim_];
      const auto ir_attr = dst_new_node->attr.ir_attr->DownCastTo<af::ascir_op::Load::AscLoadIrAttrDef>();
      GE_ASSERT_NOTNULL(ir_attr);
      GE_CHK_STATUS_RET(ir_attr->SetOffset(offset), "Failed to set offset to %s", dst_new_node->GetNamePtr());
      GELOGI("Load node: %s added, offset = %s", dst_new_node->GetName().c_str(), offset.Serialize().get());
    } else {
    }
    // Clones of direct concat consumers expose the anchors the caller has to feed.
    if (const auto it = out_node_name_to_indices_.find(asc_node->GetName()); it != out_node_name_to_indices_.cend()) {
      for (const auto in_anchor_index : it->second) {
        in_anchors.emplace_back(dst_new_node->GetInDataAnchor(in_anchor_index));
      }
    }
    if (!ScheduleUtils::IsBuffer(dst_new_node)) {
      GE_ASSERT_SUCCESS(ReplaceAxis(dst_new_node, concat_dim_, new_axis, new_axis_ids));
    }
  }
  for (const auto &src_node : post_concat_nodes_) {
    GE_CHK_STATUS_RET(af::GraphUtils::RelinkGraphEdges(src_node, suffix, all_new_nodes), "RelinkGraphEdges failed");
  }
  return af::SUCCESS;
}
// Points `node` at `new_axis_ids` (schedule axes and the axis list of every output) and updates the
// repeat/stride of each output at `axis_index` to the size of `to_axis`.
af::Status ConcatFusionCaseGenerator::ReplaceAxis(const af::AscNodePtr &node, size_t axis_index,
                                                  const af::Axis &to_axis,
                                                  const std::vector<ascir::AxisId> &new_axis_ids) {
  node->attr.sched.axis = new_axis_ids;

  const auto output_count = node->outputs().size();
  for (uint32_t out_idx = 0U; out_idx < output_count; ++out_idx) {
    node->outputs[out_idx].attr.axis = new_axis_ids;
    GE_ASSERT_SUCCESS(UpdateRepeatAndStrides(node, axis_index, to_axis.size, node->outputs[out_idx].attr),
                      "Failed to update repeat and strides for outputs[%u], node = %s(%s)", out_idx,
                      node->GetNamePtr(), node->GetTypePtr());
  }

  GELOGD("Replace axis for node: %s(%s) success", node->GetNamePtr(), node->GetTypePtr());
  return af::SUCCESS;
}
// Sets the repeat of `tensor_attr` at `axis_index` to `axis_size` and fixes up the strides.
//
// - If the current repeat at `axis_index` is statically 1, nothing is changed.
// - For Load/Store nodes only the stride at `axis_index` is touched: it becomes 0 when the new size is
//   statically 1; the remaining strides are kept as they are.
// - For all other nodes the strides are rebuilt contiguously from the innermost axis outwards, with
//   stride 0 on every axis whose repeat is statically 1.
//
// Fails when repeats and strides differ in length or `axis_index` is out of range.
af::Status ConcatFusionCaseGenerator::UpdateRepeatAndStrides(const af::AscNodePtr &node,
                                                             size_t axis_index,
                                                             const af::Expression &axis_size,
                                                             af::AscTensorAttr &tensor_attr) {
  auto &repeats = tensor_attr.repeats;
  auto &strides = tensor_attr.strides;
  GELOGD("before update, repeats = %s, strides = %s", af::ToString(repeats).c_str(), af::ToString(strides).c_str());
  GE_ASSERT_TRUE(repeats.size() == strides.size());
  GE_ASSERT_TRUE(axis_index < repeats.size(), "axis_index = %zu, out of range [0, %zu)", axis_index, repeats.size());
  if (af::SymbolicUtils::StaticCheckEq(repeats[axis_index], af::ops::One) == af::TriBool::kTrue) {
    return af::SUCCESS;
  }
  repeats[axis_index] = axis_size;
  if (af::ops::IsOps<af::ascir_op::Load>(node) || af::ops::IsOps<af::ascir_op::Store>(node)) {
    if (af::SymbolicUtils::StaticCheckEq(axis_size, af::ops::One) == af::TriBool::kTrue) {
      // Write through the attribute being updated, not unconditionally through outputs[0].
      strides[axis_index] = af::ops::Zero;
    }
    return af::SUCCESS;
  }
  af::Expression stride = af::ops::One;
  for (auto i = static_cast<int32_t>(repeats.size() - 1); i >= 0; --i) {
    // The innermost axis keeps stride one; every outer axis multiplies in the axis just inside it.
    if (i != static_cast<int32_t>(repeats.size() - 1UL)) {
      stride = stride * repeats[i + 1];
    }
    if (af::SymbolicUtils::StaticCheckEq(repeats[i], af::ops::One) == af::TriBool::kTrue) {
      strides[i] = af::ops::Zero;
    } else {
      strides[i] = stride;
    }
  }
  GELOGD("after update, repeats = %s, strides = %s", af::ToString(repeats).c_str(), af::ToString(strides).c_str());
  return af::SUCCESS;
}
// Adds a size-1 axis named "axis_1d" to the graph and makes it the leading axis of every node that is not
// an IO buffer.
//
// The graph's axis list is rebuilt from copies of the existing axes plus the new axis, whose id equals
// the previous axis count. For each affected node the new id is put in front of the schedule axes and of
// every output's axis list; each output also gets a leading repeat of 1 and a leading stride of
// repeats[0] * strides[0] (or 1 when strides[0] is 0), both taken from the values before insertion.
Status ConcatFusionCaseGenerator::InsertAxis(const ascir::ImplGraph &optimized_graph) {
  const auto graph_attr =
      af::AscGraphUtils::GetComputeGraph(optimized_graph)->GetOrCreateAttrsGroup<af::AscGraphAttr>();
  GE_ASSERT_NOTNULL(graph_attr);
  GELOGD("before: axes = %s", ScheduleUtils::AxesToString(graph_attr->axis).c_str());
  const auto src_axes = graph_attr->axis;
  std::vector<af::AxisPtr> new_axes;
  for (const auto &src_axis : src_axes) {
    std::shared_ptr<af::Axis> new_axis = af::MakeShared<af::Axis>();
    GE_CHECK_NOTNULL(new_axis, "create axis failed");
    new_axis->id = src_axis->id;
    new_axis->name = src_axis->name;
    new_axis->type = src_axis->type;
    new_axis->size = src_axis->size;
    new_axes.push_back(std::move(new_axis));
  }
  // The new axis is appended to the list, so its id is the number of existing axes.
  const auto new_axis_id = static_cast<int64_t>(new_axes.size());
  std::shared_ptr<af::Axis> const_axis = af::MakeShared<af::Axis>();
  GE_CHECK_NOTNULL(const_axis, "create axis failed");
  const_axis->id = new_axis_id;
  const_axis->name = "axis_1d";
  const_axis->type = af::Axis::kAxisTypeOriginal;
  const_axis->size = af::ops::One;
  new_axes.push_back(std::move(const_axis));
  graph_attr->axis = std::move(new_axes);
  GELOGD("after: axes = %s", ScheduleUtils::AxesToString(graph_attr->axis).c_str());
  for (const auto &node : optimized_graph.GetAllNodes()) {
    if (ScheduleUtils::IsIOBuffer(node)) {
      continue;
    }
    node->attr.sched.axis.insert(node->attr.sched.axis.begin(), new_axis_id);
    for (const auto &output_attr : node->outputs()) {
      output_attr->attr.axis.insert(output_attr->attr.axis.begin(), new_axis_id);
      // The stride must be derived before the leading repeat is inserted below.
      if (output_attr->attr.strides[0UL] == 0) {
        output_attr->attr.strides.insert(output_attr->attr.strides.begin(), af::ops::One);
      } else {
        output_attr->attr.strides.insert(output_attr->attr.strides.begin(),
                                         af::sym::Mul(output_attr->attr.repeats[0UL], output_attr->attr.strides[0UL]));
      }
      output_attr->attr.repeats.insert(output_attr->attr.repeats.begin(), af::ops::One);
    }
  }
  return af::SUCCESS;
}
// Appends `graph` unchanged as a template when the whole concat is small enough for a single kernel.
//
// The graph qualifies when the node has at most kMaxInputNum inputs and the last-axis size of every
// input is a constant not larger than kSmallDimSizeThreshold / element size. When a criterion is not
// met the function returns SUCCESS without adding anything.
Status ConcatFusionCaseGenerator::AddTemplateIfCanFitInOneKernel(const af::AscNodePtr &concat_node,
                                                                 ascir::HintGraph &graph,
                                                                 std::vector<ascir::ImplGraph> &graphs) {
  GE_CHK_BOOL_RET_SPECIAL_STATUS(concat_node->inputs.Size() > kMaxInputNum, af::SUCCESS,
                                 "input num(%u) > max_input_num(%u), cannot fit in one kernel",
                                 concat_node->inputs.Size(), kMaxInputNum);
  const auto data_type_size = af::GetSizeByDataType(concat_node->outputs[0].attr.dtype);
  GE_ASSERT_TRUE(data_type_size > 0);
  const auto max_dim_size = kSmallDimSizeThreshold / data_type_size;
  for (uint32_t i = 0U; i < concat_node->inputs.Size(); ++i) {
    const auto dim_expr = concat_node->inputs[i].attr.repeats.back();
    // `i` is a uint32_t, so it is formatted with %u.
    GE_CHK_BOOL_RET_SPECIAL_STATUS((!dim_expr.IsConstExpr()), af::SUCCESS, "input[%u] is not known shape: %s", i,
                                   af::ToString(concat_node->inputs[i].attr.repeats).c_str());
    int64_t dim_size = -1;
    GE_ASSERT_TRUE(dim_expr.GetConstValue(dim_size));
    GE_CHK_BOOL_RET_SPECIAL_STATUS((dim_size > max_dim_size), af::SUCCESS,
                                   "input[%u] dim_size(%ld) is larger than threshold(%ld)", i, dim_size, max_dim_size);
  }
  graphs.emplace_back(graph);
  return af::SUCCESS;
}
// Sets the boolean attribute "_keep_first_axis" to true on every Concat node of every graph in `graphs`.
Status ConcatFusionCaseGenerator::MarkNoMergeFirstAxis(const std::vector<ascir::ImplGraph> &graphs) {
  for (const auto &template_graph : graphs) {
    for (const auto &candidate : template_graph.GetAllNodes()) {
      if (!af::ops::IsOps<af::ascir_op::Concat>(candidate)) {
        continue;
      }
      GE_ASSERT_TRUE(af::AttrUtils::SetBool(candidate->GetOpDesc(), "_keep_first_axis", true));
    }
  }
  return af::SUCCESS;
}
// Decides whether the templates collected before the grouped one should be kept.
// Always false with small-tail support; otherwise true when something is recomputed, the concat has at
// most 16 inputs, or the output block is small.
bool ConcatFusionCaseGenerator::KeepOriginGraph(const af::AscNodePtr &concat_node) const {
  if (support_small_tail_) {
    return false;
  }
  constexpr uint32_t kMaxNumInputs = 16;
  return has_recompute_ || (concat_node->inputs.Size() <= kMaxNumInputs) || IsSmallBlock(concat_node, concat_dim_);
}
// Reports whether the concat output, counted from axis `concat_dim` to the innermost axis, occupies
// fewer than 16 KiB (element size times the product of those axis sizes).
// The answer is false as soon as one of those axis sizes is not a constant expression.
// NOTE(review): GE_WARN_ASSERT presumably logs and returns false when its condition fails - confirm.
bool ConcatFusionCaseGenerator::IsSmallBlock(const af::AscNodePtr &concat_node, size_t concat_dim) {
  constexpr int64_t kMaxOutputBlockSize = 16 * 1024;
  const auto dtype_size = GetSizeByDataType(concat_node->outputs[0].attr.dtype);
  GE_WARN_ASSERT(dtype_size > 0);

  const auto &output_repeats = concat_node->outputs[0].attr.repeats;
  int64_t block_bytes = dtype_size;
  bool all_sizes_known = true;
  for (size_t dim = concat_dim; dim < output_repeats.size(); ++dim) {
    auto &dim_expr = output_repeats[dim];
    if (!dim_expr.IsConstExpr()) {
      all_sizes_known = false;
      break;
    }
    int64_t dim_size = -1;
    GE_WARN_ASSERT(dim_expr.GetConstValue(dim_size));
    block_bytes *= dim_size;
  }

  const bool is_small_block = (all_sizes_known && (block_bytes >= 0) && (block_bytes < kMaxOutputBlockSize));
  GELOGI("output shape = %s, concat_dim = %zu, is_small_block = %d", af::ToString(output_repeats).c_str(), concat_dim,
         static_cast<int32_t>(is_small_block));
  return is_small_block;
}
// If the node currently feeding `in_anchor` has an entry in `name_to_node` (looked up by name), the
// anchor is unlinked and linked to the same-index output of the mapped node instead.
// Fails when the anchor has no peer or the peer has no owner node.
Status ConcatFusionCaseGenerator::ReconnectIfShareSameAncestor(
    const std::unordered_map<std::string, af::NodePtr> &name_to_node, const af::InDataAnchorPtr &in_anchor) {
  auto producer_anchor = in_anchor->GetPeerOutAnchor();
  GE_ASSERT_NOTNULL(producer_anchor);
  auto producer_node = producer_anchor->GetOwnerNode();
  GE_ASSERT_NOTNULL(producer_node);

  const auto replacement = name_to_node.find(producer_node->GetName());
  if (replacement == name_to_node.end()) {
    // The producer was not cloned; keep the existing edge.
    return af::SUCCESS;
  }
  in_anchor->UnlinkAll();
  GE_ASSERT_GRAPH_SUCCESS(in_anchor->LinkFrom(replacement->second->GetOutDataAnchor(producer_anchor->GetIdx())));
  return af::SUCCESS;
}
// For every input except the first, compares cumulative products of the axis sizes after `concat_dim`
// between that input and the output, and logs a warning whenever EXPECT_SYMBOL_EQ does not hold.
// Two passes are made: the first accumulates from the innermost axis outwards, the second from the
// axis right after `concat_dim` inwards.
// NOTE(review): EXPECT_SYMBOL_EQ presumably also records the equality as a shape-env assumption, which
// would be the real purpose of this function - confirm against the macro definition.
Status ConcatFusionCaseGenerator::AddExtraShapeEnv(const af::AscNodePtr &concat_node, size_t concat_dim) {
  const auto &output_repeats = concat_node->outputs[0].attr.repeats;

  // Pass 1: innermost axis first.
  for (uint32_t input_idx = 1; input_idx < concat_node->inputs.Size(); ++input_idx) {
    const auto &input_repeats = concat_node->inputs[input_idx].attr.repeats;
    auto input_axis_size = af::ops::One;
    auto output_axis_size = af::ops::One;
    for (size_t dim = output_repeats.size() - 1; dim > concat_dim; --dim) {
      input_axis_size = input_axis_size * input_repeats[dim];
      output_axis_size = output_axis_size * output_repeats[dim];
      GE_LOGW_IF(!EXPECT_SYMBOL_EQ(input_axis_size, output_axis_size),
                 "expect axis eq failed, concat_dim = %zu, cur_dim = %zu, input_axis_size = %s, output_axis_size = %s",
                 concat_dim, dim, input_axis_size.Str().get(), output_axis_size.Str().get());
    }
  }

  // Pass 2: axis right after the concat axis first.
  for (uint32_t input_idx = 1; input_idx < concat_node->inputs.Size(); ++input_idx) {
    const auto &input_repeats = concat_node->inputs[input_idx].attr.repeats;
    auto input_axis_size = af::ops::One;
    af::Expression output_axis_size = af::ops::One;
    for (size_t dim = concat_dim + 1; dim < output_repeats.size(); ++dim) {
      input_axis_size = input_axis_size * input_repeats[dim];
      output_axis_size = output_axis_size * output_repeats[dim];
      GE_LOGW_IF(!EXPECT_SYMBOL_EQ(input_axis_size, output_axis_size),
                 "expect axis eq failed, concat_dim = %zu, cur_dim = %zu, input_axis_size = %s, output_axis_size = %s",
                 concat_dim, dim, input_axis_size.Str().get(), output_axis_size.Str().get());
    }
  }
  return af::SUCCESS;
}
// Gathers the bookkeeping needed before the graph around `concat_node` is rewritten: the post-concat
// nodes, the Load nodes reachable from the concat inputs, and - per direct consumer of the concat's
// output 0, keyed by consumer name - the input anchor indices through which it reads the concat.
Status ConcatFusionCaseGenerator::PrepareForModifyingGraph(const af::AscNodePtr &concat_node) {
  GE_ASSERT_SUCCESS(CollectBackwardNodes(concat_node, post_concat_nodes_));
  GE_ASSERT_SUCCESS(CollectReachableLoadNodes(concat_node, reachable_load_nodes_));

  for (const auto &consumer_entry : af::NodeUtils::GetOutDataNodesWithAnchorByIndex(*concat_node, 0)) {
    const auto &consumer_anchor = consumer_entry.first;
    const auto &consumer_node = consumer_entry.second;
    out_node_name_to_indices_[consumer_node->GetName()].emplace_back(consumer_anchor->GetIdx());
  }
  return af::SUCCESS;
}
// Runs the cast optimization pass on every template in `graphs`, parameterised with the backend's
// concat algorithm. Stops at the first graph for which the pass fails.
Status ConcatFusionCaseGenerator::RunCastOptimizationPass(std::vector<ascir::ImplGraph> &graphs) {
  const auto spec = BackendSpec::GetInstance();
  GE_ASSERT_NOTNULL(spec);
  for (auto &template_graph : graphs) {
    GE_ASSERT_SUCCESS(af::optimize::CastOptimizationPass::Run(template_graph, spec->concat_alg));
  }
  return af::SUCCESS;
}
}