* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GE_HYBRID_MODEL_NODE_ITEM_H_
#define GE_HYBRID_MODEL_NODE_ITEM_H_
#include <mutex>
#include <vector>
#include "framework/memory/memory_api.h"
#include "graph/node.h"
#include "graph/op_desc.h"
#include "graph/utils/node_utils.h"
#include "graph/runtime_inference_context.h"
#include "hybrid/common/tensor_value.h"
#include "hybrid/model/infer/node_shape_infer.h"
namespace ge {
namespace hybrid {
class NodeTask;
class NodeExecutor;
struct GraphExecutionContext;
class SubgraphContext;
struct FusedSubgraph {
std::map<int32_t, std::vector<GeTensorDescPtr>> input_mapping;
std::map<int32_t, GeTensorDescPtr> output_mapping;
std::vector<std::pair<std::pair<int64_t, int32_t>, std::pair<OpDesc *, int32_t>>> data_dependencies;
std::vector<NodePtr> nodes;
std::vector<int64_t> name_to_prof_indexes;
ComputeGraphPtr graph;
};
bool IsControlFlowV2Op(const std::string &op_type);
bool IsFftsGraphNode(const OpDesc &op_desc);
bool IsFftsKernelNode(const OpDesc &op_desc);
class OptionalMutexGuard {
public:
OptionalMutexGuard(std::mutex *const mutex, const std::string &name);
~OptionalMutexGuard();
private:
std::mutex *mu_{nullptr};
std::string name_;
};
struct NodeItem : NodeShapeInfer {
~NodeItem() = default;
static Status Create(const NodePtr &node_ptr, std::unique_ptr<NodeItem> &node_item);
OpDescPtr GetOpDesc() const {
return node->GetOpDesc();
}
GeTensorDescPtr MutableOutputDesc(const int32_t index) const;
Status UpdateInputDesc(const int32_t index, const GeTensorDesc &tensor_desc) const;
GeTensorDescPtr MutableInputDesc(const int32_t index) const;
Status GetInputDesc(const int32_t index, GeTensorDesc &tensor_desc) const;
Status GetCanonicalInputIndex(const uint32_t index, int32_t &canonical_index) const;
bool IsNoOp() const;
bool IsControlFlowV2Op() const {
return is_ctrl_flow_v2_op_;
}
bool IsControlFlowOp() const {
return is_ctrl_flow_op_;
}
bool IsMergeOp() const {
return is_merge_op_;
}
bool IsEnterOp() const {
return kEnterOpTypes.count(node_type) > 0U;
}
bool IsExitOp() const {
return kExitOpTypes.count(node_type) > 0U;
}
bool IsNextIterationOp() const {
return kNextIterationOpTypes.count(node_type) > 0UL;
}
bool IsHcclOp() const;
bool IsFftsSubNode() const {
return is_ffts_sub_node_;
}
bool IsInputEnterFeedNode(const NodePtr &node) const ;
void SetToDynamic();
void SetDataSend(NodeItem *const node_item, int32_t anchor_index);
void SetCtrlSend(NodeItem *const node_item, const uint32_t switch_index);
void SetMergeCtrl(NodeItem *const node_item, const uint32_t merge_index);
size_t GetMergeCtrl(const uint32_t merge_index) const;
OptionalMutexGuard MutexGuard(const std::string &name) const {
return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name);
}
std::string DebugString() const;
NodePtr node;
int64_t node_id = -1;
int32_t num_inputs = 0;
int32_t num_outputs = 0;
int32_t input_start = -1;
int32_t output_start = -1;
bool has_observer = false;
bool has_optional_inputs = false;
std::vector<MemStorageType> output_mem_types_;
std::vector<ge::NodePtr> dependents_for_shape_inference;
std::vector<ge::NodePtr> dependents_for_execution;
std::set<int32_t> to_const_output_id_list;
std::vector<std::vector<std::pair<int32_t, NodeItem *>>> outputs;
bool is_root_node_ = false;
bool is_ctrl_flow_v2_op_ = false;
bool is_ctrl_flow_op_ = false;
bool is_merge_op_ = false;
bool is_enter_active_ = false;
bool is_ffts_sub_node_ = false;
int64_t frame_index_ = -1;
int64_t parent_frame_ = -1;
std::set<const NodeItem *> root_ctrl_;
std::map<const NodeItem *, std::set<int32_t>> root_data_;
std::set<const NodeItem *> enter_ctrl_;
std::map<const NodeItem *, std::set<int32_t>> enter_data_;
std::set<const NodeItem *> data_send_;
std::map<const NodeItem *, int32_t> data_recv_;
std::set<const NodeItem *> ctrl_send_;
std::set<const NodeItem *> ctrl_recv_;
std::vector<std::set<const NodeItem *>> switch_groups_;
std::shared_ptr<NodeTask> kernel_task;
std::unique_ptr<FusedSubgraph> fused_subgraph;
const NodeExecutor *node_executor = nullptr;
std::map<int32_t, ge::NodePtr> ref_outputs;
std::map<int32_t, int32_t> reuse_inputs;
std::map<int32_t, int32_t> reuse_outputs;
int32_t num_static_input_shapes = 0;
bool is_profiling_report = false;
std::unordered_set<int32_t> skip_sufficiency_of_input_check_;
private:
explicit NodeItem(NodePtr node_ptr);
Status Init();
Status InitInputsAndOutputs();
void ResolveOptionalInputs();
Status ResolveDynamicState();
Status ResolveStaticInputsAndOutputs();
void ResolveUnknownShapeType();
void ResolveFftsPlusStatus();
Status ResolveOutputMemType();
GeTensorDescPtr DoGetInputDesc(const int32_t index) const;
std::vector<uint32_t> input_desc_indices_;
std::shared_ptr<std::mutex> copy_mu_;
mutable std::mutex mu_;
};
}
}
#endif