* -------------------------------------------------------------------------
* This file is part of the MultimodalSDK project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MultimodalSDK is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
* @Description:
* @Version: 1.0
* @Date: 2025-3-17 9:00:00
* @LastEditors: dev
* @LastEditTime: 2025-3-17 9:00:00
*/
#ifndef ACCDATA_SRC_CPP_INTERFACE_OCKACCDATATENSOR_H_
#define ACCDATA_SRC_CPP_INTERFACE_OCKACCDATATENSOR_H_
#include <memory>
#include <vector>
#include <iostream>
#include <string>
#include <sys/types.h>
#include "accdata_error_code.h"
namespace acclib {
namespace accdata {
* @brief Tensor layout, value is same with APP_ACC
*/
enum class TensorLayout {
NCHW = 0,
NHWC = 1,
PLAIN = 2,
LAST = -1,
};
std::ostream &operator << (std::ostream &os, TensorLayout layout);
* @brief Tensor data type, value is same with CANN
*/
enum class TensorDataType {
FP32 = 0,
UINT8 = 4,
CHAR = 13,
LAST = -1,
};
int32_t TensorDataTypeSize(TensorDataType dataType);
* @brief Tensor shape
*/
using TensorShape = std::vector<size_t>;
std::ostream &operator << (std::ostream &os, TensorDataType layout);
* @class AccDataTensor
* @brief AccData数据格式
*/
class AccDataTensor {
public:
* @brief 复制数据到当前Tensor对象
*
* @note 如果data为空,将返回错误码并终止复制操作
*
* @param data 指向源数据的指针
* @param shape 输入数据的形状
* @param dataType 输入数据的数据类型
*
* @return AccData错误码:
* - H_OK 成功复制Tensor
* - H_TENSOR_ERROR data为空或者Tensor拷贝失败
*/
virtual AccDataErrorCode Copy(const void *data, const TensorShape &shape, TensorDataType dataType) = 0;
* @brief 共享数据到当前Tensor对象
*
* @param data 指向源数据的共享指针
* @param shape 输入数据的形状
* @param dataType 输入数据的数据类型
*/
virtual AccDataErrorCode ShareData(const std::shared_ptr<void> &data, const TensorShape &shape,
TensorDataType dataType) = 0;
* @brief 获取当前Tensor的数据指针
*
* @return 数据指针
*/
virtual std::shared_ptr<void> RawDataPtr() const = 0;
* @brief 获取当前Tensor的数据布局
*
* @return TensorLayout:
* - NHWC
* - NCHW
* - LAST 不支持的布局
*/
virtual TensorLayout Layout() const = 0;
* @brief 获取当前Tensor的数据类型格式
*
* @return TensorDataType:
* - UINT8 1字节无符号整数
* - FP32 单精度浮点数
* - CHAR 字符
* - LAST 不支持的数据类型
*/
virtual TensorDataType DataType() const = 0;
* @brief 获取当前Tensor的数据形状
*
* @return 当前Tensor的数据形状
*/
virtual const TensorShape &Shape() const = 0;
* @brief 设置当前Tensor的数据布局
*
* @param layout 当前Tensor的布局
*/
virtual void SetLayout(TensorLayout layout) = 0;
};
* @class AccDataTensorList
* @brief 非均匀共享的Tensor列表
*
* AccDataTensorList管理一个非均匀共享的Tensor列表,列表中的Tensor可以有不同的形状、数据类型、布局
*/
class AccDataTensorList {
public:
* @brief 创建一个TensorList,包含batchSize个Tensor对象
*
* @param batchSize TensorList中包含的Tensor数量
*
* @return 返回一个std::shared_ptr<AccDataTensorList>,指向新创建的对象
*/
static std::shared_ptr<AccDataTensorList> Create(uint64_t batchSize);
* @brief 获取当前TensorList包含的Tensor数量
*
* @return Tensor数量
*/
virtual uint64_t NumTensors() const = 0;
* @brief 获取指定索引的AccDataTensor引用
*
* 通过下标访问TensorList中的元素。
* 返回指定索引idx位置上的AccDataTensor的引用,使得可以修改该位置上的 Tensor。
*
* @param idx 索引值,表示TensorList中的元素位置。
*
* @return AccDataTensor& 指定索引位置的AccDataTensor引用。
*/
virtual AccDataTensor &operator[](uint64_t idx) = 0;
* @brief 获取指定索引的AccDataTensor常量引用。
*
* 通过下标访问TensorList中的元素,返回一个常量引用。
*
* @param idx 索引值,表示TensorList中的元素位置。
*
* @return const AccDataTensor& 指定索引位置的AccDataTensor常量引用。
*/
virtual const AccDataTensor &operator[](uint64_t idx) const = 0;
};
}
}
#endif