* -------------------------------------------------------------------------
* This file is part of the MultimodalSDK project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MultimodalSDK is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
* @Description:
* @Version: 1.0
* @Date: 2025-2-10 15:00:00
* @LastEditors: dev
* @LastEditTime: 2025-2-10 15:00:00
*/
#ifndef ACCDATA_SRC_CPP_TENSOR_TENSOR_H_
#define ACCDATA_SRC_CPP_TENSOR_TENSOR_H_
#include <memory>
#include <vector>
#include <utility>
#include <numeric>
#include <unordered_set>
#include "securec.h"
#include "common/check.h"
#include "common/utility.h"
#include "accdata_tensor.h"
namespace acclib {
namespace accdata {
template <typename T>
inline constexpr TensorDataType TensorDataTypeEnum()
{
if constexpr (std::is_same_v<T, uint8_t>) {
return TensorDataType::UINT8;
} else if constexpr (std::is_same_v<T, float>) {
return TensorDataType::FP32;
} else if constexpr (std::is_same_v<T, char>) {
return TensorDataType::CHAR;
} else {
ACCDATA_ERROR("Unsupported data type.");
return TensorDataType::LAST;
}
}
inline int64_t NumElements(const TensorShape& shape, uint32_t dim = 0)
{
if (dim >= shape.size()) {
return 0;
}
return std::accumulate(shape.begin() + dim, shape.end(), 1ULL, std::multiplies<int64_t>());
}
inline bool IsValidShape(const TensorShape& shape)
{
for (auto dim : shape) {
if (dim <= 0) {
return false;
}
}
return true;
}
inline bool IsValidDataType(TensorDataType dataType)
{
static const std::unordered_set<TensorDataType> validDataTypes = {
TensorDataType::CHAR, TensorDataType::FP32, TensorDataType::UINT8
};
return validDataTypes.find(dataType) != validDataTypes.end();
}
inline bool IsValidLayout(TensorLayout layoutType)
{
static const std::unordered_set<TensorLayout> validDataLayouts = {
TensorLayout::NCHW, TensorLayout::NHWC, TensorLayout::PLAIN
};
return validDataLayouts.find(layoutType) != validDataLayouts.end();
}
* @brief Tensor
*/
class Tensor : public AccDataTensor {
public:
Tensor() = default;
Tensor(const Tensor &) = delete;
Tensor &operator = (const Tensor &) = delete;
Tensor(Tensor &&other)
{
*this = std::move(other);
}
Tensor &operator = (Tensor &&other)
{
if (this == &other) {
return *this;
}
mLayout = std::exchange(other.mLayout, TensorLayout::PLAIN);
mDataType = std::exchange(other.mDataType, TensorDataType::FP32);
mShape = std::exchange(other.mShape, {});
mNumBytes = std::exchange(mNumBytes, 0);
mData = std::move(other.mData);
return *this;
}
~Tensor() = default;
AccDataErrorCode Copy(const void *data, const TensorShape &shape, TensorDataType dataType)
{
if (data == nullptr || !IsValidShape(shape) || !IsValidDataType(dataType)) {
ACCDATA_ERROR("Tensor copy failed: data is null, shape is valid, or data type is invalid.");
return AccDataErrorCode::H_TENSOR_ERROR;
}
Resize(shape, dataType);
const uint8_t *src = static_cast<const uint8_t *>(data);
int64_t numBytes = NumElements(shape) * TensorDataTypeSize(dataType);
auto ret = memcpy_s(RawDataPtr<uint8_t>(), mNumBytes, src, numBytes);
if (ret != EOK) {
ACCDATA_ERROR("Memcpy_s failed during copy.");
return AccDataErrorCode::H_TENSOR_ERROR;
}
return AccDataErrorCode::H_OK;
}
template <typename T>
AccDataErrorCode Copy(const T *data, const TensorShape &shape)
{
return Copy(data, shape, TensorDataTypeEnum<T>());
}
AccDataErrorCode Copy(const Tensor &other)
{
SetLayout(other.mLayout);
return Copy(other.mData.get(), other.mShape, other.mDataType);
}
AccDataErrorCode ShareData(const std::shared_ptr<void> &data, const TensorShape &shape, TensorDataType dataType)
{
if (data.get() == nullptr || !IsValidShape(shape) || !IsValidDataType(dataType)) {
ACCDATA_ERROR("Tensor share failed: data is null, shape is valid, or data type is invalid.");
return AccDataErrorCode::H_TENSOR_ERROR;
}
mDataType = dataType;
mShape = shape;
mNumBytes = NumElements(shape) * TensorDataTypeSize(dataType);
mData = data;
return AccDataErrorCode::H_OK;
}
template <typename T>
AccDataErrorCode ShareData(const std::shared_ptr<T> &data, const TensorShape &shape)
{
return ShareData(data, shape, TensorDataTypeEnum<T>());
}
AccDataErrorCode ShareData(const Tensor &other)
{
SetLayout(other.Layout());
return ShareData(other.mData, other.mShape, other.mDataType);
}
void Resize(const TensorShape &shape, TensorDataType dataType)
{
mDataType = dataType;
mShape = shape;
int64_t numBytes = NumElements(shape) * TensorDataTypeSize(dataType);
if (numBytes <= mNumBytes) {
return;
}
mNumBytes = numBytes;
mData = std::shared_ptr<uint8_t>(new (std::align_val_t(ACCDATA_ALIGN_SIZE)) uint8_t[mNumBytes],
[](uint8_t* ptr) {
operator delete[] (ptr, (std::align_val_t(ACCDATA_ALIGN_SIZE)));
}
);
return;
}
template <typename T>
void Resize(const TensorShape &shape)
{
return Resize(shape, TensorDataTypeEnum<T>());
}
template <typename T>
AccDataErrorCode Data(std::shared_ptr<T> &data) const
{
if (!IsDataType<T>()) {
ACCDATA_ERROR("Different datatype.");
return AccDataErrorCode::H_TENSOR_ERROR;
}
data = std::static_pointer_cast<T>(mData);
return AccDataErrorCode::H_OK;
}
template <typename T> T *RawDataPtr() const
{
return static_cast<T *>(mData.get());
}
int64_t GetSize() const
{
return mNumBytes;
}
std::shared_ptr<void> RawDataPtr() const
{
return mData;
}
void Reset()
{
mLayout = TensorLayout::PLAIN;
mDataType = TensorDataType::FP32;
mShape.clear();
mNumBytes = 0;
mData.reset();
return;
}
void SetLayout(TensorLayout layout)
{
mLayout = layout;
return;
}
TensorLayout Layout() const
{
return mLayout;
}
TensorDataType DataType() const
{
return mDataType;
}
const TensorShape &Shape() const
{
return mShape;
}
template <typename T> bool IsDataType() const
{
return mDataType == TensorDataTypeEnum<T>();
}
bool IsValid() const
{
bool valid = (mData != nullptr && mNumBytes > 0 &&
IsValidDataType(mDataType) && IsValidLayout(mLayout) &&
IsValidShape(mShape) && mShape.size() == 4ULL);
if (!valid) {
ACCDATA_ERROR("Tensor is invalid, Info : [ IsValidDataType: " << IsValidDataType(mDataType) <<
", IsValidLayout: " << IsValidLayout(mLayout) <<
", IsValidShape: " << IsValidShape(mShape) && mShape.size() == 4ULL);
}
return valid;
}
private:
TensorLayout mLayout{ TensorLayout::LAST };
TensorDataType mDataType{ TensorDataType::LAST };
TensorShape mShape;
int64_t mNumBytes{ 0 };
std::shared_ptr<void> mData{ nullptr };
};
}
}
#endif