* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/utils/ffts_graph_utils.h"
#include <queue>
#include "graph_metadef/graph/debug/ge_util.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_op_types.h"
namespace {
// Running counter appended to "FFTS_Plus_Subgraph_" so every generated subgraph gets a distinct name.
uint32_t g_fftsPlusSubgraphNum{0U};
// Depth at which the recursive Calculate / PartitionGraphWithLimit helpers give up and report an error.
constexpr uint32_t kMaxiumRecursionDepth{10U};
}  // namespace
namespace af {
// Clips the nodes listed in `unsupported_nodes` out of the FFTS+ subgraph that belongs to `graph`.
//
// Steps:
//   1. Resolve the FFTS+ graph through GetFftsPlusGraph().
//   2. Collect the nodes to clip and the subgraphs affected by them; return early if there are none.
//   3. Remove ATTR_NAME_FFTS_PLUS_SUB_GRAPH from the parent node of the FFTS+ graph.
//   4. For every affected graph (plus the FFTS+ graph itself): split it into stages when it directly
//      contains a clipped node, otherwise wrap all of its calc nodes into one new FFTS+ subgraph.
//
// Returns SUCCESS / GRAPH_SUCCESS when the work is done or skipped, otherwise the failing step's status.
graphStatus FftsGraphUtils::GraphPartition(ComputeGraph &graph, const std::set<NodePtr> &unsupported_nodes) {
  if (unsupported_nodes.empty()) {
    GELOGI("Graph:%s, no node is unsupported, skip clipping", graph.GetName().c_str());
    return SUCCESS;
  }

  const auto root_graph = GetFftsPlusGraph(graph);
  GE_CHECK_NOTNULL(root_graph);

  std::unordered_set<NodePtr> clip_nodes;
  std::unordered_set<ComputeGraphPtr> split_graphs;
  GE_CHK_STATUS_RET(CollectClipNodesAndGraphs(root_graph, unsupported_nodes, clip_nodes, split_graphs),
                    "[Collect][NeedClip] nodes and subgraphs in graph %s failed", root_graph->GetName().c_str());
  const bool nothing_to_clip = clip_nodes.empty() && split_graphs.empty();
  if (nothing_to_clip) {
    GELOGI("Graph:%s, no node/subgraph need to be clipped, skip", root_graph->GetName().c_str());
    return SUCCESS;
  }

  const auto owner_node = root_graph->GetParentNode();
  GE_CHECK_NOTNULL(owner_node);
  (void)owner_node->GetOpDesc()->DelAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH);

  // The FFTS+ graph itself is always processed together with the collected subgraphs.
  (void)split_graphs.emplace(root_graph);
  for (const auto &candidate : split_graphs) {
    if (!IsGraphNeedSplit(candidate, clip_nodes)) {
      GE_CHK_STATUS_RET(BuildFftsPlusSubgraphWithAllNodes(candidate),
                        "[Build][FftsPlusSubgraph] failed, graph:%s", candidate->GetName().c_str());
      continue;
    }
    std::vector<std::pair<bool, std::set<NodePtr>>> stages;
    GE_CHK_STATUS_RET(SplitNodesWithCheck(candidate, clip_nodes, stages),
                      "[Split][Nodes] failed, graph:%s", candidate->GetName().c_str());
    GE_CHK_STATUS_RET(SplitSubgraph(candidate, stages),
                      "[Split][Subgraph] %s failed", candidate->GetName().c_str());
  }
  return GRAPH_SUCCESS;
}
// Collects everything that has to be clipped because of `unsupported_nodes`.
//
// For each node of `graph` (all nesting levels) that is listed in `unsupported_nodes`:
//   - the node itself is added to `nodes_need_clip`;
//   - the chain of owner graphs is climbed towards `graph`; on every level the parent node is added to
//     `nodes_need_clip` and all direct subgraphs of that parent node are added to `graphs_need_split`.
//
// @param graph              graph whose nodes are scanned; also the point where climbing stops
// @param unsupported_nodes  nodes that must not stay inside an FFTS+ subgraph
// @param nodes_need_clip    [out] unsupported nodes plus the parent nodes that enclose them
// @param graphs_need_split  [out] direct subgraphs of those parent nodes
// @return GRAPH_SUCCESS, or the status reported when fetching direct subgraphs fails
graphStatus FftsGraphUtils::CollectClipNodesAndGraphs(const ComputeGraphPtr &graph,
                                                      const std::set<NodePtr> &unsupported_nodes,
                                                      std::unordered_set<NodePtr> &nodes_need_clip,
                                                      std::unordered_set<ComputeGraphPtr> &graphs_need_split) {
  for (const auto &node : graph->GetAllNodes()) {
    if (unsupported_nodes.count(node) == 0U) {
      continue;
    }
    (void)nodes_need_clip.emplace(node);
    ComputeGraphPtr cur_graph = node->GetOwnerComputeGraph();
    // Guard against a null owner/parent graph: the chain may end before `graph` is reached, and
    // dereferencing a null graph pointer here would crash.
    while ((cur_graph != nullptr) && (cur_graph != graph)) {
      const auto &parent_node = cur_graph->GetParentNode();
      if (parent_node == nullptr) {
        break;
      }
      (void)nodes_need_clip.emplace(parent_node);
      std::vector<ComputeGraphPtr> subgraphs;
      GE_CHK_STATUS_RET(NodeUtils::GetDirectSubgraphs(parent_node, subgraphs), "[Get][Subgraphs] failed for node %s",
                        parent_node->GetName().c_str());
      for (const auto &subgraph : subgraphs) {
        (void)graphs_need_split.emplace(subgraph);
      }
      cur_graph = cur_graph->GetParentGraph();
    }
  }
  return GRAPH_SUCCESS;
}
bool FftsGraphUtils::IsGraphNeedSplit(const ComputeGraphPtr &graph,
const std::unordered_set<NodePtr> &nodes_need_clip) {
for (const auto &node : graph->GetDirectNode()) {
if (nodes_need_clip.count(node) > 0U) {
return true;
}
}
return false;
}
// Splits the direct nodes of `graph` into alternating stages of supported and to-be-clipped nodes.
//
// Source nodes (nodes without inputs) seed the traversal: the ones that are NOT in `nodes_need_clip`
// start the first stage, the clipped ones are queued for the following stage. Each pass of the loop
// grows the current stage through SplitNodes() and then flips the stage kind.
//
// @param graph            graph whose direct nodes are split
// @param nodes_need_clip  nodes that must not be placed in a supported stage
// @param split_nodes      [out] appended stages as (is_supported, calc nodes of the stage); stages whose
//                         calc-node set is empty are not recorded
// @return always GRAPH_SUCCESS
graphStatus FftsGraphUtils::SplitNodesWithCheck(const ComputeGraphPtr &graph,
                                                const std::unordered_set<NodePtr> &nodes_need_clip,
                                                std::vector<std::pair<bool, std::set<NodePtr>>> &split_nodes) {
  // Seed both stages with the source nodes of the graph.
  std::set<NodePtr> cur_nodes;
  std::set<NodePtr> next_nodes;
  for (const auto &node : graph->GetDirectNode()) {
    if (!node->GetInAllNodes().empty()) {
      continue;
    }
    if (nodes_need_clip.count(node) == 0U) {
      (void)cur_nodes.insert(node);
    } else {
      (void)next_nodes.insert(node);
    }
  }

  // Only calc nodes are reported in the resulting stages.
  std::set<NodePtr> calc_nodes;
  CollectCalcNodeInSubgraph(graph, calc_nodes);

  // The initial `cur_nodes` hold the supported sources, so the first stage must be a supported one.
  // Starting with `false` would label the stages the wrong way round and mix both kinds of nodes.
  bool support_flag = true;
  std::set<NodePtr> visited_nodes;
  while (!(cur_nodes.empty() && next_nodes.empty())) {
    // The predicate is consumed synchronously by SplitNodes, so the set is captured by reference
    // instead of being copied on every pass.
    const auto is_cur_stage = [support_flag, &nodes_need_clip](const NodePtr &node_ptr) -> bool {
      return (support_flag == (nodes_need_clip.count(node_ptr) == 0U));
    };
    SplitNodes(calc_nodes, is_cur_stage, visited_nodes, cur_nodes, next_nodes);

    std::set<NodePtr> cur_split_nodes;
    for (const auto &cur_node : cur_nodes) {
      if (calc_nodes.count(cur_node) > 0U) {
        (void)cur_split_nodes.insert(cur_node);
      }
    }
    if (!cur_split_nodes.empty()) {
      split_nodes.emplace_back(support_flag, cur_split_nodes);
    }

    // Move on to the opposite kind of stage, seeded with the nodes deferred by this pass.
    support_flag = !support_flag;
    cur_nodes.clear();
    std::swap(cur_nodes, next_nodes);
  }
  return GRAPH_SUCCESS;
}
void FftsGraphUtils::SplitNodes(const std::set<NodePtr> &calc_nodes,
const std::function<bool(const NodePtr &)> &is_cur_stage,
std::set<NodePtr> &visited_nodes,
std::set<NodePtr> &cur_nodes,
std::set<NodePtr> &next_nodes) {
visited_nodes.insert(cur_nodes.cbegin(), cur_nodes.cend());
std::queue<NodePtr> nodes;
for (const auto &node : cur_nodes) {
nodes.push(node);
}
while (!nodes.empty()) {
const auto &node = nodes.front();
nodes.pop();
if (calc_nodes.find(node) != calc_nodes.end()) {
(void)cur_nodes.insert(node);
} else {
(void)node->GetOpDesc()->DelAttr(ATTR_NAME_THREAD_SCOPE_ID);
}
(void)visited_nodes.insert(node);
for (const auto &out_node : node->GetOutAllNodes()) {
const auto &in_nodes = out_node->GetInAllNodes();
const bool all_in_node_seen = !std::any_of(in_nodes.begin(), in_nodes.end(),
[visited_nodes](const NodePtr &node_ptr) {
return visited_nodes.count(node_ptr) == 0U;
});
if (!all_in_node_seen) {
continue;
}
if (is_cur_stage(out_node)) {
(void)nodes.push(out_node);
} else {
(void)next_nodes.insert(out_node);
}
}
}
}
// Turns the stages produced by SplitNodesWithCheck() into FFTS+ subgraphs.
//
// A stage flagged `true` with a non-empty node set is wrapped into a new subgraph named
// "FFTS_Plus_Subgraph_<counter>" and tagged through SetAttrForFftsPlusSubgraph(). For every other stage
// the ATTR_NAME_THREAD_SCOPE_ID attribute is removed from its nodes.
//
// Returns GRAPH_FAILED when a subgraph cannot be built, the failing status when tagging fails, and
// GRAPH_SUCCESS otherwise.
graphStatus FftsGraphUtils::SplitSubgraph(const ComputeGraphPtr &subgraph,
                                          const std::vector<std::pair<bool, std::set<NodePtr>>> &split_nodes) {
  for (const auto &stage : split_nodes) {
    const bool is_supported = stage.first;
    const std::set<NodePtr> &stage_nodes = stage.second;
    if (!is_supported || stage_nodes.empty()) {
      for (const auto &node : stage_nodes) {
        (void)node->GetOpDesc()->DelAttr(ATTR_NAME_THREAD_SCOPE_ID);
      }
      continue;
    }

    const std::string subgraph_name = "FFTS_Plus_Subgraph_" + std::to_string(g_fftsPlusSubgraphNum++);
    const auto built_graph = GraphUtils::BuildSubgraphWithNodes(subgraph, stage_nodes, subgraph_name);
    if (built_graph == nullptr) {
      REPORT_INNER_ERR_MSG("E18888", "Build subgraph %s failed", subgraph_name.c_str());
      GELOGE(GRAPH_FAILED, "[Build][Subgraph] %s failed", subgraph_name.c_str());
      return GRAPH_FAILED;
    }
    GE_CHK_STATUS_RET(SetAttrForFftsPlusSubgraph(built_graph), "[Set][Attr] failed for ffts+ subgraph");
  }
  return GRAPH_SUCCESS;
}
// Wraps all calc nodes of `subgraph` into one new subgraph named "FFTS_Plus_Subgraph_<counter>" and tags
// it through SetAttrForFftsPlusSubgraph().
//
// Returns GRAPH_FAILED when the subgraph cannot be built, the failing status when the null check or the
// tagging fails, and GRAPH_SUCCESS otherwise.
graphStatus FftsGraphUtils::BuildFftsPlusSubgraphWithAllNodes(const ComputeGraphPtr &subgraph) {
  GE_CHECK_NOTNULL(subgraph);

  std::set<NodePtr> wrapped_nodes;
  CollectCalcNodeInSubgraph(subgraph, wrapped_nodes);

  const std::string subgraph_name = "FFTS_Plus_Subgraph_" + std::to_string(g_fftsPlusSubgraphNum++);
  const auto built_graph = GraphUtils::BuildSubgraphWithNodes(subgraph, wrapped_nodes, subgraph_name);
  if (built_graph != nullptr) {
    GE_CHK_STATUS_RET(SetAttrForFftsPlusSubgraph(built_graph), "[Set][Attr] failed for ffts+ subgraph");
    return GRAPH_SUCCESS;
  }

  REPORT_INNER_ERR_MSG("E18888", "Build subgraph %s failed", subgraph_name.c_str());
  GELOGE(GRAPH_FAILED, "[Build][Subgraph] %s failed", subgraph_name.c_str());
  return GRAPH_FAILED;
}
void FftsGraphUtils::CollectCalcNodeInSubgraph(const ComputeGraphPtr &subgraph, std::set<NodePtr> &calc_nodes) {
std::set<NodePtr> edge_nodes;
const std::set<std::string> ctrl_goto_types = { LABELSET, LABELGOTOEX, LABELSWITCHBYINDEX };
CollectEndNodeInSubgraph(subgraph, ctrl_goto_types, edge_nodes);
std::queue<NodePtr> start_nodes;
for (const auto &node : subgraph->GetDirectNode()) {
if ((node->GetType() == DATA) ||
((node->GetInAllNodes().empty()) && (ctrl_goto_types.count(node->GetType()) > 0U))) {
start_nodes.push(node);
}
}
while (!start_nodes.empty()) {
const auto &cur_node = start_nodes.front();
start_nodes.pop();
(void)edge_nodes.insert(cur_node);
for (const auto &out_node : cur_node->GetOutAllNodes()) {
if (ctrl_goto_types.count(out_node->GetType()) > 0U) {
start_nodes.push(out_node);
}
}
}
for (const auto &node : subgraph->GetDirectNode()) {
if (edge_nodes.count(node) == 0U) {
(void)calc_nodes.insert(node);
}
}
}
void FftsGraphUtils::CollectEndNodeInSubgraph(const ComputeGraphPtr &subgraph,
const std::set<std::string> &ctrl_goto_types,
std::set<NodePtr> &edge_nodes) {
const auto &net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT);
if (net_output_node == nullptr) {
return;
}
std::set<NodePtr> out_nodes;
for (const auto &in_node : net_output_node->GetInAllNodes()) {
for (const auto &out_node : in_node->GetOutAllNodes()) {
(void)out_nodes.insert(out_node);
}
}
std::queue<NodePtr> end_nodes;
end_nodes.push(net_output_node);
for (const auto &out_node : out_nodes) {
if (ctrl_goto_types.count(out_node->GetType()) > 0U) {
end_nodes.push(out_node);
}
}
while (!end_nodes.empty()) {
const auto &cur_node = end_nodes.front();
end_nodes.pop();
(void)edge_nodes.insert(cur_node);
for (const auto &out_node : cur_node->GetOutAllNodes()) {
end_nodes.push(out_node);
}
}
}
// Resolves the FFTS+ graph for `graph`: the single direct subgraph of the parent node of `graph`.
//
// Returns nullptr (after reporting an error) when `graph` has no parent node, when the direct subgraphs
// cannot be fetched, or when the parent node does not own exactly one subgraph.
ComputeGraphPtr FftsGraphUtils::GetFftsPlusGraph(ComputeGraph &graph) {
  const auto owner_node = graph.GetParentNode();
  if (owner_node == nullptr) {
    REPORT_INNER_ERR_MSG("E18888", "parent node of graph %s is null", graph.GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Check][Param] parent node of graph %s is null", graph.GetName().c_str());
    return nullptr;
  }

  std::vector<ComputeGraphPtr> children;
  const auto fetch_status = NodeUtils::GetDirectSubgraphs(owner_node, children);
  if (fetch_status != GRAPH_SUCCESS) {
    REPORT_INNER_ERR_MSG("E18888", "Get subgraph failed, node:%s", owner_node->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Get][Subgraph] failed, node:%s", owner_node->GetName().c_str());
    return nullptr;
  }

  if (children.size() == 1U) {
    return children.front();
  }

  REPORT_INNER_ERR_MSG("E18888", "Number of subgraphs in parent_node:%s is %zu, graph:%s",
                       owner_node->GetName().c_str(), children.size(), graph.GetName().c_str());
  GELOGE(GRAPH_FAILED, "[Check][Param] Number of subgraphs in parent_node:%s is %zu, graph:%s",
         owner_node->GetName().c_str(), children.size(), graph.GetName().c_str());
  return nullptr;
}
// Tags `subgraph` as an FFTS+ subgraph.
//
// The subgraph name is stored in ATTR_NAME_FFTS_PLUS_SUB_GRAPH on the parent node, and
// ATTR_NAME_THREAD_SCOPE_ID is set to 0 on every node of the subgraph (all nesting levels).
//
// Returns GRAPH_FAILED when the subgraph has no parent node, GRAPH_SUCCESS otherwise.
graphStatus FftsGraphUtils::SetAttrForFftsPlusSubgraph(const ComputeGraphPtr &subgraph) {
  const auto owner_node = subgraph->GetParentNodeBarePtr();
  if (owner_node != nullptr) {
    (void)AttrUtils::SetStr(owner_node->GetOpDesc(), ATTR_NAME_FFTS_PLUS_SUB_GRAPH, subgraph->GetName().c_str());
    for (const auto &member : subgraph->GetAllNodes()) {
      (void)AttrUtils::SetInt(member->GetOpDesc(), ATTR_NAME_THREAD_SCOPE_ID, 0);
    }
    return GRAPH_SUCCESS;
  }

  REPORT_INNER_ERR_MSG("E18888", "Parent node of subgraph %s is null", subgraph->GetName().c_str());
  GELOGE(GRAPH_FAILED, "[Check][Param] Parent node of subgraph %s is null", subgraph->GetName().c_str());
  return GRAPH_FAILED;
}
// Partitions the FFTS+ subgraph that belongs to `graph` so that each part stays within `upper_limit`.
//
// `calc_func` yields a value vector per node; the vectors are summed per graph by Calculate(). When the
// FFTS+ graph as a whole does not exceed `upper_limit` nothing is changed. Otherwise
// ATTR_NAME_FFTS_PLUS_SUB_GRAPH is removed from its parent node, the graph is partitioned with
// PartitionGraphWithLimit(), and finally every single-subgraph PARTITIONEDCALL whose parent node does not
// carry ATTR_NAME_FFTS_PLUS_SUB_GRAPH is unfolded again.
//
// NOTE(review): values are compared against `upper_limit` with std::vector's operator<=, which is a
// lexicographic comparison rather than an element-wise one - confirm that this is intended.
//
// Returns SUCCESS / GRAPH_SUCCESS when finished or skipped, otherwise the failing step's status.
graphStatus FftsGraphUtils::GraphPartition(ComputeGraph &graph,
                                           const CalcFunc &calc_func,
                                           const std::vector<uint32_t> &upper_limit) {
  const bool nothing_to_measure = (calc_func == nullptr) || upper_limit.empty();
  if (nothing_to_measure) {
    GELOGI("Graph:%s, calculate function or upper_limit is empty, skip graph partition",
           graph.GetName().c_str());
    return SUCCESS;
  }

  const auto root_graph = GetFftsPlusGraph(graph);
  GE_CHECK_NOTNULL(root_graph);

  std::map<NodePtr, std::vector<uint32_t>> node_value;
  std::map<ComputeGraphPtr, std::vector<uint32_t>> graph_value;
  GE_CHK_STATUS_RET(Calculate(root_graph, calc_func, node_value, graph_value),
                    "[Calculate][Value] failed for graph %s", root_graph->GetName().c_str());
  if (!IsValueValid(root_graph, upper_limit, node_value, graph_value)) {
    REPORT_INNER_ERR_MSG("E18888", "Check value invalid");
    GELOGE(GRAPH_FAILED, "[Check][Value] invalid");
    return GRAPH_FAILED;
  }

  const auto root_value = graph_value.find(root_graph);
  if ((root_value != graph_value.end()) && (root_value->second <= upper_limit)) {
    GELOGI("Graph %s not exceed limit, skip graph partition", root_graph->GetName().c_str());
    return SUCCESS;
  }

  const auto owner_node = root_graph->GetParentNodeBarePtr();
  GE_CHECK_NOTNULL(owner_node);
  (void)owner_node->GetOpDesc()->DelAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH);
  GE_CHK_STATUS_RET(PartitionGraphWithLimit(root_graph, node_value, graph_value, upper_limit),
                    "[Partition][Graph] failed, graph:%s", root_graph->GetName().c_str());

  // Selects the subgraphs to unfold: owned by a PARTITIONEDCALL node with exactly one subgraph instance
  // that is not marked as an FFTS+ subgraph.
  const auto &filter = [](const ComputeGraphPtr &graph_ptr) {
    const auto parent = graph_ptr->GetParentNodeBarePtr();
    if (parent == nullptr) {
      return false;
    }
    const auto &op_desc = parent->GetOpDesc();
    if (op_desc == nullptr) {
      return false;
    }
    const bool single_partitioned_call =
        (parent->GetType() == PARTITIONEDCALL) && (op_desc->GetSubgraphInstanceNames().size() == 1U);
    return single_partitioned_call && !op_desc->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH);
  };
  GE_CHK_STATUS_RET(GraphUtils::UnfoldSubgraph(root_graph, filter), "[Unfold][Subgraph] failed, graph:%s",
                    root_graph->GetName().c_str());
  return GRAPH_SUCCESS;
}
// Computes the value vector of every direct node of `graph` and their element-wise sum for the graph.
//
// Nodes without subgraph instances are measured by `calc_func`; nodes that own subgraphs are measured by
// the node overload of Calculate(), which descends into those subgraphs. Results are stored in
// `node_value` and `graph_value`.
//
// Returns GRAPH_FAILED when `recursive_depth` reaches kMaxiumRecursionDepth, when a node with subgraphs
// yields an empty value, or when value vectors of different sizes meet; SUCCESS otherwise.
graphStatus FftsGraphUtils::Calculate(const ComputeGraphPtr &graph,
                                      const CalcFunc &calc_func,
                                      std::map<NodePtr, std::vector<uint32_t>> &node_value,
                                      std::map<ComputeGraphPtr, std::vector<uint32_t>> &graph_value,
                                      const uint32_t recursive_depth) {
  if (recursive_depth >= kMaxiumRecursionDepth) {
    REPORT_INNER_ERR_MSG("E18888", "param depth:%u >= %u(allow max subgraphs)", recursive_depth, kMaxiumRecursionDepth);
    GELOGE(GRAPH_FAILED, "[Check][Param]exist too much subgraphs:%u > %u(allow max subgraphs)",
           recursive_depth, kMaxiumRecursionDepth);
    return GRAPH_FAILED;
  }
  GE_CHECK_NOTNULL(graph);

  std::vector<uint32_t> graph_total;
  for (const auto &node : graph->GetDirectNode()) {
    std::vector<uint32_t> value;
    const bool owns_subgraphs = !node->GetOpDesc()->GetSubgraphInstanceNames().empty();
    if (owns_subgraphs) {
      value = Calculate(node, calc_func, node_value, graph_value, recursive_depth);
      if (value.empty()) {
        REPORT_INNER_ERR_MSG("E18888", "Calculate value for func node %s failed", node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Calculate][Value] for func node %s failed", node->GetName().c_str());
        return GRAPH_FAILED;
      }
    } else {
      value = calc_func(node);
    }
    node_value[node] = value;

    // The first non-empty value defines the vector size for the whole graph.
    if (graph_total.empty()) {
      graph_total = value;
      continue;
    }
    if (graph_total.size() != value.size()) {
      REPORT_INNER_ERR_MSG("E18888",
                           "Value size not match, value size of graph %s is %zu, "
                           "value size of node %s is %zu",
                           graph->GetName().c_str(), graph_total.size(), node->GetName().c_str(),
                           value.size());
      GELOGE(GRAPH_FAILED, "[Check][Param] Value size not match, value size of graph %s is %zu, "
             "value size of node %s is %zu", graph->GetName().c_str(), graph_total.size(),
             node->GetName().c_str(), value.size());
      return GRAPH_FAILED;
    }
    (void) std::transform(graph_total.begin(), graph_total.end(), value.begin(),
                          graph_total.begin(), std::plus<uint32_t>());
  }
  graph_value[graph] = graph_total;
  return SUCCESS;
}
std::vector<uint32_t> FftsGraphUtils::Calculate(const NodePtr &node, const CalcFunc &calc_func,
std::map<NodePtr, std::vector<uint32_t>> &node_value,
std::map<ComputeGraphPtr, std::vector<uint32_t>> &graph_value,
const uint32_t recursive_depth) {
std::vector<ComputeGraphPtr> subgraphs;
if (NodeUtils::GetDirectSubgraphs(node, subgraphs) != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E18888", "Get subgraphs failed");
GELOGE(GRAPH_FAILED, "[Get][Subgraphs] failed");
return {};
}
std::vector<uint32_t> cur_node_value;
for (const auto &subgraph : subgraphs) {
if (graph_value.count(subgraph) == 0U) {
if (Calculate(subgraph, calc_func, node_value, graph_value, recursive_depth + 1U) != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E18888", "Calculate value failed, graph %s, parent_node:%s", subgraph->GetName().c_str(),
node->GetName().c_str());
GELOGE(GRAPH_FAILED, "[Calculate][Value] failed, graph %s, parent_node:%s",
subgraph->GetName().c_str(), node->GetName().c_str());
return {};
}
}
if (graph_value.find(subgraph) == graph_value.end()) {
REPORT_INNER_ERR_MSG("E18888", "Find value failed for graph %s", subgraph->GetName().c_str());
GELOGE(GRAPH_FAILED, "[Find][Value] failed for graph %s", subgraph->GetName().c_str());
return {};
}
const auto &subgraph_value = graph_value[subgraph];
if (cur_node_value.empty()) {
cur_node_value = subgraph_value;
} else if (cur_node_value.size() != subgraph_value.size()) {
REPORT_INNER_ERR_MSG("E18888",
"Value size not match, value size of node %s is %zu, value size of subgraph %s "
"is %zu",
node->GetName().c_str(), cur_node_value.size(), subgraph->GetName().c_str(),
subgraph_value.size());
GELOGE(GRAPH_FAILED, "[Check][Param] Value size not match, value size of node %s is %zu, "
"value size of subgraph %s is %zu", node->GetName().c_str(), cur_node_value.size(),
subgraph->GetName().c_str(), subgraph_value.size());
return {};
} else {
(void) std::transform(cur_node_value.begin(), cur_node_value.end(),
subgraph_value.begin(), cur_node_value.begin(), std::plus<uint32_t>());
}
}
return cur_node_value;
}
// Checks that the values produced by Calculate() are complete and consistent with `upper_limit`.
//
// The check passes only when
//   - every subgraph found recursively below `graph` has an entry in `graph_value`,
//   - every calc node of those subgraphs has an entry in `node_value`, and
//   - every stored node/graph value has exactly as many elements as `upper_limit`.
//
// @return true when all checks pass; false (after reporting an error) otherwise
bool FftsGraphUtils::IsValueValid(const ComputeGraphPtr &graph, const std::vector<uint32_t> &upper_limit,
                                  const std::map<NodePtr, std::vector<uint32_t>> &node_value,
                                  const std::map<ComputeGraphPtr, std::vector<uint32_t>> &graph_value) {
  std::vector<ComputeGraphPtr> subgraphs;
  if (GraphUtils::GetSubgraphsRecursively(graph, subgraphs) != GRAPH_SUCCESS) {
    REPORT_INNER_ERR_MSG("E18888", "Get subgraphs failed, graph:%s", graph->GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Get][Subgraphs] failed, graph:%s", graph->GetName().c_str());
    return false;
  }
  for (const auto &subgraph : subgraphs) {
    if (graph_value.count(subgraph) == 0U) {
      REPORT_INNER_ERR_MSG("E18888", "Find graph value failed, graph:%s", subgraph->GetName().c_str());
      GELOGE(GRAPH_FAILED, "[Check][Param] Find graph value failed, graph:%s", subgraph->GetName().c_str());
      return false;
    }
    std::set<NodePtr> calc_nodes;
    CollectCalcNodeInSubgraph(subgraph, calc_nodes);
    for (const auto &node : calc_nodes) {
      if (node_value.count(node) == 0U) {
        REPORT_INNER_ERR_MSG("E18888", "Find node value failed, node:%s", node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] Find node value failed, node:%s", node->GetName().c_str());
        return false;
      }
    }
  }

  // Only the element count of the limit is needed, so capture that instead of copying the vector.
  // The predicates take the maps' own value_type: a plain std::pair<K, V> parameter would force a
  // converting copy of every entry (key pointer and value vector) during the scan.
  const auto expected_size = upper_limit.size();
  using NodeValueEntry = std::map<NodePtr, std::vector<uint32_t>>::value_type;
  using GraphValueEntry = std::map<ComputeGraphPtr, std::vector<uint32_t>>::value_type;

  const auto node_size_mismatch = [expected_size](const NodeValueEntry &entry) {
    return entry.second.size() != expected_size;
  };
  if (std::any_of(node_value.begin(), node_value.end(), node_size_mismatch)) {
    REPORT_INNER_ERR_MSG("E18888", "Node value size not match");
    GELOGE(GRAPH_FAILED, "[Check][Param] Node value size not match");
    return false;
  }

  const auto graph_size_mismatch = [expected_size](const GraphValueEntry &entry) {
    return entry.second.size() != expected_size;
  };
  if (std::any_of(graph_value.begin(), graph_value.end(), graph_size_mismatch)) {
    REPORT_INNER_ERR_MSG("E18888", "Graph value size not match");
    GELOGE(GRAPH_FAILED, "[Check][Param] Graph value size not match");
    return false;
  }
  return true;
}
// Groups the calc nodes of `graph` into consecutive chunks whose summed value stays within
// `upper_limit`, and wraps every chunk into a new FFTS+ subgraph.
//
// Direct nodes are visited in the order returned by GetDirectNode(). Non-calc nodes only lose their
// ATTR_NAME_THREAD_SCOPE_ID attribute. A calc node that pushes the running sum over the limit opens a new
// chunk; a node that exceeds the limit on its own is not placed in any chunk but handed to
// SplitFuncNode(), which partitions its subgraphs with `recursive_depth + 1`.
//
// NOTE(review): the comparisons with `upper_limit` use std::vector's relational operators, which compare
// lexicographically rather than element by element - confirm that this is intended.
//
// Returns GRAPH_FAILED when `recursive_depth` reaches kMaxiumRecursionDepth or a subgraph cannot be
// built; otherwise the status of the failing step or of SplitFuncNode().
graphStatus FftsGraphUtils::PartitionGraphWithLimit(const ComputeGraphPtr &graph,
                                                    std::map<NodePtr, std::vector<uint32_t>> &node_value,
                                                    std::map<ComputeGraphPtr, std::vector<uint32_t>> &graph_value,
                                                    const std::vector<uint32_t> &upper_limit,
                                                    const uint32_t recursive_depth) {
  if (recursive_depth >= kMaxiumRecursionDepth) {
    REPORT_INNER_ERR_MSG("E18888", "param depth:%u >= %u(allow max subgraphs)", recursive_depth, kMaxiumRecursionDepth);
    GELOGE(GRAPH_FAILED, "[Check][Param]exist too much subgraphs:%u > %u(allow max subgraphs)",
           recursive_depth, kMaxiumRecursionDepth);
    return GRAPH_FAILED;
  }
  GE_CHECK_NOTNULL(graph);

  std::set<NodePtr> calc_nodes;
  CollectCalcNodeInSubgraph(graph, calc_nodes);

  uint32_t chunk_index = 0U;
  std::map<uint32_t, std::set<NodePtr>> chunks;
  std::vector<NodePtr> oversized_nodes;
  std::vector<uint32_t> running_total;
  for (const auto &node : graph->GetDirectNode()) {
    if (calc_nodes.count(node) == 0U) {
      (void)node->GetOpDesc()->DelAttr(ATTR_NAME_THREAD_SCOPE_ID);
      continue;
    }

    const std::vector<uint32_t> node_cost = node_value[node];
    if (running_total.empty()) {
      running_total = node_cost;
    } else {
      (void)std::transform(running_total.begin(), running_total.end(), node_cost.begin(), running_total.begin(),
                           std::plus<uint32_t>());
    }

    if (running_total <= upper_limit) {
      (void)chunks[chunk_index].emplace(node);
      continue;
    }

    // The running sum went over the limit: close the current chunk.
    ++chunk_index;
    if (node_cost > upper_limit) {
      // Too big even alone; handled separately, the next node starts from an empty sum.
      (void)oversized_nodes.emplace_back(node);
      running_total.clear();
    } else {
      (void)chunks[chunk_index].emplace(node);
      running_total = node_cost;
    }
  }

  for (const auto &chunk : chunks) {
    const std::string subgraph_name = "FFTS_Plus_Subgraph_" + std::to_string(g_fftsPlusSubgraphNum++);
    const auto built_graph = GraphUtils::BuildSubgraphWithNodes(graph, chunk.second, subgraph_name);
    if (built_graph == nullptr) {
      REPORT_INNER_ERR_MSG("E18888", "Build subgraph %s failed", subgraph_name.c_str());
      GELOGE(GRAPH_FAILED, "[Build][Subgraph] %s failed", subgraph_name.c_str());
      return GRAPH_FAILED;
    }
    GE_CHK_STATUS_RET(SetAttrForFftsPlusSubgraph(built_graph), "[Set][Attr] failed for ffts+ subgraph");
  }
  return SplitFuncNode(oversized_nodes, node_value, graph_value, upper_limit, recursive_depth);
}
// Handles the nodes that exceed `upper_limit` on their own.
//
// ATTR_NAME_THREAD_SCOPE_ID is removed from each such node. Every direct subgraph of the node is then
// either wrapped as a whole into an FFTS+ subgraph (its value does not exceed the limit) or partitioned
// further through PartitionGraphWithLimit() with `recursive_depth + 1`.
//
// Returns GRAPH_SUCCESS, or the status of the first failing step.
graphStatus FftsGraphUtils::SplitFuncNode(const std::vector<NodePtr> exceed_single_node,
                                          std::map<NodePtr, std::vector<uint32_t>> &node_value,
                                          std::map<ComputeGraphPtr, std::vector<uint32_t>> &graph_value,
                                          const std::vector<uint32_t> &upper_limit,
                                          const uint32_t recursive_depth) {
  for (const auto &func_node : exceed_single_node) {
    (void)func_node->GetOpDesc()->DelAttr(ATTR_NAME_THREAD_SCOPE_ID);

    std::vector<ComputeGraphPtr> children;
    GE_CHK_STATUS_RET(NodeUtils::GetDirectSubgraphs(func_node, children), "[Get][Subgraphs] of node %s failed",
                      func_node->GetName().c_str());
    for (const auto &child : children) {
      const bool fits_limit = (graph_value[child] <= upper_limit);
      if (!fits_limit) {
        GE_CHK_STATUS_RET(PartitionGraphWithLimit(child, node_value, graph_value, upper_limit, recursive_depth + 1U),
                          "[Partition][Subgraph] failed, graph:%s ", child->GetName().c_str());
        continue;
      }
      GE_CHK_STATUS_RET(BuildFftsPlusSubgraphWithAllNodes(child), "[Build][FftsPlusSubgraph] failed, graph:%s ",
                        child->GetName().c_str());
    }
  }
  return GRAPH_SUCCESS;
}
}