* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GRAPH_NODE_IMPL_H_
#define GRAPH_NODE_IMPL_H_
#include "graph/node.h"
namespace af {
class Node::NodeImpl {
public:
NodeImpl() = default;
NodeImpl(const OpDescPtr &op, const ComputeGraphPtr &owner_graph);
~NodeImpl();
graphStatus Init(const NodePtr &node);
std::string GetName() const;
const char *GetNamePtr() const;
std::string GetType() const;
const char *GetTypePtr() const;
bool NodeMembersAreEqual(const NodeImpl &r_node) const;
bool NodeAnchorIsEqual(const AnchorPtr &left_anchor,
const AnchorPtr &right_anchor,
const size_t i) const;
graphStatus AddLinkFrom(const NodePtr &input_node, const NodePtr &owner_node);
graphStatus AddLinkFrom(const uint32_t &index,
const NodePtr &input_node,
const NodePtr &owner_node);
graphStatus AddLinkFrom(const uint32_t &index, const NodePtr &input_node, const uint32_t &input_node_index,
const NodePtr &owner_node);
graphStatus AddLinkFromForParse(const NodePtr &input_node, const NodePtr &owner_node);
graphStatus AddLinkFrom(const std::string &name, const NodePtr &input_node, const NodePtr &owner_node);
ComputeGraphPtr GetOwnerComputeGraph() const;
ComputeGraph *GetOwnerComputeGraphBarePtr() const;
graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph);
graphStatus ClearOwnerGraph(const ComputeGraphPtr &graph);
Node::Vistor<InDataAnchorPtr> GetAllInDataAnchors(const ConstNodePtr &node_ptr) const;
std::vector<InDataAnchor *> GetAllInDataAnchorsPtr() const;
Node::Vistor<OutDataAnchorPtr> GetAllOutDataAnchors(const ConstNodePtr &node_ptr) const;
std::vector<OutDataAnchor *> GetAllOutDataAnchorsPtr() const;
uint32_t GetAllInDataAnchorsSize() const;
uint32_t GetAllOutDataAnchorsSize() const;
Node::Vistor<AnchorPtr> GetAllInAnchors(const ConstNodePtr &owner_node) const;
std::vector<Anchor *> GetAllInAnchorsPtr() const;
Node::Vistor<AnchorPtr> GetAllOutAnchors(const ConstNodePtr &owner_node) const;
std::vector<Anchor *> GetAllOutAnchorsPtr() const;
InDataAnchorPtr GetInDataAnchor(const int32_t idx) const;
AnchorPtr GetInAnchor(const int32_t idx) const;
AnchorPtr GetOutAnchor(const int32_t idx) const;
OutDataAnchorPtr GetOutDataAnchor(const int32_t idx) const;
InControlAnchorPtr GetInControlAnchor() const;
OutControlAnchorPtr GetOutControlAnchor() const;
Node::Vistor<NodePtr> GetInAllNodes(const ConstNodePtr &owner_node) const;
Node::Vistor<NodePtr> GetInNodes(const ConstNodePtr &owner_node) const;
std::vector<Node *> GetInNodesPtr() const;
bool IsAllInNodesSeen(const std::unordered_set<Node *> &nodes_seen) const;
Node::Vistor<NodePtr> GetInDataNodes(const ConstNodePtr &owner_node) const;
Node::Vistor<NodePtr> GetInControlNodes(const ConstNodePtr &owner_node) const;
Node::Vistor<NodePtr> GetOutDataNodes(const ConstNodePtr &owner_node) const;
std::vector<Node *> GetOutDataNodesPtr() const;
uint32_t GetOutDataNodesSize() const;
uint32_t GetOutControlNodesSize() const;
uint32_t GetOutNodesSize() const;
size_t GetInDataNodesSize() const;
size_t GetInControlNodesSize() const;
size_t GetInNodesSize() const;
Node::Vistor<NodePtr> GetOutControlNodes(const ConstNodePtr &owner_node) const;
Node::Vistor<NodePtr> GetOutNodes(const ConstNodePtr &owner_node) const;
Node::Vistor<NodePtr> GetOutAllNodes(const ConstNodePtr &owner_node) const;
std::vector<Node *> GetOutNodesPtr() const;
OpDescPtr GetOpDesc() const;
OpDesc *GetOpDescBarePtr() const;
graphStatus UpdateOpDesc(const OpDescPtr &op_desc);
Node::Vistor<std::pair<NodePtr, OutDataAnchorPtr>>
GetInDataNodesAndAnchors(const ConstNodePtr &owner_node) const;
Node::Vistor<std::pair<NodePtr, InDataAnchorPtr>>
GetOutDataNodesAndAnchors(const ConstNodePtr &owner_node) const;
void AddSendEventId(const uint32_t event_id) { send_event_id_list_.push_back(event_id); }
void AddRecvEventId(const uint32_t event_id) { recv_event_id_list_.push_back(event_id); }
const std::vector<uint32_t> &GetSendEventIdList() const { return send_event_id_list_; }
const std::vector<uint32_t> &GetRecvEventIdList() const { return recv_event_id_list_; }
void GetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) const {
fusion_input_list = fusion_input_dataflow_list_;
}
void GetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) const {
fusion_output_list = fusion_output_dataflow_list_;
}
void SetFusionInputFlowList(const kFusionDataFlowVec_t &fusion_input_list) {
fusion_input_dataflow_list_ = fusion_input_list;
}
void SetFusionOutputFlowList(const kFusionDataFlowVec_t &fusion_output_list) {
fusion_output_dataflow_list_ = fusion_output_list;
}
bool GetHostNode() const { return host_node_; }
void SetHostNode(const bool is_host) { host_node_ = is_host; }
void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; }
NodePtr GetOrigNode() { return orig_node_; }
private:
friend class NodeUtils;
friend class TuningUtils;
friend class OnnxUtils;
OpDescPtr op_;
std::weak_ptr<ComputeGraph> owner_graph_;
ComputeGraph *owner_graph_ptr_ = nullptr;
std::vector<InDataAnchorPtr> in_data_anchors_;
std::vector<OutDataAnchorPtr> out_data_anchors_;
InControlAnchorPtr in_control_anchor_;
OutControlAnchorPtr out_control_anchor_;
bool has_init_{false};
bool host_node_{false};
bool anchor_status_updated_{false};
std::vector<uint32_t> send_event_id_list_;
std::vector<uint32_t> recv_event_id_list_;
kFusionDataFlowVec_t fusion_input_dataflow_list_;
kFusionDataFlowVec_t fusion_output_dataflow_list_;
NodePtr orig_node_{nullptr};
};
}
#endif