* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIR_COMPILER_PNE_DATA_FLOW_GRAPH_DATA_FLOW_GRAPH_H
#define AIR_COMPILER_PNE_DATA_FLOW_GRAPH_DATA_FLOW_GRAPH_H
#include <cstdint>
#include <functional>
#include <future>
#include <map>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "framework/common/debug/log.h"
#include "graph/compute_graph.h"
#include "dflow/inc/data_flow/model/flow_model.h"
#include "proto/dflow.pb.h"
#include "common/thread_pool/thread_pool.h"
namespace ge {
// Size argument handed to the ThreadPool that DataFlowGraph creates in its constructor.
constexpr uint32_t kDataFlowGraphThreadPoolSize{8U};
class DataFlowGraph {
public:
explicit DataFlowGraph(const ComputeGraphPtr &compute_graph,
const std::string &data_flow_scope = "",
const bool enable_cache = false,
const bool cache_manual_check = false,
const uint32_t data_flow_depth = 0U)
: root_graph_(compute_graph),
data_flow_scope_(data_flow_scope),
data_flow_depth_(data_flow_depth),
enable_cache_(enable_cache),
cache_manual_check_(cache_manual_check),
thread_pool_("df_pre_load_", kDataFlowGraphThreadPoolSize, true) {}
~DataFlowGraph() = default;
Status Initialize();
const ComputeGraphPtr &GetRootGraph() const {
return root_graph_;
}
const std::string &GetName() const {
return graph_name_;
}
const std::map<std::string, ComputeGraphPtr> &GetAllSubgraphs() const {
return subgraphs_;
}
const std::map<std::string, std::vector<ComputeGraphPtr>> &GetNodeSubgraphs() const {
return node_subgraphs_;
}
const std::map<std::string, std::vector<std::pair<ComputeGraphPtr, uint32_t>>> &GetNodesInputs() const {
return nodes_inputs_;
}
const std::map<std::string, std::vector<std::pair<ComputeGraphPtr, uint32_t>>> &GetNodesOutputs() const {
return nodes_outputs_;
}
const std::vector<std::string> &GetInvokeKeys(const std::string &graph_name) const;
bool InvokedByBuiltIn(const std::string &invoke_key) const;
const std::map<std::string, std::string> &GetGraphBuildOptions(const std::string &graph_name) const;
const std::string &GetInvokedGraphKey(const std::string &graph_name) const;
const std::string &GetInvokedKeyOriginName(const std::string &invoke_key) const;
const std::map<std::string, FlowModelPtr> &GetAllLoadedModels() const {
return loaded_models_;
}
const std::map<std::string, std::vector<FlowModelPtr>> &GetNodeLoadedModels() const {
return node_loaded_models_;
}
Status AddLoadedModel(const std::string &node_name, const std::string &graph_name, const FlowModelPtr &model);
bool IsInvokedGraph(const std::string &graph_name) const;
bool EnableCache() const {
return enable_cache_;
}
bool CacheManualCheck() const {
return cache_manual_check_;
}
bool IsRootDataFlow() const {
return data_flow_scope_.empty();
}
const std::string &GetDataFlowScope() const {
return data_flow_scope_;
}
uint32_t GetDataFlowDepth() const {
return data_flow_depth_;
}
Status CommitPreprocessTask(const std::string &name, std::function<Status()> &task);
Status GetInvokedModelFusionAttrs(const std::vector<std::string> &invoke_keys,
std::string &invoked_model_attrs) const;
private:
friend class ProcessPointLoader;
Status CheckGraph() const;
Status MapNodeInputsAndOutputs(const NodePtr &node, const dataflow::ProcessPoint &process_point);
Status MapNodeInputs(const NodePtr &node, const dataflow::ProcessPoint &process_point);
Status MapNodeOutputs(const NodePtr &node, const dataflow::ProcessPoint &process_point);
void GetInOrOutIndex(const std::vector<std::pair<ComputeGraphPtr, uint32_t>> &vec, size_t &index) const;
Status UpdateInputsFlowAttrs(const NodePtr &node);
Status CheckFlowNode(const NodePtr &node) const;
Status CheckAlignAttrs(bool &align_enable) const;
Status CheckAndFixDataFlowAttrs() const;
Status InitializeFlowNode(const NodePtr &node);
Status WaitPreprocessTaskFinish();
static bool NeedSkip(const std::string &op_type);
ComputeGraphPtr root_graph_;
std::string data_flow_scope_;
uint32_t data_flow_depth_;
std::string graph_name_;
std::map<std::string, ComputeGraphPtr> subgraphs_;
std::map<std::string, std::vector<ComputeGraphPtr>> node_subgraphs_;
std::map<std::string, std::vector<std::pair<ComputeGraphPtr, uint32_t>>> nodes_inputs_;
std::map<std::string, std::vector<std::pair<ComputeGraphPtr, uint32_t>>> nodes_outputs_;
std::map<std::string, std::vector<std::string>> invokes_;
std::map<std::string, std::string> invoked_keys_;
std::map<std::string, std::string> invoked_graphs_;
std::map<std::string, bool> invoked_by_built_in_;
std::map<std::string, std::string> invoke_origins_;
std::map<std::string, std::map<std::string, std::string>> graphs_build_options_;
std::mutex loaded_models_mt_;
std::map<std::string, FlowModelPtr> loaded_models_;
std::map<std::string, std::vector<FlowModelPtr>> node_loaded_models_;
bool enable_cache_;
bool cache_manual_check_;
ThreadPool thread_pool_;
std::map<std::string, std::future<Status>> preprocess_tasks_;
};
}
#endif