* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_
#define INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_
#include <map>
#include <memory>
#include <string>
#include <set>
#include <unordered_set>
#include <unordered_map>
#include <utility>
#include <vector>
#include "graph/detail/any_map.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/attr_store.h"
#include "graph/any_value.h"
// Forward declarations of the protobuf types referenced by this header. They let the
// declarations below name these types without including any protobuf header.
namespace google {
namespace protobuf {
class Message;

template<typename KeyT, typename ValueT>
class Map;
}  // namespace protobuf
}  // namespace google
namespace ge {
namespace proto {
class AttrDef;
class TensorDef;
class TensorDescriptor;
class ShapeDef;
class NamedAttrs;
class ModelDef;
class OpDef;
class GraphDef;
}
using ProtoAttrMap = AttrStore;
using ConstProtoAttrMap = const AttrStore;
using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>;
/**
 * Lightweight handle pairing a raw pointer to a protobuf message of type `ProtoType`
 * with a shared owner pointer (`ProtoMsgOwner`).
 *
 * Copying a helper copies both pointers (the owner's reference count is shared);
 * the message itself is only copied/moved through CopyValueFrom/MoveValueFrom.
 * A helper for `T` and one for `const T` are friends of each other, so either can
 * read the other's private pointers.
 */
template<class ProtoType>
class GeIrProtoHelper {
 public:
  /**
   * @param protoOwner shared owner pointer stored alongside the message pointer
   * @param protoMsg   pointer to the wrapped message (may be nullptr)
   */
  GeIrProtoHelper(const ProtoMsgOwner &protoOwner, ProtoType *const protoMsg)
      : protoOwner_(protoOwner), protoMsg_(protoMsg) {}

  /// Creates an empty helper: null owner and null message pointer.
  GeIrProtoHelper() : protoOwner_(nullptr), protoMsg_(nullptr) {}

  virtual ~GeIrProtoHelper() = default;

  /// Converting copy, e.g. from GeIrProtoHelper<T> to GeIrProtoHelper<const T>.
  template<typename T>
  GeIrProtoHelper(const GeIrProtoHelper<T> &other)
      : protoOwner_(other.protoOwner_), protoMsg_(other.protoMsg_) {}

  GeIrProtoHelper(const GeIrProtoHelper<ProtoType> &other)
      : protoOwner_(other.protoOwner_), protoMsg_(other.protoMsg_) {}

  /// Converting assignment; shares `other`'s owner and message pointer.
  template<typename T>
  GeIrProtoHelper &operator=(const GeIrProtoHelper<T> &other) {
    // Fixed: the original read the non-existent member `protoOnwer_`, so any
    // instantiation of this operator failed to compile.
    protoOwner_ = other.protoOwner_;
    protoMsg_ = other.protoMsg_;
    return *this;
  }

  GeIrProtoHelper &operator=(const GeIrProtoHelper<ProtoType> &other) {
    if (this != &other) {
      protoOwner_ = other.protoOwner_;
      protoMsg_ = other.protoMsg_;
    }
    return *this;
  }

  /// Defined out of line (not visible in this header).
  void InitDefault();

  /// Identity comparison: equal when both the owner and the message pointer match.
  template<typename T>
  bool operator==(const GeIrProtoHelper<T> &other) const {
    return (protoOwner_ == other.protoOwner_) && (protoMsg_ == other.protoMsg_);
  }

  inline const ProtoMsgOwner &GetProtoOwner() const {
    return protoOwner_;
  }

  inline ProtoType *GetProtoMsg() const {
    return protoMsg_;
  }

  /// Copy-assigns the message content of `other` into this helper's message.
  /// Does nothing if either message pointer is null.
  void CopyValueFrom(const GeIrProtoHelper<const ProtoType> &other) {
    if ((other.protoMsg_ != nullptr) && (protoMsg_ != nullptr)) {
      *protoMsg_ = *other.protoMsg_;
    }
  }

  /// Move-assigns the message content of `other` into this helper's message.
  /// Does nothing if either message pointer is null.
  void MoveValueFrom(GeIrProtoHelper<ProtoType> &&other) {
    if ((other.protoMsg_ != nullptr) && (protoMsg_ != nullptr)) {
      *protoMsg_ = std::move(*other.protoMsg_);
    }
  }

  /// Exchanges the owner and message pointers (not the message contents) with `other`.
  void Swap(GeIrProtoHelper<ProtoType> &other) noexcept {
    protoOwner_.swap(other.protoOwner_);
    std::swap(protoMsg_, other.protoMsg_);
  }

  // Grants access between the const and non-const instantiations of this template.
  friend class GeIrProtoHelper<typename std::conditional<
      std::is_const<ProtoType>::value, typename std::remove_const<ProtoType>::type, const ProtoType>::type>;
  friend class ComputerGraphImpl;
  friend class GeTensorSerializeUtils;

 private:
  ProtoMsgOwner protoOwner_ = nullptr;
  ProtoType *protoMsg_ = nullptr;
};
/**
 * Base class for objects that carry attributes.
 *
 * Members visible here manage three kinds of data:
 *  - regular attributes (SetAttr/GetAttr/HasAttr/DelAttr, defined out of line);
 *  - attribute groups, delegated to the attribute map returned by the derived
 *    class's MutableAttrMap()/GetAttrMap();
 *  - "extended" attributes (SetExtAttr/GetExtAttr/...), stored in the private
 *    AnyMap `ext_attrs_`.
 * It also keeps a name -> type table of required attributes.
 */
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder {
 public:
  AttrHolder() = default;
  virtual ~AttrHolder() = default;

  /**
   * Sets the attribute `name` to `value` on this AttrHolder.
   * Note: if an attribute named `name` already exists, its value is overwritten.
   * @param name  attribute name
   * @param value attribute value
   * @return GRAPH_SUCCESS on success, GRAPH_FAILED on failure
   */
  graphStatus SetAttr(const std::string &name, const AnyValue &value);

  /**
   * Tries to set the attribute `name` to `value` on this AttrHolder.
   * Note: if an attribute named `name` already exists, its value is NOT overwritten.
   * @param name  attribute name
   * @param value attribute value
   * @return GRAPH_SUCCESS on success, GRAPH_FAILED on failure
   */
  graphStatus TrySetAttr(const std::string &name, const AnyValue &value);

  /// Reads attribute `name` into `value` (defined out of line).
  graphStatus GetAttr(const std::string &name, AnyValue &value) const;

  /// @return whether attribute `name` exists (defined out of line).
  bool HasAttr(const std::string &name) const;

  /// @return true if `name` is a key of the required-attribute table.
  bool HasRequiredAttr(const std::string &name) const {
    return required_attrs_and_type_.find(name) != required_attrs_and_type_.end();
  }

  /// Removes attribute `name` (defined out of line).
  graphStatus DelAttr(const std::string &name);
  void CopyAttrsFrom(const AttrHolder &holder);
  void CopyFrom(const AttrHolder &holder);

  /**
   * Swaps the required-attribute table and the extended attributes with `holder`.
   * NOTE: the attribute map returned by MutableAttrMap() is not swapped here.
   */
  void SwapBase(AttrHolder &holder) {
    required_attrs_and_type_.swap(holder.required_attrs_and_type_);
    ext_attrs_.Swap(holder.ext_attrs_);
  }

  /**
   * Sets the extended attribute `name` of type T to `value`.
   * If the object already has an attribute `name` of type T, its value is
   * overwritten; if it already has an attribute `name` of a different type,
   * the call fails.
   * @param name  attribute name
   * @param value attribute value of any type
   * @return true on success, false on failure
   */
  template<class T>
  bool SetExtAttr(const std::string &name, const T &value) {
    return ext_attrs_.Set(name, value);
  }

  /**
   * Looks up the extended attribute `name` of type T. The lookup fails if the
   * object has no attribute `name` or if its type is not T.
   * @param name         attribute name
   * @param defaultValue value returned when the lookup fails
   * @return the stored value on success, otherwise `defaultValue`
   */
  template<class T>
  T TryGetExtAttr(const std::string &name, const T defaultValue) const {
    T ret(defaultValue);
    // On failure `ret` keeps the default value, so the status is intentionally ignored.
    (void) ext_attrs_.Get(name, ret);
    return ret;
  }

  /// Returns attribute group T from the attribute map, creating it if needed.
  template<class T>
  typename std::enable_if<std::is_base_of<AttrGroupsBase, T>::value, T *>::type GetOrCreateAttrsGroup() {
    return MutableAttrMap().GetOrCreateAttrsGroup<T>();
  }

  /// Deletes attribute group T from the attribute map.
  template<class T>
  typename std::enable_if<std::is_base_of<AttrGroupsBase, T>::value, bool>::type DeleteAttrsGroup() {
    return MutableAttrMap().DeleteAttrsGroup<T>();
  }

  /**
   * Creates attribute group T in the attribute map, forwarding `args` to the
   * attribute map's CreateAttrsGroup<T>.
   */
  template<typename T, typename... Args>
  typename std::enable_if<std::is_base_of<AttrGroupsBase, T>::value, T *>::type CreateAttrsGroup(Args &&... args) {
    // Perfect-forward so rvalue arguments are moved (and move-only ones are accepted)
    // instead of being passed on as lvalues.
    return MutableAttrMap().CreateAttrsGroup<T>(std::forward<Args>(args)...);
  }

  /**
   * Returns attribute group T, or whatever the attribute map returns when it is absent.
   * NOTE(review): this non-const overload goes through the const GetAttrMap(), so it
   * relies on AttrStore::GetAttrsGroup<T>() being callable on a const store and
   * returning T*; presumably it never creates the group — confirm against AttrStore.
   */
  template<class T>
  typename std::enable_if<std::is_base_of<AttrGroupsBase, T>::value, T *>::type GetAttrsGroup() {
    return GetAttrMap().GetAttrsGroup<T>();
  }

  /// Const overload of GetAttrsGroup.
  template<class T>
  typename std::enable_if<std::is_base_of<AttrGroupsBase, T>::value, const T *>::type GetAttrsGroup() const {
    return GetAttrMap().GetAttrsGroup<T>();
  }

  /**
   * Looks up the extended attribute `name` of type T. The lookup fails if the
   * object has no attribute `name` or if its type is not T.
   * @param name attribute name
   * @return pointer to the stored value on success, nullptr on failure
   */
  template<class T>
  const T *GetExtAttr(const std::string &name) const {
    return ext_attrs_.Get<T>(name);
  }

  /// Non-const overload of GetExtAttr; same lookup, mutable result.
  template<class T>
  T *GetExtAttr(const std::string &name) {
    return const_cast<T *>(ext_attrs_.Get<T>(name));
  }

  /// Erases the extended attribute `name`; returns the result of AnyMap::Erase.
  bool DelExtAttr(const std::string &name) {
    return ext_attrs_.Erase(name);
  }

  graphStatus AddRequiredAttr(const std::string &name);
  graphStatus AddRequiredAttrWithType(const std::string &name, const std::string &type);

  /// @return the required-attribute table (attribute name -> type string).
  const std::unordered_map<std::string, std::string> &GetRequiredAttrWithType() const {
    return required_attrs_and_type_;
  }

 protected:
  const std::set<std::string> GetAllAttrNames() const;
  const std::map<std::string, AnyValue> GetAllAttrs() const;
  const std::map<std::string, AnyValue> GetAllAttrsWithFilter(const AttrNameFilter &attr_filter) const;

  /// Attribute storage supplied by the derived class.
  virtual ProtoAttrMap &MutableAttrMap() = 0;
  virtual ConstProtoAttrMap &GetAttrMap() const = 0;

  friend class ModelSerializeImp;
  friend class AttrUtils;
  friend class OpDescUtils;
  friend class GraphUtils;
  friend class AttrGroupSerialize;

  // Required attributes: attribute name -> type string.
  std::unordered_map<std::string, std::string> required_attrs_and_type_;

 private:
  // Extended (type-erased) attributes, managed via the *ExtAttr members.
  AnyMap ext_attrs_;
};
}
#endif