/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GRAPH_OP_DESC_IMPL_H_
#define GRAPH_OP_DESC_IMPL_H_
#include <functional>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "graph/types.h"
#include "graph/op_desc.h"
#include "graph/ge_tensor.h"
#include "graph/small_vector.h"
#include "graph/ascend_limits.h"
#include "graph/type/tensor_type_impl.h"
#include "graph/ir/ir_meta.h"
namespace ge {
/**
 * @brief Selects where data-type inference takes its information from.
 *
 * The enumerator names indicate the source consulted (an attribute, an input,
 * or an output); the exact semantics are defined by the code that consumes this
 * enum (not visible here). kInvalidStrategy serves as the "none/unknown" value.
 *
 * The values are spelled out explicitly; they equal the implicit numbering (0..3).
 */
enum class DataTypeInferStrategy {
  kInferFromAttr = 0,    // infer from an attribute
  kInferFromInput = 1,   // infer from an input
  kInferFromOutput = 2,  // infer from an output
  kInvalidStrategy = 3   // sentinel: no valid strategy
};
class OpDescImpl {
public:
OpDescImpl();
OpDescImpl(const std::string &name, const std::string &type);
OpDescImpl(const OpDescImpl &op_desc_impl);
OpDescImpl& operator=(const OpDescImpl &op_desc_impl) &;
explicit OpDescImpl(const ge::proto::OpDef &op_def);
~OpDescImpl() = default;
const char *GetNamePtr() const;
std::string GetName() const;
void SetName(const std::string &name);
const char *GetTypePtr() const;
std::string GetType() const;
void SetType(const std::string &type);
void SetIrRelated(const OpDescImpl *r_op_desc);
graphStatus AddInputDesc(const ge::GeTensorDesc &input_desc);
graphStatus AddInputDesc(const uint32_t index, const ge::GeTensorDesc &input_desc);
graphStatus AddInputDesc(const std::string &name, const ge::GeTensorDesc &input_desc);
graphStatus AddInputDescMiddle(const std::string &name, const uint32_t num, const size_t index);
graphStatus AddOutputDescMiddle(const std::string &name, const uint32_t num, const size_t index);
graphStatus AddInputDescForward(const std::string &name, const uint32_t num);
graphStatus AddOutputDescForward(const std::string &name, const uint32_t num);
graphStatus AddOptionalInputDesc(const std::string &name, const ge::GeTensorDesc &input_desc);
graphStatus UpdateInputDesc(const uint32_t index, const ge::GeTensorDesc &tensor_Desc);
graphStatus UpdateInputDesc(const std::string &name, const ge::GeTensorDesc &tensor_Desc);
bool OpDescMembersAreEqual(const OpDescImpl &r_op_desc) const;
bool OpDescAttrsAreEqual(const OpDescImpl &r_op_desc) const;
bool OpDescGenTensorDescsAreEqual(const OpDescImpl &r_op_desc) const;
bool InputIsSet(const std::string &name) const;
const GeTensorDesc &GetInputDesc(const uint32_t index) const;
const GeTensorDesc &GetInputDesc(const std::string &name) const;
GeTensorDescPtr MutableInputDesc(const uint32_t index) const;
GeTensorDescPtr MutableInputDesc(const std::string &name) const;
OpDesc::Vistor<string> GetAllInputNames(const ConstOpDescPtr &op_desc) const;
void SetOpKernelLibName(const std::string &name);
std::string GetOpKernelLibName() const;
void SetOpEngineName(const std::string &name);
std::string GetOpEngineName() const;
OpDesc::Vistor<GeTensorDesc> GetAllInputsDesc(const ConstOpDescPtr &op_desc) const;
OpDesc::Vistor<GeTensorDescPtr> GetAllInputsDescPtr(const ConstOpDescPtr &op_desc) const;
size_t GetInputsSize() const;
size_t GetIrInputsSize() const;
size_t GetAllInputsSize() const;
graphStatus AddOutputDesc(const ge::GeTensorDesc &output_desc);
graphStatus AddOutputDesc(const std::string &name, const ge::GeTensorDesc &output_desc);
graphStatus UpdateOutputDesc(const uint32_t index, const ge::GeTensorDesc &tensor_Desc);
graphStatus UpdateOutputDesc(const std::string &name, const ge::GeTensorDesc &tensor_Desc);
const GeTensorDesc &GetOutputDesc(const uint32_t index) const;
const GeTensorDesc &GetOutputDesc(const std::string &name) const;
GeTensorDescPtr MutableOutputDesc(const uint32_t index) const;
GeTensorDescPtr MutableOutputDesc(const std::string &name) const;
uint32_t GetAllOutputsDescSize() const;
OpDesc::Vistor<GeTensorDesc> GetAllOutputsDesc(const ConstOpDescPtr &op_desc) const;
OpDesc::Vistor<GeTensorDescPtr> GetAllOutputsDescPtr(const ConstOpDescPtr &op_desc) const;
ConstGeTensorDescPtr GetOutputDescPtr(const uint32_t index) const;
size_t GetOutputsSize() const;
ConstGeTensorDescPtr GetInputDescPtr(const uint32_t index) const;
ConstGeTensorDescPtr GetInputDescPtrDfault(const uint32_t index) const;
ConstGeTensorDescPtr GetInputDescPtr(const std::string &name) const;
graphStatus AddDynamicInputDesc(const std::string &name, const uint32_t num, const bool is_push_back);
graphStatus AddDynamicInputDescByIndex(const std::string &name, const uint32_t num, const size_t index);
graphStatus AddDynamicOutputDesc(const std::string &name, const uint32_t num, const bool is_push_back);
bool IsOptionalInput(const uint32_t index) const;
std::map<std::string, uint32_t> GetAllInputName() const;
std::map<std::string, uint32_t> GetAllOutputName();
std::map<uint32_t, std::string> GetAllOutputIndexToName();
std::map<std::string, uint32_t>& MutableAllInputName();
std::map<std::string, uint32_t>& MutableAllOutputName();
bool UpdateInputName(std::map<std::string, uint32_t> input_name_idx);
bool UpdateOutputName(std::map<std::string, uint32_t> output_name_idx);
std::function<graphStatus(Operator &)> GetInferFunc() const;
std::function<graphStatus(Operator &)> GetVerifyFunc() const;
std::function<graphStatus(Operator &)> GetInferFormatFunc() const;
std::function<graphStatus(Operator &)> GetInferValueRangeFunc() const;
std::function<graphStatus(Operator &)> GetInferDataSliceFunc() const;
void AddInferFunc(const std::function<graphStatus(Operator &)> &func);
void AddInferFormatFunc(const std::function<graphStatus(Operator &)> &func);
void AddInferValueRangeFunc(const std::function<graphStatus(Operator &)> &func);
void AddVerifierFunc(const std::function<graphStatus(Operator &)> &func);
void AddInferDataSliceFunc(const std::function<graphStatus(Operator &)> &func);
bool IsSupportSymbolicInferDataType() const;
graphStatus SymbolicInferDataType(const OpDescPtr &op_desc) const;
graphStatus DefaultInferFormat(const ConstOpDescPtr &op_desc) const;
std::string GetInputNameByIndex(const uint32_t index) const;
int32_t GetInputIndexByName(const std::string &name) const;
graphStatus GetDynamicInputIndexesByName(const std::string &name, std::vector<int32_t> &indexes) const;
std::string GetValidInputNameByIndex(const uint32_t index) const;
std::string GetOutputNameByIndex(const uint32_t index) const;
int32_t GetOutputIndexByName(const std::string &name) const;
graphStatus GetDynamicOutputIndexesByName(const std::string &name, std::vector<int32_t> &indexes) const;
ProtoAttrMap &MutableAttrMap();
ConstProtoAttrMap &GetAttrMap() const;
IRMetaData &MutableIRMeta();
const IRMetaData &GetIRMeta() const;
void SetId(const int64_t id);
int64_t GetId() const;
void SetStreamId(const int64_t stream_id);
int64_t GetStreamId() const;
void SetInputName(const std::vector<std::string> &input_name);
std::vector<std::string> GetInputName() const;
void SetSrcName(const std::vector<std::string> &src_name);
std::vector<std::string> GetSrcName() const;
void SetSrcIndex(const std::vector<int64_t> &src_index);
std::vector<int64_t> GetSrcIndex() const;
void SetInputOffset(const std::vector<int64_t> &input);
std::vector<int64_t> GetInputOffset() const;
void SetOutputOffset(const std::vector<int64_t> &output);
std::vector<int64_t> GetOutputOffset() const;
void SetDstName(const std::vector<std::string> &dst_name);
std::vector<std::string> GetDstName() const;
void SetDstIndex(const std::vector<int64_t> &dst_index);
void SetWorkspace(const std::vector<int64_t> &workspace);
std::vector<int64_t> GetWorkspace() const;
void SetWorkspaceBytes(const std::vector<int64_t> &workspace_bytes);
std::vector<int64_t> GetWorkspaceBytes() const;
void SetIsInputConst(const std::vector<bool> &is_input_const);
std::vector<bool> GetIsInputConst() const;
std::string GetSubgraphInstanceName(const size_t index) const;
const std::vector<std::string> &GetSubgraphInstanceNames() const;
void RemoveSubgraphInstanceName(const std::string &name);
graphStatus AddSubgraphName(const std::string &name);
const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const;
graphStatus SetSubgraphInstanceName(const size_t index, const std::string &name);
graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const;
void *GetTilingFuncInfo() const;
void SetTilingFuncInfo(void *tiling_func_info);
void *GetAtomicTilingFuncInfo() const;
void SetAtomicTilingFuncInfo(void *atomic_tiling_func_info);
private:
void DeSerializeOpDefToMetaData(const proto::OpDef &op_def);
void SerializeMetaDataToOpDef(proto::OpDef * const op_def);
friend class AttrUtils;
friend class OpDescUtils;
friend class ModelSerializeImp;
friend class OnnxUtils;
friend class GraphUtils;
friend class NodeUtils;
friend class FastNodeUtils;
friend class ExecuteGraphUtils;
std::vector<std::string> subgraph_instance_names_;
std::map<std::string, uint32_t> subgraph_names_to_index_;
std::vector<GeTensorDescPtr> inputs_desc_{};
std::map<std::string, uint32_t> input_name_idx_{};
std::vector<GeTensorDescPtr> outputs_desc_{};
std::map<std::string, uint32_t> output_name_idx_{};
std::function<graphStatus(Operator &)> infer_func_ = nullptr;
std::function<graphStatus(Operator &)> infer_format_func_ = nullptr;
std::function<graphStatus(Operator &)> infer_value_range_func_ = nullptr;
std::function<graphStatus(Operator &)> verifier_func_ = nullptr;
std::function<graphStatus(Operator &)> infer_data_slice_func_ = nullptr;
std::string op_kernel_lib_name_;
std::string engine_name_;
OpMetadata meta_data_;
AttrStore attrs_;
void *tiling_func_info_ = nullptr;
void *atomic_tiling_func_info_ = nullptr;
};
}
#endif