* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "att_utils.h"
#include "att_const_values.h"
#include <numeric>
namespace att {
namespace {
bool IsOps(const af::AscNode *node, const std::string &node_type) {
return node->GetType() == node_type;
}
void CollectOrigAxisNames(const SubAxis *dim, std::set<std::string> &orig_names) {
for (auto *orig_axis : dim->orig_axis) {
orig_names.insert(orig_axis->name);
}
}
bool IsReduceAxis(const SubAxis *dim, const Expr &repeat, const Expr &stride,
const std::vector<TensorPtr> &output_tensors) {
for (const auto &output_tensor : output_tensors) {
for (size_t j = 0; j < output_tensor->dim_info.size(); j++) {
if (output_tensor->dim_info[j]->name != dim->name) {
continue;
}
auto output_repeat = output_tensor->repeat[j];
auto output_stride = output_tensor->stride[j];
bool is_reduce = (repeat != af::sym::kSymbolOne && output_repeat == af::sym::kSymbolOne) ||
(stride != af::sym::kSymbolZero && output_stride == af::sym::kSymbolZero);
if (is_reduce) {
return true;
}
}
}
return false;
}
}
bool AttUtils::IsLoadNode(af::AscNode *node) {
GE_ASSERT_NOTNULL(node);
if (IsOps(node, kData) || IsOps(node, kWorkspace)) {
return false;
}
const auto input_size = node->inputs().size();
std::vector<size_t> indices(input_size);
std::iota(indices.begin(), indices.end(), 0U);
bool is_any_input_gm = std::any_of(indices.begin(), indices.end(), [&node](size_t id) {
const auto &input = node->inputs[id];
return (input.attr.mem.hardware == af::MemHardware::kMemHardwareGM);
});
return is_any_input_gm;
}
bool AttUtils::IsStoreNode(af::AscNode *node) {
GE_ASSERT_NOTNULL(node);
if (IsOps(node, kData) || IsOps(node, kWorkspace)) {
return false;
}
const auto output_size = node->outputs().size();
std::vector<size_t> indices(output_size);
std::iota(indices.begin(), indices.end(), 0U);
bool is_any_output_gm = std::any_of(indices.begin(), indices.end(), [&node](size_t id) {
const auto &output = node->outputs[id];
return (output.attr.mem.hardware == af::MemHardware::kMemHardwareGM);
});
return is_any_output_gm;
}
bool AttUtils::IsLoadStoreNode(af::AscNode *node) {
return IsLoadNode(node) || IsStoreNode(node);
}
bool AttUtils::IsTileSplitAxis(const AttAxisPtr &axis) {
return (axis->axis_pos == AxisPosition::INNER) && (!axis->bind_multicore);
}
bool AttUtils::GetLastTileSplitAxisName(af::AscNode &node, const af::AscGraph &graph, std::string &axis_name) {
if (node.outputs().empty()) {
return false;
}
const auto &node_attr = node.outputs[0].attr;
if (node_attr.axis.empty()) {
return false;
}
const auto &last_axis_id = node_attr.axis.back();
for (const auto &axis : graph.GetAllAxis()) {
if (axis->id == last_axis_id) {
axis_name = axis->name;
return true;
}
}
return false;
}
void AttUtils::CollectReduceAxisNames(const NodeInfo &node_info,
std::set<std::string> &reduce_axis_orig_names) {
std::set<std::string> reduce_axis_names;
GELOGD("[DFX] CollectReduceAxisNames: node_name=%s, node_type=%s",
node_info.name.c_str(), node_info.node_type.c_str());
for (const auto &tensor : node_info.inputs) {
size_t dim_size = tensor->dim_info.size();
if (tensor->repeat.size() != dim_size || tensor->stride.size() != dim_size) {
GELOGW("[DFX] CollectReduceAxisNames: tensor=%s has inconsistent sizes", tensor->ToString().c_str());
continue;
}
for (size_t i = 0; i < dim_size; i++) {
auto dim = tensor->dim_info[i];
auto repeat = tensor->repeat[i];
auto stride = tensor->stride[i];
if (IsReduceAxis(dim, repeat, stride, node_info.outputs)) {
CollectOrigAxisNames(dim, reduce_axis_orig_names);
reduce_axis_names.insert(dim->name);
}
}
}
GELOGD("[DFX] Collected reduce_axis_name:%s, reduce_axis_orig_names: %s",
std::accumulate(
reduce_axis_names.begin(), reduce_axis_names.end(), std::string(),
[](const std::string &acc, const std::string &name) { return acc.empty() ? name : acc + "," + name; })
.c_str(),
std::accumulate(
reduce_axis_orig_names.begin(), reduce_axis_orig_names.end(), std::string(),
[](const std::string &acc, const std::string &name) { return acc.empty() ? name : acc + "," + name; })
.c_str());
}
static bool IsBroadcastAxis(const SubAxis *dim, const Expr &repeat, const Expr &stride,
const std::vector<TensorPtr> &output_tensors) {
for (const auto &output_tensor : output_tensors) {
for (size_t j = 0; j < output_tensor->dim_info.size(); j++) {
if (output_tensor->dim_info[j]->name != dim->name) {
continue;
}
auto output_repeat = output_tensor->repeat[j];
auto output_stride = output_tensor->stride[j];
bool is_broadcast = (output_repeat != af::sym::kSymbolOne && repeat == af::sym::kSymbolOne) ||
(output_stride != af::sym::kSymbolZero && stride == af::sym::kSymbolZero);
if (is_broadcast) {
return true;
}
}
}
return false;
}
void AttUtils::CollectBroadcastAxisNames(const NodeInfo &node_info,
std::set<std::string> &broadcast_axis_orig_names) {
std::set<std::string> broadcast_axis_names;
GELOGD("[DFX] CollectBroadcastAxisNames: node_name=%s, node_type=%s",
node_info.name.c_str(), node_info.node_type.c_str());
for (const auto &tensor : node_info.inputs) {
size_t dim_size = tensor->dim_info.size();
if (tensor->repeat.size() != dim_size ||
tensor->stride.size() != dim_size ||
tensor->gm_stride.size() != dim_size) {
GELOGW("[DFX] CollectBroadcastAxisNames: tensor=%s has inconsistent sizes", tensor->ToString().c_str());
continue;
}
for (size_t i = 0; i < dim_size; i++) {
auto dim = tensor->dim_info[i];
auto repeat = tensor->repeat[i];
auto stride = tensor->stride[i];
auto gm_stride = tensor->gm_stride[i];
if (IsBroadcastAxis(dim, repeat, stride, node_info.outputs)) {
CollectOrigAxisNames(dim, broadcast_axis_orig_names);
broadcast_axis_names.insert(dim->name);
}
}
}
GELOGD("[DFX] Collected broadcast_axis_name:%s, broadcast_axis_orig_names: %s",
std::accumulate(
broadcast_axis_names.begin(), broadcast_axis_names.end(), std::string(),
[](const std::string &acc, const std::string &name) { return acc.empty() ? name : acc + "," + name; })
.c_str(),
std::accumulate(
broadcast_axis_orig_names.begin(), broadcast_axis_orig_names.end(), std::string(),
[](const std::string &acc, const std::string &name) { return acc.empty() ? name : acc + "," + name; })
.c_str());
}
}