* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "optimize/task_generator/concat_group_partitioner.h"
#include <queue>
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "graph/symbolizer/symbolic_utils.h"
#include "graph/utils/graph_utils.h"
#include "ascir_ops.h"
#include "ascir/meta/ascir_ops_utils.h"
#include "optimize/schedule_utils.h"
#include "platform/platform_factory.h"
namespace optimize {
namespace {
// Byte multiple a group's width must hit to be classified as aligned (see GetGroupType).
constexpr int32_t kAlignment = 32;
// Value of BackendSpec::concat_alg that selects the transpose-based concat.
constexpr int32_t kConcatAlgTranspose = 0;
// Byte budget used to decide whether an input of a single-row concat is a "scalar" input.
constexpr int64_t kVectorBlockSize = 256;
// Divisor applied to the input count when the known row count is small (see Initialize).
constexpr int64_t kMinGroupNum = 2L;
// Lower bound, in bytes, for the per-group width chosen by TryOptimizeGroupSize.
constexpr int64_t kMinGroupSizeByte = 1024U;

// Rounds n1 / n2 up to the next integer, computed in int64_t.
// Both operands are asserted to be non-zero before the division.
template <typename T1, typename T2>
int64_t CeilDiv(const T1 n1, const T2 n2) {
  GE_ASSERT_TRUE(n1 != 0);
  GE_ASSERT_TRUE(n2 != 0);
  const auto dividend = static_cast<int64_t>(n1);
  const auto divisor = static_cast<int64_t>(n2);
  return ((dividend - 1) / divisor) + 1;
}
}  // namespace
// Stores the concat node to partition and the index of the axis being concatenated.
// No analysis happens here; the node is inspected later by ParseConcatNode()/Initialize().
// @param concat_node  node whose inputs are to be grouped (ownership is moved in).
// @param concat_dim   index into the tensors' `repeats` that identifies the concat axis.
ConcatGroupPartitioner::ConcatGroupPartitioner(af::AscNodePtr concat_node, size_t concat_dim)
: concat_node_(std::move(concat_node)), concat_dim_(concat_dim) {
}
// Derives the grouping limits from the backend spec and from the values gathered by
// ParseConcatNode() (dtype_size_, known_rows_, total_rows_, output_cols_, concat_dim_sizes_).
// Must run after ParseConcatNode(): it divides by dtype_size_.
// @return ge::SUCCESS, or a failure propagated from the assertions / TryOptimizeGroupSize().
Status ConcatGroupPartitioner::Initialize() {
  const auto spec = BackendSpec::GetInstance();
  GE_ASSERT_NOTNULL(spec);
  // concat_by_transpose_ has to be set before MaxInputNumPerGroup() is called, which reads it.
  concat_by_transpose_ = (spec->concat_alg == kConcatAlgTranspose);
  default_cols_per_group_ = kMaxBlockSize / dtype_size_;
  max_input_num_per_group_ = MaxInputNumPerGroup();
  if (known_rows_ <= kGroupParallelRowThreshold) {
    // Small row count: cap the inputs per group at ceil(input_num / kMinGroupNum).
    const auto capped_input_num = static_cast<uint32_t>(CeilDiv(concat_dim_sizes_.size(), kMinGroupNum));
    max_input_num_per_group_ = std::min(max_input_num_per_group_, capped_input_num);
  }

  const auto &output_attr = concat_node_->outputs[0].attr;
  if (!concat_by_transpose_) {
    if (output_cols_ > 0) {
      GE_ASSERT_SUCCESS(TryOptimizeGroupSize());
    }
    is_concat_scalar_ = (total_rows_ == 1);
  } else {
    // Small-tail grouping is only enabled when concatenating on the last axis with a 2- or 4-byte dtype.
    const auto last_dim = output_attr.repeats.size() - 1UL;
    const bool dtype_supported = (dtype_size_ == sizeof(uint16_t) || dtype_size_ == sizeof(uint32_t));
    can_use_small_tail_ = (concat_dim_ == last_dim) && dtype_supported;
    group_type_to_limit_[kGroupTypeSmallTail] = kMaxBlockSizeForSmallTail;
  }

  GELOGI("output_repeat = %s, concat_dim = %zu, input_num = %u, max_input_num_per_group = %u",
         af::ToString(output_attr.repeats).c_str(), concat_dim_, concat_node_->inputs.Size(),
         max_input_num_per_group_);
  return ge::SUCCESS;
}
// Splits the concat node's inputs into consecutive groups and hands them back through `groups`.
//
// Flow: parse the node, initialise the limits, sweep the inputs once to open/extend/close
// groups, merge and normalise small groups, then recompute producers shared across groups.
// NOTE(review): when stride_ is still -1 after parsing, the GE_CHK_BOOL_RET_SPECIAL_STATUS line
// appears to return ge::SUCCESS early, leaving `groups` untouched - confirm against the macro.
//
// @param groups  [out] receives the computed groups (moved out of groups_).
// @return ge::SUCCESS, or a failure propagated from any step.
Status ConcatGroupPartitioner::PartitionGroups(std::vector<ConcatGroup> &groups) {
  GE_ASSERT_SUCCESS(ParseConcatNode());
  GE_CHK_BOOL_RET_SPECIAL_STATUS((stride_ == -1), ge::SUCCESS, "contains non-static dim after concat dim");
  GE_ASSERT_SUCCESS(Initialize());

  const size_t input_num = concat_node_->GetAllInDataAnchorsPtr().size();
  for (size_t idx = 0UL; idx < input_num; ++idx) {
    const int64_t cols = concat_dim_sizes_[idx];
    const auto candidate_type = GetGroupType(cols);
    // Inputs that cannot be grouped become a single-input group of type kGroupTypeNone.
    const bool must_stand_alone = InputHasTransposeOrReduce(idx) || (cols < 0) ||
                                  ((cols * dtype_size_) > kGroupEltSizeThreshold) ||
                                  (candidate_type == kGroupTypeNone);
    if (must_stand_alone) {
      if (index_start_ != -1) {
        GroupEnd(idx);
      }
      groups_.emplace_back(ConcatGroup{idx, idx + 1, kGroupTypeNone, -1});
    } else if (index_start_ == -1) {
      // No group is open: this input starts one.
      GroupStart(static_cast<int64_t>(idx), candidate_type, cols);
    } else if (NeedSubmit(idx, cols, candidate_type)) {
      // The open group cannot take this input: close it and start a new one here.
      GroupEnd(idx);
      GroupStart(static_cast<int64_t>(idx), candidate_type, cols);
    } else {
      UpdateStatus(cols);
    }
  }
  if (index_start_ != -1) {
    GroupEnd(input_num);
  }

  MergeSmallGroups();
  ConvertToDefaultIfTooSmall();
  MergeSmallGroups();

  // Recompute is skipped when there is a single group or when every input is its own group.
  const auto group_num = groups_.size();
  const bool partially_grouped = (group_num > 1) && (group_num != concat_node_->inputs.Size());
  if (partially_grouped) {
    GE_ASSERT_SUCCESS(RecomputeNodesCrossGroups(false, has_recompute_));
    GE_ASSERT_SUCCESS(RecomputeNodesCrossGroups(true, has_recompute_));
  }
  groups = std::move(groups_);
  return ge::SUCCESS;
}
// Returns has_recompute_, the flag handed to RecomputeNodesCrossGroups(), which sets it to true
// whenever it re-creates a producer node for one of the concat inputs.
bool ConcatGroupPartitioner::HasRecompute() const {
return has_recompute_;
}
// Decides whether the currently open group must be closed before input `i` is added.
//
// Side effect: when the input can join the group, group_type_ may be narrowed to the
// bitwise AND of the current and the new type, or reset to kGroupTypeDefault.
//
// @param i               index of the input about to be added.
// @param size            width of that input (as stored in concat_dim_sizes_).
// @param new_group_type  group type computed for that input by GetGroupType().
// @return true if the open group has to be ended first, false if the input can join it.
bool ConcatGroupPartitioner::NeedSubmit(size_t i, int64_t size, uint32_t new_group_type) {
  // Hard limits: accumulated width and number of inputs in one group.
  if (((cur_size_ + size) > GetSizeLimitByGroupType(group_type_)) ||
      (i - index_start_ >= max_input_num_per_group_)) {
    GELOGD("Size limit(%ld) or number limit(%u) reached, index = %zu, size = %ld, num = %zu",
           GetSizeLimitByGroupType(group_type_), max_input_num_per_group_,
           i, cur_size_ + size, i - index_start_);
    return true;
  }
  if (new_group_type == group_type_) {
    return false;
  }
  auto merged_group_type = group_type_ & new_group_type;
  if (merged_group_type == 0U) {
    // The types share no bit. A group holding a single input is downgraded to Default so that
    // the new input can still join it; Default and Scalar groups are never downgraded.
    if ((i - index_start_ == 1) && (group_type_ != kGroupTypeDefault) && (group_type_ != kGroupTypeScalar)) {
      // Log the type that is actually assigned below (previously this printed new_group_type).
      GELOGD("group has only one element, convert to %s", GroupTypeToString(kGroupTypeDefault).c_str());
      group_type_ = kGroupTypeDefault;
      return false;
    }
    GELOGD("Group type changed, index = %zu, group_type = [%s], new_group_type = [%s]",
           i, GroupTypeToString(group_type_).c_str(), GroupTypeToString(new_group_type).c_str());
    return true;
  }
  // Narrowing to SmallTail brings in the SmallTail size limit, which may be different.
  if ((merged_group_type == kGroupTypeSmallTail) && (cur_size_ + size) > GetSizeLimitByGroupType(kGroupTypeSmallTail)) {
    GELOGD("Size limit(%ld) reached, index = %zu, size = %ld",
           GetSizeLimitByGroupType(kGroupTypeSmallTail), i, cur_size_ + size);
    return true;
  }
  group_type_ = merged_group_type;
  return false;
}
// Adds `size` to the open group's accumulated width. A group typed AlignAndSmallTail is
// turned into a plain Aligned group once its width reaches the SmallTail size limit.
void ConcatGroupPartitioner::UpdateStatus(int64_t size) {
  cur_size_ += size;
  if (group_type_ != kGroupTypeSmallTailAndAligned) {
    return;
  }
  const auto small_tail_limit = GetSizeLimitByGroupType(kGroupTypeSmallTail);
  if (cur_size_ < small_tail_limit) {
    return;
  }
  GELOGD("size(%ld) >= size limit(%ld), concat type from [AlignAndSmallTail] to [Aligned]",
         cur_size_, small_tail_limit);
  group_type_ = kGroupTypeAligned;
}
// Opens a new group whose first input is `index_start`.
// @param index_start  index of the first input of the group.
// @param group_type   initial type of the group.
// @param size         width of the first input; becomes the group's accumulated width.
void ConcatGroupPartitioner::GroupStart(int64_t index_start, uint32_t group_type, int64_t size) {
  index_start_ = index_start;
  group_type_ = group_type;
  cur_size_ = size;
  // index_start is an int64_t, so it is printed with %ld like the other int64_t values in this file.
  GELOGD("group start, index = %ld, type = %s", index_start, GroupTypeToString(group_type).c_str());
}
// Closes the open group as [index_start_, index_end), appends it to groups_ and marks
// that no group is open any more (index_start_ = -1).
// @param index_end  one past the index of the last input in the group.
void ConcatGroupPartitioner::GroupEnd(size_t index_end) {
  // Convert once so that the %zu argument and the stored start index are the same size_t value.
  const auto group_start = static_cast<size_t>(index_start_);
  GELOGD("group end, start_index = %zu, end_index = %zu, type = [%s], size = %ld",
         group_start, index_end, GroupTypeToString(group_type_).c_str(), cur_size_);
  groups_.emplace_back(ConcatGroup{group_start, index_end, group_type_, cur_size_});
  index_start_ = -1;
  // NOTE(review): group_type_ is used as an unsigned bit mask elsewhere, so -1 presumably wraps
  // to an all-ones sentinel; kept as-is, confirm the intended "no group" value.
  group_type_ = -1;
}
// Looks up the width limit registered for `group_type` in group_type_to_limit_.
// Types without an entry fall back to default_cols_per_group_.
int64_t ConcatGroupPartitioner::GetSizeLimitByGroupType(uint32_t group_type) const {
  const auto limit_it = group_type_to_limit_.find(group_type);
  if (limit_it == group_type_to_limit_.end()) {
    return default_cols_per_group_;
  }
  return limit_it->second;
}
uint32_t ConcatGroupPartitioner::GetGroupType(int64_t size) const {
if (is_concat_scalar_) {
return (size <= (kVectorBlockSize / dtype_size_)) ? kGroupTypeScalar : kGroupTypeNone;
}
if (use_default_group_) {
return kGroupTypeDefault;
}
bool aligned = ((size * dtype_size_) % kAlignment) == 0;
bool is_small_tail = (can_use_small_tail_ && (size <= kSmallTailLimit));
if (aligned && is_small_tail) {
return kGroupTypeSmallTailAndAligned;
}
if (aligned) {
return kGroupTypeAligned;
}
if (is_small_tail) {
return kGroupTypeSmallTail;
}
return kGroupTypeDefault;
}
// Folds `lhs_group` into `rhs_group`: rhs takes over lhs's start index and adds lhs's size,
// keeping its own end index and group type. lhs is then emptied (size 0, start == end), which
// is the marker MergeSmallGroups() uses to drop it afterwards.
void ConcatGroupPartitioner::MergeGroups(std::vector<ConcatGroup>::value_type &lhs_group,
std::vector<ConcatGroup>::value_type &rhs_group) {
rhs_group.start = lhs_group.start;
rhs_group.size += lhs_group.size;
lhs_group.size = 0;
// start == end marks the group as empty.
lhs_group.end = lhs_group.start;
}
void ConcatGroupPartitioner::MergeSmallGroups() {
std::vector<ConcatGroup> groups;
for (size_t index = 0UL; index < groups_.size() - 1UL; ++index) {
auto &lhs_group = groups_[index];
auto &rhs_group = groups_[index + 1];
if (CanMerge(lhs_group, rhs_group)) {
if (((lhs_group.group_type != kGroupTypeNone) && (lhs_group.group_type == rhs_group.group_type)) ||
(IsAligned(lhs_group.group_type) && (rhs_group.group_type == kGroupTypeDefault))) {
MergeGroups(lhs_group, rhs_group);
}
}
}
for (const auto &group : groups_) {
if (group.start != group.end) {
groups.emplace_back(group);
}
}
groups_ = std::move(groups);
}
// Checks the size constraints for merging two groups (group types are checked by the caller).
// The combined width must not exceed kMaxBlockSize / dtype_size_, and the combined input count
// must not exceed max_input_num_per_group_ unless one of the groups holds exactly one input.
bool ConcatGroupPartitioner::CanMerge(const ConcatGroupPartitioner::ConcatGroup &lhs,
                                      const ConcatGroupPartitioner::ConcatGroup &rhs) const {
  if ((lhs.size + rhs.size) > (kMaxBlockSize / dtype_size_)) {
    return false;
  }
  const auto lhs_input_num = lhs.end - lhs.start;
  const auto rhs_input_num = rhs.end - rhs.start;
  if ((lhs_input_num == 1) || (rhs_input_num == 1)) {
    return true;
  }
  return (lhs_input_num + rhs_input_num) <= max_input_num_per_group_;
}
void ConcatGroupPartitioner::ConvertToDefaultIfTooSmall() {
for (auto &group : groups_) {
if ((group.end - group.start == 1) &&
(IsSmallTail(group.group_type) || ((!concat_by_transpose_) && IsAligned(group.group_type)))) {
GELOGD("group start with index = %zu, size = 1, concat type from [%s] to [Default]", group.start,
GroupTypeToString(group.group_type).c_str());
group.group_type = kGroupTypeDefault;
}
}
}
// Returns a printable name for `group_type`, for use in log messages.
// Types without a table entry (for example kGroupTypeNone) yield "Unknown(<value>)" instead of
// throwing std::out_of_range, so that a log statement can never abort the partitioning.
std::string ConcatGroupPartitioner::GroupTypeToString(uint32_t group_type) {
  static const std::map<uint32_t, std::string> kGroupTypeToStr {
      {ConcatGroupPartitioner::kGroupTypeDefault, "Default"},
      {ConcatGroupPartitioner::kGroupTypeAligned, "Aligned"},
      {ConcatGroupPartitioner::kGroupTypeSmallTailAndAligned, "AlignAndSmallTail"},
      {ConcatGroupPartitioner::kGroupTypeSmallTail, "SmallTail"},
      {ConcatGroupPartitioner::kGroupTypeScalar, "Scalar"},
  };
  const auto it = kGroupTypeToStr.find(group_type);
  if (it == kGroupTypeToStr.end()) {
    return "Unknown(" + std::to_string(group_type) + ")";
  }
  return it->second;
}
// True when `group_type` shares at least one bit with kGroupTypeAligned.
bool ConcatGroupPartitioner::IsAligned(uint32_t group_type) {
  const uint32_t aligned_bits = group_type & kGroupTypeAligned;
  return aligned_bits != 0U;
}
// True when `group_type` shares at least one bit with kGroupTypeSmallTail.
bool ConcatGroupPartitioner::IsSmallTail(uint32_t group_type) {
  const uint32_t small_tail_bits = group_type & kGroupTypeSmallTail;
  return small_tail_bits != 0U;
}
// For every input of every group, repeatedly looks for a producer on the input's upstream path
// whose output also feeds the concat at an index >= the group's end (see
// FindFirstMultiOutputAnchors) and re-creates that producer for this input (RecomputeInNodes),
// until no such producer is left.
//
// @param search_backward  forwarded to FindFirstMultiOutputAnchors(); widens the reachability
//                         search to also walk data inputs.
// @param has_recompute    [out] set to true when at least one node was re-created; never reset.
// @return ge::SUCCESS, or a failure if an anchor is missing, a helper fails, or an input still
//         needs splitting after kMaxSplitRounds rounds.
af::Status ConcatGroupPartitioner::RecomputeNodesCrossGroups(bool search_backward, bool &has_recompute) const {
  // Upper bound on split rounds per input; guards against a search that never converges.
  constexpr int32_t kMaxSplitRounds = 1024;
  for (const auto &group : groups_) {
    // Both containers are shared by all inputs of one group: anchors are examined only once,
    // and a producer re-created for one input is reused by the other inputs of the group.
    std::set<af::InDataAnchorPtr> visited_in_anchors;
    std::map<std::string, af::AscNodePtr> name_to_new_node;
    const auto end_index = static_cast<int32_t>(group.end);
    for (size_t i = group.start; i < group.end; ++i) {
      GELOGD("input[%zu] check recompute start", i);
      const auto in_anchor = concat_node_->GetInDataAnchor(static_cast<int32_t>(i));
      GE_ASSERT_NOTNULL(in_anchor);
      int32_t depth = kMaxSplitRounds;
      while (--depth >= 0) {
        af::InDataAnchorPtr to_split;
        GE_ASSERT_SUCCESS(
            FindFirstMultiOutputAnchors(in_anchor, end_index, search_backward, visited_in_anchors, to_split));
        if (to_split == nullptr) {
          break;
        }
        GE_ASSERT_SUCCESS(RecomputeInNodes(to_split, i, name_to_new_node));
        has_recompute = true;
      }
      // depth only drops below zero when the round budget ran out without reaching the break above.
      GE_ASSERT_TRUE(depth >= 0);
    }
  }
  return ge::SUCCESS;
}
// Breadth-first walk from concat input `in_anchor` towards its producers, looking for the first
// input anchor whose producer output is shared and must be split for the current group.
//
// An anchor is examined only the first time it is seen (tracked in `visited_in_anchors`, which
// the caller keeps across calls). Its producer counts as shared when it has more than one
// distinct consumer node, or when its single consumer is a "Concat" node that the output anchor
// feeds more than once. CheckIsAncestorOfConcat() then decides whether a split is needed.
//
// @param in_anchor           concat input anchor where the walk starts.
// @param end_index           end index of the current group, forwarded to CheckIsAncestorOfConcat().
// @param search_backward     forwarded to CheckIsAncestorOfConcat().
// @param visited_in_anchors  [in,out] anchors that have already been examined.
// @param to_split            [out] the anchor to split; left untouched when none is found.
// @return ge::SUCCESS, or a failure propagated from the assertions.
af::Status ConcatGroupPartitioner::FindFirstMultiOutputAnchors(const af::InDataAnchorPtr &in_anchor, int32_t end_index,
                                                               bool search_backward,
                                                               std::set<af::InDataAnchorPtr> &visited_in_anchors,
                                                               af::InDataAnchorPtr &to_split) const {
  GE_ASSERT_NOTNULL(in_anchor);
  std::queue<af::InDataAnchorPtr> in_anchors;
  in_anchors.push(in_anchor);
  const auto &concat_dim_size = concat_node_->inputs[in_anchor->GetIdx()].attr.repeats[concat_dim_];
  // Producers whose inputs are already queued. Expanding each producer once keeps the walk
  // linear on graphs with re-converging paths and does not change the order in which anchors
  // are first seen.
  std::set<af::NodePtr> expanded_nodes;
  while (!in_anchors.empty()) {
    const auto cur_in_anchor = in_anchors.front();
    in_anchors.pop();
    auto out_anchor = cur_in_anchor->GetPeerOutAnchor();
    if (out_anchor == nullptr) {
      continue;
    }
    auto owner_node = out_anchor->GetOwnerNode();
    GE_ASSERT_NOTNULL(owner_node);
    if (visited_in_anchors.emplace(cur_in_anchor).second) {
      // De-duplicate the consumers: one node may be connected through several anchors.
      std::set<af::NodePtr> out_nodes;
      for (const auto &out_node : owner_node->GetOutDataNodes()) {
        out_nodes.emplace(out_node);
      }
      const bool has_multi_consumers = out_nodes.size() > 1UL;
      const bool multi_ref_by_concat = (out_nodes.size() == 1UL) &&
                                       ((*out_nodes.begin())->GetType() == "Concat") &&
                                       (out_anchor->GetPeerAnchorsSize() > 1UL);
      if (has_multi_consumers || multi_ref_by_concat) {
        bool need_split = false;
        GE_ASSERT_SUCCESS(CheckIsAncestorOfConcat(out_anchor, end_index, concat_dim_size, search_backward, need_split));
        GELOGD("%s has multi-ref output, end_index = %d, need_split = %d", owner_node->GetNamePtr(),
               end_index, need_split);
        if (need_split) {
          to_split = cur_in_anchor;
          return ge::SUCCESS;
        }
      }
    }
    if (!expanded_nodes.emplace(owner_node).second) {
      continue;
    }
    for (const auto &in_data_anchor : owner_node->GetAllInDataAnchors()) {
      if (in_data_anchor != nullptr) {
        in_anchors.push(in_data_anchor);
      }
    }
  }
  return ge::SUCCESS;
}
// Determines whether the value produced at `out_anchor` reaches the concat node through an
// input anchor for which NeedSplit() returns true.
//
// The direct consumers of `out_anchor` are checked first; the remaining consumers are then
// explored depth-first along their data outputs and, when `search_backward` is set, along
// their data inputs as well. The concat node itself is never traversed.
//
// @param out_anchor       output anchor of the candidate producer.
// @param start_index      index forwarded to NeedSplit(); FindFirstMultiOutputAnchors() passes
//                         the end index of the current group here.
// @param concat_dim_size  concat-axis size of the input being processed, forwarded to NeedSplit().
// @param search_backward  also walk the data inputs of the visited nodes.
// @param need_split       [out] set to true when a split is required; left untouched otherwise,
//                         so the caller has to initialise it.
// @return ge::SUCCESS, or a failure if a consumer anchor has no owner node.
af::Status ConcatGroupPartitioner::CheckIsAncestorOfConcat(const af::OutDataAnchorPtr &out_anchor, int32_t start_index,
const af::Expression &concat_dim_size,
bool search_backward,
bool &need_split) const {
// `nodes` is used as a stack (push/pop at the back).
std::vector<const af::Node *> nodes;
std::set<const af::Node *> visited;
// Pre-mark the concat node so that it is never pushed onto the stack.
visited.emplace(concat_node_.get());
for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
const auto owner_node = peer_in_anchor->GetOwnerNode().get();
GE_ASSERT_NOTNULL(owner_node);
// Direct edge from the producer into the concat node.
if ((owner_node == concat_node_.get()) && NeedSplit(peer_in_anchor, start_index, concat_dim_size)) {
need_split = true;
return ge::SUCCESS;
}
if (visited.emplace(owner_node).second) {
nodes.emplace_back(owner_node);
}
}
while (!nodes.empty()) {
const auto cur_node = nodes.back();
nodes.pop_back();
for (const auto &[out_node, in_anchor] : cur_node->GetOutDataNodesAndAnchors()) {
if (out_node == concat_node_) {
// Indirect path into the concat node: `in_anchor` is the concat input that is reached.
if (NeedSplit(in_anchor, start_index, concat_dim_size)) {
need_split = true;
return ge::SUCCESS;
}
} else {
if (visited.emplace(out_node.get()).second) {
nodes.emplace_back(out_node.get());
}
}
}
if (search_backward) {
// Also follow the data inputs of the current node.
for (const auto &in_node: cur_node->GetInDataNodes()) {
if (visited.emplace(in_node.get()).second) {
nodes.emplace_back(in_node.get());
}
}
}
}
return ge::SUCCESS;
}
// Decides whether a producer shared with concat input `in_anchor` has to be split.
// A split is needed when that input's index is >= start_index, unless both the group containing
// input (start_index - 1) and the group containing this input hold exactly one input and the
// two concat-axis sizes are statically known to be equal.
//
// @param in_anchor     concat input anchor reached from the shared producer.
// @param start_index   first input index that lies outside the current group.
// @param cur_dim_size  concat-axis size of the input currently being processed.
bool ConcatGroupPartitioner::NeedSplit(const af::InDataAnchorPtr &in_anchor, int32_t start_index,
                                       const af::Expression &cur_dim_size) const {
  const auto peer_index = in_anchor->GetIdx();
  const auto &peer_dim_size = concat_node_->inputs[peer_index].attr.repeats[concat_dim_];
  const bool is_single_group = (GetGroupSize(static_cast<size_t>(start_index - 1)) == 1) &&
                               (GetGroupSize(static_cast<size_t>(peer_index)) == 1);
  bool need_split = false;
  if (peer_index >= start_index) {
    need_split = (!is_single_group) ||
                 (af::SymbolicUtils::StaticCheckEq(peer_dim_size, cur_dim_size) != af::TriBool::kTrue);
  }
  GELOGD("start_index = %d, next_index = %d, is_single_group = %d, cur_size = %s, next_size = %s, need_split = %d",
         start_index, peer_index, static_cast<int32_t>(is_single_group),
         af::SymbolicUtils::ToString(cur_dim_size).c_str(), af::SymbolicUtils::ToString(peer_dim_size).c_str(),
         static_cast<int32_t>(need_split));
  return need_split;
}
// Returns the number of inputs in the first group whose [start, end) range contains `index`,
// or 0 when no group contains it.
size_t ConcatGroupPartitioner::GetGroupSize(size_t index) const {
  size_t input_num = 0UL;
  for (const auto &group : groups_) {
    const bool contains_index = (index >= group.start) && (index < group.end);
    if (contains_index) {
      input_num = group.end - group.start;
      break;
    }
  }
  return input_num;
}
// Gives `in_anchor` a private copy of its producer node so that the producer's output is no
// longer shared with other consumers.
//
// The copy is created at most once per producer name within `name_to_new_nodes`; later calls
// with the same map reuse it. The copy is named "<producer>_recompute_<index>", receives the
// producer's tensor attributes and is fed by the same upstream outputs as the producer.
// Finally `in_anchor` is disconnected from all peers and connected to output 0 of the copy.
//
// @param in_anchor          input anchor whose producer is duplicated.
// @param index              concat input index being processed; used for naming and logging.
// @param name_to_new_nodes  [in,out] cache of copies, keyed by the original producer's name.
// @return ge::SUCCESS, or a failure if the producer is missing, is not an AscNode, has more
//         than one output anchor, or a graph operation fails.
af::Status ConcatGroupPartitioner::RecomputeInNodes(const af::InDataAnchorPtr &in_anchor, size_t index,
std::map<std::string, af::AscNodePtr> &name_to_new_nodes) const {
ascir::ImplGraph owner_graph("");
GE_ASSERT_SUCCESS(af::AscGraphUtils::FromComputeGraph(concat_node_->GetOwnerComputeGraph(), owner_graph));
auto out_anchor = in_anchor->GetPeerOutAnchor();
GE_ASSERT_NOTNULL(out_anchor);
auto asc_node = std::dynamic_pointer_cast<af::AscNode>(out_anchor->GetOwnerNode());
GE_ASSERT_NOTNULL(asc_node);
// operator[] inserts a null entry on first use; the reference lets the branch below fill it in.
af::AscNodePtr &dst_new_node = name_to_new_nodes[asc_node->GetName()];
if (dst_new_node == nullptr) {
GELOGD("concat input index = %zu, ancestor node %s multi-ref output, re-compute it", index, asc_node->GetNamePtr());
// Only producers with a single output anchor are supported.
GE_ASSERT_EQ(asc_node->GetAllOutDataAnchorsSize(), 1U);
const auto &op_desc = af::GraphUtils::CopyOpDesc(asc_node->GetOpDesc(), nullptr);
GE_CHECK_NOTNULL(op_desc);
op_desc->SetName(asc_node->GetName() + "_recompute_" + std::to_string(index));
af::Operator op = af::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
dst_new_node = owner_graph.AddNode(op);
GE_ASSERT_TRUE(af::AscGraph::CopyAscNodeTensorAttr(asc_node, dst_new_node),
"DoCopyAscNodeTensorAttr failed, node = %s[%s]", asc_node->GetNamePtr(), asc_node->GetTypePtr());
// Connect the copy to the same upstream outputs, at the same input indices, as the original.
for (const auto &in_data_anchor : asc_node->GetAllInDataAnchorsPtr()) {
if (in_data_anchor != nullptr) {
const auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor != nullptr) {
GE_ASSERT_GRAPH_SUCCESS(
af::GraphUtils::AddEdge(peer_out_anchor, dst_new_node->GetInDataAnchor(in_data_anchor->GetIdx())));
}
}
}
}
// Re-route the consumer from the shared producer to the copy.
in_anchor->UnlinkAll();
GE_ASSERT_GRAPH_SUCCESS(af::GraphUtils::AddEdge(dst_new_node->GetOutDataAnchor(0), in_anchor));
return ge::SUCCESS;
}
// Reads the static shape information of the concat node into member state:
//  - dtype_size_        element size of the output dtype (asserted > 0);
//  - stride_            product of the output dims after the concat axis;
//  - output_cols_       concat-axis size * stride_, only when the concat-axis size is constant;
//  - known_rows_        multiplied by every constant dim before the concat axis;
//  - total_rows_        known_rows_ when all those dims are constant, otherwise -1;
//  - concat_dim_sizes_  per input: concat-axis size * stride_, or -1 when it is not constant.
//
// If a dim after the concat axis is not constant, the function returns ge::SUCCESS early and
// leaves stride_ unassigned; PartitionGroups() detects this by checking stride_ == -1.
// NOTE(review): this relies on stride_ being initialised to -1 in the class - confirm.
//
// @return ge::SUCCESS, or a failure if the concat axis is out of range, the dtype is
//         unsupported, or a constant dim cannot be read or is negative.
Status ConcatGroupPartitioner::ParseConcatNode() {
  const auto &output_attr = concat_node_->outputs[0].attr;
  dtype_size_ = ge::GetSizeByDataType(output_attr.dtype);
  GE_ASSERT_TRUE(dtype_size_ > 0, "unsupported dtype: %d", static_cast<int32_t>(output_attr.dtype));
  // repeats is indexed with concat_dim_ below, so reject an out-of-range axis up front.
  GE_ASSERT_TRUE(concat_dim_ < output_attr.repeats.size(), "concat_dim %zu is out of range, output dim num = %zu",
                 concat_dim_, static_cast<size_t>(output_attr.repeats.size()));

  int64_t stride = 1;
  for (size_t i = concat_dim_ + 1; i < output_attr.repeats.size(); ++i) {
    const auto &expr = output_attr.repeats[i];
    GE_CHK_BOOL_RET_SPECIAL_STATUS(!expr.IsConstExpr(), ge::SUCCESS, "contains non-static dim after concat dim");
    int64_t value = -1;
    GE_ASSERT_TRUE(expr.GetConstValue(value));
    GE_ASSERT_TRUE(value >= 0);
    stride *= value;
  }
  stride_ = stride;

  int64_t concat_dim_size = -1;
  if (output_attr.repeats[concat_dim_].IsConstExpr()) {
    GE_ASSERT_TRUE(output_attr.repeats[concat_dim_].GetConstValue(concat_dim_size));
    output_cols_ = concat_dim_size * stride_;
  }
  GELOGD("concat_dim_stride = %ld, concat_dim = %ld", stride_, concat_dim_size);

  bool all_known = true;
  for (size_t i = 0UL; i < concat_dim_; ++i) {
    const auto &expr = output_attr.repeats[i];
    if (expr.IsConstExpr()) {
      int64_t value = -1;
      GE_ASSERT_TRUE(expr.GetConstValue(value) && (value >= 0));
      known_rows_ *= value;
    }
    all_known = all_known && expr.IsConstExpr();
  }
  total_rows_ = all_known ? known_rows_ : -1;
  GELOGD("known rows = %ld, total_rows = %ld", known_rows_, total_rows_);

  const auto &all_in_data_anchors = concat_node_->GetAllInDataAnchorsPtr();
  for (size_t i = 0UL; i < all_in_data_anchors.size(); ++i) {
    const auto &repeats = concat_node_->inputs[i].attr.repeats;
    GE_ASSERT_TRUE(concat_dim_ < repeats.size(), "concat_dim %zu is out of range for input[%zu], dim num = %zu",
                   concat_dim_, i, static_cast<size_t>(repeats.size()));
    const auto &expr = repeats[concat_dim_];
    if (expr.IsConstExpr()) {
      int64_t size = -1;
      GE_ASSERT_TRUE(expr.GetConstValue(size));
      GE_ASSERT_TRUE(size >= 0);
      size *= stride_;
      concat_dim_sizes_.emplace_back(size);
    } else {
      // -1 marks an input whose concat-axis size is not known statically.
      concat_dim_sizes_.emplace_back(-1);
    }
  }
  return ge::SUCCESS;
}
// Picks default_cols_per_group_ from the output width instead of the fixed block-size default.
// The estimate is output_cols_ / ceil(input_num / max_input_num_per_group_), capped at
// kMaxBlockSize / dtype_size_ and raised to at least kMinGroupSizeByte / dtype_size_. When the
// lower bound applies, max_input_num_per_group_ is raised back to at least MaxInputNumPerGroup().
// Also switches GetGroupType() to always return the Default type (use_default_group_).
// @return ge::SUCCESS, or a failure if the estimated group count is not positive.
Status ConcatGroupPartitioner::TryOptimizeGroupSize() {
  GELOGD("all input concat dim is known shape, try to optimize group size");
  use_default_group_ = true;
  const auto min_cols_per_group = kMinGroupSizeByte / dtype_size_;
  const auto num_inputs = concat_dim_sizes_.size();
  const auto num_groups = CeilDiv(num_inputs, max_input_num_per_group_);
  GE_ASSERT_TRUE(num_groups > 0);

  int64_t cols_per_group = output_cols_ / num_groups;
  cols_per_group = std::min(cols_per_group, (kMaxBlockSize / dtype_size_));
  GELOGD("num_inputs = %zu, max_input_num_per_group = %u, estimated num_groups = %ld, cols_per_group = %ld", num_inputs,
         max_input_num_per_group_, num_groups, cols_per_group);

  if (cols_per_group <= min_cols_per_group) {
    cols_per_group = min_cols_per_group;
    max_input_num_per_group_ = std::max(max_input_num_per_group_, MaxInputNumPerGroup());
    GELOGD("group is too small, adjust cols_per_group = %ld, max_input_num_per_group = %u", min_cols_per_group,
           max_input_num_per_group_);
  }
  default_cols_per_group_ = cols_per_group;
  return ge::SUCCESS;
}
// Upper bound on the number of inputs in one group: 32 when the concat has at least 512 inputs;
// otherwise 32 for the transpose-based concat and 16 for the other algorithm.
uint32_t ConcatGroupPartitioner::MaxInputNumPerGroup() const {
  constexpr uint32_t kLargeInputNum = 512;
  constexpr uint32_t kMaxInputNum = 32U;
  if (concat_dim_sizes_.size() >= kLargeInputNum) {
    return kMaxInputNum;
  }
  return concat_by_transpose_ ? 32U : 16U;
}
// Appends one single-input Default group per concat input to groups_, then runs both recompute
// passes (first without, then with backward search) over those groups.
// @return ge::SUCCESS, or a failure propagated from RecomputeNodesCrossGroups().
Status ConcatGroupPartitioner::RecomputeDiffAxes() {
  const auto input_count = concat_node_->inputs.Size();
  for (uint32_t input_idx = 0U; input_idx < input_count; ++input_idx) {
    groups_.emplace_back(ConcatGroup{input_idx, input_idx + 1, kGroupTypeDefault, 0});
  }
  for (const bool search_backward : {false, true}) {
    GE_ASSERT_SUCCESS(RecomputeNodesCrossGroups(search_backward, has_recompute_));
  }
  return ge::SUCCESS;
}
// Reports whether any node upstream of concat input `input_index` (the direct producer
// included) is a Transpose or Reduce node, as judged by ScheduleUtils. The upstream nodes are
// visited breadth-first along data inputs, each node once.
//
// All null checks use GE_ASSERT_NOTNULL so that every failure leaves through the same path.
// The owner-node check used to be GE_CHECK_NOTNULL, which returns a status code; in a function
// returning bool that non-zero code would be reported as `true`.
//
// @param input_index  index of the concat input to inspect.
// @return true if a Transpose/Reduce node is found on the input's upstream path.
bool ConcatGroupPartitioner::InputHasTransposeOrReduce(size_t input_index) const {
  const auto in_anchor = concat_node_->GetInDataAnchor(static_cast<int32_t>(input_index));
  GE_ASSERT_NOTNULL(in_anchor);
  auto out_anchor = in_anchor->GetPeerOutAnchor();
  GE_ASSERT_NOTNULL(out_anchor);
  std::set<const af::Node *> visited;
  std::queue<const af::Node *> nodes;
  auto owner_node = out_anchor->GetOwnerNode();
  GE_ASSERT_NOTNULL(owner_node);
  nodes.push(owner_node.get());
  visited.insert(owner_node.get());
  while (!nodes.empty()) {
    const auto cur_node = nodes.front();
    nodes.pop();
    // The queue stores raw pointers; recover the shared AscNode handle that ScheduleUtils expects.
    auto asc_node = std::dynamic_pointer_cast<af::AscNode>(const_cast<af::Node *>(cur_node)->shared_from_this());
    GE_ASSERT_NOTNULL(asc_node);
    if (ScheduleUtils::IsTranspose(asc_node) || ScheduleUtils::IsReduce(asc_node)) {
      GELOGI("concat input[%zu] path reaches Transpose/Reduce node %s", input_index, cur_node->GetNamePtr());
      return true;
    }
    for (const auto &in_node : cur_node->GetInDataNodes()) {
      if (visited.insert(in_node.get()).second) {
        nodes.push(in_node.get());
      }
    }
  }
  return false;
}
}