* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "recompute_case_generator.h"
#include <unordered_set>
#include <queue>
#include "schedule_utils.h"
#include "ascgraph_info_complete.h"
#include "graph/utils/graph_utils.h"
namespace {
// Dynamic-shape rule in IsRecomputableNode: an output tensor passes only when the number of
// leading repeats that are statically equal to one is strictly greater than this value.
constexpr size_t kContinuousOneAxisNum = 2UL;
// IsRecomputableAlwaysBetter() returns false as soon as the number of compute-type nodes in
// the merged recompute set reaches this value.
constexpr size_t kRecomputeNumThreshold = 5UL;
// IsRecomputableAlwaysBetter() also requires
// (compute-type recompute nodes * this factor) < (compute-type path nodes).
constexpr size_t kRecomputeGapThreshold = 5UL;
// Static-shape rule in IsRecomputableNode: the accumulated product of graph-axis sizes over the
// leading axes of an output tensor must be statically greater than this value.
const af::Symbol kContinuousOneAxisSizeThreshold = af::Symbol(256);
/**
 * Tells whether a node of the given compute type is accepted by the recompute analysis.
 * Only elementwise, broadcast, load and store compute types are accepted.
 */
bool IsSupportedComputeType(af::ComputeType compute_type) {
  static const std::unordered_set<af::ComputeType> kAcceptedTypes{
      af::ComputeType::kComputeElewise,
      af::ComputeType::kComputeBroadcast,
      af::ComputeType::kComputeLoad,
      af::ComputeType::kComputeStore,
  };
  return kAcceptedTypes.find(compute_type) != kAcceptedTypes.end();
}
/**
 * Creates a stand-alone copy of `ori_node` inside `graph`.
 *
 * The op description of `ori_node` is copied, renamed to "<original name>_recompute", turned
 * into an operator and added to `graph`. The tensor attributes of the original AscNode are then
 * copied onto the new node. No edges are created here; wiring is left to the caller.
 *
 * @param ori_node  node to duplicate; must not be null and must be an af::AscNode.
 * @param graph     graph that receives the duplicate.
 * @param new_node  [out] the node added to `graph`.
 * @return ge::SUCCESS on success, a failure status if any step fails.
 */
ge::Status CopyRecomputeNode(af::Node *ori_node, af::AscGraph &graph, af::AscNodePtr &new_node) {
  // Every statement below dereferences ori_node, so reject a null pointer up front.
  GE_ASSERT_NOTNULL(ori_node);
  const auto &op_desc = af::GraphUtils::CopyOpDesc(ori_node->GetOpDesc(), nullptr);
  GE_CHECK_NOTNULL(op_desc);
  op_desc->SetName(ori_node->GetName() + "_recompute");

  af::Operator op = af::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
  new_node = graph.AddNode(op);
  GE_ASSERT_NOTNULL(new_node);

  // CopyAscNodeTensorAttr works on AscNode pointers, so recover the AscNode view of the source.
  auto src_asc_node = std::dynamic_pointer_cast<af::AscNode>(ori_node->shared_from_this());
  GE_ASSERT_NOTNULL(src_asc_node);
  GE_ASSERT_TRUE(af::AscGraph::CopyAscNodeTensorAttr(src_asc_node, new_node),
                 "DoCopyAscNodeTensorAttr failed, node = %s[%s]", ori_node->GetNamePtr(), ori_node->GetTypePtr());
  return ge::SUCCESS;
}
/**
 * Returns how many nodes in `nodes` carry an AscNodeAttr whose api type is kAPITypeCompute.
 * The attribute group is fetched with GetOrCreateAttrsGroup on each node's op description.
 */
size_t CountComputeNodes(const std::unordered_set<af::Node *> &nodes) {
  size_t compute_node_num = 0UL;
  for (af::Node *candidate : nodes) {
    auto attr_group = candidate->GetOpDescBarePtr()->GetOrCreateAttrsGroup<af::AscNodeAttr>();
    if (attr_group == nullptr) {
      continue;
    }
    if (attr_group->api.type != af::ApiType::kAPITypeCompute) {
      continue;
    }
    ++compute_node_num;
  }
  return compute_node_num;
}
}
namespace optimize {
/**
 * Entry point: generates schedule tasks for `hint_graph`, adding a recompute variant when the
 * graph qualifies.
 *
 * Falls back to plain GroupPartitionAndGenTasks when
 *   - any non-buffer node has an unsupported compute type,
 *   - no store node passes IsRecomputableNode, or
 *   - merging the analysed paths leaves no path node.
 *
 * @param hint_graph graph to analyse (and possibly split).
 * @param tasks      [out] receives the generated schedule tasks.
 * @param options    unused.
 */
Status RecomputeCaseGenerator::GeneratorTask(ascir::HintGraph &hint_graph, std::vector<ScheduleTask> &tasks,
                                             const OptimizerOptions &options) {
  (void)options;

  // The graph counts as static only if every axis size is a constant expression.
  auto all_axes = hint_graph.GetAllAxis();
  is_static_graph_ = std::all_of(all_axes.begin(), all_axes.end(),
                                 [](const af::AxisPtr &axis) { return axis->size.IsConstExpr(); });

  // Collect store nodes that look recomputable; give up on the first unsupported compute type.
  std::vector<af::AscNodePtr> store_candidates;
  for (const auto &node : hint_graph.GetAllNodes()) {
    if (ScheduleUtils::IsBuffer(node)) {
      continue;
    }
    if (!IsSupportedComputeType(node->attr.api.compute_type)) {
      return GroupPartitionAndGenTasks(hint_graph, tasks);
    }
    if (ScheduleUtils::IsStore(node) && IsRecomputableNode(hint_graph, node.get())) {
      GELOGD("Store node [%s] may be recomputable.", node->GetNamePtr());
      store_candidates.push_back(node);
    }
  }
  if (store_candidates.empty()) {
    return GroupPartitionAndGenTasks(hint_graph, tasks);
  }

  // NOTE(review): result_ is only appended to here; confirm a generator instance is not reused
  // across graphs.
  for (const auto &candidate : store_candidates) {
    GE_ASSERT_SUCCESS(AnalyzeSplittablePath(hint_graph, candidate));
  }
  MergeAllPaths();

  if (!result_.merged_path_nodes.empty()) {
    GE_ASSERT_SUCCESS(DoTaskGenerator(hint_graph, tasks));
    return ge::SUCCESS;
  }
  return GroupPartitionAndGenTasks(hint_graph, tasks);
}
/**
 * Walks upstream from `potential_store` and records the path if every node on it is recomputable.
 *
 * The walk is a breadth-first traversal over input data nodes, starting at the store. The single
 * consumer of the store is recorded as the output node of the path. Each visited node must pass
 * IsRecomputableNode; on the first node that does not, the whole path is dropped and the function
 * returns ge::SUCCESS without touching result_.
 *
 * A visited node is additionally marked as a recompute node when at least one of its data
 * consumers is not (yet) part of this path, or is itself already marked as a recompute node.
 *
 * @param hint_graph      graph that owns the nodes.
 * @param potential_store store node to start from; must have exactly one output data node.
 * @return ge::SUCCESS unless an internal assertion fails.
 */
Status RecomputeCaseGenerator::AnalyzeSplittablePath(ascir::HintGraph &hint_graph,
                                                     const af::AscNodePtr &potential_store) {
  SplitResultInOnePath single_path_results;
  auto out_nodes = potential_store->GetOutDataNodes();
  GE_ASSERT_TRUE(out_nodes.size() == 1UL);
  auto output_node = out_nodes.at(0UL).get();
  single_path_results.valid_path.emplace(output_node);
  single_path_results.output_nodes.emplace(output_node);

  std::set<af::Node *> visited_nodes;
  visited_nodes.emplace(potential_store.get());
  std::queue<af::AscNode *> head_nodes;
  head_nodes.emplace(potential_store.get());

  while (!head_nodes.empty()) {
    const auto top_node = head_nodes.front();
    head_nodes.pop();
    GE_ASSERT_NOTNULL(top_node);

    bool current_valid = IsRecomputableNode(hint_graph, top_node);
    if (!current_valid) {
      // First placeholder is the store being analysed, second is the node that blocks it.
      GELOGD("Store node [%s] cannot be split for node:[%s].", potential_store->GetNamePtr(), top_node->GetNamePtr());
      return ge::SUCCESS;
    }
    // First placeholder is the node joining the path, second is the store that owns the path.
    GELOGD("Add path node [%s] in path of split node:[%s].", top_node->GetNamePtr(), potential_store->GetNamePtr());
    single_path_results.valid_path.emplace(top_node);

    // The node must be recomputed if any consumer lies outside the path or is itself recomputed.
    bool is_all_in_path = true;
    for (const auto &out_node : top_node->GetOutDataNodes()) {
      if (single_path_results.valid_path.count(out_node.get()) == 0UL ||
          single_path_results.recompute_nodes.count(out_node.get()) > 0UL) {
        is_all_in_path = false;
        break;
      }
    }
    if (!is_all_in_path) {
      GELOGD("Add recompute node [%s] in path of split node:[%s].", top_node->GetNamePtr(),
             potential_store->GetNamePtr());
      single_path_results.recompute_nodes.emplace(top_node);
    }

    // Continue upstream; each producer is queued at most once.
    for (auto &in_node : top_node->GetInDataNodes()) {
      auto asc_node = std::dynamic_pointer_cast<af::AscNode>(in_node);
      GE_ASSERT_NOTNULL(asc_node);
      if (visited_nodes.insert(asc_node.get()).second) {
        head_nodes.push(asc_node.get());
      }
    }
  }

  result_.valid_path_results.emplace_back(std::move(single_path_results));
  return ge::SUCCESS;
}
/**
 * Folds every per-store path result into the merged sets of result_, then prunes the merged
 * recompute set.
 *
 * A node is removed from the merged recompute set when each of its data consumers is a merged
 * path node and none of them is (still) a merged recompute node.
 */
void RecomputeCaseGenerator::MergeAllPaths() {
  for (const auto &single_path : result_.valid_path_results) {
    result_.merged_output_nodes.insert(single_path.output_nodes.begin(), single_path.output_nodes.end());
    result_.merged_recompute_nodes.insert(single_path.recompute_nodes.begin(), single_path.recompute_nodes.end());
    result_.merged_path_nodes.insert(single_path.valid_path.begin(), single_path.valid_path.end());
  }

  for (auto iter = result_.merged_recompute_nodes.begin(); iter != result_.merged_recompute_nodes.end();) {
    bool must_keep = false;
    for (const auto &consumer : (*iter)->GetOutDataNodes()) {
      const bool outside_path = (result_.merged_path_nodes.count(consumer.get()) == 0);
      if (outside_path || result_.merged_recompute_nodes.count(consumer.get()) > 0UL) {
        must_keep = true;
        break;
      }
    }
    if (must_keep) {
      ++iter;
      continue;
    }
    GELOGD("Node [%s]'s outputs are all in path, no need to do recompute.", (*iter)->GetNamePtr());
    // erase() hands back the next valid iterator.
    iter = result_.merged_recompute_nodes.erase(iter);
  }

  GELOGD("After merging path, [%zu] path nodes, [%zu] recompute_nodes [%zu] output nodes remaining.",
         result_.merged_path_nodes.size(), result_.merged_recompute_nodes.size(), result_.merged_output_nodes.size());
}
/**
 * Emits the schedule tasks for the recompute variant of `impl_graph`.
 *
 * Unless IsRecomputableAlwaysBetter() holds, tasks for an unmodified copy of the graph are
 * generated first. Afterwards the compute graph behind `impl_graph` is renamed with the
 * "_recompute_case" suffix, split via DoGraphSplit, and its tasks are appended as well.
 *
 * @param impl_graph graph to split in place.
 * @param tasks      [out] receives the generated schedule tasks.
 */
Status RecomputeCaseGenerator::DoTaskGenerator(ascir::ImplGraph &impl_graph, std::vector<ScheduleTask> &tasks) const {
  const bool keep_origin_case = !IsRecomputableAlwaysBetter();
  if (keep_origin_case) {
    // Work on a copy so that the split below does not affect the original variant.
    ascir::ImplGraph origin_graph(impl_graph.GetName().c_str());
    GE_ASSERT_TRUE(origin_graph.CopyFrom(impl_graph), "Failed to copy graph from [%s].", impl_graph.GetName().c_str());
    GE_ASSERT_SUCCESS(GroupPartitionAndGenTasks(origin_graph, tasks));
  }

  auto compute_graph = af::AscGraphUtils::GetComputeGraph(impl_graph);
  GE_ASSERT_NOTNULL(compute_graph);
  const auto recompute_graph_name = compute_graph->GetName() + "_recompute_case";
  compute_graph->SetName(recompute_graph_name);

  GE_ASSERT_SUCCESS(DoGraphSplit(impl_graph), "Failed to split recompute graph from [%s].",
                    impl_graph.GetName().c_str());
  GE_ASSERT_SUCCESS(GroupPartitionAndGenTasks(impl_graph, tasks));
  return ge::SUCCESS;
}
/**
 * Rewires `graph` so that the merged path consumes duplicated ("_recompute") producers.
 *
 * Step 1 duplicates every node in result_.merged_recompute_nodes via CopyRecomputeNode.
 * Step 2 walks upstream (breadth-first, over input data anchors) from the merged output nodes
 * and, per input edge, either mirrors the edge between the two duplicates or redirects the
 * input of a non-recompute node to the duplicate of its producer.
 *
 * @param graph graph to modify in place.
 * @return ge::SUCCESS on success, a failure status if copying or edge editing fails.
 */
Status RecomputeCaseGenerator::DoGraphSplit(af::AscGraph &graph) const {
// Maps each original recompute node to its freshly created duplicate.
std::unordered_map<af::Node *, af::Node *> recompute_to_new;
for (auto recompute_node : result_.merged_recompute_nodes) {
af::AscNodePtr new_node;
GE_ASSERT_SUCCESS(CopyRecomputeNode(recompute_node, graph, new_node), "Failed to copy recompute node [%s].",
recompute_node->GetNamePtr());
recompute_to_new[recompute_node] = new_node.get();
}
// Seed the upstream traversal with every merged output node.
std::queue<af::Node *> bfs_queue;
std::unordered_set<af::Node *> visited;
for (auto &out_node : result_.merged_output_nodes) {
bfs_queue.emplace(out_node);
visited.emplace(out_node);
}
while (!bfs_queue.empty()) {
af::Node *current = bfs_queue.front();
bfs_queue.pop();
auto cur_iter = recompute_to_new.find(current);
bool is_cur_recompute = cur_iter != recompute_to_new.end();
for (const auto &in_anchor : current->GetAllInDataAnchors()) {
GE_ASSERT_NOTNULL(in_anchor);
auto peer_out = in_anchor->GetPeerOutAnchor();
// Unconnected input: nothing to rewire.
if (peer_out == nullptr) {
continue;
}
auto input_node = peer_out->GetOwnerNodeBarePtr();
GE_ASSERT_NOTNULL(input_node);
auto in_iter = recompute_to_new.find(input_node);
bool is_in_recompute = in_iter != recompute_to_new.end();
if (is_cur_recompute) {
// A recompute node is required to have only recompute producers. Mirror the edge between
// the two duplicates on the same anchor indices; the original edge is left untouched.
GE_ASSERT_TRUE(is_in_recompute);
GE_ASSERT_SUCCESS(af::GraphUtils::AddEdge(in_iter->second->GetOutDataAnchor(peer_out->GetIdx()),
cur_iter->second->GetInDataAnchor(in_anchor->GetIdx())));
} else {
if (is_in_recompute) {
// Non-recompute consumer of a recompute producer: detach it from the original producer
// and feed the same input anchor from the duplicate instead.
GE_ASSERT_SUCCESS(af::GraphUtils::RemoveEdge(peer_out, in_anchor));
GE_ASSERT_SUCCESS(af::GraphUtils::AddEdge(in_iter->second->GetOutDataAnchor(peer_out->GetIdx()), in_anchor));
}
}
// Keep walking upstream through the original producer; each node is queued once.
if (visited.count(input_node) == 0UL) {
visited.insert(input_node);
bfs_queue.push(input_node);
}
}
}
return ge::SUCCESS;
}
/**
 * Decides whether `node` may take part in a recompute path.
 *
 * Buffer-type nodes are always accepted. For any other node every output tensor has to pass
 * one of two rules, selected by is_static_graph_:
 *   - static graph: walking the leading axes, multiply the graph-axis sizes until a repeat is
 *     statically different from one, the axis is unknown, or the axis size is statically one;
 *     the accumulated product must be statically greater than kContinuousOneAxisSizeThreshold.
 *   - dynamic graph: the number of leading repeats statically equal to one must be greater
 *     than kContinuousOneAxisNum.
 *
 * @param hint_graph graph used to resolve axis ids to axes.
 * @param node       node to inspect; must not be null.
 */
bool RecomputeCaseGenerator::IsRecomputableNode(ascir::HintGraph &hint_graph, af::AscNode *node) const {
  if (node->attr.api.type == af::ApiType::kAPITypeBuffer) {
    return true;
  }

  auto static_tensor_ok = [&hint_graph](af::AscTensorAttr &tensor_attr) -> bool {
    const auto &repeats = tensor_attr.repeats;
    const auto &axes = tensor_attr.axis;
    af::Expression leading_size = af::sym::kSymbolOne;
    for (size_t idx = 0UL; (idx < repeats.size()) && (idx < axes.size()); ++idx) {
      const bool repeat_differs_from_one =
          af::SymbolicUtils::StaticCheckNe(repeats[idx], af::sym::kSymbolOne) == af::TriBool::kTrue;
      if (repeat_differs_from_one) {
        break;
      }
      auto axis = hint_graph.FindAxis(axes[idx]);
      if (axis == nullptr) {
        break;
      }
      if (af::SymbolicUtils::StaticCheckEq(axis->size, af::sym::kSymbolOne) == af::TriBool::kTrue) {
        break;
      }
      // Compare against threshold / size before multiplying; once exceeded, pin the product just
      // above the threshold and stop.
      auto remaining_budget = kContinuousOneAxisSizeThreshold / axis->size;
      if (af::SymbolicUtils::StaticCheckGt(leading_size, remaining_budget) == af::TriBool::kTrue) {
        leading_size = kContinuousOneAxisSizeThreshold + af::sym::kSymbolOne;
        break;
      }
      leading_size = leading_size * axis->size;
    }
    return af::SymbolicUtils::StaticCheckGt(leading_size, kContinuousOneAxisSizeThreshold) == af::TriBool::kTrue;
  };

  auto dynamic_tensor_ok = [](af::AscTensorAttr &tensor_attr) -> bool {
    size_t leading_ones = 0UL;
    for (const auto &repeat : tensor_attr.repeats) {
      if (af::SymbolicUtils::StaticCheckEq(repeat, af::sym::kSymbolOne) != af::TriBool::kTrue) {
        break;
      }
      ++leading_ones;
    }
    return leading_ones > kContinuousOneAxisNum;
  };

  auto output_tensors = node->outputs();
  const bool use_static_rule = is_static_graph_;
  return std::all_of(output_tensors.begin(), output_tensors.end(), [&](af::AscTensor *&tensor) {
    return use_static_rule ? static_tensor_ok(tensor->attr) : dynamic_tensor_ok(tensor->attr);
  });
}
/**
 * Heuristic telling whether only the recompute variant needs to be generated.
 *
 * Returns true only when the graph is static, fewer than kRecomputeNumThreshold compute-type
 * nodes have to be recomputed, and that count multiplied by kRecomputeGapThreshold is still
 * smaller than the number of compute-type nodes on the merged path.
 */
bool RecomputeCaseGenerator::IsRecomputableAlwaysBetter() const {
  if (is_static_graph_) {
    const size_t recompute_compute_num = CountComputeNodes(result_.merged_recompute_nodes);
    if (recompute_compute_num < kRecomputeNumThreshold) {
      const size_t path_compute_num = CountComputeNodes(result_.merged_path_nodes);
      return path_compute_num > (recompute_compute_num * kRecomputeGapThreshold);
    }
  }
  return false;
}
}