/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef INC_GRAPH_OP_DESC_H_
#define INC_GRAPH_OP_DESC_H_
#include "graph/range_vistor.h"
#include "graph/ge_tensor.h"
#include "graph/opp_impl_version.h"
namespace ge {
// Standard-library names made usable without the std:: prefix inside namespace ge.
// Declarations further down rely on this (e.g. the unqualified `shared_ptr` in
// OpDesc::Vistor and `string` in OpDesc::GetAllInputNames()).
// NOTE(review): using-declarations in a header are visible to every includer;
// removing them could break code that spells ge::string / ge::vector, so they stay.
using std::map;
using std::pair;
using std::shared_ptr;
using std::string;
using std::vector;
// Forward declarations: only references/pointers to these types appear in this header.
class Operator;
class OpDescImpl;
class IRMetaData;
// Shared handle to the implementation object held by OpDesc (see OpDesc::impl_).
using OpDescImplPtr = std::shared_ptr<OpDescImpl>;
/// Kind tag attached to a subgraph IR name (see OpDesc::RegisterSubgraphIrName).
/// Unscoped on purpose: the enumerators are used unqualified elsewhere.
enum SubgraphType {
  kStatic = 0,
  kDynamic = 1,
  kSubgraphTypeEnd = 2  // last enumerator; presumably an end-of-range sentinel, not a real kind
};
/// Kind tag stored next to each IR input name (see OpDesc::AppendIrInput / GetIrInputs).
enum IrInputType {
  kIrInputRequired = 0,
  kIrInputOptional = 1,
  kIrInputDynamic = 2,
  kIrInputTypeEnd = 3  // last enumerator; presumably an end-of-range sentinel
};
/// Kind tag stored next to each IR output name (see OpDesc::AppendIrOutput / GetIrOutputs).
enum IrOutputType {
  kIrOutputRequired = 0,
  kIrOutputDynamic = 1,
  kIrOutputTypeEnd = 2  // last enumerator; presumably an end-of-range sentinel
};
class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {
public:
using OpDescPtr = std::shared_ptr<OpDesc>;
using ConstOpDescPtr = std::shared_ptr<const OpDesc>;
template <class T>
using Vistor = RangeVistor<T, shared_ptr<const OpDesc>>;
friend class GraphBuilderImpl;
friend class OperatorImpl;
friend class RecoverIrUtils;
OpDesc(const std::string &name, const std::string &type);
OpDesc(const OpDesc &op_desc);
OpDesc(OpDesc &&op_desc);
explicit OpDesc(const ge::proto::OpDef &op_def);
OpDesc();
~OpDesc() override;
bool operator==(const OpDesc &r_op_desc) const;
std::string GetName() const;
const char *GetNamePtr() const;
void SetName(const std::string &name);
void SetNamePtr(const char_t *name);
std::string GetType() const;
const char *GetTypePtr() const;
void SetType(const std::string &type);
void SetIrRelated(const OpDescPtr &op_desc);
graphStatus AddInputDesc(const GeTensorDesc &input_desc);
graphStatus AddInputDesc(const std::string &name, const GeTensorDesc &input_desc);
graphStatus AddInputDesc(const uint32_t index, const ge::GeTensorDesc &input_desc);
graphStatus AddInputDescMiddle(const std::string &name, const uint32_t num, const size_t index);
graphStatus AddOutputDescMiddle(const std::string &name, const uint32_t num, const size_t index);
graphStatus AddOutputDescForward(const std::string &name, const uint32_t num);
graphStatus AddOptionalInputDesc(const std::string &name, const GeTensorDesc &input_desc);
graphStatus UpdateInputDesc(const uint32_t index, const GeTensorDesc &tensor_desc);
graphStatus UpdateInputDesc(const std::string &name, const GeTensorDesc &tensor_desc);
bool InputIsSet(const std::string &name) const;
const GeTensorDesc &GetInputDesc(const uint32_t index) const;
const GeTensorDesc &GetInputDesc(const std::string &name) const;
bool IsOptionalInput(const uint32_t index) const;
Vistor<string> GetAllInputNames() const;
GeTensorDescPtr MutableInputDesc(const uint32_t index) const;
GeTensorDescPtr MutableInputDesc(const std::string &name) const;
* 获取OpDesc的所有输入的GeTensorDesc对象的拷贝,
* 需要注意拷贝行为对性能的影响
* @return
*/
Vistor<GeTensorDesc> GetAllInputsDesc() const;
* 获取OpDesc的所有输入的GeTensorDesc对象的引用,
* 无特殊需求,推荐使用此接口替代GetAllInputsDesc()
* @return
*/
Vistor<GeTensorDescPtr> GetAllInputsDescPtr() const;
size_t GetInputsSize() const;
size_t GetAllInputsSize() const;
graphStatus AddOutputDesc(const GeTensorDesc &output_desc);
graphStatus AddOutputDesc(const std::string &name, const GeTensorDesc &output_desc);
graphStatus UpdateOutputDesc(const uint32_t index, const GeTensorDesc &tensor_desc);
graphStatus UpdateOutputDesc(const std::string &name, const GeTensorDesc &tensor_desc);
const GeTensorDesc &GetOutputDesc(const uint32_t index) const;
const GeTensorDesc &GetOutputDesc(const std::string &name) const;
GeTensorDescPtr MutableOutputDesc(const uint32_t index) const;
GeTensorDescPtr MutableOutputDesc(const std::string &name) const;
uint32_t GetAllOutputsDescSize() const;
Vistor<GeTensorDesc> GetAllOutputsDesc() const;
Vistor<GeTensorDescPtr> GetAllOutputsDescPtr() const;
size_t GetOutputsSize() const;
ConstGeTensorDescPtr GetOutputDescPtr(const uint32_t index) const;
ConstGeTensorDescPtr GetInputDescPtr(const uint32_t index) const;
ConstGeTensorDescPtr GetInputDescPtrDfault(const uint32_t index) const;
ConstGeTensorDescPtr GetInputDescPtr(const std::string &name) const;
graphStatus AddDynamicInputDesc(const std::string &name, const uint32_t num, const bool is_push_back = true);
graphStatus AddDynamicInputDescByIndex(const std::string &name, const uint32_t num, const size_t index);
graphStatus AddDynamicOutputDesc(const std::string &name, const uint32_t num, const bool is_push_back = true);
bool IsOptionalInput(const std::string &name) const;
std::map<std::string, uint32_t> GetAllInputName() const;
std::map<std::string, uint32_t> GetAllOutputName();
std::map<uint32_t, std::string> GetAllOutputIndexToName();
std::map<std::string, uint32_t>& MutableAllInputName();
std::map<std::string, uint32_t>& MutableAllOutputName();
bool UpdateInputName(const std::map<std::string, uint32_t> input_name_idx);
bool UpdateOutputName(const std::map<std::string, uint32_t> output_name_idx);
void *GetTilingFuncInfo() const;
void SetTilingFuncInfo(void *tiling_func_info);
void *GetAtomicTilingFuncInfo() const;
void SetAtomicTilingFuncInfo(void *atomic_tiling_func_info);
bool IsSupportSymbolicInferDataType() const;
graphStatus SymbolicInferDataType();
void AddInferFunc(const std::function<graphStatus(Operator &)> &func);
void AddInferFormatFunc(const std::function<graphStatus(Operator &)> &func);
void AddInferValueRangeFunc(const std::function<graphStatus(Operator &)> &func);
void AddVerifierFunc(const std::function<graphStatus(Operator &)> &func);
void AddInferDataSliceFunc(const std::function<graphStatus(Operator &)> &func);
graphStatus DefaultInferFormat();
std::function<graphStatus(Operator &)> GetInferFunc() const;
std::function<graphStatus(Operator &)> GetVerifyFunc() const;
std::function<graphStatus(Operator &)> GetInferFormatFunc() const;
std::function<graphStatus(Operator &)> GetInferDataSliceFunc() const;
std::function<graphStatus(Operator &)> GetInferValueRangeFunc() const;
graphStatus CommonVerify() const;
graphStatus AddRegisterInputName(const std::string &name);
graphStatus AddRegisterOutputName(const std::string &name);
std::vector<std::string> GetRegisterInputName() const;
std::vector<std::string> GetRegisterOutputName() const;
void AppendIrAttrName(const std::string &name);
const std::vector<std::string> &GetIrAttrNames() const;
void AppendIrInput(std::string name, IrInputType input_type);
const std::vector<std::pair<std::string, IrInputType>> &GetIrInputs() const;
size_t GetIrInputsSize() const;
void AppendIrOutput(std::string name, IrOutputType output_type);
const std::vector<std::pair<std::string, IrOutputType>> &GetIrOutputs() const;
void SetInputDtypeSymbol(const std::string &ir_input, IrInputType type, const std::string &sym_id);
void SetOutputDtypeSymbol(const std::string &ir_output, IrOutputType type, const std::string &sym_id);
void DeclareDtypeSymbol(const std::string &sym_id, const TensorType &type);
void DeclareDtypeSymbol(const std::string &sym_id, const ListTensorType &type);
void DeclareDtypeSymbol(const std::string &sym_id, const Promote &type);
void ShareDtypeSymbolsFrom(const ge::OpDesc &src);
using AttrHolder::AddRequiredAttr;
using AttrHolder::DelAttr;
using AttrHolder::GetAllAttrNames;
using AttrHolder::GetAllAttrs;
using AttrHolder::GetAttr;
using AttrHolder::HasAttr;
using AttrHolder::SetAttr;
void SetId(const int64_t id);
int64_t GetId() const;
void SetStreamId(const int64_t stream_id);
int64_t GetStreamId() const;
void SetAttachedStreamId(const int64_t stream_id);
int64_t GetAttachedStreamId() const;
void SetAttachedStreamIds(const std::vector<int64_t> &stream_ids);
std::vector<int64_t> GetAttachedStreamIds() const;
bool HasValidAttachedStreamId() const;
void SetInputName(const std::vector<std::string> &input_name);
std::vector<std::string> GetInputName() const;
void SetSrcName(const std::vector<std::string> &src_name);
std::vector<std::string> GetSrcName() const;
void SetSrcIndex(const std::vector<int64_t> &src_index);
std::vector<int64_t> GetSrcIndex() const;
void SetInputOffset(const std::vector<int64_t> &input);
std::vector<int64_t> GetInputOffset() const;
void SetOutputOffset(const std::vector<int64_t> &output);
std::vector<int64_t> GetOutputOffset() const;
void SetDstName(const std::vector<std::string> &dst_name);
std::vector<std::string> GetDstName() const;
void SetDstIndex(const std::vector<int64_t> &dst_index);
void SetWorkspace(const std::vector<int64_t> &workspace);
std::vector<int64_t> GetWorkspace() const;
void SetWorkspaceBytes(const std::vector<int64_t> &workspace_bytes);
std::vector<int64_t> GetWorkspaceBytes() const;
void SetIsInputConst(const std::vector<bool> &is_input_const);
std::vector<bool> GetIsInputConst() const;
void SetOpInferDepends(const std::vector<std::string> &depend_names);
std::vector<std::string> GetOpInferDepends() const;
std::string GetInputNameByIndex(const uint32_t index) const;
std::string GetValidInputNameByIndex(const uint32_t index) const;
int32_t GetInputIndexByName(const std::string &name) const;
graphStatus GetDynamicInputIndexesByName(const std::string &name, std::vector<int32_t> &indexes) const;
std::string GetOutputNameByIndex(const uint32_t index) const;
int32_t GetOutputIndexByName(const std::string &name) const;
graphStatus GetDynamicOutputIndexesByName(const std::string &name, std::vector<int32_t> &indexes) const;
void SetOpKernelLibName(const std::string &name);
std::string GetOpKernelLibName() const;
void SetOpEngineName(const std::string &name);
std::string GetOpEngineName() const;
OppImplVersion GetOppImplVersion() const;
void RegisterSubgraphIrName(const std::string &name, const SubgraphType type);
const std::map<std::string, SubgraphType> &GetSubgraphIrNames() const;
* @brief Get subgraph names in IR order
* @return subgraph ir names in IR order
*/
const std::vector<std::pair<std::string, SubgraphType>> &GetOrderedSubgraphIrNames() const;
SubgraphType GetSubgraphTypeByIrName(const std::string &name) const;
graphStatus AddSubgraphName(const std::string &name);
const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const;
std::string GetSubgraphInstanceName(const uint32_t index) const;
const std::vector<std::string> &GetSubgraphInstanceNames() const;
graphStatus SetSubgraphInstanceName(const uint32_t index, const std::string &name);
void RemoveSubgraphInstanceName(const std::string &name);
graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const;
graphStatus GetPromoteIrInputList(std::vector<std::vector<size_t>> &promote_index_list);
protected:
ProtoAttrMap &MutableAttrMap() override;
ConstProtoAttrMap &GetAttrMap() const override;
private:
bool OpDescMembersAreEqual(const OpDesc &r_op_desc) const;
bool OpDescAttrsAreEqual(const OpDesc &r_op_desc) const;
bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const;
OpDescImplPtr impl_;
friend class OpDescUtils;
friend class ModelSerializeImp;
friend class AttrUtils;
friend class GeAttrValueImp;
friend class OnnxUtils;
friend class GraphUtils;
friend class NodeUtils;
friend class FastNodeUtils;
friend class ExecuteGraphUtils;
};
using OpDescPtr = OpDesc::OpDescPtr;
using ConstOpDescPtr = OpDesc::ConstOpDescPtr;
using ConstOpDesc = const OpDesc;
class OpDescBuilder {
public:
OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {}
OpDescBuilder(const OpDescBuilder &) = delete;
OpDescBuilder &operator=(const OpDescBuilder &) = delete;
OpDescBuilder(const OpDescBuilder &&) = delete;
OpDescBuilder &operator=(const OpDescBuilder &&) = delete;
~OpDescBuilder() = default;
* @brief Add input
* @param [in] name
* @return OpDescBuilder
*/
OpDescBuilder &AddInput(const std::string &name);
* @brief Add input
* @param [in] name
* @param [in] tensor
* @return OpDescBuilder
*/
OpDescBuilder &AddInput(const std::string &name, const GeTensorDesc &tensor);
* @brief Add dynamic input
* @param [in] name
* @param [in] num
* @return OpDescBuilder
*/
OpDescBuilder &AddDynamicInput(const std::string &name, const uint32_t num);
* @brief Add dynamic input
* @param [in] name
* @param [in] num
* @param [in] tensor
* @return OpDescBuilder
*/
OpDescBuilder &AddDynamicInput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor);
* @brief Add output
* @param [in] name
* @return OpDescBuilder
*/
OpDescBuilder &AddOutput(const std::string &name);
* @brief Add output
* @param [in] name
* @param [in] tensor
* @return OpDescBuilder
*/
OpDescBuilder &AddOutput(const std::string &name, const GeTensorDesc &tensor);
* @brief Add dynamic output
* @param [in] name
* @param [in] num
* @return OpDescBuilder
*/
OpDescBuilder &AddDynamicOutput(const std::string &name, const uint32_t num);
* @brief Add dynamic output
* @param [in] name
* @param [in] num
* @param [in] tensor
* @return OpDescBuilder
*/
OpDescBuilder &AddDynamicOutput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor);
* @brief Build op_desc
* @return OpDescPtr
*/
OpDescPtr Build();
private:
std::string name_;
std::string type_;
std::vector<std::pair<std::string, GeTensorDesc>> inputs_;
std::vector<std::pair<std::string, GeTensorDesc>> outputs_;
};
}
#endif