* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/fast_graph/fast_node.h"
#include <cstddef>
#include <memory>
#include "common/checker.h"
#include "utils/ge_ir_utils.h"
#include "fast_graph_utils.h"
#include "graph/debug/ge_op_types.h"
#include "graph/debug/ge_util.h"
namespace ge {
namespace {
// Shared empty edge list; GetOutEdgesRefByIndex returns a reference to it when the requested index is invalid,
// so the caller always receives a valid (empty) container reference.
const std::vector<FastEdge *> kEmpty;
}
// Construction and destruction need no custom work: all members manage themselves (in-class initialisers, RAII
// containers and smart pointers). Edge storage is set up later by Init()/Reset().
FastNode::FastNode() = default;
FastNode::~FastNode() = default;
/**
 * Binds this node to an operator description and (re)builds its edge storage.
 * @param op operator description; must not be null. Its input/output counts size the data-edge containers.
 * @return GRAPH_SUCCESS on success; GRAPH_FAILED if op is null or Reset() fails.
 */
graphStatus FastNode::Init(const OpDescPtr &op) {
  // Reject a null OpDesc before it is dereferenced below.
  GE_ASSERT_NOTNULL(op);
  opdesc_ = op;
  data_in_num_ = op->GetAllInputsSize();
  data_out_num_ = op->GetOutputsSize();
  // The node's own address is used as its token.
  node_token_ = reinterpret_cast<size_t>(this);
  return Reset();
}
/**
 * Returns the node to an edge-free state.
 * If no ExtendInfo exists yet it is allocated; otherwise every recorded edge slot is dropped, all edge counters are
 * zeroed and the existing ExtendInfo is cleared. In both cases the symbol tables and the per-index edge containers
 * are then sized to the current data input/output counts.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED when the ExtendInfo allocation fails.
 */
graphStatus FastNode::Reset() {
  if (extend_info_ == nullptr) {
    extend_info_ = ComGraphMakeUnique<ExtendInfo>();
    GE_CHK_BOOL_EXEC(extend_info_ != nullptr,
                     REPORT_INNER_ERR_MSG("E18888", "Failed to allocate memory for extend information.");
                     return GRAPH_FAILED, "[Check][Param] Failed to allocate memory for extend information.");
  } else {
    // Drop all edge slots.
    in_data_edges_.clear();
    in_control_edges_.clear();
    out_data_edges_.clear();
    out_control_edges_.clear();
    out_data_edges_info_.per_edges_num.clear();
    // Zero all live-edge counters.
    in_data_edges_count_ = 0UL;
    in_control_edge_count_ = 0UL;
    out_control_edges_count_ = 0UL;
    out_data_edges_info_.total_num = 0UL;
    extend_info_->Clear();
  }
  extend_info_->UpdateInputSymbols(data_in_num_);
  extend_info_->UpdateOutputSymbols(data_out_num_);
  UpdateDataForIoNumChange();
  return GRAPH_SUCCESS;
}
/**
 * Resizes the per-index data-edge containers so they match data_in_num_ / data_out_num_.
 * Newly added input slots are nullptr, new output lists are empty and new per-output counters are 0.
 * NOTE(review): shrinking drops trailing slots without adjusting in_data_edges_count_ or
 * out_data_edges_info_.total_num — callers presumably shrink only when those slots are empty; confirm.
 */
void FastNode::UpdateDataForIoNumChange() {
  const bool already_sized = (out_data_edges_info_.per_edges_num.size() == data_out_num_) &&
                             (in_data_edges_.size() == data_in_num_) && (out_data_edges_.size() == data_out_num_);
  if (already_sized) {
    return;
  }
  out_data_edges_info_.per_edges_num.resize(data_out_num_, 0UL);
  in_data_edges_.resize(data_in_num_, nullptr);
  out_data_edges_.resize(data_out_num_);
}
/// @return shared handle to this node's operator description; may be null (e.g. after UpdateOpDesc(nullptr)).
OpDescPtr FastNode::GetOpDescPtr() const { return opdesc_; }

/// @return non-owning pointer to this node's operator description, or nullptr when none is set.
OpDesc *FastNode::GetOpDescBarePtr() const { return opdesc_.get(); }
// Returns the operator type string from the OpDesc; reports an error and returns an empty string when no OpDesc is set.
std::string FastNode::GetType() const {
GE_CHK_BOOL_EXEC(opdesc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr");
return std::string(), "[Check][Param] original OpDesc is nullptr");
return opdesc_->GetType();
}
// Returns the operator name from the OpDesc; reports an error and returns an empty string when no OpDesc is set.
std::string FastNode::GetName() const {
GE_CHK_BOOL_EXEC(opdesc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr");
return std::string(), "[Check][Param] original OpDesc is nullptr");
return opdesc_->GetName();
}
// Returns the OpDesc's name as a C string (owned by the OpDesc); reports an error and returns nullptr when no OpDesc
// is set.
const char *FastNode::GetNamePtr() const {
GE_CHK_BOOL_EXEC(opdesc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr");
return nullptr, "[Check][Param] original OpDesc is nullptr");
return opdesc_->GetNamePtr();
}
// Returns the OpDesc's type as a C string (owned by the OpDesc); reports an error and returns nullptr when no OpDesc
// is set.
const char *FastNode::GetTypePtr() const {
GE_CHK_BOOL_EXEC(opdesc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr");
return nullptr, "[Check][Param] original OpDesc is nullptr");
return opdesc_->GetTypePtr();
}
/**
 * Field-by-field equality. Each comparison goes through IsEqual with a field label, so the first mismatching field
 * is identified; evaluation stops at that first mismatch.
 * @param r_node node to compare against.
 * @return true when every compared field is equal.
 */
bool FastNode::operator==(const FastNode &r_node) const {
  // ExtendInfo is compared by value; guard the dereference so a node without one does not crash. Two nodes without
  // ExtendInfo compare equal on this field; exactly one missing means they differ.
  const auto extend_info_equal = [this, &r_node]() -> bool {
    if ((extend_info_ == nullptr) || (r_node.extend_info_ == nullptr)) {
      return extend_info_ == r_node.extend_info_;
    }
    return IsEqual(*extend_info_, *(r_node.extend_info_), "node.extend_info_");
  };
  return (IsEqual(name_, r_node.name_, "node.name") && IsEqual(node_token_, r_node.node_token_, "node.token") &&
          IsEqual(opdesc_, r_node.opdesc_, "node.opdesc_") &&
          IsEqual(self_ptr_, r_node.self_ptr_, "node.self_ptr_") &&
          IsEqual(data_in_num_, r_node.data_in_num_, "node.data_in_num_") &&
          IsEqual(data_out_num_, r_node.data_out_num_, "node.data_out_num_") &&
          IsEqual(in_data_edges_, r_node.in_data_edges_, "node.in_data_edges_") &&
          IsEqual(out_data_edges_, r_node.out_data_edges_, "node.out_data_edges_") &&
          IsEqual(in_control_edges_, r_node.in_control_edges_, "node.in_control_edges_") &&
          IsEqual(out_control_edges_, r_node.out_control_edges_, "node.out_control_edges_") &&
          IsEqual(in_data_edges_count_, r_node.in_data_edges_count_, "node.in_data_edges_count_") &&
          // Compare against the other node's counter (previously compared with itself).
          IsEqual(in_control_edge_count_, r_node.in_control_edge_count_, "node.in_control_edge_count_") &&
          IsEqual(out_control_edges_count_, r_node.out_control_edges_count_, "node.out_control_edges_count_") &&
          extend_info_equal() &&
          IsEqual(out_data_edges_info_.total_num, r_node.out_data_edges_info_.total_num,
                  "node.out_data_edges_info_.total_num"));
}
/**
 * Appends an incoming control edge and stores its list position in edge->in_edge_index, which
 * EraseInControlEdge later uses to find and clear that slot.
 * @return always GRAPH_SUCCESS.
 */
graphStatus FastNode::RecordInControlEdge(FastEdge *const edge) {
  const auto slot = in_control_edges_.size();
  edge->in_edge_index = slot;
  in_control_edges_.emplace_back(edge);
  ++in_control_edge_count_;
  return GRAPH_SUCCESS;
}
/**
 * Appends an outgoing control edge and stores its list position in edge->out_edge_index, which
 * EraseOutControlEdge later uses to find and clear that slot.
 * @return always GRAPH_SUCCESS.
 */
graphStatus FastNode::RecordOutControlEdge(FastEdge *const edge) {
  const auto slot = out_control_edges_.size();
  edge->out_edge_index = slot;
  out_control_edges_.emplace_back(edge);
  ++out_control_edges_count_;
  return GRAPH_SUCCESS;
}
/**
 * Stores edge as the single incoming data edge of input `index`.
 * @return GRAPH_FAILED when index is outside [0, data_in_num_) or the input already has an edge; GRAPH_SUCCESS
 *         otherwise.
 */
graphStatus FastNode::RecordInDataEdge(FastEdge *const edge, int32_t index) {
  const bool index_ok = CheckDataIndexIsValid(index, DirectionType::kDirectionInType);
  if (!index_ok) {
    REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of in edge.", index, data_in_num_);
    GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of in edge.", index, data_in_num_);
    return GRAPH_FAILED;
  }
  FastEdge *&slot = in_data_edges_[index];
  // A data input accepts at most one producer.
  if (slot != nullptr) {
    REPORT_INNER_ERR_MSG("E18888", "Failed to record edge in node [%s] for multiple input.", GetName().c_str());
    GELOGE(GRAPH_FAILED, "[Record][Edge] Failed to record edge in node [%s] for multiple input.", GetName().c_str());
    return GRAPH_FAILED;
  }
  slot = edge;
  ++in_data_edges_count_;
  return GRAPH_SUCCESS;
}
/**
 * Appends edge to the list of consumers of output `index` and bumps both the per-output and total counters.
 * @return GRAPH_FAILED when index is outside [0, data_out_num_); GRAPH_SUCCESS otherwise.
 */
graphStatus FastNode::RecordOutDataEdge(FastEdge *const edge, int32_t index) {
  if (CheckDataIndexIsValid(index, DirectionType::kDirectionOutType)) {
    out_data_edges_[index].push_back(edge);
    ++out_data_edges_info_.per_edges_num[index];
    ++out_data_edges_info_.total_num;
    return GRAPH_SUCCESS;
  }
  REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_);
  GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_);
  return GRAPH_FAILED;
}
/**
 * Records edge on this node in the given direction, dispatching to the control or data variant according to the
 * edge's port index (kControlEdgeIndex selects the control list). Any type other than kDirectionInType is treated as
 * the out direction.
 * @return GRAPH_FAILED for a null edge or an invalid index; otherwise the result of the specific Record* call.
 */
graphStatus FastNode::RecordEdge(FastEdge *const edge, DirectionType type) {
  // The edge is dereferenced immediately below; reject null instead of crashing.
  GE_ASSERT_NOTNULL(edge);
  if (type == DirectionType::kDirectionInType) {
    const int32_t index = edge->dst_input;
    if (index == kControlEdgeIndex) {
      return RecordInControlEdge(edge);
    }
    // A data input has a single slot, so the edge's position within it is always 0.
    edge->in_edge_index = 0;
    return RecordInDataEdge(edge, index);
  }
  const int32_t index = edge->src_output;
  if (index == kControlEdgeIndex) {
    return RecordOutControlEdge(edge);
  }
  GE_ASSERT_TRUE(static_cast<size_t>(index) < out_data_edges_.size());
  // The edge will be appended, so its position is the current list length.
  edge->out_edge_index = out_data_edges_[index].size();
  return RecordOutDataEdge(edge, index);
}
// Removes an incoming control edge by clearing its slot at edge->in_edge_index (the list is not compacted) and
// decrements the live in-control-edge count. If the stored index is out of range or the slot does not hold this edge,
// the GE_ASSERT_TRUE check fails and the function returns early without modifying anything.
graphStatus FastNode::EraseInControlEdge(const FastEdge *const edge) {
GE_ASSERT_TRUE(static_cast<size_t>(edge->in_edge_index) < in_control_edges_.size());
GE_ASSERT_TRUE(in_control_edges_[edge->in_edge_index] == edge);
in_control_edges_[edge->in_edge_index] = nullptr;
in_control_edge_count_--;
return GRAPH_SUCCESS;
}
// Outgoing counterpart of EraseInControlEdge: clears the slot at edge->out_edge_index in out_control_edges_ and
// decrements out_control_edges_count_, after checking the slot actually holds this edge.
graphStatus FastNode::EraseOutControlEdge(const FastEdge *const edge) {
GE_ASSERT_TRUE(static_cast<size_t>(edge->out_edge_index) < out_control_edges_.size());
GE_ASSERT_TRUE(out_control_edges_[edge->out_edge_index] == edge);
out_control_edges_[edge->out_edge_index] = nullptr;
out_control_edges_count_--;
return GRAPH_SUCCESS;
}
/**
 * Removes the incoming data edge at input edge->dst_input and decrements the live in-data-edge count.
 * @return GRAPH_FAILED for a null edge, an input index outside [0, data_in_num_), or when the slot does not hold
 *         this edge; GRAPH_SUCCESS otherwise.
 */
graphStatus FastNode::EraseInDataEdge(const FastEdge *const edge) {
  GE_ASSERT_NOTNULL(edge);
  const int32_t index = edge->dst_input;
  if (!CheckDataIndexIsValid(index, DirectionType::kDirectionInType)) {
    // This is an input index checked against data_in_num_, so report it as an "in edge".
    REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of in edge.", index, data_in_num_);
    GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of in edge.", index, data_in_num_);
    return GRAPH_FAILED;
  }
  GE_ASSERT_TRUE(in_data_edges_[index] == edge);
  in_data_edges_[index] = nullptr;
  in_data_edges_count_--;
  return GRAPH_SUCCESS;
}
// Removes an outgoing data edge from the consumer list of output `index` by clearing its slot at
// edge->out_edge_index (the list is not compacted), then decrements the total and per-output edge counters.
// Returns GRAPH_FAILED for an output index outside [0, data_out_num_); if the stored position is out of range or the
// slot does not hold this edge, the GE_ASSERT_TRUE check fails and the function returns early.
graphStatus FastNode::EraseOutDataEdge(const FastEdge *const edge, int32_t index) {
if (!CheckDataIndexIsValid(index, DirectionType::kDirectionOutType)) {
REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_);
GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_);
return GRAPH_FAILED;
}
GE_ASSERT_TRUE(static_cast<size_t>(edge->out_edge_index) < out_data_edges_[index].size());
GE_ASSERT_TRUE(out_data_edges_[index][edge->out_edge_index] == edge);
out_data_edges_[index][edge->out_edge_index] = nullptr;
out_data_edges_info_.total_num--;
out_data_edges_info_.per_edges_num[index]--;
return GRAPH_SUCCESS;
}
/**
 * Removes edge from this node in the given direction, dispatching on the edge's port index (kControlEdgeIndex
 * selects the control list). Any type other than kDirectionOutType is treated as the in direction.
 * @return GRAPH_FAILED for a null edge; otherwise the result of the specific Erase* call.
 */
graphStatus FastNode::EraseEdge(const FastEdge *const edge, DirectionType type) {
  // The edge's ports are read immediately below; reject null instead of crashing.
  GE_ASSERT_NOTNULL(edge);
  if (type == DirectionType::kDirectionOutType) {
    const int32_t out_index = edge->src_output;
    return (out_index == kControlEdgeIndex) ? EraseOutControlEdge(edge) : EraseOutDataEdge(edge, out_index);
  }
  const int32_t in_index = edge->dst_input;
  return (in_index == kControlEdgeIndex) ? EraseInControlEdge(edge) : EraseInDataEdge(edge);
}
/**
 * Validates the arguments of MoveEdge.
 * @param type direction of the edge list (in or out).
 * @param io_idx data input/output index, or kControlEdgeIndex to select the control-edge list; values below -1 are
 *        rejected.
 * @param cur_array_index position the edge is moved from.
 * @param replace_array_index position the edge is moved to.
 * @return GRAPH_SUCCESS when io_idx is within the data input/output count (for data lists) and both positions are
 *         below the size of the selected list; GRAPH_FAILED otherwise. A data input counts as a list of size 1, and
 *         negative positions fail because they wrap to large unsigned values.
 */
graphStatus FastNode::CheckAllInputParamter(DirectionType type, int32_t io_idx, int32_t cur_array_index,
                                            int32_t replace_array_index) const {
  if (io_idx < -1) {
    REPORT_INNER_ERR_MSG("E18888", "The idx[%d] exceed the max capacity of in_edges.", io_idx);
    GELOGE(GRAPH_FAILED, "[Check][Param] The idx[%d] exceed the max capacity of in_edges.", io_idx);
    return GRAPH_FAILED;
  }
  const bool is_in = (type == DirectionType::kDirectionInType);
  const bool is_out = (type == DirectionType::kDirectionOutType);
  size_t edge_size = 0UL;
  if (io_idx == kControlEdgeIndex) {
    if (is_in) {
      edge_size = in_control_edges_.size();
    } else if (is_out) {
      edge_size = out_control_edges_.size();
    }
  } else {
    size_t io_size = 0UL;
    if (is_in) {
      io_size = data_in_num_;
    } else if (is_out) {
      io_size = data_out_num_;
    }
    if (static_cast<size_t>(io_idx) >= io_size) {
      REPORT_INNER_ERR_MSG("E18888", "The idx [%d] exceed the max capacity [%zu] of in_edges.", io_idx, io_size);
      GELOGE(GRAPH_FAILED, "[Check][Param] The idx [%d] exceed the max capacity [%zu] of in_edges.", io_idx, io_size);
      return GRAPH_FAILED;
    }
    // A data input holds a single edge slot; a data output keeps a list of consumers.
    if (is_in) {
      edge_size = 1;
    } else if (is_out) {
      edge_size = out_data_edges_[io_idx].size();
    }
  }
  const bool replace_out_of_range = static_cast<size_t>(replace_array_index) >= edge_size;
  const bool cur_out_of_range = static_cast<size_t>(cur_array_index) >= edge_size;
  if (replace_out_of_range || cur_out_of_range) {
    REPORT_INNER_ERR_MSG("E18888",
                         "The replace index [%d] or current index [%d] exceed the max capacity [%zu] of in_edges.",
                         replace_array_index, cur_array_index, edge_size);
    GELOGE(GRAPH_FAILED,
           "[Check][Param] The replace index [%d] or current index [%d] exceed the max capacity [%zu] of in_edges.",
           replace_array_index, cur_array_index, edge_size);
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
/**
 * Relocates one edge slot within an edge list: the edge at cur_array_index is written to replace_array_index, the
 * old slot is cleared, and the edge's stored list position is updated.
 * Movable lists are the in/out control-edge lists (io_idx == kControlEdgeIndex) and the per-output data-edge lists.
 * A data input owns a single slot, so an in-direction data move is validated and then left as a no-op.
 * NOTE(review): an edge already occupying replace_array_index is overwritten; presumably callers only target empty
 * slots — confirm.
 * @return the CheckAllInputParamter status on invalid arguments; GRAPH_SUCCESS otherwise.
 */
graphStatus FastNode::MoveEdge(DirectionType type, int32_t io_idx, int32_t cur_array_index,
                               int32_t replace_array_index) {
  const auto ret = CheckAllInputParamter(type, io_idx, cur_array_index, replace_array_index);
  if (ret != GRAPH_SUCCESS) {
    return ret;
  }
  // With cur == replace, the write-then-clear sequence below would clear the slot and drop the edge
  // (without adjusting any counter); moving onto the same position must leave the edge in place.
  if (cur_array_index == replace_array_index) {
    return GRAPH_SUCCESS;
  }
  if (type == DirectionType::kDirectionInType) {
    if (io_idx != kControlEdgeIndex) {
      return GRAPH_SUCCESS;
    }
    FastEdge *const edge = in_control_edges_[cur_array_index];
    in_control_edges_[replace_array_index] = edge;
    in_control_edges_[cur_array_index] = nullptr;
    if (edge != nullptr) {
      edge->in_edge_index = replace_array_index;
    }
  } else if (type == DirectionType::kDirectionOutType) {
    std::vector<FastEdge *> &edges = (io_idx == kControlEdgeIndex) ? out_control_edges_ : out_data_edges_[io_idx];
    FastEdge *const edge = edges[cur_array_index];
    edges[replace_array_index] = edge;
    edges[cur_array_index] = nullptr;
    if (edge != nullptr) {
      edge->out_edge_index = replace_array_index;
    }
  }
  return GRAPH_SUCCESS;
}
/// @return number of live incoming edges, control and data combined.
size_t FastNode::GetAllInEdgeSize() const {
  const size_t data_part = in_data_edges_count_;
  const size_t control_part = in_control_edge_count_;
  return control_part + data_part;
}
/// @return raw per-input slots (one per data input; empty inputs are nullptr).
const std::vector<Edge<FastNode> *> &FastNode::GetAllInDataEdgesRef() const { return in_data_edges_; }
/// @return mutable raw per-input slots (one per data input; empty inputs are nullptr).
std::vector<Edge<FastNode> *> &FastNode::MutableAllInDataEdges() { return in_data_edges_; }
/// @return raw incoming control-edge list; erased edges leave nullptr holes.
const std::vector<Edge<FastNode> *> &FastNode::GetAllInControlEdgesRef() const { return in_control_edges_; }
/// @return raw outgoing control-edge list; erased edges leave nullptr holes.
const std::vector<Edge<FastNode> *> &FastNode::GetAllOutControlEdgesRef() const { return out_control_edges_; }
/// @return raw per-output consumer lists; erased edges leave nullptr holes.
const std::vector<std::vector<Edge<FastNode> *>> &FastNode::GetAllOutDataEdgesRef() const { return out_data_edges_; }
/// @return copy of the incoming data edges, skipping empty input slots, in input-index order.
std::vector<Edge<FastNode> *> FastNode::GetAllInDataEdges() const {
  std::vector<FastEdge *> live;
  live.reserve(in_data_edges_count_);
  for (FastEdge *const edge : in_data_edges_) {
    if (edge != nullptr) {
      live.push_back(edge);
    }
  }
  return live;
}
/// @return copy of the incoming control edges, skipping erased (nullptr) slots.
std::vector<Edge<FastNode> *> FastNode::GetAllInControlEdges() const {
  std::vector<FastEdge *> live;
  live.reserve(in_control_edge_count_);
  for (FastEdge *const edge : in_control_edges_) {
    if (edge != nullptr) {
      live.push_back(edge);
    }
  }
  return live;
}
/// @return copy of the outgoing control edges, skipping erased (nullptr) slots.
std::vector<Edge<FastNode> *> FastNode::GetAllOutControlEdges() const {
  std::vector<FastEdge *> live;
  live.reserve(out_control_edges_count_);
  for (FastEdge *const edge : out_control_edges_) {
    if (edge != nullptr) {
      live.push_back(edge);
    }
  }
  return live;
}
/// @return copy of all outgoing data edges, output by output, skipping erased (nullptr) slots.
std::vector<Edge<FastNode> *> FastNode::GetAllOutDataEdges() const {
  std::vector<FastEdge *> live;
  live.reserve(out_data_edges_info_.total_num);
  for (const auto &consumers : out_data_edges_) {
    for (FastEdge *const edge : consumers) {
      if (edge != nullptr) {
        live.push_back(edge);
      }
    }
  }
  return live;
}
/**
 * Checks that `index` addresses an existing data input (kDirectionInType) or data output (kDirectionOutType).
 * @return true when 0 <= index < data_in_num_ / data_out_num_ for the given direction; false for any other direction.
 */
inline bool FastNode::CheckDataIndexIsValid(int32_t index, DirectionType type) const {
  if (index < 0) {
    return false;
  }
  // Compare in size_t: casting the size_t counts down to int32_t (as before) could narrow for very large counts.
  const auto unsigned_index = static_cast<size_t>(index);
  if (type == DirectionType::kDirectionOutType) {
    return unsigned_index < data_out_num_;
  }
  if (type == DirectionType::kDirectionInType) {
    return unsigned_index < data_in_num_;
  }
  return false;
}
/// @return true when the node has no live outgoing edges (data or control).
bool FastNode::OutNodesIsEmpty() const {
  const size_t live_out = out_control_edges_count_ + out_data_edges_info_.total_num;
  return live_out == 0UL;
}
/// @return number of live outgoing edges, control and data combined.
size_t FastNode::GetAllOutEdgesSize() const {
  const size_t data_part = out_data_edges_info_.total_num;
  return data_part + out_control_edges_count_;
}
/// @return number of live outgoing data edges.
size_t FastNode::GetAllOutDataEdgesSize() const { return out_data_edges_info_.total_num; }
/// @return number of live outgoing control edges.
size_t FastNode::GetAllOutControlEdgesSize() const { return out_control_edges_count_; }
/// @return number of connected data inputs.
size_t FastNode::GetAllInDataEdgesSize() const { return in_data_edges_count_; }
/// @return number of live incoming control edges.
size_t FastNode::GetAllInControlEdgesSize() const { return in_control_edge_count_; }
std::vector<FastNode *> FastNode::GetAllOutNodes() const {
std::vector<FastNode *> tmp;
tmp.reserve(out_control_edges_count_ + out_data_edges_info_.total_num);
for (size_t i = 0UL; i < out_data_edges_.size(); i++) {
std::for_each(out_data_edges_[i].begin(), out_data_edges_[i].end(), [&tmp](FastEdge *edge) {
if (edge != nullptr) {
tmp.push_back(edge->dst);
}
});
}
std::for_each(out_control_edges_.begin(), out_control_edges_.end(), [&tmp](FastEdge *edge) {
if (edge != nullptr) {
tmp.push_back(edge->dst);
}
});
return tmp;
}
std::vector<FastNode *> FastNode::GetAllInNodes() const {
std::vector<FastNode *> tmp;
tmp.reserve(in_control_edge_count_ + in_data_edges_count_);
std::for_each(in_data_edges_.begin(), in_data_edges_.end(), [&tmp](FastEdge *edge) {
if (edge != nullptr) {
tmp.push_back(edge->src);
}
});
std::for_each(in_control_edges_.begin(), in_control_edges_.end(), [&tmp](FastEdge *edge) {
if (edge != nullptr) {
tmp.push_back(edge->src);
}
});
return tmp;
}
std::vector<FastNode *> FastNode::GetInDataNodes() const {
std::vector<FastNode *> in_data_nodes;
in_data_nodes.reserve(in_data_edges_count_);
auto &ref = GetAllInDataEdgesRef();
std::for_each(ref.begin(), ref.end(), [&in_data_nodes](FastEdge *edge) {
if (edge != nullptr) {
in_data_nodes.push_back(edge->src);
}
});
return in_data_nodes;
}
std::vector<FastNode *> FastNode::GetOutDataNodesByIndex(int32_t index) const {
std::vector<FastNode *> out_data_nodes;
if (!CheckDataIndexIsValid(index, DirectionType::kDirectionOutType)) {
REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_);
GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_);
return out_data_nodes;
}
out_data_nodes.reserve(out_data_edges_info_.per_edges_num[index]);
auto &ref = GetOutEdgesRefByIndex(index);
std::for_each(ref.begin(), ref.end(), [&out_data_nodes](FastEdge *edge) {
if (edge != nullptr) {
out_data_nodes.push_back(edge->dst);
}
});
return out_data_nodes;
}
std::vector<FastNode *> FastNode::GetOutDataNodes() const {
std::vector<FastNode *> out_nodes;
out_nodes.reserve(out_data_edges_info_.total_num);
for (size_t i = 0UL; i < out_data_edges_.size(); i++) {
std::for_each(out_data_edges_[i].begin(), out_data_edges_[i].end(), [&out_nodes](FastEdge *edge) {
if (edge != nullptr) {
out_nodes.push_back(edge->dst);
}
});
}
return out_nodes;
}
std::vector<FastNode *> FastNode::GetOutControlNodes() const {
std::vector<FastNode *> out_ctrl_nodes;
out_ctrl_nodes.reserve(out_control_edges_count_);
for (const auto &edge : out_control_edges_) {
if (edge != nullptr) {
out_ctrl_nodes.push_back(edge->dst);
}
}
return out_ctrl_nodes;
}
std::vector<FastNode *> FastNode::GetInControlNodes() const {
std::vector<FastNode *> in_ctrl_nodes;
in_ctrl_nodes.reserve(in_control_edge_count_);
for (const auto &edge : in_control_edges_) {
if (edge != nullptr) {
in_ctrl_nodes.push_back(edge->src);
}
}
return in_ctrl_nodes;
}
/// @return the token assigned in Init() (this node's address at that time).
size_t FastNode::GetNodeToken() const { return node_token_; }
/**
 * @param idx data input index, or kControlEdgeIndex for the control inputs.
 * @return live in-control-edge count for kControlEdgeIndex; 1 or 0 depending on whether data input idx is connected;
 *         0 (with an error report) when idx is outside [0, data_in_num_).
 */
size_t FastNode::GetInEdgesSizeByIndex(int32_t idx) const {
  if (idx == kControlEdgeIndex) {
    return in_control_edge_count_;
  }
  if (!CheckDataIndexIsValid(idx, DirectionType::kDirectionInType)) {
    // idx is checked against data_in_num_, so the message must say "in edge".
    REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of in edge.", idx, data_in_num_);
    GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of in edge.", idx, data_in_num_);
    return 0UL;
  }
  return (in_data_edges_[idx] != nullptr) ? 1UL : 0UL;
}
size_t FastNode::GetOutEdgesSizeByIndex(int32_t idx) const {
if (idx == kControlEdgeIndex) {
return out_control_edges_count_;
}
if (!CheckDataIndexIsValid(idx, DirectionType::kDirectionOutType)) {
REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", idx + 1, data_in_num_);
GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", idx + 1, data_in_num_);
return 0UL;
}
if (out_data_edges_info_.per_edges_num.size() <= static_cast<size_t>(idx)) {
return 0UL;
}
return out_data_edges_info_.per_edges_num[idx];
}
/**
 * @param idx data input index.
 * @return the edge feeding input idx (nullptr if unconnected); nullptr with an error report when idx is outside
 *         [0, data_in_num_).
 */
Edge<FastNode> *FastNode::GetInDataEdgeByIndex(int32_t idx) const {
  if (CheckDataIndexIsValid(idx, DirectionType::kDirectionInType)) {
    return in_data_edges_[idx];
  }
  REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of in edge.", idx, data_in_num_);
  GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of in edge.", idx, data_in_num_);
  return nullptr;
}
/**
 * @param node candidate controller.
 * @return true when some live incoming control edge has a non-null source equal to node. A null node therefore never
 *         matches.
 */
bool FastNode::IsDirectlyControlledByNode(FastNode const *node) const {
  for (const FastEdge *const ctrl_edge : in_control_edges_) {
    if (ctrl_edge == nullptr) {
      continue;
    }
    const FastNode *const source = ctrl_edge->src;
    if ((source != nullptr) && (source == node)) {
      return true;
    }
  }
  return false;
}
/**
 * @param idx data output index, or kControlEdgeIndex for the outgoing control edges.
 * @return copy of the live (non-null) edges for idx; empty (with an error report) when idx is an invalid data index.
 *         Use GetOutEdgesRefByIndex for the raw list including erased (nullptr) slots.
 */
std::vector<Edge<FastNode> *> FastNode::GetOutEdgesByIndex(int32_t idx) const {
  const std::vector<FastEdge *> *source = nullptr;
  size_t expected = 0UL;
  if (idx == kControlEdgeIndex) {
    // Filter erased slots here too, like the data branch: previously the control list was returned verbatim,
    // including nullptr holes.
    source = &out_control_edges_;
    expected = out_control_edges_count_;
  } else {
    if (!CheckDataIndexIsValid(idx, DirectionType::kDirectionOutType)) {
      REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", idx, data_out_num_);
      GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", idx, data_out_num_);
      return std::vector<Edge<FastNode> *>{};
    }
    source = &out_data_edges_[idx];
    expected = out_data_edges_info_.per_edges_num[idx];
  }
  std::vector<FastEdge *> live;
  live.reserve(expected);
  for (FastEdge *const edge : *source) {
    if (edge != nullptr) {
      live.push_back(edge);
    }
  }
  return live;
}
/**
 * @param idx data output index, or kControlEdgeIndex for the outgoing control edges.
 * @return reference to the raw edge list for idx, which may contain nullptr slots for erased edges; a reference to a
 *         shared empty list (with an error report) when idx is an invalid data index.
 */
const std::vector<Edge<FastNode> *> &FastNode::GetOutEdgesRefByIndex(int32_t idx) const {
  if (idx == kControlEdgeIndex) {
    return out_control_edges_;
  }
  if (CheckDataIndexIsValid(idx, DirectionType::kDirectionOutType)) {
    return out_data_edges_[idx];
  }
  REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", idx, data_out_num_);
  GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", idx, data_out_num_);
  return kEmpty;
}
/**
 * Decrements in_edge_size when fast_edge comes from a NextIteration / RefNextIteration node, so such edges are not
 * counted. NOTE(review): presumably because they are loop back-edges — confirm.
 * @param fast_edge incoming edge to inspect; a null edge or null source is ignored.
 * @param in_edge_size running count to adjust in place.
 * @return GRAPH_FAILED if a decrement is required while in_edge_size is already 0; GRAPH_SUCCESS otherwise.
 */
graphStatus FastNode::ModifySizeByNodeType(const FastEdge *const fast_edge, size_t &in_edge_size) const {
  if ((fast_edge == nullptr) || (fast_edge->src == nullptr)) {
    return GRAPH_SUCCESS;
  }
  const std::string src_type = fast_edge->src->GetType();
  const bool from_next_iteration = (strcmp(src_type.c_str(), NEXTITERATION) == 0) ||
                                   (strcmp(src_type.c_str(), REFNEXTITERATION) == 0);
  if (!from_next_iteration) {
    return GRAPH_SUCCESS;
  }
  if (in_edge_size == 0UL) {
    GELOGE(GRAPH_FAILED, "[Check][Param] If [in_edge_size = 0], the result will be reversed");
    return GRAPH_FAILED;
  }
  --in_edge_size;
  return GRAPH_SUCCESS;
}
size_t FastNode::GetInEdgeSize() const {
size_t in_edge_size = GetAllInEdgeSize();
auto &edges = GetAllInDataEdgesRef();
for (size_t i = 0UL; i < edges.size(); i++) {
auto edge = edges[i];
if (edge == nullptr) {
continue;
}
auto ret = ModifySizeByNodeType(edge, in_edge_size);
if (ret != GRAPH_SUCCESS) {
return 0;
}
}
return in_edge_size;
}
// Invokes remove_edge_func once for every live (non-null) edge attached to this node, in this order: incoming data,
// incoming control, outgoing data (output by output), outgoing control. The callback is expected to perform the
// actual removal. Index-based loops re-read each container element on every iteration, so slots cleared by the
// callback are seen as nullptr and skipped. NOTE(review): assumes the callback does not resize these containers —
// confirm.
void FastNode::RemoveAllEdge(std::function<void(Edge<FastNode> *)> const &remove_edge_func) {
for (size_t i = 0UL; i < in_data_edges_.size(); ++i) {
auto edge = in_data_edges_[i];
if (edge != nullptr) {
remove_edge_func(edge);
}
}
for (size_t i = 0UL; i < in_control_edges_.size(); ++i) {
auto edge = in_control_edges_[i];
if (edge != nullptr) {
remove_edge_func(edge);
}
}
for (size_t i = 0UL; i < out_data_edges_.size(); ++i) {
for (size_t j = 0UL; j < out_data_edges_[i].size(); ++j) {
auto edge = out_data_edges_[i][j];
if (edge != nullptr) {
remove_edge_func(edge);
}
}
}
for (size_t i = 0UL; i < out_control_edges_.size(); ++i) {
auto edge = out_control_edges_[i];
if (edge != nullptr) {
remove_edge_func(edge);
}
}
return;
}
/// @return number of data inputs this node currently exposes.
size_t FastNode::GetDataInNum() const { return data_in_num_; }
/// @return number of data outputs this node currently exposes.
size_t FastNode::GetDataOutNum() const { return data_out_num_; }
/**
 * Changes the number of data inputs, resizing the input edge slots and the input symbol table accordingly.
 * Requires the ExtendInfo to exist (i.e. Init() has succeeded).
 */
void FastNode::UpdateDataInNum(size_t new_num) {
  data_in_num_ = new_num;
  UpdateDataForIoNumChange();
  extend_info_->UpdateInputSymbols(new_num);
}
/**
 * Changes the number of data outputs, resizing the output edge lists, per-output counters and the output symbol
 * table accordingly. Requires the ExtendInfo to exist (i.e. Init() has succeeded).
 */
void FastNode::UpdateDataOutNum(size_t new_num) {
  data_out_num_ = new_num;
  UpdateDataForIoNumChange();
  extend_info_->UpdateOutputSymbols(new_num);
}
/// Attaches the owning Node: keeps a shared reference and caches its raw address.
void FastNode::SetNodePtr(const std::shared_ptr<Node> &node) {
  node_bare_ptr_ = node.get();
  self_ptr_ = node;
}
/// Drops the shared reference to the Node; the cached raw pointer is left intact.
void FastNode::ClearNodePtr() { self_ptr_.reset(); }
/// Forgets the cached raw Node pointer; the shared reference (if any) is left intact.
void FastNode::ClearNodeBarePtr() { node_bare_ptr_ = nullptr; }
/**
 * @return the held shared reference if present; otherwise a shared_ptr recovered from the raw pointer via
 *         shared_from_this(); nullptr when neither is set.
 */
std::shared_ptr<Node> FastNode::GetNodePtr() const {
  if (self_ptr_ == nullptr) {
    if (node_bare_ptr_ == nullptr) {
      return nullptr;
    }
    return node_bare_ptr_->shared_from_this();
  }
  return self_ptr_;
}
/// @return the cached raw Node pointer (may be nullptr).
Node *FastNode::GetNodeBarePtr() const { return node_bare_ptr_; }
void FastNode::UpdateOpDesc(const OpDescPtr &new_opdesc) {
if (new_opdesc == nullptr) {
opdesc_.reset();
return;
}
opdesc_ = new_opdesc;
}
ExtendInfo *FastNode::GetExtendInfo() const {
return extend_info_.get();
}
/// Restores every field to its cleared value: no owner graph, input index kControlEdgeIndex, no output indices,
/// not a host node, empty symbol tables.
void ExtendInfo::Clear() {
  execute_graph_ = nullptr;
  input_index_ = kControlEdgeIndex;
  output_index_.clear();
  host_node_ = false;
  output_symbols_.clear();
  input_symbols_.clear();
}
/// Sets the stored input index.
void ExtendInfo::SetInputIndex(int32_t idx) { input_index_ = idx; }
/// @return the stored input index (kControlEdgeIndex after Clear()).
int32_t ExtendInfo::GetInputIndex() const { return input_index_; }
/// Appends idx to the stored output-index list.
void ExtendInfo::AddOneOutputIndex(int32_t idx) { output_index_.emplace_back(idx); }
/// @return mutable reference to the stored output-index list.
std::vector<int32_t> &ExtendInfo::GetOutputIndex() { return output_index_; }
/// @return the graph currently recorded as owner (may be nullptr).
ExecuteGraph *ExtendInfo::GetOwnerGraphBarePtr() const { return execute_graph_; }
/**
 * Records graph as the owner of fast_node.
 * When an owner is already set and differs from graph, the node's list element is first erased from its owner list,
 * but only if that list exists and the element is in ListMode::kFreeMode.
 * @return always GRAPH_SUCCESS.
 */
graphStatus ExtendInfo::SetOwnerGraph(ExecuteGraph *const graph, const FastNode *const fast_node) {
  const bool switching_graph = (execute_graph_ != nullptr) && (graph != execute_graph_);
  if (switching_graph) {
    auto list_element = FastGraphUtils::GetListElementAddr(fast_node);
    auto owner_list = list_element->owner;
    const auto element_mode = list_element->mode;
    const bool detach = (owner_list != nullptr) && (element_mode == ListMode::kFreeMode);
    if (detach) {
      owner_list->erase(list_element);
    }
  }
  execute_graph_ = graph;
  return GRAPH_SUCCESS;
}
/**
 * Compares owner graph pointer, input index and output-index list (each via IsEqual with a field label).
 * NOTE(review): host_node_ and the input/output symbol tables are not part of the comparison.
 */
bool ExtendInfo::operator==(const ExtendInfo &r_info) const {
  if (!IsEqual(execute_graph_, r_info.execute_graph_, "node.execute_graph_")) {
    return false;
  }
  if (!IsEqual(input_index_, r_info.input_index_, "node.input_index_")) {
    return false;
  }
  return IsEqual(output_index_, r_info.output_index_, "node.output_index_");
}
/// @return whether the node is flagged as a host node.
bool ExtendInfo::GetHostNode() const { return host_node_; }
/// Sets the host-node flag.
void ExtendInfo::SetHostNode(const bool is_host) { host_node_ = is_host; }
/// Resizes the input symbol table to data_in_num entries; newly added entries are kInvalidSymbol.
void ExtendInfo::UpdateInputSymbols(size_t data_in_num) { input_symbols_.resize(data_in_num, kInvalidSymbol); }
/// Resizes the output symbol table to data_out_num entries; newly added entries are kInvalidSymbol.
void ExtendInfo::UpdateOutputSymbols(size_t data_out_num) { output_symbols_.resize(data_out_num, kInvalidSymbol); }
/// Stores symbol for input idx. @return GRAPH_FAILED when idx is out of range, GRAPH_SUCCESS otherwise.
graphStatus ExtendInfo::SetInputSymbol(size_t idx, uint64_t symbol) {
  if (IsDataIndexValid(idx, input_symbols_)) {
    input_symbols_[idx] = symbol;
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}
/// Stores symbol for output idx. @return GRAPH_FAILED when idx is out of range, GRAPH_SUCCESS otherwise.
graphStatus ExtendInfo::SetOutputSymbol(size_t idx, uint64_t symbol) {
  if (IsDataIndexValid(idx, output_symbols_)) {
    output_symbols_[idx] = symbol;
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}
/// @return the symbol of input idx, or kInvalidSymbol when idx is out of range.
uint64_t ExtendInfo::GetInputSymbol(size_t idx) {
  return IsDataIndexValid(idx, input_symbols_) ? input_symbols_[idx] : kInvalidSymbol;
}
/// @return the symbol of output idx, or kInvalidSymbol when idx is out of range.
uint64_t ExtendInfo::GetOutputSymbol(size_t idx) {
  return IsDataIndexValid(idx, output_symbols_) ? output_symbols_[idx] : kInvalidSymbol;
}
/// @return true when idx < symbols.size(); otherwise the GE_ASSERT_TRUE check fails (logging the index and size) and
/// the function returns early with its failure value.
bool ExtendInfo::IsDataIndexValid(size_t idx, const std::vector<uint64_t> &symbols) const {
  GE_ASSERT_TRUE(idx < symbols.size(), "The index [%zu] exceeds the size [%zu] of symbols.", idx, symbols.size());
  return true;
}
}