/*
* -------------------------------------------------------------------------
* This file is part of the MultimodalSDK project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MultimodalSDK is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
* @Description: Declares the pipeline Graph class and its OpNode / DataNode structures.
* @Version: 1.0
* @Date: 2025-2-17 17:00:00
* @LastEditors: dev
* @LastEditTime: 2025-2-17 17:00:00
*/
#ifndef ACCDATA_SRC_CPP_PIPELINE_GRAPH_GRAPH_H_
#define ACCDATA_SRC_CPP_PIPELINE_GRAPH_GRAPH_H_
#include <cinttypes>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "operator/operator.h"
#include "operator/op_spec.h"
namespace acclib {
namespace accdata {
// Identifier of an OpNode. Graph::GetOpNode() uses it as an index into the graph's OpNode
// list, so valid values lie in [0, Graph::NumOpNode()).
using OpNodeId = uint32_t;
// Identifier of a DataNode. Graph::GetDataNode() uses it as an index into the graph's
// DataNode list, so valid values lie in [0, Graph::NumDataNode()).
using DataNodeId = uint32_t;
// A node of the graph that wraps one operator together with its connectivity.
// The struct is move-only because it holds a std::unique_ptr.
struct OpNode {
// Operator instance owned by this node.
std::unique_ptr<Operator> op;
// Ids of the OpNodes recorded as parents of this node.
// NOTE(review): presumably the producers of this node's inputs - confirm in graph.cpp.
std::set<OpNodeId> parents;
// Ids of the OpNodes recorded as children of this node.
// NOTE(review): presumably the consumers of this node's outputs - confirm in graph.cpp.
std::set<OpNodeId> children;
// Ids of the DataNodes this node reads.
std::vector<DataNodeId> inputs;
// Ids of the DataNodes this node writes.
std::vector<DataNodeId> outputs;
};
/**
 * @brief A named piece of data that connects OpNodes in the graph.
 */
struct DataNode {
    /** Name of the data; Graph::GetDataNodeId() resolves a name to a DataNodeId. */
    std::string name;
    /**
     * Id of the OpNode recorded as the producer of this data.
     *
     * Explicitly initialized so that a default-initialized DataNode (`DataNode d;`) never
     * carries an indeterminate id. Note that 0 is also a valid OpNodeId (ids start at 0),
     * so this default is NOT a "no producer" sentinel.
     */
    OpNodeId producer{0};
    /** Ids of the OpNodes recorded as consumers of this data. */
    std::vector<OpNodeId> consumers;
};
class Graph {
public:
* @brief Add graph node
*/
AccDataErrorCode AddNode(const OpSpec &spec);
* @brief Build based on desired outputs.
*
* @param [in] outputs Output names.
* @param [in] enableFusion Whether enable fusion or not
*/
AccDataErrorCode Build(const std::vector<std::string> &outputs, bool enableFusion);
uint64_t NumOpNode() const
{
return mOpNodes.size();
}
uint64_t NumDataNode() const
{
return mDataNodes.size();
}
* @brief Get the OpNode.
*
* @param [in] id OpNode id, that is in [0, NumOpNode()).
* @return const OpNode&
*/
const OpNode &GetOpNode(OpNodeId id, AccDataErrorCode &errCode) const
{
errCode = AccDataErrorCode::H_PIPELINE_ERROR;
if (id < NumOpNode()) {
errCode = AccDataErrorCode::H_OK;
return mOpNodes[id];
}
return opNode;
}
* @brief Get the DataNode.
*
* @param [in] id OpNode id, that is in [0, NumDataNode()).
* @return const DataNode&
*/
const DataNode &GetDataNode(DataNodeId id, AccDataErrorCode &errCode) const
{
errCode = AccDataErrorCode::H_PIPELINE_ERROR;
if (id < NumDataNode()) {
errCode = AccDataErrorCode::H_OK;
return mDataNodes[id];
}
return dataNode;
}
* @brief Get the DataNode ID by name.
*
* @param name
* @return const DataNode&
*/
AccDataErrorCode GetDataNodeId(const std::string &name, DataNodeId &dataNodeId) const
{
auto it = mName2DataNode.find(name);
if (it == mName2DataNode.end()) {
ACCDATA_ERROR("Name is not a DataNode.");
return AccDataErrorCode::H_PIPELINE_ERROR;
}
dataNodeId = it->second;
return AccDataErrorCode::H_OK;
}
* @brief Get the ids of DataNode that is specified as outputs during Build().
*/
const std::vector<DataNodeId> &GetOutputs() const
{
return mOutputs;
}
* @brief Get stringed graph.
*/
std::string ToString();
private:
* Find the path produces the outputs
* @param [in] outputs outputs names
* @param [in] throwIfErr whether throw exception when encounter error or not.
* @return a vector of OpNodeIds denote all the paths produce the outputs
*/
AccDataErrorCode PathToOutputs(const std::vector<std::string> &outputs, std::vector<OpNodeId> &path);
* @brief Fuse operators.
*/
AccDataErrorCode Fuse(const std::vector<std::string> &outputs, std::vector<OpNodeId> &path);
* Find fusion plan for the pipeline denoted by the parameter originPath
* @param [in] originPath the path that generate the outputs
* @param [in] outputs the name of all outputs
* @return a vector of strings denotes the ops names in the found fusion plan
*/
AccDataErrorCode FindFusePlan(const std::vector<OpNodeId> &originPath, const std::vector<std::string> &outputs,
std::vector<std::string> &fusePlan);
bool ValidateFusePlan(const std::vector<std::string> &plan, const std::vector<OpNodeId> &originPath,
const std::vector<std::string> &outputs);
AccDataErrorCode FuseOps(std::vector<OpNode> &opNodes, const std::string &fuseOpName,
std::vector<OpNodeId> nodeIds);
AccDataErrorCode UpdateGraph(std::vector<OpNode> &opNodes);
* @brief Traverse the path that generate the output.
*
* @param [in] output Output name.
* @param [out] path Path of OpNodes that generates the specified output.
*/
AccDataErrorCode Traverse(const std::string &output, std::vector<OpNodeId> &path);
AccDataErrorCode BuildFromPath(const std::vector<OpNodeId> &path);
private:
AccDataErrorCode LinkInput(const std::string &name, std::unordered_map<std::string, DataNodeId> &name2DataNode,
std::vector<DataNode> &dataNodes, size_t opNodeId, OpNode &opNode);
std::vector<OpNode> mOpNodes{};
std::vector<DataNode> mDataNodes{};
std::unordered_map<std::string, DataNodeId> mName2DataNode{};
std::vector<DataNodeId> mOutputs{};
OpNode opNode{};
DataNode dataNode{};
};
}
}
#endif