* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef FAST_GRAPH_FAST_GRAPH_UTILS_H
#define FAST_GRAPH_FAST_GRAPH_UTILS_H
#include <string.h>
#include "graph/anchor.h"
#include "quick_list.h"
#include "graph/fast_graph/execute_graph.h"
#include "graph/utils/tensor_utils.h"
namespace ge {
// Walk state stored per node in NodeStatus::status.
enum class FastWalkStatus {
  kNotWalked,
  kWalking,
  kWalked
};
// Per-node bookkeeping record. GetNodeOutputSize() uses `size` as a cache for
// the summed output memory size of a node, where 0 means "not cached yet".
struct NodeStatus {
  size_t size = 0U;
  // Explicitly defaulted so that a default-initialized NodeStatus (e.g. a plain
  // local `NodeStatus s;`) never carries an indeterminate walk state.
  FastWalkStatus status = FastWalkStatus::kNotWalked;
};
// Extension record attached to a graph; currently it only carries a validity
// flag, which starts out as false.
struct GraphExtendInfo {
  bool is_valid_flag_{false};
};
// List element whose `data` member holds a FastNode by value.
using QuickNode = ListElement<FastNode>;
// List element whose `data` member holds an Edge<FastNode> by value.
using QuickEdge = ListElement<Edge<FastNode>>;
// List element whose `data` member holds a (non-owning, as far as this header shows) ExecuteGraph pointer.
using QuickGraph = ListElement<ExecuteGraph *>;
// Static-only helper that exposes the fields of the list-element wrappers
// (QuickNode / QuickEdge / QuickGraph) and maps a payload address back to the
// list element that contains it. None of the accessors checks its pointer
// argument for null; callers must pass valid elements.
class FastGraphUtils {
 public:
  // ---- Edge payload accessors -------------------------------------------

  // Mutable reference to the edge stored in `quick_edge`.
  static Edge<FastNode> &GetEdge(QuickEdge *const quick_edge) {
    return quick_edge->data;
  }

  // Mutable reference to the edge's source-node pointer.
  static FastNode *&GetEdgeSrc(QuickEdge *const quick_edge) {
    return quick_edge->data.src;
  }

  // Read-only view of the edge's source-node pointer.
  static FastNode *const &GetConstEdgeSrc(const QuickEdge *const quick_edge) {
    return quick_edge->data.src;
  }

  // Mutable reference to the edge's destination-node pointer.
  static FastNode *&GetEdgeDst(QuickEdge *const quick_edge) {
    return quick_edge->data.dst;
  }

  // Read-only view of the edge's destination-node pointer.
  static FastNode *const &GetConstEdgeDst(const QuickEdge *const quick_edge) {
    return quick_edge->data.dst;
  }

  // Mutable reference to the edge's `src_output` field.
  static int32_t &GetEdgeSrcOutput(QuickEdge *const quick_edge) {
    return quick_edge->data.src_output;
  }

  // Copy of the edge's `src_output` field.
  static int32_t GetConstEdgeSrcOutput(const QuickEdge *const quick_edge) {
    return quick_edge->data.src_output;
  }

  // Mutable reference to the edge's `dst_input` field.
  static int32_t &GetEdgeDstInput(QuickEdge *const quick_edge) {
    return quick_edge->data.dst_input;
  }

  // Copy of the edge's `dst_input` field.
  static int32_t GetConstEdgeDstInput(const QuickEdge *const quick_edge) {
    return quick_edge->data.dst_input;
  }

  // Mutable reference to the edge's `in_edge_index` field.
  static int32_t &GetEdgeInEdgeIndex(QuickEdge *const quick_edge) {
    return quick_edge->data.in_edge_index;
  }

  // Mutable reference to the edge's `out_edge_index` field.
  static int32_t &GetEdgeOutEdgeIndex(QuickEdge *const quick_edge) {
    return quick_edge->data.out_edge_index;
  }

  // ---- Graph / node payload accessors -----------------------------------

  // Graph pointer stored in `quick_graph`.
  static ExecuteGraph *GetGraph(const QuickGraph *const quick_graph) {
    return quick_graph->data;
  }

  // ComputeGraph pointer stored in `compute_graph`.
  static ComputeGraph *GetComputeGraph(const ListElement<ComputeGraph *> *const compute_graph) {
    return compute_graph->data;
  }

  // Mutable reference to the node stored in `quick_node`.
  static FastNode &GetNode(QuickNode *const quick_node) {
    return quick_node->data;
  }

  // Read-only reference to the node stored in `quick_node`.
  static const FastNode &GetConstNode(const QuickNode *const quick_node) {
    return quick_node->data;
  }

  // ---- Generic list-element fields --------------------------------------

  // Mutable reference to the element's `mode` field.
  template <class T>
  static ListMode &GetMode(ListElement<T> *const list_element) {
    return list_element->mode;
  }

  // The element's `owner` list pointer.
  template <class T>
  static QuickList<T> *GetOwner(ListElement<T> *const list_element) {
    return list_element->owner;
  }

  // ---- Payload address -> enclosing list element -------------------------

  // Computes the address of the QuickNode whose `data` member is `*fast_node`
  // by subtracting the member offset from the payload address. The result is
  // only meaningful when `fast_node` really is the `data` member of a
  // QuickNode. Note that the returned pointer is non-const even though the
  // argument is a pointer-to-const.
  static QuickNode *GetListElementAddr(const FastNode *const fast_node) {
    // Offset of `data` within QuickNode, taken from a null base pointer.
    const auto data_offset = reinterpret_cast<uintptr_t>(&reinterpret_cast<QuickNode *>(0)->data);
    const auto payload_addr = reinterpret_cast<uintptr_t>(fast_node);
    return reinterpret_cast<QuickNode *>(payload_addr - data_offset);
  }

  // Same as above for edges: maps the address of an edge payload back to the
  // QuickEdge that contains it, using offsetof for the member offset.
  static QuickEdge *GetListElementAddr(const FastEdge *const edge) {
    const auto payload_addr = reinterpret_cast<uintptr_t>(edge);
    return reinterpret_cast<QuickEdge *>(payload_addr - offsetof(QuickEdge, data));
  }
};
template <class NODE_T>
int64_t GetNodeOutputSize(NODE_T *node, std::vector<NodeStatus> &reverse_dfs_nodes_info) {
int64_t total_size = 0LL;
if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr)) {
return total_size;
}
NodeStatus &reverse_dfs_node_info = reverse_dfs_nodes_info[static_cast<size_t>(node->GetOpDescBarePtr()->GetId())];
total_size = reverse_dfs_node_info.size;
if (total_size != 0) {
return total_size;
}
for (const auto &out_desc : node->GetOpDescBarePtr()->GetAllOutputsDescPtr()) {
if (out_desc == nullptr) {
continue;
}
int64_t output_size = 0LL;
(void)ge::TensorUtils::CalcTensorMemSize(out_desc->GetShape(), out_desc->GetFormat(), out_desc->GetDataType(),
output_size);
total_size += output_size;
}
if (total_size != 0) {
reverse_dfs_node_info.size = total_size;
}
return total_size;
}
// Comparator over node pointers: returns true when `lhs` has the larger
// output size (per GetNodeOutputSize); when the sizes are equal, returns true
// when the name of `lhs` compares greater than the name of `rhs` under strcmp.
// Both node pointers must be non-null whenever the sizes tie, since their
// names are dereferenced in that case.
template <class NODE_T>
struct NodeCmp {
  explicit NodeCmp(std::vector<NodeStatus> *reverse_dfs_nodes_info) : reverse_dfs_nodes_info_(reverse_dfs_nodes_info) {}

  bool operator()(NODE_T *lhs, NODE_T *rhs) const {
    std::vector<NodeStatus> &nodes_info = *reverse_dfs_nodes_info_;
    const auto left_size = GetNodeOutputSize(lhs, nodes_info);
    const auto right_size = GetNodeOutputSize(rhs, nodes_info);
    if (left_size != right_size) {
      return left_size > right_size;
    }
    // Equal sizes: fall back to the node names.
    return strcmp(lhs->GetNamePtr(), rhs->GetNamePtr()) > 0;
  }

  // Size cache handed to GetNodeOutputSize; not owned by the comparator.
  std::vector<NodeStatus> *reverse_dfs_nodes_info_;
};
// Ordering key built from a node: its out-edge count, its summed output size
// (per GetNodeOutputSize) and its name. Keys compare lexicographically in that
// field order.
template <class NODE_T>
struct NodeOutInfo {
  // Captures the three key fields from `node`. Both pointers are dereferenced
  // unconditionally, so neither may be null.
  NodeOutInfo(NODE_T *node, std::vector<NodeStatus> *reverse_dfs_nodes_info)
      : num_out_data_nodes(node->GetAllOutEdgesSize()),
        output_size(GetNodeOutputSize(node, *reverse_dfs_nodes_info)),
        node_name(node->GetName()) {}

  // Smaller out-edge count first, then smaller output size, then smaller name.
  bool operator<(const NodeOutInfo &rhs) const {
    if (num_out_data_nodes != rhs.num_out_data_nodes) {
      return num_out_data_nodes < rhs.num_out_data_nodes;
    }
    if (output_size != rhs.output_size) {
      return output_size < rhs.output_size;
    }
    return node_name < rhs.node_name;
  }

  int64_t num_out_data_nodes;
  int64_t output_size;
  std::string node_name;
};
// Work list of node pointers with two storage modes, selected at construction:
//  * memory-priority mode (is_mem_priority && !is_reverse_dfs): nodes are kept
//    ordered by their NodeOutInfo key and Pop() returns the smallest key;
//  * otherwise: nodes live in a plain list, Pop() takes from the back, and
//    Push() appends at the back when is_dfs is set, else inserts at the front.
//
// The container stores raw pointers and does not own the nodes.
template <class NODE_T>
class TopoSortStack {
 public:
  // @param reverse_dfs_nodes_info  Size cache forwarded to NodeOutInfo; it is
  //                                dereferenced on Push() in memory-priority
  //                                mode and must be non-null there.
  explicit TopoSortStack(std::vector<NodeStatus> *reverse_dfs_nodes_info, const bool is_mem_priority = false,
                         const bool is_dfs = false, const bool is_reverse_dfs = false)
      : is_mem_priority_(is_mem_priority),
        is_dfs_(is_dfs),
        is_reverse_dfs_(is_reverse_dfs),
        reverse_dfs_nodes_info_(reverse_dfs_nodes_info) {}

  // Removes and returns the next node, or nullptr when nothing is queued
  // (previously popping an empty container was undefined behaviour).
  NODE_T *Pop() {
    if (UseMemPriority()) {
      if (mem_priority_stack_.empty()) {
        return nullptr;
      }
      const auto first = mem_priority_stack_.begin();
      NODE_T *const node = first->second;
      (void)mem_priority_stack_.erase(first);
      return node;
    }
    if (normal_stack_.empty()) {
      return nullptr;
    }
    NODE_T *const node = normal_stack_.back();
    normal_stack_.pop_back();
    return node;
  }

  // Queues `node`. In memory-priority mode the node is dereferenced to build
  // its ordering key, so it must be non-null in that mode.
  void Push(NODE_T *node) {
    if (UseMemPriority()) {
      // multimap: a node whose key equals an already queued key is kept too
      // (a unique-key map would silently discard it). Equal keys are popped in
      // insertion order.
      (void)mem_priority_stack_.emplace(NodeOutInfo<NODE_T>(node, reverse_dfs_nodes_info_), node);
      return;
    }
    if (is_dfs_) {
      normal_stack_.push_back(node);
    } else {
      normal_stack_.push_front(node);
    }
  }

  // True when the container selected by the current mode holds no nodes.
  bool Empty() const {
    return UseMemPriority() ? mem_priority_stack_.empty() : normal_stack_.empty();
  }

 private:
  // Whether the key-ordered container is the active one.
  bool UseMemPriority() const {
    return is_mem_priority_ && (!is_reverse_dfs_);
  }

  bool is_mem_priority_;
  bool is_dfs_;
  bool is_reverse_dfs_;
  std::vector<NodeStatus> *reverse_dfs_nodes_info_;
  std::list<NODE_T *> normal_stack_;
  std::multimap<NodeOutInfo<NODE_T>, NODE_T *> mem_priority_stack_;
};
}
#endif