* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
* ===================================================================================================================*/
#ifndef INC_EXTERNAL_GRAPH_TENSOR_H_
#define INC_EXTERNAL_GRAPH_TENSOR_H_
#include <atomic>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "./ge_error_codes.h"
#include "./types.h"
#include "ascend_string.h"
namespace ge {
class ShapeImpl;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape {
public:
Shape();
~Shape() = default;
explicit Shape(const std::vector<int64_t> &dims);
* `GetDimNum()`标识有效的dim的个数,跟`GetDims().size()`不等价,调用方按需选择
* 比如如果dim是[-2], 维度未可知时:
* GetDimNum()会返回0;
* 而GetDims().size()会返回dim的个数,即1;
* 另外如果需要判断是否是标量,推荐使用接口`GetDims.size() == 0U`来判断
* @return
*/
size_t GetDimNum() const;
int64_t GetDim(size_t idx) const;
graphStatus SetDim(size_t idx, int64_t value);
* `GetDims`标识dim的个数,跟`GetDimNum()`不等价,调用方按需选择
* 比如如果dim是[-2], 维度未可知时:
* GetDimNum()会返回0;
* 而GetDims().size()会返回dim的个数,即1;
* 另外如果需要判断是否是标量,推荐使用接口`GetDims.size() == 0U`来判断
* @return
*/
std::vector<int64_t> GetDims() const;
* 获取shape的各个维度的dim值的乘积
* @return
* 如果dim值包含-1或者-2,那么size直接返回-1, 含义是unknown shape
* 如果dim值包含0,那么size直接返回0,含义是空tensor
* 如果dim值的个数为0,那么size直接返回0,含义是标量
* 如果dim值的乘积产生了int64的溢出,那么size直接返回0,含义是乘积溢出
*/
int64_t GetShapeSize() const;
private:
std::shared_ptr<ShapeImpl> impl_;
};
class TensorDescImpl;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc {
public:
TensorDesc();
~TensorDesc() = default;
explicit TensorDesc(Shape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);
TensorDesc(const TensorDesc &desc);
TensorDesc(TensorDesc &&desc);
TensorDesc &operator=(const TensorDesc &desc);
TensorDesc &operator=(TensorDesc &&desc);
void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);
Shape GetShape() const;
void SetShape(const Shape &shape);
graphStatus SetUnknownDimNumShape();
graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range);
graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const;
Format GetFormat() const;
void SetFormat(Format format);
Shape GetOriginShape() const;
void SetOriginShape(const Shape &origin_shape);
Format GetOriginFormat() const;
void SetOriginFormat(Format origin_format);
DataType GetDataType() const;
void SetDataType(DataType dt);
ATTRIBUTED_DEPRECATED(graphStatus GetName(AscendString &))
std::string GetName() const;
graphStatus GetName(AscendString &name);
graphStatus GetName(AscendString &name) const;
ATTRIBUTED_DEPRECATED(void SetName(const char_t *))
void SetName(const std::string &name);
void SetName(const char_t *name);
void SetSize(int64_t size);
int64_t GetSize() const;
int64_t GetRealDimCnt() const;
void SetRealDimCnt(const int64_t real_dim_cnt);
void SetPlacement(Placement placement);
Placement GetPlacement() const;
void SetConstData(std::unique_ptr<uint8_t[]> const_data_buffer, const size_t &const_data_len);
bool GetConstData(uint8_t **const_data_buffer, size_t &const_data_len) const;
* 补维类似于ExpandDims算子,在原有shape的基础上,添加一到多个维度,例如原shape[2,2]有两根轴,那么在两根轴中间补两维后的shape为[2,1,1,2]。
* 补维后shape的第0、3根轴被称为原始轴,第1、2根轴被称为补维轴。
*
* 通过1和0描述补维规则,1代表当前轴为补维轴,0代表当前轴为原始轴,从左到右依次代表当前shape每根轴的来源,例如:
* | 补维规则 | 补维前shape | 补维后shape |
* | -------- | ----------- | ------------------------------------------------------------ |
* | 0110 | [2, 2] | [2, 1, 1, 2] |
* | 100 | [2, 3] | [1, 2, 3] |
* | 1000 | [2, 3] | 补维规则与补维前shape不匹配,规则指定原始轴有3根,但原始shape只有2根轴,补维报错。 |
*
*/
void SetExpandDimsRule(const AscendString &expand_dims_rule);
graphStatus GetExpandDimsRule(AscendString &expand_dims_rule) const;
void SetReuseInputIndex(const uint32_t idx);
private:
std::shared_ptr<TensorDescImpl> impl;
friend class TensorAdapter;
};
class TensorImpl;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor {
public:
using DeleteFunc = std::function<void(uint8_t *)>;
Tensor();
~Tensor() = default;
explicit Tensor(const TensorDesc &tensor_desc);
Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data);
Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size);
Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data);
TensorDesc GetTensorDesc() const;
graphStatus SetTensorDesc(const TensorDesc &tensor_desc);
const uint8_t *GetData() const;
uint8_t *GetData();
size_t GetSize() const;
std::unique_ptr<uint8_t[], Tensor::DeleteFunc> ResetData();
graphStatus SetData(std::vector<uint8_t> &&data);
graphStatus SetData(const std::vector<uint8_t> &data);
graphStatus SetData(const uint8_t *data, size_t size);
ATTRIBUTED_DEPRECATED(graphStatus SetData(const char_t *data))
graphStatus SetData(const std::string &data);
graphStatus SetData(const char_t *data);
ATTRIBUTED_DEPRECATED(graphStatus SetData(const std::vector<AscendString> &))
graphStatus SetData(const std::vector<std::string> &data);
graphStatus SetData(const std::vector<AscendString> &datas);
graphStatus SetData(uint8_t *data, size_t size, const Tensor::DeleteFunc &deleter_func);
graphStatus IsValid();
graphStatus SetOriginShapeDimNum(const size_t dim_num);
size_t GetOriginShapeDimNum() const;
graphStatus SetOriginShapeDim(const size_t idx, const int64_t dim_value);
int64_t GetOriginShapeDim(const size_t idx) const;
graphStatus SetOriginFormat(const ge::Format &format);
ge::Format GetOriginFormat() const;
graphStatus SetShapeDimNum(const size_t dim_num);
size_t GetShapeDimNum() const;
graphStatus SetShapeDim(const size_t idx, const int64_t dim_value);
int64_t GetShapeDim(const size_t idx) const;
graphStatus SetFormat(const ge::Format &format);
ge::Format GetFormat() const;
graphStatus SetDataType(const ge::DataType &dtype);
ge::DataType GetDataType() const;
graphStatus SetPlacement(const ge::Placement &placement);
ge::Placement GetPlacement() const;
* 补维类似于ExpandDims算子,在原有shape的基础上,添加一到多个维度,例如原shape[2,2]有两根轴,那么在两根轴中间补两维后的shape为[2,1,1,2]。
* 补维后shape的第0、3根轴被称为原始轴,第1、2根轴被称为补维轴。
*
* 通过1和0描述补维规则,1代表当前轴为补维轴,0代表当前轴为原始轴,从左到右依次代表当前shape每根轴的来源,例如:
* | 补维规则 | 补维前shape | 补维后shape |
* | -------- | ----------- | ------------------------------------------------------------ |
* | 0110 | [2, 2] | [2, 1, 1, 2] |
* | 100 | [2, 3] | [1, 2, 3] |
* | 1000 | [2, 3] | 补维规则与补维前shape不匹配,规则指定原始轴有3根,但原始shape只有2根轴,补维报错。 |
*
*/
graphStatus SetExpandDimsRule(const AscendString &expand_dims_rule);
graphStatus GetExpandDimsRule(AscendString &expand_dims_rule) const;
graphStatus ResetData(uint8_t *data, size_t size, const Tensor::DeleteFunc &deleter_func);
Tensor Clone() const;
private:
std::shared_ptr<TensorImpl> impl;
friend class TensorAdapter;
};
}
#endif