* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "data_utils.h"
#include <map>
#include <limits>
#include <stdexcept>
#include "udf_log.h"
namespace FlowFunc {
namespace {
// Marker offset used in the dtype size table: a size value greater than this
// offset is not a byte count but encodes a per-element bit width equal to
// (size - kDataTypeSizeBitOffset), used for sub-byte types.
constexpr int32_t kDataTypeSizeBitOffset{1000};
}  // namespace
/**
 * @brief Reports whether a + b would wrap around in uint64_t arithmetic.
 * @param a first addend
 * @param b second addend
 * @return true if the exact sum exceeds UINT64_MAX, false otherwise.
 */
bool CheckAddOverflowUint64(const uint64_t &a, const uint64_t &b) {
  // The sum wraps exactly when b is larger than the room left above a.
  const uint64_t headroom = UINT64_MAX - a;
  return b > headroom;
}
/**
 * @brief Reports whether a * b would overflow int64_t.
 *
 * Decides by sign case and compares against a quotient of the int64_t limits,
 * so no intermediate expression overflows (the divisor is never -1 paired with
 * INT64_MIN as dividend).
 * @param a first factor
 * @param b second factor
 * @return true if the exact product lies outside [INT64_MIN, INT64_MAX].
 */
bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) {
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  if (a > 0) {
    // Positive a: a positive product is bounded by kMax, a non-positive one by kMin.
    return (b > 0) ? (a > kMax / b) : (b < kMin / a);
  }
  if (b > 0) {
    // Non-positive a, positive b: product is non-positive, bounded by kMin.
    return a < kMin / b;
  }
  // Both non-positive: product is non-negative, bounded by kMax (a == 0 never overflows).
  return (a != 0) && (b < kMax / a);
}
/**
 * @brief Looks up the per-element storage size of a tensor data type.
 *
 * For ordinary types the value is a byte count. For sub-byte types the value is
 * kDataTypeSizeBitOffset + <bits per element>; callers must subtract the offset
 * to recover the bit width.
 * @param data_type the tensor element type
 * @return the encoded size, or -1 if data_type is not in the table.
 */
int32_t GetSizeByDataType(TensorDataType data_type) {
  static const std::map<TensorDataType, int32_t> kTypeSizeTable{
      // 1 byte per element
      {TensorDataType::DT_INT8, 1},
      {TensorDataType::DT_UINT8, 1},
      {TensorDataType::DT_BOOL, 1},
      {TensorDataType::DT_QINT8, 1},
      {TensorDataType::DT_QUINT8, 1},
      // 2 bytes per element
      {TensorDataType::DT_FLOAT16, 2},
      {TensorDataType::DT_BF16, 2},
      {TensorDataType::DT_INT16, 2},
      {TensorDataType::DT_UINT16, 2},
      {TensorDataType::DT_QINT16, 2},
      {TensorDataType::DT_QUINT16, 2},
      // 4 bytes per element
      {TensorDataType::DT_FLOAT, 4},
      {TensorDataType::DT_INT32, 4},
      {TensorDataType::DT_UINT32, 4},
      {TensorDataType::DT_QINT32, 4},
      // 8 bytes per element
      {TensorDataType::DT_INT64, 8},
      {TensorDataType::DT_UINT64, 8},
      {TensorDataType::DT_DOUBLE, 8},
      // DT_DUAL is recorded as 5 bytes per element
      {TensorDataType::DT_DUAL, 5},
      // sub-byte types: offset + bit width
      {TensorDataType::DT_UINT1, kDataTypeSizeBitOffset + 1},
      {TensorDataType::DT_INT2, kDataTypeSizeBitOffset + 2},
      {TensorDataType::DT_UINT2, kDataTypeSizeBitOffset + 2},
      {TensorDataType::DT_INT4, kDataTypeSizeBitOffset + 4}
  };
  const auto found = kTypeSizeTable.find(data_type);
  return (found != kTypeSizeTable.cend()) ? found->second : -1;
}
/**
 * @brief Multiplies all dimensions of a shape to get its element count.
 *
 * An empty shape yields 1. Dimensions are checked in order; the first negative
 * dimension or the first multiplication that would overflow int64_t aborts the
 * computation with an error log.
 * @param shape dimension sizes
 * @return the product of all dimensions, or -1 on a negative dim or overflow.
 */
int64_t CalcElementCnt(const std::vector<int64_t> &shape) {
  int64_t product = 1;
  for (const int64_t dim : shape) {
    if (dim < 0) {
      // Negative (e.g. unknown) dims are rejected rather than multiplied.
      UDF_LOG_ERROR("dim is negative, not support now, dim=%ld.", dim);
      return -1;
    }
    if (CheckMultiplyOverflowInt64(product, dim)) {
      UDF_LOG_ERROR("CalcElementCnt failed, when multiplying %ld and %ld.", product, dim);
      return -1;
    }
    product *= dim;
  }
  return product;
}
/**
 * @brief Computes the number of bytes needed to store a tensor.
 *
 * Byte-sized types give element_count * type_size. Sub-byte types (encoded as
 * kDataTypeSizeBitOffset + bits by GetSizeByDataType) give the total bit count
 * rounded up to whole bytes.
 * @param shape tensor dimensions; must be non-negative
 * @param data_type tensor element type
 * @return the size in bytes, or -1 if the type is unsupported, the shape is
 *         invalid, or the computation would overflow int64_t.
 */
int64_t CalcDataSize(const std::vector<int64_t> &shape, TensorDataType data_type) {
  const int32_t type_size = GetSizeByDataType(data_type);
  if (type_size < 0) {
    UDF_LOG_ERROR("data_type=%d is not support.", static_cast<int32_t>(data_type));
    return -1;
  }
  const int64_t element_cnt = CalcElementCnt(shape);
  if (element_cnt < 0) {
    UDF_LOG_ERROR("CalcElementCnt failed, result=%ld.", element_cnt);
    return -1;
  }
  // Values above the offset describe a per-element bit width, otherwise a byte width.
  const bool sized_in_bits = type_size > kDataTypeSizeBitOffset;
  const int32_t unit_size = sized_in_bits ? (type_size - kDataTypeSizeBitOffset) : type_size;
  if (CheckMultiplyOverflowInt64(element_cnt, static_cast<int64_t>(unit_size))) {
    UDF_LOG_ERROR("CalcDataSize failed, when multiplying %ld and %d.", element_cnt, unit_size);
    return -1;
  }
  const int64_t total = element_cnt * unit_size;
  if (!sized_in_bits) {
    return total;
  }
  // Round bits up to bytes by dividing first, so no addition can overflow near INT64_MAX.
  constexpr int64_t kBitsPerByte = 8;
  return (total / kBitsPerByte) + (((total % kBitsPerByte) != 0) ? 1 : 0);
}
/**
 * @brief Parses a base-10 integer string into an int32_t.
 *
 * Parsing follows std::stoi (leading whitespace and an optional sign are
 * accepted), but the entire remainder of the string must be consumed. Input
 * such as "12abc" is rejected instead of being silently truncated to 12.
 * @param str the text to parse
 * @param val receives the parsed value; left unchanged when parsing fails
 * @return true on success, false (with an error log) on invalid or
 *         out-of-range input.
 */
bool ConvertToInt32(const std::string &str, int32_t &val) {
  std::string::size_type parsed_len = 0;
  int32_t parsed = 0;
  try {
    parsed = std::stoi(str, &parsed_len);
  } catch (const std::invalid_argument &) {
    UDF_LOG_ERROR("The digit str:%s is invalid", str.c_str());
    return false;
  } catch (const std::out_of_range &) {
    UDF_LOG_ERROR("The digit str:%s is out of range int32", str.c_str());
    return false;
  } catch (...) {
    UDF_LOG_ERROR("The digit str:%s cannot change to int", str.c_str());
    return false;
  }
  // std::stoi stops at the first non-digit; any unconsumed tail means the string is not a pure integer.
  if (parsed_len != str.size()) {
    UDF_LOG_ERROR("The digit str:%s is invalid", str.c_str());
    return false;
  }
  val = parsed;
  return true;
}
/**
 * @brief Parses a base-10 integer string into an int64_t.
 *
 * Parsing follows std::stoll (leading whitespace and an optional sign are
 * accepted), but the entire remainder of the string must be consumed. Input
 * such as "12abc" is rejected instead of being silently truncated to 12.
 * @param str the text to parse
 * @param val receives the parsed value; left unchanged when parsing fails
 * @return true on success, false (with an error log) on invalid or
 *         out-of-range input.
 */
bool ConvertToInt64(const std::string &str, int64_t &val) {
  std::string::size_type parsed_len = 0;
  int64_t parsed = 0;
  try {
    parsed = std::stoll(str, &parsed_len);
  } catch (const std::invalid_argument &) {
    UDF_LOG_ERROR("The digit str:%s is invalid", str.c_str());
    return false;
  } catch (const std::out_of_range &) {
    UDF_LOG_ERROR("The digit str:%s is out of range int64", str.c_str());
    return false;
  } catch (...) {
    UDF_LOG_ERROR("The digit str:%s cannot change to int64", str.c_str());
    return false;
  }
  // std::stoll stops at the first non-digit; any unconsumed tail means the string is not a pure integer.
  if (parsed_len != str.size()) {
    UDF_LOG_ERROR("The digit str:%s is invalid", str.c_str());
    return false;
  }
  val = parsed;
  return true;
}
}