* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GRAPH_GE_TENSOR_IMPL_H_
#define GRAPH_GE_TENSOR_IMPL_H_
#include <string>
#include <vector>
#include <memory>
#include "graph/attr_store.h"
#include "graph/detail/attributes_holder.h"
#include "graph/buffer.h"
#include "graph/ge_error_codes.h"
#include "proto/ge_ir.pb.h"
#include "graph/ge_tensor.h"
#include "graph/types.h"
namespace ge {
class GeTensorDescImpl {
public:
GeTensorDescImpl() = default;
GeTensorDescImpl(const GeShape &shape, const Format format, const DataType dt);
explicit GeTensorDescImpl(proto::TensorDescriptor *const proto_msg);
~GeTensorDescImpl() = default;
GeShape &ShapeReference() const;
GeShape &OriginShapeReference() const;
bool GeTensorDescAttrsAreEqual(const GeTensorDescImpl &other) const;
bool operator==(const GeTensorDescImpl &other) const;
ProtoAttrMap &MutableAttrMap();
ConstProtoAttrMap &GetAttrMap() const;
void SetShape(GeShape &shape) const;
void SetDataType(const DataType dtype);
DataType GetDataType() const;
void SetFormat(const Format format);
Format GetFormat() const;
void SetOriginFormat(const Format format);
Format GetOriginFormat() const;
void SetOriginDataType(const DataType dtype);
DataType GetOriginDataType() const;
void SetName(const std::string &name);
const std::string GetName() const;
bool IsOriginShapeInited() const {
return ext_meta_.IsOriginShapeInited();
}
void SetOriginShapeInited(const bool origin_shape_inited) {
ext_meta_.SetOriginShapeInited(origin_shape_inited);
}
void SetExpandDimsRule(const std::string &expand_dims_rule) {
ext_meta_.SetExpandDimsRule(expand_dims_rule);
}
std::string GetExpandDimsRule() const {
return ext_meta_.GetExpandDimsRule();
}
private:
friend class GeTensorImpl;
friend class TensorUtils;
friend class GeAttrValueImp;
friend class ModelSerializeImp;
friend class GeTensorSerializeUtils;
friend class OnnxUtils;
class ExtMeta {
public:
bool operator==(const ExtMeta& other) const {
return (name == other.name) && (device_type == other.device_type) && (size == other.size) &&
(weight_size == other.weight_size) && (cmps_tab_offset == other.cmps_tab_offset) &&
(reuse_input_index == other.reuse_input_index) && (cmps_tab == other.cmps_tab) &&
(data_offset == other.data_offset) && (cmps_size == other.cmps_size) &&
(real_dim_cnt == other.real_dim_cnt) &&
(other.reuse_input ? reuse_input : !reuse_input) &&
(other.input_tensor ? input_tensor : !input_tensor) &&
(other.output_tensor ? output_tensor : !output_tensor) &&
(other.origin_shape_inited_ ? origin_shape_inited_ : !origin_shape_inited_);
}
std::string GetName() const {
return name;
}
void SetName(const std::string &v) {
name = v;
}
DeviceType GetDeviceType() const {
return device_type;
}
std::string GetDeviceTypeStr() const;
void SetDeviceType(const DeviceType v) {
device_type = v;
}
int64_t GetSize() const {
return size;
}
void SetSize(const int64_t v) {
size = v;
}
int64_t GetWeightSize() const {
return weight_size;
}
void SetWeightSize(const int64_t v) {
weight_size = v;
}
int64_t GetDataOffset() const {
return data_offset;
}
void SetDataOffset(const int64_t v) {
data_offset = v;
}
uint32_t GetRealDimCnt() const {
return real_dim_cnt;
}
void SetRealDimCnt(const uint32_t v) {
real_dim_cnt = v;
}
bool GetInputTensor() const {
return input_tensor;
}
void SetInputTensor(const bool v) {
input_tensor = v;
}
bool GetReuseInput() const {
return reuse_input;
}
void SetReuseInput(const bool v) {
reuse_input = v;
}
uint32_t GetReuseInputIndex() const {
return reuse_input_index;
}
void SetReuseInputIndex(const uint32_t v) {
reuse_input_index = v;
}
bool GetOutputTensor() const {
return output_tensor;
}
void SetOutputTensor(const bool v) {
output_tensor = v;
}
int64_t GetCmpsSize() const {
return cmps_size;
}
void SetCmpsSize(const int64_t v) {
cmps_size = v;
}
std::string GetCmpsTab() const {
return cmps_tab;
}
void SetCmpsTab(const std::string &v) {
cmps_tab = v;
}
int64_t GetCmpsTabOffset() const {
return cmps_tab_offset;
}
void SetCmpsTabOffset(const int64_t v) {
cmps_tab_offset = v;
}
bool IsOriginShapeInited() const {
return origin_shape_inited_;
}
void SetOriginShapeInited(const bool origin_shape_inited) {
origin_shape_inited_ = origin_shape_inited;
}
void SetExpandDimsRule(const std::string &expand_dims_rule) {
expand_dims_rule_ = expand_dims_rule;
}
std::string GetExpandDimsRule() const {
return expand_dims_rule_;
}
private:
int64_t size{0};
int64_t data_offset{0};
int64_t cmps_tab_offset{0};
int64_t cmps_size{0};
int64_t weight_size{0};
uint32_t real_dim_cnt{0U};
uint32_t reuse_input_index{0U};
DeviceType device_type{NPU};
bool input_tensor{false};
bool reuse_input{false};
bool output_tensor{false};
bool origin_shape_inited_{false};
std::string cmps_tab;
std::string name;
std::string expand_dims_rule_;
};
mutable GeShape shape_;
Format format_{FORMAT_ND};
DataType dtype_{DT_FLOAT};
mutable GeShape origin_shape_;
Format origin_format_{FORMAT_ND};
DataType origin_dtype_{DT_UNDEFINED};
ExtMeta ext_meta_;
AttrStore attrs_;
};
class TensorDataImpl {
public:
TensorDataImpl() = default;
TensorDataImpl(const TensorDataImpl &other);
~TensorDataImpl() = default;
TensorDataImpl &operator=(const TensorDataImpl &other);
graphStatus SetData(const uint8_t * const data, const size_t size);
graphStatus SetData(uint8_t * const data, const size_t size, const AlignedPtr::Deleter &delete_fuc);
void SetData(std::shared_ptr<AlignedPtr> aligned_ptr, const size_t size);
graphStatus ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc);
const uint8_t *MallocAlignedPtr(const size_t size);
size_t GetSize() const;
const uint8_t *GetData() const;
uint8_t *GetData();
bool IsTensorDataValid() const;
void clear();
uint8_t operator[](const size_t index) const;
const std::shared_ptr<AlignedPtr> &GetAlignedPtr() const { return aligned_ptr_; }
private:
friend class GeTensorImpl;
friend class TensorUtils;
friend class GeAttrValueImp;
friend class ModelSerializeImp;
friend class GeTensorSerializeUtils;
std::shared_ptr<GeTensorDescImpl> tensor_descriptor_;
std::shared_ptr<AlignedPtr> aligned_ptr_ = nullptr;
size_t length_ = 0UL;
static uint32_t invalid_data_;
};
class GeTensorImpl {
public:
GeTensorImpl();
explicit GeTensorImpl(const GeTensorDesc &tensor_desc);
GeTensorImpl(const GeTensorDesc &tensor_desc, const std::vector<uint8_t> &data);
GeTensorImpl(const GeTensorDesc &tensor_desc, const uint8_t * const data, const size_t size);
GeTensorImpl(GeTensorDesc &&tensor_desc, std::vector<uint8_t> &&data);
GeTensorImpl(const GeTensorDesc &tensor_desc, const Buffer &data);
GeTensorImpl(const GeTensorDesc &tensor_desc, std::shared_ptr<AlignedPtr> aligned_ptr, const size_t size);
GeTensorImpl(const GeTensorDesc &tensor_desc, const size_t size);
GeTensorImpl(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg);
GeTensorImpl(const GeTensorImpl &other);
~GeTensorImpl() = default;
GeTensorImpl &operator=(const GeTensorImpl &other);
GeTensorDesc &DescReference() const;
void BuildAlignerPtrWithProtoData();
graphStatus SetData(std::vector<uint8_t> &&data);
graphStatus SetData(const std::vector<uint8_t> &data);
graphStatus SetData(const uint8_t * const data, size_t const size);
graphStatus SetData(const Buffer &data);
graphStatus SetData(const TensorData &data);
graphStatus SetData(uint8_t * const data, const size_t size, const AlignedPtr::Deleter &delete_fuc);
void ClearData();
void Clone(GeTensorImpl &tensor) const;
std::shared_ptr<AlignedPtr> GetAlignedPtr() const;
const TensorData &GetData() const { return tensor_data_; }
TensorData &MutableData() { return tensor_data_; }
bool IsTensorDataValid() const;
void SetData(std::shared_ptr<AlignedPtr> aligned_ptr, const size_t size) {
tensor_data_.SetData(std::move(aligned_ptr), size);
}
graphStatus ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc);
private:
friend class TensorUtils;
friend class GeAttrValueImp;
friend class ModelSerializeImp;
friend class GeTensorSerializeUtils;
GeIrProtoHelper<proto::TensorDef> tensor_def_;
mutable GeTensorDesc desc_;
TensorData tensor_data_;
};
}
#endif