* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "tiling_group.h"
#include "ascir_ops_utils.h"
#include "ascgraph_info_complete.h"
#include "ascir_utils.h"
#include "schedule_utils.h"
#include "graph/symbolizer/symbolic_utils.h"
namespace optimize::autoschedule {
namespace {
constexpr int64_t kDefaultGroup = -1;
void PrintGroup(std::stringstream &ss, const std::string &name, const std::vector<af::AxisId> &group) {
ss << name << "[";
for (auto axis_id : group) {
ss << axis_id << ",";
}
ss << "],";
}
}
std::string AxisGroup::ToString() const {
std::stringstream ss;
PrintGroup(ss, "XGroup:", x_group);
PrintGroup(ss, "YGroup:", y_group);
PrintGroup(ss, "RGroup:", r_group);
PrintGroup(ss, "NGroup:", n_group);
ss << "Order:[";
for (auto &axis : axes_order) {
ss << axis << ",";
}
ss << "]";
return ss.str();
}
bool AxisGroup::IsEmpty() const {
return x_group.empty() && y_group.empty() && r_group.empty() && n_group.empty();
}
extern "C" {
int32_t GenAscGraphAxisGroup(const af::AscGraph &graph, AxisGroup &axes_group) {
GELOGD("Enter [GenAscGraphAxisGroup], graph name: %s, axis size: %zu", graph.GetName().c_str(),
graph.GetAllAxis().size());
for (auto &axis : graph.GetAllAxis()) {
GE_CHECK_NOTNULL(axis);
GELOGD("[GenAscGraphAxisGroup] AscGraph axis.name=%s, axis.id=%ld, axis.size=%s", axis->name.c_str(), axis->id,
axis->size.Str().get());
}
GE_CHK_STATUS_RET(AscGraphInfoComplete::CompleteApiInfo(graph));
GE_CHK_STATUS_RET(TilingGroup::GenTilingGroup(graph, axes_group), "Gen axis map failed, graph[%s] may be invalid.",
graph.GetName().c_str());
GELOGD("Finish [GenAscGraphAxisGroup], graph name: %s", graph.GetName().c_str());
return ge::SUCCESS;
}
bool CanMergeAxisGroup(const AxisGroup &lhs, const AxisGroup &rhs, AxisGroup &merged_group, bool is_ge_call) {
merged_group = lhs;
AxisGroup new_rhs = rhs;
auto ret = TilingGroup::MergeAxesGroup(merged_group, new_rhs, true, is_ge_call);
GELOGI("Merged axis group: %s to target: %s, res:[%d].", rhs.ToString().c_str(), lhs.ToString().c_str(), ret);
return ret;
}
static constexpr size_t kMaxFullLoadAxisSizeForNorm = 3UL;
static bool CalculateRAxisTotalSize(const af::AscTensorAttr &input_attr,
const af::AscTensorAttr &output_attr,
int64_t &r_axis_total_size,
int64_t &a_axis_total_size) {
r_axis_total_size = 1;
a_axis_total_size = 1;
if (output_attr.repeats.empty() || output_attr.repeats.size() > kMaxFullLoadAxisSizeForNorm) {
GELOGD("Output repeats size %zu exceeds max full load axis size %zu",
output_attr.repeats.size(), kMaxFullLoadAxisSizeForNorm);
return false;
}
if (af::SymbolicUtils::StaticCheckEq(input_attr.repeats[0], output_attr.repeats[0]) != af::TriBool::kTrue) {
GELOGD("First axis of input and output not equal, not AR/ARA pattern");
return false;
}
for (size_t i = 0; i < input_attr.repeats.size(); ++i) {
if (!input_attr.repeats[i].IsConstExpr() || !output_attr.repeats[i].IsConstExpr()) {
return false;
}
int64_t input_size = 0;
int64_t output_size = 0;
if (!input_attr.repeats[i].GetConstValue(input_size) || !output_attr.repeats[i].GetConstValue(output_size)) {
return false;
}
if (input_size > output_size) {
r_axis_total_size *= input_size;
} else {
a_axis_total_size *= output_size;
}
}
return true;
}
static bool IsStaticShape(const af::AscNodePtr &node) {
if (node == nullptr || node->outputs().empty()) {
return false;
}
for (const auto &node_out : node->outputs()) {
if (node_out->attr.repeats.empty()) {
return false;
}
for (const auto &repeat : node_out->attr.repeats) {
if (!repeat.IsConstExpr()) {
return false;
}
}
}
return true;
}
static bool IsStaticGraph(const af::AscGraph &graph) {
for (const auto &node : graph.GetAllNodes()) {
if (!ScheduleUtils::IsLoad(node)) {
continue;
}
if (!IsStaticShape(node)) {
return false;
}
}
GELOGD("Graph[%s] is static", graph.GetName().c_str());
return true;
}
static bool CheckReduceNodeNormLike(const af::AscNodePtr &asc_node) {
constexpr int64_t kThresholdTR = 32;
constexpr int64_t kThresholdTA = 128;
GELOGD("Found reduce node: %s", asc_node->GetNamePtr());
if (asc_node->inputs().empty() || asc_node->outputs().empty()) {
GELOGW("Reduce node %s has no inputs or outputs", asc_node->GetNamePtr());
return false;
}
const af::AscTensorAttr *input_attr_ptr = nullptr;
af::OutDataAnchorPtr current_out_anchor = nullptr;
if (asc_node->GetAllInDataAnchorsSize() > 0 && asc_node->GetInDataAnchor(0) != nullptr) {
current_out_anchor = asc_node->GetInDataAnchor(0)->GetPeerOutAnchor();
}
int32_t traverse_depth = 0;
const int32_t max_traverse_depth = 100;
while (current_out_anchor != nullptr && traverse_depth++ < max_traverse_depth) {
auto current_tensor_attr = af::AscTensorAttr::GetTensorAttrPtr(*current_out_anchor);
auto current_node = std::dynamic_pointer_cast<af::AscNode>(current_out_anchor->GetOwnerNode());
bool is_load = (current_node != nullptr && ScheduleUtils::IsLoad(current_node));
bool has_valid_attr = (current_tensor_attr != nullptr && !current_tensor_attr->repeats.empty());
if (has_valid_attr || is_load) {
if (has_valid_attr) {
input_attr_ptr = current_tensor_attr;
}
break;
}
if (current_node == nullptr || current_node->GetAllInDataAnchorsSize() == 0 ||
current_node->GetInDataAnchor(0) == nullptr) {
return false;
}
auto next_in_anchor = current_node->GetInDataAnchor(0);
current_out_anchor = next_in_anchor->GetPeerOutAnchor();
}
const auto &output = asc_node->outputs[0];
if (input_attr_ptr == nullptr || input_attr_ptr->repeats.empty() || input_attr_ptr->repeats.size() != output.attr.repeats.size()) {
return false;
}
int64_t r_axis_total_size = 1;
int64_t a_axis_total_size = 1;
if (!CalculateRAxisTotalSize(*input_attr_ptr, output.attr, r_axis_total_size, a_axis_total_size)) {
GELOGD("Reduce node %s: failed to calculate R/A axis size (non-const shape)", asc_node->GetNamePtr());
return false;
}
if (r_axis_total_size > kThresholdTR || a_axis_total_size < kThresholdTA) {
return false;
}
GELOGI("Reduce node %s passed check: R_axis=%ld (threshold=%ld), A_axis=%ld (threshold=%ld)",
asc_node->GetNamePtr(), r_axis_total_size, kThresholdTR, a_axis_total_size, kThresholdTA);
return true;
}
bool IsNormLikeReduceGraph(const af::AscGraph &graph) {
if (!IsStaticGraph(graph)) {
GELOGI("AscGraph is not static shape, return false for IsNormLikeReduceGraph");
return false;
}
bool has_reduce = false;
for (const auto &asc_node : graph.GetAllNodes()) {
if (!ScheduleUtils::IsReduce(asc_node)) {
continue;
}
has_reduce = true;
if (!CheckReduceNodeNormLike(asc_node)) {
return false;
}
}
if (!has_reduce) {
GELOGD("AscGraph has no reduce node, return false for IsNormLikeReduceGraph");
return false;
}
GELOGD("AscGraph passed IsNormLikeReduceGraph check");
return true;
}
}
static bool CheckYAndYR(const std::vector<ascir::AxisId> &cur_y_group, const std::vector<ascir::AxisId> &new_y_group,
const std::vector<ascir::AxisId> &new_r_group, const std::vector<size_t> &new_axes_order,
const bool is_canfuse_call) {
if (cur_y_group.size() != (new_y_group.size() + new_r_group.size())) {
return false;
}
if (is_canfuse_call) {
std::set<int64_t> l_g(cur_y_group.begin(), cur_y_group.end());
std::set<int64_t> r_g(new_y_group.begin(), new_y_group.end());
r_g.insert(new_r_group.begin(), new_r_group.end());
return l_g == r_g;
}
for (size_t i = 0UL; i < new_y_group.size(); i++) {
if (cur_y_group[new_axes_order[i]] != new_y_group[i]) {
return false;
}
}
size_t y_size = new_y_group.size();
for (size_t i = 0UL; i < new_r_group.size(); i++) {
if (cur_y_group[new_axes_order[y_size + i]] != new_r_group[i]) {
return false;
}
}
return true;
}
static bool CheckYAndR(const std::vector<ascir::AxisId> &cur_y_group, const std::vector<ascir::AxisId> &new_r_group,
const bool is_canfuse_call) {
if (is_canfuse_call) {
std::set<int64_t> l_g(cur_y_group.begin(), cur_y_group.end());
std::set<int64_t> r_g(new_r_group.begin(), new_r_group.end());
return l_g == r_g;
}
return cur_y_group == new_r_group;
}
static bool MergeYAndY(AxisGroup &lhs_group, AxisGroup &rhs_group, const bool is_canfuse_call, const bool is_ge_call) {
(void)is_ge_call;
if (lhs_group.y_group.size() != rhs_group.y_group.size()) {
return false;
}
if (is_canfuse_call) {
std::set<int64_t> lhs(lhs_group.y_group.begin(), lhs_group.y_group.end());
std::set<int64_t> rhs(rhs_group.y_group.begin(), rhs_group.y_group.end());
return lhs == rhs;
}
for (size_t i = 0UL; i < lhs_group.y_group.size(); i++) {
if (lhs_group.y_group[i] != rhs_group.y_group[i]) {
return false;
}
}
return true;
}
static bool MergeYAndR(AxisGroup &lhs_group, AxisGroup &rhs_group, const bool is_canfuse_call, const bool is_ge_call) {
(void)is_ge_call;
if (!CheckYAndR(lhs_group.y_group, rhs_group.r_group, is_canfuse_call)) {
return false;
}
lhs_group = rhs_group;
return true;
}
static bool MergeRAndY(AxisGroup &lhs_group, AxisGroup &rhs_group, const bool is_canfuse_call, const bool is_ge_call) {
(void)is_ge_call;
return CheckYAndR(rhs_group.y_group, lhs_group.r_group, is_canfuse_call);
}
static bool MergeYAndXY(AxisGroup &lhs_group, AxisGroup &rhs_group, const bool is_canfuse_call, const bool is_ge_call) {
(void)is_canfuse_call;
(void)is_ge_call;
std::set<int64_t> lhs(lhs_group.y_group.begin(), lhs_group.y_group.end());
std::set<int64_t> rhs(rhs_group.y_group.begin(), rhs_group.y_group.end());
rhs.insert(rhs_group.x_group.begin(), rhs_group.x_group.end());
if (lhs != rhs) {
return false;
}
lhs_group = rhs_group;
return true;
}
static bool MergeXYAndY(AxisGroup &lhs_group, AxisGroup &rhs_group, const bool is_canfuse_call, const bool is_ge_call) {
(void)is_canfuse_call;
(void)is_ge_call;
std::set<int64_t> lhs(lhs_group.y_group.begin(), lhs_group.y_group.end());
lhs.insert(lhs_group.x_group.begin(), lhs_group.x_group.end());
std::set<int64_t> rhs(rhs_group.y_group.begin(), rhs_group.y_group.end());
if (lhs != rhs) {
return false;
}
return true;
}
static bool MergeYAndYR(AxisGroup &lhs_group, AxisGroup &rhs_group, const bool is_canfuse_call, const bool is_ge_call) {
(void)is_ge_call;
if (!CheckYAndYR(lhs_group.y_group, rhs_group.y_group, rhs_group.r_group, rhs_group.axes_order, is_canfuse_call)) {
return false;
}
lhs_group = rhs_group;
return true;
}
static bool MergeYRAndY(AxisGroup &lhs_group, AxisGroup &rhs_group, const bool is_canfuse_call, const bool is_ge_call) {
(void)is_ge_call;
return CheckYAndYR(rhs_group.y_group, lhs_group.y_group, lhs_group.r_group, lhs_group.axes_order, is_canfuse_call);
}
static bool MergeYRAndYR(AxisGroup &lhs_group, AxisGroup &rhs_group, const bool is_canfuse_call, const bool is_ge_call) {
(void)is_ge_call;
if (is_canfuse_call) {
std::set<int64_t> l_y(lhs_group.y_group.begin(), lhs_group.y_group.end());
std::set<int64_t> r_y(rhs_group.y_group.begin(), rhs_group.y_group.end());
std::set<int64_t> l_r(lhs_group.r_group.begin(), lhs_group.r_group.end());
std::set<int64_t> r_r(rhs_group.r_group.begin(), rhs_group.r_group.end());
return (l_y == r_y) && (l_r == r_r);
}
return (lhs_group == rhs_group);
}
GroupType TilingGroup::GetGroupType(const AxisGroup &axes_group) {
GroupType type = GroupType::GROUP_INVALID;
if (!axes_group.x_group.empty()) {
type = static_cast<GroupType>(type | GroupType::GROUP_X);
}
if (!axes_group.y_group.empty()) {
type = static_cast<GroupType>(type | GroupType::GROUP_Y);
}
if (!axes_group.r_group.empty()) {
type = static_cast<GroupType>(type | GroupType::GROUP_R);
}
return type;
}
static void RemoveNGroupAxisInXYRGroup(const std::set<af::AxisId> &n_groups, AxisGroup &single_node_axes_group) {
if (n_groups.empty()) {
return;
}
single_node_axes_group.n_group.assign(n_groups.begin(), n_groups.end());
std::vector<size_t> new_axes_order;
size_t idx = 0UL;
for (auto iter = single_node_axes_group.x_group.begin(); iter != single_node_axes_group.x_group.end();) {
if (n_groups.count(*iter) > 0UL) {
iter = single_node_axes_group.x_group.erase(iter);
} else {
++iter;
new_axes_order.push_back(single_node_axes_group.axes_order[idx]);
}
++idx;
}
for (auto iter = single_node_axes_group.y_group.begin(); iter != single_node_axes_group.y_group.end();) {
if (n_groups.count(*iter) > 0UL) {
iter = single_node_axes_group.y_group.erase(iter);
} else {
++iter;
new_axes_order.push_back(single_node_axes_group.axes_order[idx]);
}
++idx;
}
for (auto iter = single_node_axes_group.r_group.begin(); iter != single_node_axes_group.r_group.end();) {
if (n_groups.count(*iter) > 0UL) {
iter = single_node_axes_group.r_group.erase(iter);
} else {
++iter;
new_axes_order.push_back(single_node_axes_group.axes_order[idx]);
}
++idx;
}
single_node_axes_group.axes_order = new_axes_order;
}
static void MergeNGroup(AxisGroup &lhs_group, AxisGroup &rhs_group) {
std::set<af::AxisId> n_groupset(lhs_group.n_group.begin(), lhs_group.n_group.end());
n_groupset.insert(rhs_group.n_group.begin(), rhs_group.n_group.end());
RemoveNGroupAxisInXYRGroup(n_groupset, lhs_group);
RemoveNGroupAxisInXYRGroup(n_groupset, rhs_group);
}
bool TilingGroup::MergeAxesGroup(AxisGroup &target, AxisGroup &src, const bool is_canfuse_call, const bool is_ge_call) {
MergeNGroup(target, src);
if (target.x_group.empty() && target.y_group.empty() && target.r_group.empty()) {
target = src;
return true;
}
static std::map<std::pair<GroupType, GroupType>, AxisGroupMergeFunc> type_to_merge_func = {
{{GroupType::GROUP_Y, GroupType::GROUP_Y}, MergeYAndY},
{{GroupType::GROUP_Y, GroupType::GROUP_YR}, MergeYAndYR},
{{GroupType::GROUP_YR, GroupType::GROUP_Y}, MergeYRAndY},
{{GroupType::GROUP_YR, GroupType::GROUP_YR}, MergeYRAndYR},
{{GroupType::GROUP_Y, GroupType::GROUP_XY}, MergeYAndXY},
{{GroupType::GROUP_XY, GroupType::GROUP_Y}, MergeXYAndY},
{{GroupType::GROUP_Y, GroupType::GROUP_R}, MergeYAndR},
{{GroupType::GROUP_R, GroupType::GROUP_Y}, MergeRAndY},
};
auto merge_type = std::make_pair(GetGroupType(target), GetGroupType(src));
auto iter = type_to_merge_func.find(merge_type);
if (iter == type_to_merge_func.end()) {
return false;
}
return iter->second(target, src, is_canfuse_call, is_ge_call);
}
Status TilingGroup::GenTilingGroup(const ascir::ImplGraph &impl_graph, AxisGroup &tiling_group, bool is_reduce_fullload) {
std::vector<std::pair<std::string, AxisGroup>> node_name_to_tiling_group;
std::set<af::AxisId> n_groupset;
for (const auto &node : impl_graph.GetAllNodes()) {
if (ScheduleUtils::IsBuffer(node)) {
continue;
}
AxisGroup single_node_axes_group;
GE_CHK_STATUS_RET(GenAxisGroupForSingleNode(*node, single_node_axes_group, is_reduce_fullload));
n_groupset.insert(single_node_axes_group.n_group.begin(), single_node_axes_group.n_group.end());
node_name_to_tiling_group.emplace_back(node->GetName(), single_node_axes_group);
GELOGD("GenTilingGroup node: %s, group: %s.", node->GetName().c_str(), single_node_axes_group.ToString().c_str());
}
tiling_group.n_group.assign(n_groupset.begin(), n_groupset.end());
for (auto &iter : node_name_to_tiling_group) {
GE_ASSERT_TRUE(MergeAxesGroup(tiling_group, iter.second),
"Merged axis group: %s to target: %s failed, the graph [%s] cannot be fused.",
iter.second.ToString().c_str(), tiling_group.ToString().c_str(), impl_graph.GetName().c_str());
}
return ge::SUCCESS;
}
Status TilingGroup::GenElewiseTilingGroup(af::AscNode &node, AxisGroup &axes_group) {
GE_CHK_STATUS_RET(ScheduleUtils::GetLoopAxis(node, axes_group.y_group));
axes_group.axes_order.reserve(axes_group.y_group.size());
for (size_t i = 0UL; i < axes_group.y_group.size(); ++i) {
axes_group.axes_order.push_back(i);
}
return ge::SUCCESS;
}
std::vector<af::AxisId> CalcReduceAxes(const std::vector<af::Expression>& src_strides,
const std::vector<af::Expression>& dst_strides,
const std::vector<ascir::AxisId>& axes) {
GE_ASSERT_TRUE((src_strides.size() == dst_strides.size()),
"The output dim cnt [%zu] of reduce mismatch with input dim cnt [%zu].", dst_strides.size(),
src_strides.size());
GE_ASSERT_TRUE((src_strides.size() == axes.size()),
"The input dim cnt [%zu] of reduce mismatch with input dim cnt [%zu].", src_strides.size(),
axes.size());
std::vector<ascir::AxisId> reduce_axes;
for (size_t i = 0; i < src_strides.size(); ++i) {
if (af::SymbolicUtils::StaticCheckEq(src_strides[i], dst_strides[i]) != af::TriBool::kTrue &&
af::SymbolicUtils::StaticCheckEq(dst_strides[i], af::sym::kSymbolZero) == af::TriBool::kTrue) {
reduce_axes.push_back(axes[i]);
}
}
return reduce_axes;
}
Status TilingGroup::GenReduceTilingGroup(af::AscNode &node, AxisGroup &axes_group) {
std::vector<ascir::AxisId> axes;
GE_CHK_STATUS_RET(ScheduleUtils::GetLoopAxis(node, axes), "Get loop axis failed.");
axes_group.axes_order.resize(axes.size());
std::vector<ascir::SizeExpr> src_strides;
GE_CHK_STATUS_RET(ScheduleUtils::GetReduceInputStrides(node, src_strides), "Get loop strides failed.");
axes_group.r_group = CalcReduceAxes(src_strides, node.outputs[0].attr.strides, axes);
int64_t y_order_index = 0;
int64_t r_order_index = axes.size() - axes_group.r_group.size();
for (size_t i = 0; i < axes.size(); ++i) {
if (std::find(axes_group.r_group.begin(), axes_group.r_group.end(), axes[i]) == axes_group.r_group.end()) {
axes_group.y_group.push_back(axes[i]);
axes_group.axes_order[y_order_index++] = i;
} else {
axes_group.axes_order[r_order_index++] = i;
}
}
return ge::SUCCESS;
}
Status TilingGroup::GenReduceTilingGroupFullLoad(af::AscNode &node, AxisGroup &axes_group) {
std::vector<ascir::AxisId> axes;
GE_CHK_STATUS_RET(ScheduleUtils::GetLoopAxis(node, axes), "Get loop axis failed.");
axes_group.axes_order.resize(axes.size());
std::vector<ascir::SizeExpr> src_strides;
GE_CHK_STATUS_RET(ScheduleUtils::GetReduceInputStrides(node, src_strides), "Get loop strides failed.");
axes_group.n_group = CalcReduceAxes(src_strides, node.outputs[0].attr.strides, axes);
int64_t y_order_index = 0;
int64_t r_order_index = axes.size() - axes_group.n_group.size();
for (size_t i = 0; i < axes.size(); ++i) {
if (std::find(axes_group.n_group.begin(), axes_group.n_group.end(), axes[i]) == axes_group.n_group.end()) {
axes_group.y_group.push_back(axes[i]);
axes_group.axes_order[y_order_index++] = i;
} else {
axes_group.axes_order[r_order_index++] = i;
}
}
return ge::SUCCESS;
}
void PlaceRemainingAxis(const int64_t index, std::set<int64_t> &remaining_axis, AxisGroup &axes_group,
vector<ascir::AxisId> &output_axis) {
for (int64_t j = index; j >= 0; --j) {
if (remaining_axis.find(output_axis[j]) == remaining_axis.end()) {
continue;
}
axes_group.y_group.emplace(axes_group.y_group.begin(), output_axis[j]);
}
}
Status TilingGroup::GenTransposeTilingGroup(af::AscNode &node, AxisGroup &axes_group) {
std::vector<ascir::AxisId> input_axis;
GE_CHK_STATUS_RET(ScheduleUtils::GetInputForTranspose(node, input_axis), "Get Transpose loop axis failed.");
std::vector<ascir::AxisId> &output_axis = node.outputs[0].attr.axis;
GE_ASSERT_TRUE((input_axis.size() == output_axis.size()),
"The output dim cnt [%zu] of Transpose mismatch with input dim cnt [%zu].", output_axis.size(),
input_axis.size());
std::set<int64_t> remaining_axis(input_axis.begin(), input_axis.end());
GELOGD("GenTransposeTilingGroup input_axis %s, output_axis %s", af::ViewMemberToString(input_axis).c_str(),
af::ViewMemberToString(output_axis).c_str());
int64_t i = static_cast<int64_t>(input_axis.size()) - 1;
for (; i >= 0; --i) {
if (input_axis[i] == output_axis[i]) {
axes_group.n_group.emplace(axes_group.n_group.begin(), input_axis[i]);
remaining_axis.erase(input_axis[i]);
} else {
break;
}
}
for (; i >= 0; --i) {
if (std::find(axes_group.y_group.begin(), axes_group.y_group.end(), input_axis[i]) == axes_group.y_group.end()) {
axes_group.x_group.emplace(axes_group.x_group.begin(), input_axis[i]);
remaining_axis.erase(input_axis[i]);
}
if (std::find(axes_group.x_group.begin(), axes_group.x_group.end(), output_axis[i]) == axes_group.x_group.end()) {
axes_group.y_group.emplace(axes_group.y_group.begin(), output_axis[i]);
remaining_axis.erase(output_axis[i]);
}
if (std::find(axes_group.y_group.begin(), axes_group.y_group.end(), input_axis[i]) != axes_group.y_group.end() &&
std::find(axes_group.x_group.begin(), axes_group.x_group.end(), output_axis[i]) != axes_group.x_group.end()) {
PlaceRemainingAxis(i, remaining_axis, axes_group, output_axis);
break;
}
}
for (int64_t axis : axes_group.x_group) {
axes_group.axes_order.emplace_back(axis);
}
for (int64_t axis : axes_group.y_group) {
axes_group.axes_order.emplace_back(axis);
}
GELOGD("GenTransposeTilingGroup TilingGroup %s", axes_group.ToString().c_str());
return ge::SUCCESS;
}
Status TilingGroup::GenConcatTilingGroup(af::AscNode &node, AxisGroup &axes_group) {
std::vector<ascir::AxisId> axes;
GE_CHK_STATUS_RET(ScheduleUtils::GetLoopAxis(node, axes), "Get loop axis failed.");
const std::vector<ascir::SizeExpr> &input_repeats = node.inputs[0].attr.repeats;
const std::vector<ascir::SizeExpr> &output_repeats = node.outputs[0].attr.repeats;
GE_ASSERT_TRUE((input_repeats.size() == output_repeats.size()),
"The output dim cnt [%zu] of concat mismatch with input dim cnt [%zu].", output_repeats.size(),
input_repeats.size());
size_t concat_dim{0UL};
for (size_t i = 0UL; i < input_repeats.size(); ++i) {
if (af::SymbolicUtils::StaticCheckEq(input_repeats[i], output_repeats[i]) != af::TriBool::kTrue) {
GELOGD("Concat node [%s], input_repeats[%zu]=%s, output_repeats[%zu]=%s", node.GetNamePtr(), i,
input_repeats[i].Str().get(), i, output_repeats[i].Str().get());
concat_dim = i;
break;
}
}
GELOGD("Concat node [%s], concat_dim is %zu", node.GetNamePtr(), concat_dim);
axes_group.axes_order.reserve(axes.size());
for (size_t i = 0UL; i < axes.size(); ++i) {
if (i < concat_dim) {
axes_group.y_group.push_back(axes[i]);
axes_group.axes_order.push_back(i);
} else {
axes_group.n_group.push_back(axes[i]);
}
}
for (auto &input_view : node.inputs()) {
if (input_view->attr.axis.size() > concat_dim) {
axes_group.n_group.push_back(input_view->attr.axis[concat_dim]);
}
}
return ge::SUCCESS;
}
Status TilingGroup::GenSplitTilingGroup(af::AscNode &node, AxisGroup &axes_group) {
std::vector<ascir::AxisId> axes;
GE_CHK_STATUS_RET(ScheduleUtils::GetLoopAxis(node, axes), "Get loop axis failed.");
const std::vector<ascir::SizeExpr> &input_repeats = node.inputs[0].attr.repeats;
const std::vector<ascir::SizeExpr> &output_repeats = node.outputs[0].attr.repeats;
GE_ASSERT_TRUE((input_repeats.size() == output_repeats.size()),
"The output dim cnt [%zu] of split mismatch with input dim cnt [%zu].", output_repeats.size(),
input_repeats.size());
ascir::SizeExpr pre_size = af::ops::One;
size_t split_dim{0UL};
for (size_t i = 0UL; i < input_repeats.size(); ++i) {
if (af::SymbolicUtils::StaticCheckEq(input_repeats[i], output_repeats[i]) != af::TriBool::kTrue) {
GELOGD("Split node [%s], input_repeats[%zu]=%s, output_repeats[%zu]=%s, pre_size=%s.", node.GetNamePtr(), i,
input_repeats[i].Str().get(), i, output_repeats[i].Str().get(), pre_size.Str().get());
split_dim = i;
break;
}
pre_size = pre_size * input_repeats[i];
}
GELOGD("Split node [%s], split_dim is %zu", node.GetNamePtr(), split_dim);
if (split_dim == 0UL || (af::SymbolicUtils::StaticCheckEq(pre_size, af::ops::One) == af::TriBool::kTrue)) {
GE_CHK_STATUS_RET(ScheduleUtils::GetLoopAxis(node, axes_group.y_group));
axes_group.axes_order.reserve(axes_group.y_group.size());
for (size_t i = 0UL; i < axes_group.y_group.size(); ++i) {
axes_group.axes_order.push_back(i);
}
return ge::SUCCESS;
}
axes_group.axes_order.reserve(axes.size());
for (size_t i = 0UL; i < axes.size(); ++i) {
if (i < split_dim) {
axes_group.y_group.push_back(axes[i]);
axes_group.axes_order.push_back(i);
} else {
axes_group.n_group.push_back(axes[i]);
}
}
for (auto &input_view : node.outputs()) {
if (input_view->attr.axis.size() > split_dim) {
axes_group.n_group.push_back(input_view->attr.axis[split_dim]);
}
}
return ge::SUCCESS;
}
void TilingGroup::NormGroup(AxisGroup &group) {
if (group.x_group.empty()) {
group.x_group.push_back(kDefaultGroup);
}
if (group.r_group.empty()) {
group.r_group.push_back(kDefaultGroup);
}
}
Status TilingGroup::GenAxisGroupForSingleNode(af::AscNode &node, AxisGroup &axes_group, bool is_reduce_ar_fullLoad) {
static std::map<ascir::ComputeType, AxisGroupGenFunc> compute_type_to_group_gen_func = {
{af::ComputeType::kComputeElewise, TilingGroup::GenElewiseTilingGroup},
{af::ComputeType::kComputeBroadcast, TilingGroup::GenElewiseTilingGroup},
{af::ComputeType::kComputeGather, TilingGroup::GenElewiseTilingGroup},
{af::ComputeType::kComputeLoad, TilingGroup::GenElewiseTilingGroup},
{af::ComputeType::kComputeStore, TilingGroup::GenElewiseTilingGroup},
{af::ComputeType::kComputeReduce, TilingGroup::GenReduceTilingGroup},
{af::ComputeType::kComputeConcat, TilingGroup::GenConcatTilingGroup},
{af::ComputeType::kComputeTranspose, TilingGroup::GenTransposeTilingGroup},
{af::ComputeType::kComputeSplit, TilingGroup::GenSplitTilingGroup},
{af::ComputeType::kComputeCube, TilingGroup::GenElewiseTilingGroup},
};
if (node.attr.api.compute_type == af::ComputeType::kComputeReduce && is_reduce_ar_fullLoad) {
GE_CHK_STATUS_RET(GenReduceTilingGroupFullLoad(node, axes_group), "Gen tiling case failed, compute type [%u].",
static_cast<uint32_t>(node.attr.api.compute_type));
return ge::SUCCESS;
}
auto iter = compute_type_to_group_gen_func.find(node.attr.api.compute_type);
GE_ASSERT_TRUE(iter != compute_type_to_group_gen_func.end(), "Unsupported compute type [%u].",
node.attr.api.compute_type);
GE_CHK_STATUS_RET(iter->second(node, axes_group), "Gen tiling case failed, compute type [%u].",
static_cast<uint32_t>(node.attr.api.compute_type));
return ge::SUCCESS;
}
}