/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INC_GRAPH_GE_TENSOR_H_
#define INC_GRAPH_GE_TENSOR_H_
#include <atomic>
#include <memory>
#include <string>
#include <vector>
#include "detail/attributes_holder.h"
#include "graph/buffer.h"
#include "graph_metadef/graph/aligned_ptr.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/small_vector.h"
#include "graph/ascend_limits.h"
namespace ge {
// Implementation classes are only forward-declared here; this header refers to them
// exclusively through the std::shared_ptr aliases below.
class GeShapeImpl;
class TensorDataImpl;
class GeTensorDescImpl;
class GeTensorImpl;
// Forward-declared; its full definition appears near the end of this header.
class GeTensorSerializeUtils;

// Shared-ownership handles to the implementation objects held by the public classes.
using GeShapeImplPtr = std::shared_ptr<GeShapeImpl>;
using TensorDataImplPtr = std::shared_ptr<TensorDataImpl>;
using GeTensorDescImplPtr = std::shared_ptr<GeTensorDescImpl>;
using GeTensorImplPtr = std::shared_ptr<GeTensorImpl>;
/**
 * Tensor shape: an ordered list of int64_t dim values.
 *
 * A negative dim value marks the shape as unknown (see IsUnknownShape()). The examples
 * below use [-2] for a shape whose number of dimensions is itself unknown.
 *
 * NOTE(review): the move constructor and move assignment are not declared noexcept, so
 * standard containers (e.g. std::vector<GeShape> on reallocation) copy instead of move.
 * Adding noexcept needs a matching change in the out-of-line definitions.
 */
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape {
 public:
  GeShape();
  ~GeShape();
  /// Builds a shape from the given dim values (taken by value).
  explicit GeShape(std::vector<int64_t> s);
  /**
   * Returns the number of *valid* dims. This is NOT equivalent to `GetDims().size()`;
   * callers should pick whichever they actually need.
   * For example, when the dims are [-2] (number of dimensions unknown):
   *   - GetDimNum() returns 0;
   *   - GetDims().size() returns the number of stored dims, i.e. 1.
   * To check whether the shape is a scalar, prefer `IsScalar()`.
   */
  size_t GetDimNum() const;
  void SetDimNum(const size_t dim_num);
  void AppendDim(const int64_t dim_size);
  bool IsUnknownDimNum() const;
  void SetIsUnknownDimNum();
  int64_t GetDim(const size_t idx) const;
  /// NOTE(review): behaviour for an out-of-range idx is not visible in this header — check
  /// the returned graphStatus and confirm in the implementation.
  graphStatus SetDim(const size_t idx, const int64_t value);
  /**
   * Returns a copy of the stored dim values. Its size() is NOT equivalent to `GetDimNum()`;
   * callers should pick whichever they actually need.
   * For example, when the dims are [-2] (number of dimensions unknown):
   *   - GetDimNum() returns 0;
   *   - GetDims().size() returns the number of stored dims, i.e. 1.
   * To check whether the shape is a scalar, prefer `IsScalar()`.
   */
  std::vector<int64_t> GetDims() const;
  /// Despite its name this is a const member returning a const reference.
  /// NOTE(review): the reference presumably aliases storage owned by this shape, so it
  /// should not outlive it — confirm.
  const SmallVector<int64_t, kDefaultDimsNum> &GetMutableDims() const;
  int64_t GetShapeSize() const;
  std::string ToString() const;
  /**
   * Determines from the tensor's dim values whether the shape is unknown.
   * @return
   *   true if any dim value is less than 0 (unknown shape);
   *   false if every dim value is greater than or equal to 0 (known shape).
   */
  bool IsUnknownShape() const;
  /**
   * Determines from the number of dims whether the tensor is a scalar.
   * @return
   *   true if the shape has 0 dims (scalar);
   *   false otherwise (not a scalar).
   * NOTE(review): confirm how an unknown-rank shape such as [-2] (for which GetDimNum()
   * returns 0) is classified here.
   */
  bool IsScalar() const;
  /**
   * Determines whether the tensor is empty by checking its dim values for 0.
   * @return
   *   true if any dim value is 0 (empty tensor);
   *   false otherwise (non-empty tensor).
   */
  bool IsEmptyTensor() const;
  GeShape(const GeShape &other);
  GeShape(GeShape &&other);
  GeShape &operator=(const GeShape &other);
  GeShape &operator=(GeShape &&other);
  bool operator==(const GeShape &other) const;

 private:
  GeShapeImplPtr impl_;
  friend class GeTensorDesc;
  friend class GeTensorDescImpl;
  friend class GeTensorSerializeUtils;
  friend class ModelSerialize;
  /// Builds a shape over an existing proto::ShapeDef message.
  /// NOTE(review): presumably proto_owner keeps proto_msg alive — confirm.
  GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *const proto_msg);
};
/**
 * Tensor descriptor: shape, origin shape, format, data type and related metadata,
 * plus an attribute map inherited from AttrHolder.
 *
 * Declarations are grouped by concern. All behaviour lives in the out-of-line
 * definitions reached through impl_.
 */
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder {
 public:
  // ---- construction, copy/move, destruction ----
  GeTensorDesc();
  explicit GeTensorDesc(const GeShape &shape, const Format format = FORMAT_ND, const DataType dt = DT_FLOAT);
  explicit GeTensorDesc(proto::TensorDescriptor *const proto_msg);
  GeTensorDesc(const GeTensorDesc &desc);
  GeTensorDesc(GeTensorDesc &&desc);
  GeTensorDesc &operator=(const GeTensorDesc &desc);
  GeTensorDesc &operator=(GeTensorDesc &&desc);
  ~GeTensorDesc() override;
  GeTensorDesc Clone() const;

  // ---- comparison / validation ----
  bool operator==(const GeTensorDesc &r_ge_tensor_desc) const;
  graphStatus IsValid() const;

  // ---- shape (and value) members ----
  void Update(const GeShape &shape, const Format format = FORMAT_ND, const DataType dt = DT_FLOAT);
  const GeShape &GetShape() const;
  GeShape &MutableShape();
  void SetShape(const GeShape &shape);
  void SetShape(GeShape &&shape);
  void SetUnknownDimNumShape();
  graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range);
  graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const;
  graphStatus SetValueRange(const std::vector<std::pair<int64_t, int64_t>> &range);
  graphStatus GetValueRange(std::vector<std::pair<int64_t, int64_t>> &range) const;

  // ---- origin shape members ----
  const GeShape &GetOriginShape() const;
  // NOTE(review): a const member that hands out a non-const reference.
  GeShape &MutableOriginShape() const;
  void SetOriginShape(const GeShape &origin_shape);
  bool IsOriginShapeInitialized() const;
  graphStatus SetOriginShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range);
  graphStatus GetOriginShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const;

  // ---- format and data type (current and origin) ----
  Format GetFormat() const;
  void SetFormat(const Format format);
  Format GetOriginFormat() const;
  void SetOriginFormat(const Format origin_format);
  DataType GetDataType() const;
  void SetDataType(const DataType data_type);
  DataType GetOriginDataType() const;
  void SetOriginDataType(const DataType origin_data_type);

  // ---- other metadata ----
  // NOTE(review): the string getters return `const std::string` by value, which blocks
  // moving from the result.
  const std::string GetName() const;
  void SetName(const std::string &name);
  const std::string GetExpandDimsRule() const;
  void SetExpandDimsRule(const std::string &expand_dims_rule);
  std::vector<uint32_t> GetRefPortIndex() const;
  void SetRefPortByIndex(const std::vector<uint32_t> &index);
  Placement GetPlacement() const;
  void SetPlacement(const Placement placement);

  // ---- attribute access re-exported from AttrHolder ----
  using AttrHolder::DelAttr;
  using AttrHolder::GetAllAttrs;
  using AttrHolder::GetAttr;
  using AttrHolder::HasAttr;
  using AttrHolder::SetAttr;

  /// Forwards to GetOrCreateAttrsGroup<T>() on this descriptor's mutable attribute map.
  template<class T>
  T *GetOrCreateAttrsGroup() {
    auto &attr_map = MutableAttrMap();
    return attr_map.GetOrCreateAttrsGroup<T>();
  }

 protected:
  ProtoAttrMap &MutableAttrMap() override;
  ConstProtoAttrMap &GetAttrMap() const override;

 private:
  friend class TensorUtils;
  friend class ModelSerialize;
  friend class GeTensor;
  friend class GeTensorImpl;
  friend class GeAttrValueImp;
  friend class ModelSerializeImp;
  friend class GeTensorSerializeUtils;
  friend class OnnxUtils;

  bool GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const;
  GeShape &ShapeReference() const;
  GeShape &OriginShapeReference() const;

  GeTensorDescImplPtr impl_;
};
/**
 * Byte payload of a tensor, accessed through the AlignedPtr / raw-pointer API below.
 *
 * NOTE(review): the copy constructor and copy assignment are user-declared and no move
 * members are declared. Per the C++ rules this suppresses the implicit move constructor
 * and move assignment, so rvalue TensorData objects are copied.
 */
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorData {
 public:
  // ---- lifecycle ----
  TensorData();
  TensorData(const TensorData &other);
  TensorData &operator=(const TensorData &other);
  ~TensorData();

  // ---- setting the payload ----
  graphStatus SetData(std::vector<uint8_t> &&data);
  graphStatus SetData(const std::vector<uint8_t> &data);
  graphStatus SetData(const Buffer &data);
  graphStatus SetData(const TensorData &data);
  graphStatus SetData(const uint8_t *const data, const size_t size);
  graphStatus SetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc);
  void SetData(std::shared_ptr<AlignedPtr> aligned_ptr, const size_t size);
  graphStatus ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc);
  const uint8_t *MallocAlignedPtr(const size_t size);

  // ---- container-style access ----
  // NOTE(review): data()/GetData() and size()/GetSize() look like aliases — confirm.
  const std::uint8_t *data() const;
  std::uint8_t *data();
  std::size_t size() const;
  void clear();
  uint8_t operator[](const size_t index) const;

  // ---- GE-style access ----
  std::size_t GetSize() const;
  const std::uint8_t *GetData() const;
  std::uint8_t *GetData();
  bool IsTensorDataValid() const;
  const std::shared_ptr<AlignedPtr> &GetAlignedPtr();

 private:
  friend class GeTensor;
  friend class GeTensorImpl;
  friend class GeAttrValueImp;
  friend class ModelSerializeImp;
  friend class GeTensorSerializeUtils;
  friend class TensorUtils;

  TensorDataImplPtr impl_;
};
/**
 * A tensor: a GeTensorDesc descriptor paired with a TensorData payload.
 *
 * NOTE(review): the move constructor is declared noexcept but the move assignment is not.
 * Both would normally share the same exception specification.
 */
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor {
 public:
  // ---- construction ----
  GeTensor();
  explicit GeTensor(const GeTensorDesc &tensor_desc);
  explicit GeTensor(const GeTensorDesc &tensor_desc, const std::vector<uint8_t> &data);
  explicit GeTensor(const GeTensorDesc &tensor_desc, const Buffer &data);
  explicit GeTensor(const GeTensorDesc &tensor_desc, const uint8_t *const data, const size_t size);
  explicit GeTensor(GeTensorDesc &&tensor_desc, std::vector<uint8_t> &&data);
  explicit GeTensor(const GeTensorDesc &tensor_desc, const size_t size);
  GeTensor(const GeTensorDesc &tensor_desc, std::shared_ptr<AlignedPtr> aligned_ptr, const size_t size);

  // ---- copy / move / destruction ----
  GeTensor(const GeTensor &other);
  GeTensor(GeTensor &&other) noexcept;
  GeTensor &operator=(const GeTensor &other);
  GeTensor &operator=(GeTensor &&other);
  ~GeTensor();
  GeTensor Clone() const;

  // ---- descriptor ----
  const GeTensorDesc &GetTensorDesc() const;
  GeTensorDesc &MutableTensorDesc();
  void SetTensorDesc(const GeTensorDesc &tensor_desc);

  // ---- payload access ----
  std::shared_ptr<AlignedPtr> GetAlignedPtr();
  const TensorData &GetData() const;
  TensorData &MutableData();
  bool IsTensorDataValid() const;

  // ---- payload mutation ----
  graphStatus SetData(std::vector<uint8_t> &&data);
  graphStatus SetData(const std::vector<uint8_t> &data);
  graphStatus SetData(const Buffer &data);
  graphStatus SetData(const uint8_t *const data, const size_t size);
  graphStatus SetData(const TensorData &data);
  graphStatus SetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc);
  void SetData(std::shared_ptr<AlignedPtr> aligned_ptr, const size_t size);
  graphStatus ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc);
  void ClearData();

 private:
  friend class GeAttrValueImp;
  friend class ModelSerializeImp;
  friend class GeTensorSerializeUtils;
  friend class OnnxUtils;
  friend class TensorData;
  friend class TensorUtils;
  friend class TensorAdapter;

  // Internal constructors reserved for the friends listed above.
  GeTensor(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg);
  explicit GeTensor(GeTensorImplPtr impl);

  void BuildAlignerPtrWithProtoData();
  GeTensorDesc &DescReference() const;

  GeTensorImplPtr impl_;
};
/**
 * Static helpers that move GeShape / GeTensorDesc / GeTensor data to and from their
 * proto messages.
 *
 * The grouping below follows the function names; the exact semantics live in the
 * out-of-line definitions.
 */
class GeTensorSerializeUtils {
 public:
  // ---- GE object -> proto ----
  static void GeShapeAsProto(const GeShape &shape, proto::ShapeDef *proto);
  static void GeTensorDescAsProto(const GeTensorDesc &desc, proto::TensorDescriptor *proto);
  static void GeTensorDescAsProto(const GeTensorDescImpl &desc, proto::TensorDescriptor *proto);
  static void GeTensorAsProto(const GeTensor &tensor, proto::TensorDef *proto);
  static void GeTensorAsProto(const GeTensorImpl &tensor, proto::TensorDef *proto);

  // ---- proto -> GE object ----
  static void AssembleGeShapeFromProto(const proto::ShapeDef *proto, GeShape &shape);
  static void AssembleGeTensorDescFromProto(const proto::TensorDescriptor *const proto, GeTensorDesc &desc);
  static void AssembleGeTensorFromProto(const proto::TensorDef *proto, GeTensor &tensor);

  // ---- in-place adjustment of a descriptor proto ----
  static void NormalizeGeTensorDescProto(proto::TensorDescriptor *proto);

  // ---- single-field readers from a proto::TensorDescriptor ----
  static void GetShapeFromDescProto(const proto::TensorDescriptor *const proto, GeShape &shape);
  static void GetOriginShapeFromDescProto(const proto::TensorDescriptor *const proto, GeShape &shape);
  static void GetDtypeFromDescProto(const proto::TensorDescriptor *const proto, DataType &dtype);
  static void GetOriginDtypeFromDescProto(const proto::TensorDescriptor *const proto, DataType &dtype);
  static void GetFormatFromDescProto(const proto::TensorDescriptor *const proto, Format &format);
  static void GetOriginFormatFromDescProto(const proto::TensorDescriptor *const proto, Format &format);
};
// Shared-ownership handle to a mutable GeTensorDesc.
using GeTensorDescPtr = std::shared_ptr<GeTensorDesc>;
// Shared-ownership handle whose pointee is const (read-only access through the pointer).
using ConstGeTensorDescPtr = std::shared_ptr<GeTensorDesc const>;
}
#endif