/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "node_cache_marker.h"
#include <algorithm>
#include <queue>
#include "schedule_utils.h"
#include "common_utils.h"
#include "axis_type_info.h"
namespace optimize::autoschedule {
// Copies into `attr` the tensor attribute of the producer output that feeds input `idx` of `node`.
//
// @param node  Consumer node; it must be castable to af::AscNode.
// @param idx   Index of the input data anchor on `node`.
// @param attr  [out] Receives a copy of the peer output's tensor attribute.
// @return ge::SUCCESS when the attribute was copied.
// NOTE(review): GE_ASSERT_* are project macros not visible here; presumably they log and return a
// failure status when the checked condition does not hold.
af::Status NodeCacheMarker::GetAscNodeInputAttr(const af::NodePtr &node, int32_t idx, af::AscTensorAttr &attr) {
  const auto &asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
  GE_ASSERT_NOTNULL(asc_node);
  // A negative idx wraps to a large unsigned value, so this single comparison rejects it as well.
  GE_ASSERT_TRUE(static_cast<uint32_t>(idx) < asc_node->GetAllInDataAnchorsSize());

  const auto in_anchor = asc_node->GetInDataAnchor(idx);
  GE_ASSERT_NOTNULL(in_anchor);
  const auto pre_out_anchor = in_anchor->GetPeerOutAnchor();
  GE_ASSERT_NOTNULL(pre_out_anchor);

  const auto pre_node = pre_out_anchor->GetOwnerNode();
  GE_ASSERT_NOTNULL(pre_node);
  const auto &pre_asc_node = std::dynamic_pointer_cast<af::AscNode>(pre_node);
  GE_ASSERT_NOTNULL(pre_asc_node);

  // The index into `outputs` must be validated against the number of output anchors. The number of
  // consumer nodes (GetOutDataNodesSize) is unrelated: an output with no consumer still occupies an index.
  const auto out_idx = pre_out_anchor->GetIdx();
  GE_ASSERT_TRUE(static_cast<uint32_t>(out_idx) < pre_asc_node->GetAllOutDataAnchorsSize());
  attr = pre_asc_node->outputs[out_idx].attr;
  return ge::SUCCESS;
}
// Returns true when `node` is already present in visited_nodes_.
bool NodeCacheMarker::IsNodeVisited(const af::NodePtr &node) const {
  const auto hit = visited_nodes_.find(node);
  const bool already_seen = (hit != visited_nodes_.end());
  return already_seen;
}
// Records `node` in visited_nodes_; inserting a node that is already present has no further effect.
void NodeCacheMarker::VisitNode(const af::NodePtr &node) {
  (void)visited_nodes_.insert(node);
}
// Records `node` in cache_start_nodes_, the set that MarkIfNodeNeedsCache() later walks forward from.
void NodeCacheMarker::AddToCacheStartSet(const af::NodePtr &node) {
  (void)cache_start_nodes_.insert(node);
}
// Decides which cache mode (if any) applies when a tensor goes from (in_axis, in_repeats) to
// (out_axis, out_repeats).
//
// An output axis counts as "broadcast" when its input and output repeat expressions are not equal
// according to ascgen_utils::ExpressEq. Output axes are scanned from last to first; the trailing run of
// original-type axes is ignored, and the remaining axes are split into tile-inner axes and all others.
//
// @param in_axis      Axis ids on the input side; only its size is checked here.
// @param out_axis     Axis ids on the output side; must be non-empty and known to graph_.
// @param in_repeats   Input repeat expression per position.
// @param out_repeats  Output repeat expression per position.
// @return kNoCache when no considered tile-inner axis is broadcast;
//         kCacheBlockSplitOriginBroadcastAxis when additionally every other considered axis is broadcast
//         (also the case when there are no other considered axes);
//         kCacheBlockSplitFusedBroadcastAxis otherwise.
af::ExecuteCondition NodeCacheMarker::DoesNodeNeedCache(const std::vector<int64_t> &in_axis,
                                                        const std::vector<int64_t> &out_axis,
                                                        const std::vector<af::Expression> &in_repeats,
                                                        const std::vector<af::Expression> &out_repeats) const {
  GE_ASSERT_EQ(in_axis.size(), out_axis.size());
  GE_ASSERT_EQ(in_repeats.size(), out_repeats.size());
  GE_ASSERT_EQ(in_axis.size(), in_repeats.size());
  GE_ASSERT_TRUE(!out_axis.empty());

  // axis id -> whether the repeat changes between input and output at that position.
  std::map<int64_t, bool> axis_id_brc_map;
  size_t pos = 0UL;
  for (const int64_t axis_id : out_axis) {
    axis_id_brc_map[axis_id] = !ascgen_utils::ExpressEq(in_repeats[pos], out_repeats[pos]);
    ++pos;
  }

  std::vector<int64_t> tile_inner_axes;
  std::vector<int64_t> remaining_axes;
  tile_inner_axes.reserve(out_axis.size());
  remaining_axes.reserve(out_axis.size());

  // Stays true only while the backwards scan has seen nothing but original-type axes.
  bool in_trailing_original_run = true;
  for (size_t left = out_axis.size(); left > 0UL; --left) {
    const int64_t axis = out_axis[left - 1UL];
    const auto &axis_ptr = this->graph_.FindAxis(axis);
    GE_ASSERT_NOTNULL(axis_ptr, "Not found axis=(%ld).", axis);
    if (axis_ptr->type != af::Axis::kAxisTypeOriginal) {
      in_trailing_original_run = false;
    }
    if (in_trailing_original_run) {
      continue;
    }
    auto &bucket = (axis_ptr->type == af::Axis::kAxisTypeTileInner) ? tile_inner_axes : remaining_axes;
    bucket.push_back(axis);
  }

  // Every queried id came from out_axis, so operator[] never inserts a new entry here.
  const auto is_brc_axis = [&axis_id_brc_map](const int64_t axis) { return axis_id_brc_map[axis]; };
  const bool any_tile_inner_brc = std::any_of(tile_inner_axes.begin(), tile_inner_axes.end(), is_brc_axis);
  const bool all_remaining_brc = std::all_of(remaining_axes.begin(), remaining_axes.end(), is_brc_axis);

  if (!any_tile_inner_brc) {
    return af::ExecuteCondition::kNoCache;
  }
  return all_remaining_brc ? af::ExecuteCondition::kCacheBlockSplitOriginBroadcastAxis
                           : af::ExecuteCondition::kCacheBlockSplitFusedBroadcastAxis;
}
// Evaluates the cache condition of `node` by comparing its first input layout with its first output layout.
//
// When `node` is a broadcast whose first input is scalar-like, the input side is synthesised from the
// output axes with every repeat set to af::ops::One; otherwise the attributes of inputs[0] are used.
//
// @param node  Node to inspect; must be non-null and have at least one data input and one data consumer.
// @return The result of the four-vector DoesNodeNeedCache() overload.
// NOTE(review): GE_ASSERT_* are project macros not visible here; presumably they log and return a
// failure value when the checked condition does not hold.
af::ExecuteCondition NodeCacheMarker::DoesNodeNeedCache(const af::AscNodePtr &node) const {
  // Guard the dereferences below, matching the null check done by the af::NodePtr overload.
  GE_ASSERT_NOTNULL(node);
  GE_ASSERT_TRUE(node->GetInDataNodesSize() > 0UL);
  GE_ASSERT_TRUE(node->GetOutDataNodesSize() > 0U);
  std::vector<int64_t> in_axis;
  std::vector<af::Expression> in_repeats;
  if (ScheduleUtils::IsBroadcast(node) && ScheduleUtils::IsScalarLikeNode(node->GetInDataNodes().at(0))) {
    GELOGD("Graph [%s], find scalar broadcast [%s]", graph_.GetName().c_str(), node->GetNamePtr());
    in_axis = node->outputs[0].attr.axis;
    // Fill in place instead of building a temporary vector and copying it.
    in_repeats.assign(in_axis.size(), af::ops::One);
  } else {
    in_axis = node->inputs[0].attr.axis;
    in_repeats = node->inputs[0].attr.repeats;
  }
  const auto &out_axis = node->outputs[0].attr.axis;
  const auto &out_repeats = node->outputs[0].attr.repeats;
  return DoesNodeNeedCache(in_axis, out_axis, in_repeats, out_repeats);
}
// Convenience overload: downcasts `node` to af::AscNode and forwards to the AscNodePtr overload.
// The assertion fires when `node` is not an af::AscNode (the cast yields null).
af::ExecuteCondition NodeCacheMarker::DoesNodeNeedCache(const af::NodePtr &node) const {
  const auto asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
  GE_ASSERT_NOTNULL(asc_node);
  const af::ExecuteCondition verdict = DoesNodeNeedCache(asc_node);
  return verdict;
}
// Evaluates the cache condition for a node that takes its two inputs at indices 0 and 1.
//
// The producer attribute of input `brc_idx` is used as the "input" layout and the producer attribute of
// the other input (index 1 - brc_idx) as the "output" layout.
//
// @param node     Node whose inputs are compared.
// @param brc_idx  Index of the input treated as the broadcast side.
// @return The result of the four-vector DoesNodeNeedCache() overload.
af::ExecuteCondition NodeCacheMarker::DoesInlineNodeNeedCache(const af::NodePtr &node, int32_t brc_idx) const {
  af::AscTensorAttr in_attr;
  af::AscTensorAttr out_attr;
  GE_ASSERT_SUCCESS(GetAscNodeInputAttr(node, brc_idx, in_attr));
  GE_ASSERT_SUCCESS(GetAscNodeInputAttr(node, 1 - brc_idx, out_attr));
  return DoesNodeNeedCache(in_attr.axis, out_attr.axis, in_attr.repeats, out_attr.repeats);
}
// Sets the scheduling execute condition of `node` to kCacheBlockSplitFusedBroadcastAxis.
// Nodes that are not af::AscNode instances are left untouched.
void NodeCacheMarker::MarkNodeCacheable(const af::NodePtr &node) {
  const auto asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
  if (asc_node == nullptr) {
    return;
  }
  asc_node->attr.sched.exec_condition = af::ExecuteCondition::kCacheBlockSplitFusedBroadcastAxis;
}
/**
 * Marks `node` and all of its parent (upstream) nodes as cacheable.
 */
// Queue-based walk from `node` towards its producers:
//  - a Load node is marked cacheable, recorded as a cache start node, and not walked past;
//  - a node without data inputs is recorded as a cache start node but is not marked;
//  - any other node is marked cacheable and all of its data inputs are queued.
// Every dequeued node is also recorded through VisitNode().
//
// @param node       Node to start from.
// @param condition  Currently unused; MarkNodeCacheable() always applies the same execute condition.
void NodeCacheMarker::MarkNodesCacheableBottomUp(const af::AscNodePtr &node,
                                                 [[maybe_unused]] const af::ExecuteCondition condition) {
  std::queue<af::NodePtr> queue;
  queue.push(node);
  while (!queue.empty()) {
    // Copy the pointer out BEFORE pop(): pop() destroys the front element, so a reference obtained from
    // front() would dangle for the rest of this iteration.
    const af::NodePtr tmp_node = queue.front();
    queue.pop();
    if (tmp_node == nullptr) {
      continue;
    }
    VisitNode(tmp_node);
    const auto &asc_tmp_node = std::dynamic_pointer_cast<af::AscNode>(tmp_node);
    if (asc_tmp_node != nullptr && ScheduleUtils::IsLoad(asc_tmp_node)) {
      GELOGD("Graph [%s], find Load node [%s] as cache start boundary, truncate to avoid multi-output pollution from Data node.",
             graph_.GetName().c_str(), tmp_node->GetNamePtr());
      MarkNodeCacheable(tmp_node);
      AddToCacheStartSet(tmp_node);
      continue;
    }
    if (tmp_node->GetInDataNodesSize() == 0UL) {
      AddToCacheStartSet(tmp_node);
      continue;
    }
    MarkNodeCacheable(tmp_node);
    for (const auto &in_node : tmp_node->GetInDataNodes()) {
      queue.push(in_node);
    }
  }
}
// Convenience overload: forwards to the AscNodePtr overload when `node` is an af::AscNode,
// and does nothing otherwise.
void NodeCacheMarker::MarkNodesCacheableBottomUp(const af::NodePtr &node, const af::ExecuteCondition condition) {
  const auto asc_node = std::dynamic_pointer_cast<af::AscNode>(node);
  if (asc_node == nullptr) {
    return;
  }
  MarkNodesCacheableBottomUp(asc_node, condition);
}
// Queue-based walk from `node` towards its consumers. A consumer is followed only when it has exactly one
// data input and is not a Store op; such a consumer is marked cacheable if it is not already, and the walk
// continues from it. All other consumers are neither marked nor walked through.
//
// @param node  Node to start from; the start node itself is not marked here.
void NodeCacheMarker::MarkNodesCacheableUpBottom(const af::NodePtr &node) {
  std::queue<af::NodePtr> queue;
  queue.push(node);
  while (!queue.empty()) {
    // Copy the pointer out BEFORE pop(): pop() destroys the front element, so a reference obtained from
    // front() would dangle while the consumers are iterated below.
    const af::NodePtr tmp_node = queue.front();
    queue.pop();
    if (tmp_node == nullptr) {
      continue;
    }
    // NOTE(review): the result is unused; the call is kept because the helper is not visible here and
    // may have effects of its own - confirm before removing.
    [[maybe_unused]] const auto exec_condition = ascgen_utils::GetNodeExecCondition(tmp_node);
    for (const auto &out_node : tmp_node->GetOutDataNodes()) {
      if (out_node->GetInDataNodesSize() != 1UL || af::ops::IsOps<af::ascir_op::Store>(out_node)) {
        continue;
      }
      if (!ascgen_utils::IsNodeCacheable(out_node)) {
        MarkNodeCacheable(out_node);
      }
      queue.push(out_node);
    }
  }
}
// Recursive depth-first walk from `ge_node` towards its producers, looking for Broadcast nodes that
// qualify for caching.
//
// The walk stops at IO-buffer nodes, at nodes without data inputs, and at nodes already visited. When a
// Broadcast node yields a condition other than kNoCache, it and its producers are marked through
// MarkNodesCacheableBottomUp() and the recursion does not descend further from that node.
//
// @param ge_node  Node to start from.
// @return ge::SUCCESS; see the GE_ASSERT_* / GE_WARN_ASSERT macros for the early-exit paths.
af::Status NodeCacheMarker::ReverseDfsCacheNode(const af::NodePtr &ge_node) {
  const bool is_walk_boundary = ScheduleUtils::IsIOBuffer(ge_node) || ge_node->GetInDataNodesSize() == 0UL;
  if (is_walk_boundary || IsNodeVisited(ge_node)) {
    return ge::SUCCESS;
  }
  VisitNode(ge_node);

  const auto &node = std::dynamic_pointer_cast<af::AscNode>(ge_node);
  GE_ASSERT_NOTNULL(node);

  // Only Broadcast nodes are evaluated; everything else is treated as "no cache" and recursed through.
  const af::ExecuteCondition condition =
      ScheduleUtils::IsBroadcast(node) ? DoesNodeNeedCache(node) : af::ExecuteCondition::kNoCache;
  if (condition != af::ExecuteCondition::kNoCache) {
    MarkNodeCacheable(node);
    MarkNodesCacheableBottomUp(node, condition);
    GELOGD("Graph(%s) Broadcast(%s) supports brc cache.", graph_.GetName().c_str(), node->GetNamePtr());
    return ge::SUCCESS;
  }

  for (const auto &in_node : node->GetInDataNodes()) {
    GE_WARN_ASSERT(ReverseDfsCacheNode(in_node) == ge::SUCCESS);
  }
  return ge::SUCCESS;
}
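/**
 * Marks whether each node needs to cache its output. The logic has two phases: \n
 * Phase 1: reverse search
 * 1. Traverse every node of every template depth-first in reverse order; \n
 * 2. On a RemovePad node that has exactly one input, where that input is a Broadcast node and the cache
 *    condition is met, mark the RemovePad and all of its parent nodes as cacheable, and save the RemovePad's
 *    root nodes into StartNodesSet; \n
 * 3. On a Broadcast node that has exactly one input and meets the cache condition, mark the Broadcast and all
 *    of its parent nodes as cacheable, and save the Broadcast's root nodes into StartNodesSet; \n
 * 4. On a two-input node where one of the inputs is implicitly broadcast, check whether that broadcast input
 *    (call it A) meets the cache condition; if so, mark A and all of its parent nodes as cacheable, and save
 *    A's root nodes into StartNodesSet. \n
 * Phase 2: forward refresh
 * 1. Traverse StartNodesSet; if a child node has exactly one input and is not yet marked for caching, mark it
 *    as cacheable.
 */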
// Entry point of the marker. It does nothing for graphs that contain a transpose compute type or that
// have no Store node. Otherwise it resets the visited / start-node sets, runs the reverse walk from
// every Store node, and then propagates forward from each recorded cache start node.
//
// @return ge::SUCCESS; GE_WARN_ASSERT provides the early-exit path when a reverse walk fails.
af::Status NodeCacheMarker::MarkIfNodeNeedsCache() {
  const bool has_transpose = ScheduleUtils::HasComputeType(graph_, af::ComputeType::kComputeTranspose);
  if (has_transpose) {
    return ge::SUCCESS;
  }

  // Collect the Store nodes first; they are the roots of the reverse walk.
  std::vector<af::NodePtr> stores;
  for (const auto &candidate : graph_.GetAllNodes()) {
    if (!ScheduleUtils::IsStore(candidate)) {
      continue;
    }
    stores.push_back(candidate);
  }
  if (stores.empty()) {
    GELOGD("Graph(%s) has no store node, returning.", graph_.GetName().c_str());
    return ge::SUCCESS;
  }

  visited_nodes_.clear();
  cache_start_nodes_.clear();

  // Phase 1: walk backwards from every Store.
  for (const auto &node : stores) {
    GE_WARN_ASSERT(ReverseDfsCacheNode(node) == ge::SUCCESS);
  }

  // Phase 2: walk forwards from every recorded start node.
  std::for_each(cache_start_nodes_.begin(), cache_start_nodes_.end(),
                [this](const auto &start_node) { MarkNodesCacheableUpBottom(start_node); });
  return ge::SUCCESS;
}
}